
Commit b9f190d

JavaRDD extension functions + iterators (#174)
1 parent e0288b7 commit b9f190d
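
At a high level, this commit adds Kotlin extension functions for JavaRDD (creation helpers such as toRDD and rddOf, key/value operations such as aggregateByKey, conversions such as toJavaPairRDD) together with Iterator extensions that can be used inside mapPartitions. A rough taste of the combination inside withSpark; the sample data and names are illustrative only, and the exact set of new extensions is defined in the files below:

    import org.jetbrains.kotlinx.spark.api.*

    fun main() = withSpark {
        // List -> JavaRDD with the new toRDD extension, then transform each
        // partition's Iterator directly with the new iterator extensions
        // instead of collecting it into a list first
        val doubled = listOf(1, 2, 3, 4).toRDD()
            .mapPartitions { iter -> iter.map { it * 2 } }

        doubled.toDS().show()
    }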

File tree: 18 files changed, +2127 -94 lines changed


.github/workflows/build.yml

+1 -1

@@ -59,7 +59,7 @@ jobs:
           -Pspark=${{ matrix.spark }}
           -Pscala=${{ matrix.scala }}
           clean
-          build
+          test
           --scan
 
 #  qodana:

core/src/main/scala/org/jetbrains/kotlinx/spark/extensions/KSparkExtensions.scala

+14 -1

@@ -21,8 +21,8 @@ package org.jetbrains.kotlinx.spark.extensions
 
 import org.apache.spark.SparkContext
 import org.apache.spark.sql._
-
 import java.util
+import scala.reflect.ClassTag
 
 object KSparkExtensions {
 
@@ -58,4 +58,17 @@ object KSparkExtensions {
   }
 
   def sparkContext(s: SparkSession): SparkContext = s.sparkContext
+
+  /**
+   * Produces a ClassTag[T], which is actually just a casted ClassTag[AnyRef].
+   *
+   * This method is used to keep ClassTags out of the external Java API, as the Java compiler
+   * cannot produce them automatically. While this ClassTag-faking does please the compiler,
+   * it can cause problems at runtime if the Scala API relies on ClassTags for correctness.
+   *
+   * Often, though, a ClassTag[AnyRef] will not lead to incorrect behavior, just worse performance
+   * or security issues. For instance, an Array[AnyRef] can hold any type T, but may lose primitive
+   * specialization.
+   */
+  def fakeClassTag[T]: ClassTag[T] = ClassTag.AnyRef.asInstanceOf[ClassTag[T]]
 }
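
The new fakeClassTag helper lets Kotlin and Java callers satisfy Scala APIs that expect an implicit ClassTag, without reflection or reified generics. As a rough sketch of how it could be called from Kotlin (the wrapInJavaRDD helper below is hypothetical and not part of this commit; JavaRDD.fromRDD is Spark's own static factory, whose implicit ClassTag becomes an explicit argument when called from Kotlin):

    import org.apache.spark.api.java.JavaRDD
    import org.apache.spark.rdd.RDD
    import org.jetbrains.kotlinx.spark.extensions.KSparkExtensions

    // Hypothetical helper: wrap a Scala RDD in a JavaRDD.
    // The ClassTag parameter is supplied by fakeClassTag instead of an implicit.
    fun <T> wrapInJavaRDD(rdd: RDD<T>): JavaRDD<T> =
        JavaRDD.fromRDD(rdd, KSparkExtensions.fakeClassTag<T>())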

examples/src/main/kotlin/org/jetbrains/kotlinx/spark/examples/JupyterExample.ipynb

+19 -8

@@ -3,8 +3,8 @@
  {
   "cell_type": "markdown",
   "source": [
-   "By default the latest version of the API and the latest supported Spark version is chosen.\n",
-   "To specify your own: `%use spark(spark=3.2, v=1.1.0)`"
+   "By default, the latest version of the API and the latest supported Spark version is chosen.\n",
+   "To specify your own: `%use spark(spark=3.3.0, scala=2.13, v=1.2.0)`"
   ],
   "metadata": {
    "collapsed": false,
@@ -35,6 +35,18 @@
    }
   }
  },
+ {
+  "cell_type": "code",
+  "execution_count": null,
+  "outputs": [],
+  "source": [],
+  "metadata": {
+   "collapsed": false,
+   "pycharm": {
+    "name": "#%%\n"
+   }
+  }
+ },
  {
   "cell_type": "markdown",
   "source": [
@@ -312,14 +324,13 @@
   }
  ],
  "source": [
-  "val rdd: JavaRDD<Tuple2<Int, String>> = sc.parallelize(\n",
-  "    listOf(\n",
-  "        1 X \"aaa\",\n",
-  "        t(2, \"bbb\"),\n",
-  "        tupleOf(3, \"ccc\"),\n",
-  "    )\n",
+  "val rdd: JavaRDD<Tuple2<Int, String>> = rddOf(\n",
+  "    1 X \"aaa\",\n",
+  "    t(2, \"bbb\"),\n",
+  "    tupleOf(3, \"ccc\"),\n",
   ")\n",
   "\n",
+  "\n",
   "rdd"
  ],
  "metadata": {
examples/src/main/kotlin/org/jetbrains/kotlinx/spark/examples/GroupCalculation.kt

+223 (new file)

@@ -0,0 +1,223 @@
package org.jetbrains.kotlinx.spark.examples

import org.apache.spark.sql.Dataset
import org.jetbrains.kotlinx.spark.api.*
import org.jetbrains.kotlinx.spark.api.tuples.X
import org.jetbrains.kotlinx.spark.examples.GroupCalculation.getAllPossibleGroups
import scala.Tuple2
import kotlin.math.pow

/**
 * Gets all the possible, unique, non repeating groups of indices for a list.
 *
 * Example by Jolanrensen.
 */

fun main() = withSpark {
    val groupIndices = getAllPossibleGroups(listSize = 10, groupSize = 4)
        .sort("value")

    groupIndices.showDS(numRows = groupIndices.count().toInt())
}

object GroupCalculation {

    /**
     * Get all the possible, unique, non repeating groups (of size [groupSize]) of indices for a list of
     * size [listSize].
     *
     *
     * The workload is evenly distributed by [listSize] and [groupSize]
     *
     * @param listSize the size of the list for which to calculate the indices
     * @param groupSize the size of a group of indices
     * @return all the possible, unique non repeating groups of indices
     */
    fun KSparkSession.getAllPossibleGroups(
        listSize: Int,
        groupSize: Int,
    ): Dataset<IntArray> {
        val indices = (0 until listSize).toList().toRDD() // Easy RDD creation!

        // for a groupSize of 1, no pairing up is needed, so just return the indices converted to IntArrays
        if (groupSize == 1) {
            return indices
                .mapPartitions {
                    it.map { intArrayOf(it) }
                }
                .toDS()
        }

        // this converts all indices to (number in table, index)
        val keys = indices.mapPartitions {

            // _1 is key (item in table), _2 is index in list
            it.transformAsSequence {
                flatMap { listIndex ->

                    // for each dimension loop over the other dimensions using addTuples
                    (0 until groupSize).asSequence().flatMap { dimension ->
                        addTuples(
                            groupSize = groupSize,
                            value = listIndex,
                            listSize = listSize,
                            skipDimension = dimension,
                        )
                    }
                }
            }
        }

        // Since we have a JavaRDD<Tuple2> we can aggregateByKey!
        // Each number in table occurs for each dimension as key.
        // The values of those two will be a tuple of (key, indices as list)
        val allPossibleGroups = keys.aggregateByKey(
            zeroValue = IntArray(groupSize) { -1 },
            seqFunc = { base: IntArray, listIndex: Int ->
                // put listIndex in the first empty spot in base
                base[base.indexOfFirst { it < 0 }] = listIndex

                base
            },

            // how to merge partially filled up int arrays
            combFunc = { a: IntArray, b: IntArray ->
                // merge a and b
                var j = 0
                for (i in a.indices) {
                    if (a[i] < 0) {
                        while (b[j] < 0) {
                            j++
                            if (j == b.size) return@aggregateByKey a
                        }
                        a[i] = b[j]
                        j++
                    }
                }
                a
            },
        )
            .values() // finally just take the values

        return allPossibleGroups.toDS()
    }

    /**
     * Simple method to give each index of x dimensions a unique number.
     *
     * @param indexTuple IntArray (can be seen as Tuple) of size x with all values < listSize. The index for which to return the number
     * @param listSize The size of the list, aka the max width, height etc. of the table
     * @return the unique number for this [indexTuple]
     */
    private fun getTupleValue(indexTuple: List<Int>, listSize: Int): Int =
        indexTuple.indices.sumOf {
            indexTuple[it] * listSize.toDouble().pow(it).toInt()
        }


    /**
     * To make sure that every tuple is only picked once, this method returns true only if the indices are in the right
     * corner of the matrix. This works for any number of dimensions > 1. Here is an example for 2-D:
     *
     *
     *    -  0  1  2  3  4  5  6  7  8  9
     *    --------------------------------
     *    0| x  ✓  ✓  ✓  ✓  ✓  ✓  ✓  ✓  ✓
     *    1| x  x  ✓  ✓  ✓  ✓  ✓  ✓  ✓  ✓
     *    2| x  x  x  ✓  ✓  ✓  ✓  ✓  ✓  ✓
     *    3| x  x  x  x  ✓  ✓  ✓  ✓  ✓  ✓
     *    4| x  x  x  x  x  ✓  ✓  ✓  ✓  ✓
     *    5| x  x  x  x  x  x  ✓  ✓  ✓  ✓
     *    6| x  x  x  x  x  x  x  ✓  ✓  ✓
     *    7| x  x  x  x  x  x  x  x  ✓  ✓
     *    8| x  x  x  x  x  x  x  x  x  ✓
     *    9| x  x  x  x  x  x  x  x  x  x
     *
     * @param indexTuple a tuple of indices in the form of an IntArray
     * @return true if this tuple is in the right corner and should be included
     */
    private fun isValidIndexTuple(indexTuple: List<Int>): Boolean {
        // x - y > 0; 2d
        // (x - y) > 0 && (x - z) > 0 && (y - z) > 0; 3d
        // (x - y) > 0 && (x - z) > 0 && (x - a) > 0 && (y - z) > 0 && (y - a) > 0 && (z - a) > 0; 4d
        require(indexTuple.size >= 2) { "not a tuple" }
        for (i in 0 until indexTuple.size - 1) {
            for (j in i + 1 until indexTuple.size) {
                if (indexTuple[i] - indexTuple[j] <= 0) return false
            }
        }
        return true
    }

    /**
     * Recursive method that for [skipDimension] loops over all the other dimensions and returns all results from
     * [getTupleValue] as key and [value] as value.
     * In the end, the return value will have, for each key in the table below, a value for the key's column, row etc.
     *
     *
     * This is an example for 2D. The letters will be int indices as well (a = 0, b = 1, ..., [listSize]), but help for clarification.
     * The numbers we don't want are filtered out using [isValidIndexTuple].
     * The actual value of the number in the table comes from [getTupleValue].
     *
     *
     *    -  a  b  c  d  e  f  g  h  i  j
     *    --------------------------------
     *    a| -  1  2  3  4  5  6  7  8  9
     *    b| -  - 12 13 14 15 16 17 18 19
     *    c| -  -  - 23 24 25 26 27 28 29
     *    d| -  -  -  - 34 35 36 37 38 39
     *    e| -  -  -  -  - 45 46 47 48 49
     *    f| -  -  -  -  -  - 56 57 58 59
     *    g| -  -  -  -  -  -  - 67 68 69
     *    h| -  -  -  -  -  -  -  - 78 79
     *    i| -  -  -  -  -  -  -  -  - 89
     *    j| -  -  -  -  -  -  -  -  -  -
     *
     *
     * @param groupSize the size of index tuples to form
     * @param value the current index to work from (can be seen as a letter in the table above)
     * @param listSize the size of the list to make
     * @param skipDimension the current dimension that will have a set value [value] while looping over the other dimensions
     */
    private fun addTuples(
        groupSize: Int,
        value: Int,
        listSize: Int,
        skipDimension: Int,
    ): List<Tuple2<Int, Int>> {

        /**
         * @param currentDimension the indicator for which dimension we're currently calculating for (and how deep in the recursion we are)
         * @param indexTuple the list (or tuple) in which to store the current indices
         */
        fun recursiveCall(
            currentDimension: Int = 0,
            indexTuple: List<Int> = emptyList(),
        ): List<Tuple2<Int, Int>> = when {
            // base case
            currentDimension >= groupSize ->
                if (isValidIndexTuple(indexTuple))
                    listOf(getTupleValue(indexTuple, listSize) X value)
                else
                    emptyList()

            currentDimension == skipDimension ->
                recursiveCall(
                    currentDimension = currentDimension + 1,
                    indexTuple = indexTuple + value,
                )

            else ->
                (0 until listSize).flatMap { i ->
                    recursiveCall(
                        currentDimension = currentDimension + 1,
                        indexTuple = indexTuple + i,
                    )
                }
        }

        return recursiveCall()
    }
}
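
The example above exercises several of the extensions introduced in this commit: List.toRDD inside a KSparkSession, mapPartitions with the new Iterator helpers, aggregateByKey with Kotlin lambdas, values(), and toDS. A much smaller sketch of the same round trip; the sample data and variable names are purely illustrative:

    import org.apache.spark.api.java.JavaRDD
    import org.jetbrains.kotlinx.spark.api.*
    import org.jetbrains.kotlinx.spark.api.tuples.X
    import scala.Tuple2

    fun main() = withSpark {
        // List -> JavaRDD via the new toRDD extension
        val pairs: JavaRDD<Tuple2<String, Int>> = listOf("a" X 1, "b" X 2, "a" X 3).toRDD()

        // a JavaRDD of Tuple2s gets key/value operations such as aggregateByKey directly
        val sums = pairs.aggregateByKey(
            zeroValue = 0,
            seqFunc = { acc: Int, v: Int -> acc + v },
            combFunc = { a: Int, b: Int -> a + b },
        )

        // take only the values and convert them to a Dataset for display,
        // mirroring what the example above does
        sums.values().toDS().show()
    }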

examples/src/main/kotlin/org/jetbrains/kotlinx/spark/examples/streaming/KotlinStatefulNetworkCount.kt

+4 -7

@@ -24,11 +24,8 @@ import org.apache.spark.api.java.StorageLevels
 import org.apache.spark.streaming.Durations
 import org.apache.spark.streaming.State
 import org.apache.spark.streaming.StateSpec
-import org.jetbrains.kotlinx.spark.api.getOrElse
-import org.jetbrains.kotlinx.spark.api.mapWithState
-import org.jetbrains.kotlinx.spark.api.toPairRDD
+import org.jetbrains.kotlinx.spark.api.*
 import org.jetbrains.kotlinx.spark.api.tuples.X
-import org.jetbrains.kotlinx.spark.api.withSparkStreaming
 import java.util.regex.Pattern
 import kotlin.system.exitProcess
 
@@ -71,8 +68,8 @@ object KotlinStatefulNetworkCount {
     ) {
 
         // Initial state RDD input to mapWithState
-        val tuples = listOf("hello" X 1, "world" X 1)
-        val initialRDD = ssc.sparkContext().parallelize(tuples)
+        val tuples = arrayOf("hello" X 1, "world" X 1)
+        val initialRDD = ssc.sparkContext().rddOf(*tuples)
 
         val lines = ssc.socketTextStream(
             args.getOrElse(0) { DEFAULT_HOSTNAME },
@@ -95,7 +92,7 @@ object KotlinStatefulNetworkCount {
         val stateDstream = wordsDstream.mapWithState(
             StateSpec
                 .function(mappingFunc)
-                .initialState(initialRDD.toPairRDD())
+                .initialState(initialRDD.toJavaPairRDD())
         )
 
         stateDstream.print()
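
In this streaming example, the initial state for mapWithState is now built with rddOf directly on the streaming context's JavaSparkContext and converted with toJavaPairRDD(), which the commit uses in place of the previous toPairRDD(), before being handed to StateSpec. A minimal sketch of that pattern; the helper name initialStateOf and the sample values are illustrative only:

    import org.apache.spark.api.java.JavaPairRDD
    import org.apache.spark.api.java.JavaRDD
    import org.apache.spark.streaming.api.java.JavaStreamingContext
    import org.jetbrains.kotlinx.spark.api.*
    import org.jetbrains.kotlinx.spark.api.tuples.X
    import scala.Tuple2

    // Hypothetical helper mirroring the change above: build the initial state
    // for mapWithState from varargs instead of a List plus parallelize.
    fun initialStateOf(ssc: JavaStreamingContext): JavaPairRDD<String, Int> {
        val initial: JavaRDD<Tuple2<String, Int>> =
            ssc.sparkContext().rddOf("hello" X 1, "world" X 1)
        return initial.toJavaPairRDD()
    }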

jupyter/src/main/kotlin/org/jetbrains/kotlinx/spark/api/jupyter/SparkIntegration.kt

+4

@@ -77,6 +77,10 @@ internal class SparkIntegration : Integration() {
                 inline fun <reified T> RDD<T>.toDF(vararg colNames: String): Dataset<Row> = toDF(spark, *colNames)""".trimIndent(),
             """
                 inline fun <reified T> JavaRDDLike<T, *>.toDF(vararg colNames: String): Dataset<Row> = toDF(spark, *colNames)""".trimIndent(),
+            """
+                fun <T> List<T>.toRDD(numSlices: Int = sc.defaultParallelism()): JavaRDD<T> = sc.toRDD(this, numSlices)""".trimIndent(),
+            """
+                fun <T> rddOf(vararg elements: T, numSlices: Int = sc.defaultParallelism()): JavaRDD<T> = sc.toRDD(elements.toList(), numSlices)""".trimIndent(),
             """
                 val udf: UDFRegistration get() = spark.udf()""".trimIndent(),
         ).map(::execute)
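
These two helpers are executed in every notebook session by the Jupyter integration, so RDD creation in a cell becomes a one-liner. Assuming a notebook with the Spark integration loaded (so sc and these definitions are in scope), usage might look like this; the sample values are illustrative:

    // from a Kotlin list, optionally controlling the number of partitions
    val fromList = listOf(1, 2, 3, 4, 5).toRDD(numSlices = 2)

    // directly from varargs
    val fromVarargs = rddOf("a", "b", "c")

    // both are JavaRDDs backed by the session's JavaSparkContext (sc),
    // so the toDF helper defined just above applies to them as well
    fromVarargs.toDF().show()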
