Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

migrate to spark3 #7

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 19 additions & 19 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
<modelVersion>4.0.0</modelVersion>
<groupId>com.cisco.cognitive</groupId>
<artifactId>oraf</artifactId>
<version>1.3.0-SNAPSHOT</version>
<version>2.0.0</version>
<packaging>jar</packaging>

<name>${project.groupId}:${project.artifactId}</name>
Expand Down Expand Up @@ -40,13 +40,13 @@
</scm>

<properties>
<maven.compiler.source>1.8</maven.compiler.source>
<maven.compiler.target>1.8</maven.compiler.target>
<maven.compiler.source>11</maven.compiler.source>
<maven.compiler.target>11</maven.compiler.target>
<encoding>UTF-8</encoding>
<scala.compat.version>2.12</scala.compat.version>
<scala.binary.version>2.12</scala.binary.version>
<scala.version>2.12.10</scala.version>
<spark.version>2.4.4</spark.version>
<spark.version>3.0.3</spark.version>
</properties>

<dependencies>
Expand Down Expand Up @@ -175,25 +175,25 @@
</execution>
</executions>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-gpg-plugin</artifactId>
<version>1.5</version>
<executions>
<execution>
<id>sign-artifacts</id>
<phase>verify</phase>
<goals>
<goal>sign</goal>
</goals>
</execution>
</executions>
</plugin>
<!-- <plugin>-->
<!-- <groupId>org.apache.maven.plugins</groupId>-->
<!-- <artifactId>maven-gpg-plugin</artifactId>-->
<!-- <version>1.5</version>-->
<!-- <executions>-->
<!-- <execution>-->
<!-- <id>sign-artifacts</id>-->
<!-- <phase>verify</phase>-->
<!-- <goals>-->
<!-- <goal>sign</goal>-->
<!-- </goals>-->
<!-- </execution>-->
<!-- </executions>-->
<!-- </plugin>-->
<plugin>
<!-- see http://davidb.github.com/scala-maven-plugin -->
<groupId>net.alchim31.maven</groupId>
<artifactId>scala-maven-plugin</artifactId>
<version>4.1.0</version>
<version>4.8.0</version>
<executions>
<execution>
<goals>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,27 +59,27 @@ class OptimizedDecisionTreeClassifier @Since("1.4.0") (

/** @group setParam */
@Since("1.4.0")
override def setMaxDepth(value: Int): this.type = set(maxDepth, value)
def setMaxDepth(value: Int): this.type = set(maxDepth, value)

/** @group setParam */
@Since("1.4.0")
override def setMaxBins(value: Int): this.type = set(maxBins, value)
def setMaxBins(value: Int): this.type = set(maxBins, value)

/** @group setParam */
@Since("1.4.0")
override def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value)
def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value)

/** @group setParam */
@Since("1.4.0")
override def setMinInfoGain(value: Double): this.type = set(minInfoGain, value)
def setMinInfoGain(value: Double): this.type = set(minInfoGain, value)

/** @group expertSetParam */
@Since("1.4.0")
override def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value)
def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value)

/** @group expertSetParam */
@Since("1.4.0")
override def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value)
def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value)

/**
* Specifies how often to checkpoint the cached node IDs.
Expand All @@ -91,15 +91,15 @@ class OptimizedDecisionTreeClassifier @Since("1.4.0") (
* @group setParam
*/
@Since("1.4.0")
override def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value)
def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value)

/** @group setParam */
@Since("1.4.0")
override def setImpurity(value: String): this.type = set(impurity, value)
def setImpurity(value: String): this.type = set(impurity, value)

/** @group setParam */
@Since("1.6.0")
override def setSeed(value: Long): this.type = set(seed, value)
def setSeed(value: Long): this.type = set(seed, value)

/** @group setParam */
@Since("2.0.0")
Expand Down Expand Up @@ -261,7 +261,7 @@ class OptimizedDecisionTreeClassificationModel(
}

// TODO: Make sure this is correct
override protected def predictRaw(features: Vector): Vector = {
override def predictRaw(features: Vector): Vector = {
val predictions = Array.fill[Double](numClasses)(0.0)
predictions(rootNode.predict(features).toInt) = 1.0

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,27 +61,27 @@ class OptimizedRandomForestClassifier @Since("1.4.0") (

/** @group setParam */
@Since("1.4.0")
override def setMaxDepth(value: Int): this.type = set(maxDepth, value)
def setMaxDepth(value: Int): this.type = set(maxDepth, value)

/** @group setParam */
@Since("1.4.0")
override def setMaxBins(value: Int): this.type = set(maxBins, value)
def setMaxBins(value: Int): this.type = set(maxBins, value)

/** @group setParam */
@Since("1.4.0")
override def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value)
def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value)

/** @group setParam */
@Since("1.4.0")
override def setMinInfoGain(value: Double): this.type = set(minInfoGain, value)
def setMinInfoGain(value: Double): this.type = set(minInfoGain, value)

/** @group expertSetParam */
@Since("1.4.0")
override def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value)
def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value)

/** @group expertSetParam */
@Since("1.4.0")
override def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value)
def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value)

/**
* Specifies how often to checkpoint the cached node IDs.
Expand All @@ -93,31 +93,31 @@ class OptimizedRandomForestClassifier @Since("1.4.0") (
* @group setParam
*/
@Since("1.4.0")
override def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value)
def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value)

/** @group setParam */
@Since("1.4.0")
override def setImpurity(value: String): this.type = set(impurity, value)
def setImpurity(value: String): this.type = set(impurity, value)

// Parameters from TreeEnsembleParams:

/** @group setParam */
@Since("1.4.0")
override def setSubsamplingRate(value: Double): this.type = set(subsamplingRate, value)
def setSubsamplingRate(value: Double): this.type = set(subsamplingRate, value)

/** @group setParam */
@Since("1.4.0")
override def setSeed(value: Long): this.type = set(seed, value)
def setSeed(value: Long): this.type = set(seed, value)

// Parameters from RandomForestParams:

/** @group setParam */
@Since("1.4.0")
override def setNumTrees(value: Int): this.type = set(numTrees, value)
def setNumTrees(value: Int): this.type = set(numTrees, value)

/** @group setParam */
@Since("1.4.0")
override def setFeatureSubsetStrategy(value: String): this.type =
def setFeatureSubsetStrategy(value: String): this.type =
set(featureSubsetStrategy, value)

/** @group setParam */
Expand Down Expand Up @@ -250,15 +250,7 @@ class OptimizedRandomForestClassificationModel(
@Since("1.4.0")
override def treeWeights: Array[Double] = _treeWeights

override protected def transformImpl(dataset: Dataset[_]): DataFrame = {
val bcastModel = dataset.sparkSession.sparkContext.broadcast(this)
val predictUDF = udf { features: Any =>
bcastModel.value.predict(features.asInstanceOf[Vector])
}
dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
}

override protected def predictRaw(features: Vector): Vector = {
override def predictRaw(features: Vector): Vector = {
// TODO: When we add a generic Bagging class, handle transform there: SPARK-7128
// Classifies using majority votes.
// Ignore the tree weights since all are 1.0 for now.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,27 +58,27 @@ class OptimizedDecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override v
// Override parameter setters from parent trait for Java API compatibility.
/** @group setParam */
@Since("1.4.0")
override def setMaxDepth(value: Int): this.type = set(maxDepth, value)
def setMaxDepth(value: Int): this.type = set(maxDepth, value)

/** @group setParam */
@Since("1.4.0")
override def setMaxBins(value: Int): this.type = set(maxBins, value)
def setMaxBins(value: Int): this.type = set(maxBins, value)

/** @group setParam */
@Since("1.4.0")
override def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value)
def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value)

/** @group setParam */
@Since("1.4.0")
override def setMinInfoGain(value: Double): this.type = set(minInfoGain, value)
def setMinInfoGain(value: Double): this.type = set(minInfoGain, value)

/** @group expertSetParam */
@Since("1.4.0")
override def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value)
def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value)

/** @group expertSetParam */
@Since("1.4.0")
override def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value)
def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value)

/**
* Specifies how often to checkpoint the cached node IDs.
Expand All @@ -90,15 +90,15 @@ class OptimizedDecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override v
* @group setParam
*/
@Since("1.4.0")
override def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value)
def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value)

/** @group setParam */
@Since("1.4.0")
override def setImpurity(value: String): this.type = set(impurity, value)
def setImpurity(value: String): this.type = set(impurity, value)

/** @group setParam */
@Since("1.6.0")
override def setSeed(value: Long): this.type = set(seed, value)
def setSeed(value: Long): this.type = set(seed, value)

/** @group setParam */
@Since("2.0.0")
Expand Down Expand Up @@ -175,7 +175,7 @@ class OptimizedDecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override v
object OptimizedDecisionTreeRegressor
extends DefaultParamsReadable[OptimizedDecisionTreeRegressor] {
/** Accessor for supported impurities: variance */
final val supportedImpurities: Array[String] = TreeRegressorParams.supportedImpurities
final val supportedImpurities: Array[String] = Array("variance")

@Since("2.0.0")
override def load(path: String): OptimizedDecisionTreeRegressor = super.load(path)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,27 +59,27 @@ class OptimizedRandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override v

/** @group setParam */
@Since("1.4.0")
override def setMaxDepth(value: Int): this.type = set(maxDepth, value)
def setMaxDepth(value: Int): this.type = set(maxDepth, value)

/** @group setParam */
@Since("1.4.0")
override def setMaxBins(value: Int): this.type = set(maxBins, value)
def setMaxBins(value: Int): this.type = set(maxBins, value)

/** @group setParam */
@Since("1.4.0")
override def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value)
def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value)

/** @group setParam */
@Since("1.4.0")
override def setMinInfoGain(value: Double): this.type = set(minInfoGain, value)
def setMinInfoGain(value: Double): this.type = set(minInfoGain, value)

/** @group expertSetParam */
@Since("1.4.0")
override def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value)
def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value)

/** @group expertSetParam */
@Since("1.4.0")
override def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value)
def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value)

/**
* Specifies how often to checkpoint the cached node IDs.
Expand All @@ -91,31 +91,31 @@ class OptimizedRandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override v
* @group setParam
*/
@Since("1.4.0")
override def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value)
def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value)

/** @group setParam */
@Since("1.4.0")
override def setImpurity(value: String): this.type = set(impurity, value)
def setImpurity(value: String): this.type = set(impurity, value)

// Parameters from TreeEnsembleParams:

/** @group setParam */
@Since("1.4.0")
override def setSubsamplingRate(value: Double): this.type = set(subsamplingRate, value)
def setSubsamplingRate(value: Double): this.type = set(subsamplingRate, value)

/** @group setParam */
@Since("1.4.0")
override def setSeed(value: Long): this.type = set(seed, value)
def setSeed(value: Long): this.type = set(seed, value)

// Parameters from RandomForestParams:

/** @group setParam */
@Since("1.4.0")
override def setNumTrees(value: Int): this.type = set(numTrees, value)
def setNumTrees(value: Int): this.type = set(numTrees, value)

/** @group setParam */
@Since("1.4.0")
override def setFeatureSubsetStrategy(value: String): this.type =
def setFeatureSubsetStrategy(value: String): this.type =
set(featureSubsetStrategy, value)

/** @group setParam */
Expand Down Expand Up @@ -181,7 +181,7 @@ class OptimizedRandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override v
object OptimizedRandomForestRegressor extends DefaultParamsReadable[OptimizedRandomForestRegressor]{
/** Accessor for supported impurity settings: variance */
@Since("1.4.0")
final val supportedImpurities: Array[String] = TreeRegressorParams.supportedImpurities
final val supportedImpurities: Array[String] = Array("variance")

/** Accessor for supported featureSubsetStrategy settings: auto, all, onethird, sqrt, log2 */
@Since("1.4.0")
Expand Down
Loading