diff --git a/pom.xml b/pom.xml
index d82d5f6..d054be5 100644
--- a/pom.xml
+++ b/pom.xml
@@ -3,7 +3,7 @@
<modelVersion>4.0.0</modelVersion>
<groupId>com.cisco.cognitive</groupId>
<artifactId>oraf</artifactId>
- <version>1.3.0-SNAPSHOT</version>
+ <version>2.0.0</version>
<packaging>jar</packaging>
<name>${project.groupId}:${project.artifactId}</name>
@@ -40,13 +40,13 @@
- <maven.compiler.source>1.8</maven.compiler.source>
- <maven.compiler.target>1.8</maven.compiler.target>
+ <maven.compiler.source>11</maven.compiler.source>
+ <maven.compiler.target>11</maven.compiler.target>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<scala.compat.version>2.12</scala.compat.version>
<scala.binary.version>2.12</scala.binary.version>
<scala.version>2.12.10</scala.version>
- <spark.version>2.4.4</spark.version>
+ <spark.version>3.0.3</spark.version>
@@ -175,25 +175,25 @@
- <plugin>
- <groupId>org.apache.maven.plugins</groupId>
- <artifactId>maven-gpg-plugin</artifactId>
- <version>1.5</version>
- <executions>
- <execution>
- <id>sign-artifacts</id>
- <phase>verify</phase>
- <goals>
- <goal>sign</goal>
- </goals>
- </execution>
- </executions>
- </plugin>
+ <!--<plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-gpg-plugin</artifactId>
+ <version>1.5</version>
+ <executions>
+ <execution>
+ <id>sign-artifacts</id>
+ <phase>verify</phase>
+ <goals>
+ <goal>sign</goal>
+ </goals>
+ </execution>
+ </executions>
+ </plugin>-->
<groupId>net.alchim31.maven</groupId>
<artifactId>scala-maven-plugin</artifactId>
- <version>4.1.0</version>
+ <version>4.8.0</version>
diff --git a/src/main/scala/org/apache/spark/ml/classification/OptimizedDecisionTreeClassifier.scala b/src/main/scala/org/apache/spark/ml/classification/OptimizedDecisionTreeClassifier.scala
index 86873a0..1edd3a8 100755
--- a/src/main/scala/org/apache/spark/ml/classification/OptimizedDecisionTreeClassifier.scala
+++ b/src/main/scala/org/apache/spark/ml/classification/OptimizedDecisionTreeClassifier.scala
@@ -59,27 +59,27 @@ class OptimizedDecisionTreeClassifier @Since("1.4.0") (
/** @group setParam */
@Since("1.4.0")
- override def setMaxDepth(value: Int): this.type = set(maxDepth, value)
+ def setMaxDepth(value: Int): this.type = set(maxDepth, value)
/** @group setParam */
@Since("1.4.0")
- override def setMaxBins(value: Int): this.type = set(maxBins, value)
+ def setMaxBins(value: Int): this.type = set(maxBins, value)
/** @group setParam */
@Since("1.4.0")
- override def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value)
+ def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value)
/** @group setParam */
@Since("1.4.0")
- override def setMinInfoGain(value: Double): this.type = set(minInfoGain, value)
+ def setMinInfoGain(value: Double): this.type = set(minInfoGain, value)
/** @group expertSetParam */
@Since("1.4.0")
- override def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value)
+ def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value)
/** @group expertSetParam */
@Since("1.4.0")
- override def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value)
+ def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value)
/**
* Specifies how often to checkpoint the cached node IDs.
@@ -91,15 +91,15 @@ class OptimizedDecisionTreeClassifier @Since("1.4.0") (
* @group setParam
*/
@Since("1.4.0")
- override def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value)
+ def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value)
/** @group setParam */
@Since("1.4.0")
- override def setImpurity(value: String): this.type = set(impurity, value)
+ def setImpurity(value: String): this.type = set(impurity, value)
/** @group setParam */
@Since("1.6.0")
- override def setSeed(value: Long): this.type = set(seed, value)
+ def setSeed(value: Long): this.type = set(seed, value)
/** @group setParam */
@Since("2.0.0")
@@ -261,7 +261,7 @@ class OptimizedDecisionTreeClassificationModel(
}
// TODO: Make sure this is correct
- override protected def predictRaw(features: Vector): Vector = {
+ override def predictRaw(features: Vector): Vector = {
val predictions = Array.fill[Double](numClasses)(0.0)
predictions(rootNode.predict(features).toInt) = 1.0
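
The dropped override modifiers track Spark 3.x, where the shared tree-param traits no longer declare overridable setters and predictRaw is public on the classifier API. A minimal usage sketch follows; the SparkSession, input path, and hyperparameter values are illustrative assumptions, not part of this patch:

    import org.apache.spark.ml.classification.OptimizedDecisionTreeClassifier
    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder().appName("oraf-dtc-sketch").getOrCreate()
    // Any DataFrame with "label" and "features" columns works here.
    val training = spark.read.format("libsvm").load("data/sample_libsvm_data.txt")

    val dtc = new OptimizedDecisionTreeClassifier()
      .setMaxDepth(10)
      .setMaxBins(64)
      .setMinInstancesPerNode(2)
      .setCacheNodeIds(true)        // exercises the NodeIdCache path added later in this patch
      .setCheckpointInterval(10)
      .setImpurity("gini")
      .setSeed(42L)

    val model = dtc.fit(training)   // the Estimator API itself is unchanged
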
diff --git a/src/main/scala/org/apache/spark/ml/classification/OptimizedRandomForestClassifier.scala b/src/main/scala/org/apache/spark/ml/classification/OptimizedRandomForestClassifier.scala
index 0c398d7..b3cd055 100755
--- a/src/main/scala/org/apache/spark/ml/classification/OptimizedRandomForestClassifier.scala
+++ b/src/main/scala/org/apache/spark/ml/classification/OptimizedRandomForestClassifier.scala
@@ -61,27 +61,27 @@ class OptimizedRandomForestClassifier @Since("1.4.0") (
/** @group setParam */
@Since("1.4.0")
- override def setMaxDepth(value: Int): this.type = set(maxDepth, value)
+ def setMaxDepth(value: Int): this.type = set(maxDepth, value)
/** @group setParam */
@Since("1.4.0")
- override def setMaxBins(value: Int): this.type = set(maxBins, value)
+ def setMaxBins(value: Int): this.type = set(maxBins, value)
/** @group setParam */
@Since("1.4.0")
- override def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value)
+ def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value)
/** @group setParam */
@Since("1.4.0")
- override def setMinInfoGain(value: Double): this.type = set(minInfoGain, value)
+ def setMinInfoGain(value: Double): this.type = set(minInfoGain, value)
/** @group expertSetParam */
@Since("1.4.0")
- override def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value)
+ def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value)
/** @group expertSetParam */
@Since("1.4.0")
- override def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value)
+ def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value)
/**
* Specifies how often to checkpoint the cached node IDs.
@@ -93,31 +93,31 @@ class OptimizedRandomForestClassifier @Since("1.4.0") (
* @group setParam
*/
@Since("1.4.0")
- override def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value)
+ def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value)
/** @group setParam */
@Since("1.4.0")
- override def setImpurity(value: String): this.type = set(impurity, value)
+ def setImpurity(value: String): this.type = set(impurity, value)
// Parameters from TreeEnsembleParams:
/** @group setParam */
@Since("1.4.0")
- override def setSubsamplingRate(value: Double): this.type = set(subsamplingRate, value)
+ def setSubsamplingRate(value: Double): this.type = set(subsamplingRate, value)
/** @group setParam */
@Since("1.4.0")
- override def setSeed(value: Long): this.type = set(seed, value)
+ def setSeed(value: Long): this.type = set(seed, value)
// Parameters from RandomForestParams:
/** @group setParam */
@Since("1.4.0")
- override def setNumTrees(value: Int): this.type = set(numTrees, value)
+ def setNumTrees(value: Int): this.type = set(numTrees, value)
/** @group setParam */
@Since("1.4.0")
- override def setFeatureSubsetStrategy(value: String): this.type =
+ def setFeatureSubsetStrategy(value: String): this.type =
set(featureSubsetStrategy, value)
/** @group setParam */
@@ -250,15 +250,7 @@ class OptimizedRandomForestClassificationModel(
@Since("1.4.0")
override def treeWeights: Array[Double] = _treeWeights
- override protected def transformImpl(dataset: Dataset[_]): DataFrame = {
- val bcastModel = dataset.sparkSession.sparkContext.broadcast(this)
- val predictUDF = udf { features: Any =>
- bcastModel.value.predict(features.asInstanceOf[Vector])
- }
- dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
- }
-
- override protected def predictRaw(features: Vector): Vector = {
+ override def predictRaw(features: Vector): Vector = {
// TODO: When we add a generic Bagging class, handle transform there: SPARK-7128
// Classifies using majority votes.
// Ignore the tree weights since all are 1.0 for now.
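    }

Dropping the custom transformImpl leans on the parent model's transform, which fills the prediction columns itself, so only the now-public predictRaw needs to be supplied. A sketch of the resulting call path, with illustrative variable names and an assumed libsvm input:

    import org.apache.spark.ml.classification.OptimizedRandomForestClassifier
    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder().appName("oraf-rfc-sketch").getOrCreate()
    val data = spark.read.format("libsvm").load("data/sample_libsvm_data.txt")
    val Array(train, test) = data.randomSplit(Array(0.8, 0.2), seed = 1L)

    val rfc = new OptimizedRandomForestClassifier()
      .setNumTrees(50)
      .setFeatureSubsetStrategy("sqrt")
      .setSubsamplingRate(0.8)
      .setMaxDepth(8)
      .setSeed(42L)

    val model = rfc.fit(train)
    // The inherited transform adds the prediction, rawPrediction and probability columns.
    model.transform(test).select("prediction", "rawPrediction", "probability").show(5)
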
diff --git a/src/main/scala/org/apache/spark/ml/regression/OptimizedDecisionTreeRegressor.scala b/src/main/scala/org/apache/spark/ml/regression/OptimizedDecisionTreeRegressor.scala
index 0b80297..f37ecf3 100755
--- a/src/main/scala/org/apache/spark/ml/regression/OptimizedDecisionTreeRegressor.scala
+++ b/src/main/scala/org/apache/spark/ml/regression/OptimizedDecisionTreeRegressor.scala
@@ -58,27 +58,27 @@ class OptimizedDecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override v
// Override parameter setters from parent trait for Java API compatibility.
/** @group setParam */
@Since("1.4.0")
- override def setMaxDepth(value: Int): this.type = set(maxDepth, value)
+ def setMaxDepth(value: Int): this.type = set(maxDepth, value)
/** @group setParam */
@Since("1.4.0")
- override def setMaxBins(value: Int): this.type = set(maxBins, value)
+ def setMaxBins(value: Int): this.type = set(maxBins, value)
/** @group setParam */
@Since("1.4.0")
- override def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value)
+ def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value)
/** @group setParam */
@Since("1.4.0")
- override def setMinInfoGain(value: Double): this.type = set(minInfoGain, value)
+ def setMinInfoGain(value: Double): this.type = set(minInfoGain, value)
/** @group expertSetParam */
@Since("1.4.0")
- override def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value)
+ def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value)
/** @group expertSetParam */
@Since("1.4.0")
- override def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value)
+ def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value)
/**
* Specifies how often to checkpoint the cached node IDs.
@@ -90,15 +90,15 @@ class OptimizedDecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override v
* @group setParam
*/
@Since("1.4.0")
- override def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value)
+ def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value)
/** @group setParam */
@Since("1.4.0")
- override def setImpurity(value: String): this.type = set(impurity, value)
+ def setImpurity(value: String): this.type = set(impurity, value)
/** @group setParam */
@Since("1.6.0")
- override def setSeed(value: Long): this.type = set(seed, value)
+ def setSeed(value: Long): this.type = set(seed, value)
/** @group setParam */
@Since("2.0.0")
@@ -175,7 +175,7 @@ class OptimizedDecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override v
object OptimizedDecisionTreeRegressor
extends DefaultParamsReadable[OptimizedDecisionTreeRegressor] {
/** Accessor for supported impurities: variance */
- final val supportedImpurities: Array[String] = TreeRegressorParams.supportedImpurities
+ final val supportedImpurities: Array[String] = Array("variance")
@Since("2.0.0")
override def load(path: String): OptimizedDecisionTreeRegressor = super.load(path)
diff --git a/src/main/scala/org/apache/spark/ml/regression/OptimizedRandomForestRegressor.scala b/src/main/scala/org/apache/spark/ml/regression/OptimizedRandomForestRegressor.scala
index df8cd38..c0e2ffd 100755
--- a/src/main/scala/org/apache/spark/ml/regression/OptimizedRandomForestRegressor.scala
+++ b/src/main/scala/org/apache/spark/ml/regression/OptimizedRandomForestRegressor.scala
@@ -59,27 +59,27 @@ class OptimizedRandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override v
/** @group setParam */
@Since("1.4.0")
- override def setMaxDepth(value: Int): this.type = set(maxDepth, value)
+ def setMaxDepth(value: Int): this.type = set(maxDepth, value)
/** @group setParam */
@Since("1.4.0")
- override def setMaxBins(value: Int): this.type = set(maxBins, value)
+ def setMaxBins(value: Int): this.type = set(maxBins, value)
/** @group setParam */
@Since("1.4.0")
- override def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value)
+ def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value)
/** @group setParam */
@Since("1.4.0")
- override def setMinInfoGain(value: Double): this.type = set(minInfoGain, value)
+ def setMinInfoGain(value: Double): this.type = set(minInfoGain, value)
/** @group expertSetParam */
@Since("1.4.0")
- override def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value)
+ def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value)
/** @group expertSetParam */
@Since("1.4.0")
- override def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value)
+ def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value)
/**
* Specifies how often to checkpoint the cached node IDs.
@@ -91,31 +91,31 @@ class OptimizedRandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override v
* @group setParam
*/
@Since("1.4.0")
- override def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value)
+ def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value)
/** @group setParam */
@Since("1.4.0")
- override def setImpurity(value: String): this.type = set(impurity, value)
+ def setImpurity(value: String): this.type = set(impurity, value)
// Parameters from TreeEnsembleParams:
/** @group setParam */
@Since("1.4.0")
- override def setSubsamplingRate(value: Double): this.type = set(subsamplingRate, value)
+ def setSubsamplingRate(value: Double): this.type = set(subsamplingRate, value)
/** @group setParam */
@Since("1.4.0")
- override def setSeed(value: Long): this.type = set(seed, value)
+ def setSeed(value: Long): this.type = set(seed, value)
// Parameters from RandomForestParams:
/** @group setParam */
@Since("1.4.0")
- override def setNumTrees(value: Int): this.type = set(numTrees, value)
+ def setNumTrees(value: Int): this.type = set(numTrees, value)
/** @group setParam */
@Since("1.4.0")
- override def setFeatureSubsetStrategy(value: String): this.type =
+ def setFeatureSubsetStrategy(value: String): this.type =
set(featureSubsetStrategy, value)
/** @group setParam */
@@ -181,7 +181,7 @@ class OptimizedRandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override v
object OptimizedRandomForestRegressor extends DefaultParamsReadable[OptimizedRandomForestRegressor]{
/** Accessor for supported impurity settings: variance */
@Since("1.4.0")
- final val supportedImpurities: Array[String] = TreeRegressorParams.supportedImpurities
+ final val supportedImpurities: Array[String] = Array("variance")
/** Accessor for supported featureSubsetStrategy settings: auto, all, onethird, sqrt, log2 */
@Since("1.4.0")
diff --git a/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala b/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala
new file mode 100644
index 0000000..19bd10a
--- /dev/null
+++ b/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala
@@ -0,0 +1,199 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.tree.impl
+
+import java.io.IOException
+
+import scala.collection.mutable
+
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.ml.tree.{LearningNode, Split}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.storage.StorageLevel
+
+
+/**
+ * This is used by the node id cache to find the child id that a data point would belong to.
+ * @param split Split information.
+ * @param nodeIndex The current node index of a data point that this will update.
+ */
+private[tree] case class NodeIndexUpdater(split: Split, nodeIndex: Int) {
+
+ /**
+ * Determine a child node index based on the feature value and the split.
+ * @param binnedFeature Binned feature value.
+ * @param splits Split information to convert the bin indices to approximate feature values.
+ * @return Child node index to update to.
+ */
+ def updateNodeIndex(binnedFeature: Int, splits: Array[Split]): Int = {
+ if (split.shouldGoLeft(binnedFeature, splits)) {
+ LearningNode.leftChildIndex(nodeIndex)
+ } else {
+ LearningNode.rightChildIndex(nodeIndex)
+ }
+ }
+}
+
+/**
+ * Each TreePoint belongs to a particular node per tree.
+ * Each row in the nodeIdsForInstances RDD is an array over trees of the node index
+ * in each tree. Initially, values should all be 1 for root node.
+ * The nodeIdsForInstances RDD needs to be updated at each iteration.
+ * @param nodeIdsForInstances The initial values in the cache
+ * (should be an Array of all 1's (meaning the root nodes)).
+ * @param checkpointInterval The checkpointing interval
+ * (how often the cache should be checkpointed).
+ */
+private[spark] class NodeIdCache(
+ var nodeIdsForInstances: RDD[Array[Int]],
+ val checkpointInterval: Int) extends Logging {
+
+ // Keep a reference to the previous node Ids for instances.
+ // Because we will keep on re-persisting updated node Ids,
+ // we want to unpersist the previous RDD.
+ private var prevNodeIdsForInstances: RDD[Array[Int]] = null
+
+ // To keep track of the past checkpointed RDDs.
+ private val checkpointQueue = mutable.Queue[RDD[Array[Int]]]()
+ private var rddUpdateCount = 0
+
+ // Indicates whether we can checkpoint
+ private val canCheckpoint = nodeIdsForInstances.sparkContext.getCheckpointDir.nonEmpty
+
+ // Hadoop Configuration for deleting checkpoints as needed
+ private val hadoopConf = nodeIdsForInstances.sparkContext.hadoopConfiguration
+
+ /**
+ * Update the node index values in the cache.
+ * This updates the RDD and its lineage.
+ * TODO: Passing bin information to executors seems unnecessary and costly.
+ * @param data The RDD of training rows.
+ * @param nodeIdUpdaters A map of node index updaters.
+ * The key is the indices of nodes that we want to update.
+ * @param splits Split information needed to find child node indices.
+ */
+ def updateNodeIndices(
+ data: RDD[BaggedPoint[TreePoint]],
+ nodeIdUpdaters: Array[mutable.Map[Int, NodeIndexUpdater]],
+ splits: Array[Array[Split]]): Unit = {
+ if (prevNodeIdsForInstances != null) {
+ // Unpersist the previous one if one exists.
+ prevNodeIdsForInstances.unpersist(false)
+ }
+
+ prevNodeIdsForInstances = nodeIdsForInstances
+ nodeIdsForInstances = data.zip(nodeIdsForInstances).map { case (point, ids) =>
+ var treeId = 0
+ while (treeId < nodeIdUpdaters.length) {
+ val nodeIdUpdater = nodeIdUpdaters(treeId).getOrElse(ids(treeId), null)
+ if (nodeIdUpdater != null) {
+ val featureIndex = nodeIdUpdater.split.featureIndex
+ val newNodeIndex = nodeIdUpdater.updateNodeIndex(
+ binnedFeature = point.datum.binnedFeatures(featureIndex),
+ splits = splits(featureIndex))
+ ids(treeId) = newNodeIndex
+ }
+ treeId += 1
+ }
+ ids
+ }
+
+ // Keep on persisting new ones.
+ nodeIdsForInstances.persist(StorageLevel.MEMORY_AND_DISK)
+ rddUpdateCount += 1
+
+ // Handle checkpointing if the directory is not None.
+ if (canCheckpoint && checkpointInterval != -1 && (rddUpdateCount % checkpointInterval) == 0) {
+ // Let's see if we can delete previous checkpoints.
+ var canDelete = true
+ while (checkpointQueue.size > 1 && canDelete) {
+ // We can delete the oldest checkpoint iff
+ // the next checkpoint actually exists in the file system.
+ if (checkpointQueue(1).getCheckpointFile.isDefined) {
+ val old = checkpointQueue.dequeue()
+ // Since the old checkpoint is not deleted by Spark, we'll manually delete it here.
+ try {
+ val path = new Path(old.getCheckpointFile.get)
+ val fs = path.getFileSystem(hadoopConf)
+ fs.delete(path, true)
+ } catch {
+ case e: IOException =>
+ logError("Decision Tree learning using cacheNodeIds failed to remove checkpoint" +
+ s" file: ${old.getCheckpointFile.get}")
+ }
+ } else {
+ canDelete = false
+ }
+ }
+
+ nodeIdsForInstances.checkpoint()
+ checkpointQueue.enqueue(nodeIdsForInstances)
+ }
+ }
+
+ /**
+ * Call this after training is finished to delete any remaining checkpoints.
+ */
+ def deleteAllCheckpoints(): Unit = {
+ while (checkpointQueue.nonEmpty) {
+ val old = checkpointQueue.dequeue()
+ if (old.getCheckpointFile.isDefined) {
+ try {
+ val path = new Path(old.getCheckpointFile.get)
+ val fs = path.getFileSystem(hadoopConf)
+ fs.delete(path, true)
+ } catch {
+ case e: IOException =>
+ logError("Decision Tree learning using cacheNodeIds failed to remove checkpoint" +
+ s" file: ${old.getCheckpointFile.get}")
+ }
+ }
+ }
+ if (nodeIdsForInstances != null) {
+ // Unpersist current one if one exists.
+ nodeIdsForInstances.unpersist(false)
+ }
+ if (prevNodeIdsForInstances != null) {
+ // Unpersist the previous one if one exists.
+ prevNodeIdsForInstances.unpersist(false)
+ }
+ }
+}
+
+private[spark] object NodeIdCache {
+ /**
+ * Initialize the node Id cache with initial node Id values.
+ * @param data The RDD of training rows.
+ * @param numTrees The number of trees that we want to create cache for.
+ * @param checkpointInterval The checkpointing interval
+ * (how often the cache should be checkpointed).
+ * @param initVal The initial values in the cache.
+ * @return A node Id cache containing an RDD of initial root node Indices.
+ */
+ def init(
+ data: RDD[BaggedPoint[TreePoint]],
+ numTrees: Int,
+ checkpointInterval: Int,
+ initVal: Int = 1): NodeIdCache = {
+ new NodeIdCache(
+ data.map(_ => Array.fill[Int](numTrees)(initVal)),
+ checkpointInterval)
+ }
+}
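
The new NodeIdCache mirrors Spark's own node-id cache: every training row keeps one node index per tree, the RDD is re-persisted after each round of splits, and every checkpointInterval-th update is checkpointed so the lineage stays short. A lifecycle sketch under assumed names (baggedInput, nodeIdUpdaters and splits would come from the OptimizedRandomForest training loop):

    // Assumed inputs from the training loop:
    //   baggedInput:    RDD[BaggedPoint[TreePoint]]
    //   nodeIdUpdaters: Array[mutable.Map[Int, NodeIndexUpdater]]   (one map per tree)
    //   splits:         Array[Array[Split]]                         (per feature)
    val nodeIdCache = NodeIdCache.init(
      data = baggedInput,
      numTrees = 10,
      checkpointInterval = 10)

    // After the best split is chosen for each node in the current group:
    nodeIdCache.updateNodeIndices(baggedInput, nodeIdUpdaters, splits)

    // When all trees are fully grown:
    nodeIdCache.deleteAllCheckpoints()
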
diff --git a/src/main/scala/org/apache/spark/ml/tree/impl/OptimizedDTStatsAggregator.scala b/src/main/scala/org/apache/spark/ml/tree/impl/OptimizedDTStatsAggregator.scala
index da1e890..369374e 100644
--- a/src/main/scala/org/apache/spark/ml/tree/impl/OptimizedDTStatsAggregator.scala
+++ b/src/main/scala/org/apache/spark/ml/tree/impl/OptimizedDTStatsAggregator.scala
@@ -108,14 +108,14 @@ class OptimizedDTStatsAggregator(
*/
def update(featureIndex: Int, binIndex: Int, label: Double, instanceWeight: Double): Unit = {
val i = featureOffsets(featureIndex) + binIndex * statsSize
- impurityAggregator.update(allStats, i, label, instanceWeight)
+ impurityAggregator.update(allStats, i, label, 1, instanceWeight)
}
/**
* Update the parent node stats using the given label.
*/
def updateParent(label: Double, instanceWeight: Double): Unit = {
- impurityAggregator.update(parentStats, 0, label, instanceWeight)
+ impurityAggregator.update(parentStats, 0, label, 1, instanceWeight)
}
/**
@@ -130,8 +130,7 @@ class OptimizedDTStatsAggregator(
binIndex: Int,
label: Double,
instanceWeight: Double): Unit = {
- impurityAggregator.update(allStats, featureOffset + binIndex * statsSize,
- label, instanceWeight)
+ impurityAggregator.update(allStats, featureOffset + binIndex * statsSize, label, 1, instanceWeight)
}
/**
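
The extra argument threaded through these calls reflects the Spark 3 ImpurityAggregator.update signature, which takes a raw sample count alongside the sample weight; this patch passes a count of 1 per call and keeps instanceWeight as the weight. The flat layout the offsets index into is unchanged, as the following self-contained illustration shows (toy values, not the patch's API):

    // Each (feature, bin) pair owns a contiguous block of statsSize doubles;
    // the write position is featureOffset + binIndex * statsSize.
    val statsSize = 3
    val featureOffsets = Array(0, 4 * statsSize)   // feature 0 has 4 bins
    def offsetFor(featureIndex: Int, binIndex: Int): Int =
      featureOffsets(featureIndex) + binIndex * statsSize

    assert(offsetFor(0, 2) == 6)                   // third bin of feature 0
    assert(offsetFor(1, 0) == 12)                  // first bin of feature 1
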
diff --git a/src/main/scala/org/apache/spark/ml/tree/impl/OptimizedDecisionTreeMetadata.scala b/src/main/scala/org/apache/spark/ml/tree/impl/OptimizedDecisionTreeMetadata.scala
index 0860ecc..e2fcf77 100644
--- a/src/main/scala/org/apache/spark/ml/tree/impl/OptimizedDecisionTreeMetadata.scala
+++ b/src/main/scala/org/apache/spark/ml/tree/impl/OptimizedDecisionTreeMetadata.scala
@@ -44,6 +44,7 @@ import scala.util.Try
class OptimizedDecisionTreeMetadata(
numFeatures: Int,
numExamples: Long,
+ weightedNumExamples: Double,
numClasses: Int,
maxBins: Int,
featureArity: Map[Int, Int],
@@ -53,10 +54,27 @@ class OptimizedDecisionTreeMetadata(
quantileStrategy: QuantileStrategy,
maxDepth: Int,
minInstancesPerNode: Int,
+ minWeightFractionPerNode: Double,
minInfoGain: Double,
numTrees: Int,
- numFeaturesPerNode: Int) extends DecisionTreeMetadata(numFeatures,
- numExamples, numClasses, maxBins, featureArity, unorderedFeatures, numBins, impurity, quantileStrategy, maxDepth, minInstancesPerNode, minInfoGain, numTrees, numFeaturesPerNode) with Serializable {
+ numFeaturesPerNode: Int) extends DecisionTreeMetadata(
+ numFeatures,
+ numExamples,
+ weightedNumExamples,
+ numClasses,
+ maxBins,
+ featureArity,
+ unorderedFeatures,
+ numBins,
+ impurity,
+ quantileStrategy,
+ maxDepth,
+ minInstancesPerNode,
+ minWeightFractionPerNode,
+ minInfoGain,
+ numTrees,
+ numFeaturesPerNode
+) with Serializable {
}
object OptimizedDecisionTreeMetadata extends Logging {
@@ -78,7 +96,11 @@ object OptimizedDecisionTreeMetadata extends Logging {
}
require(numFeatures > 0, s"DecisionTree requires number of features > 0, " +
s"but was given an empty features vector")
- val numExamples = input.count()
+ val (numExamples, weightSum) = input.aggregate((0L, 0.0))(
+ seqOp = (cw, instance) => (cw._1 + 1L, cw._2 + instance.weight),
+ combOp = (cw1, cw2) => (cw1._1 + cw2._1, cw1._2 + cw2._2)
+ )
+
val numClasses = strategy.algo match {
case Classification => strategy.numClasses
case Regression => 0
@@ -169,10 +191,11 @@ object OptimizedDecisionTreeMetadata extends Logging {
}
}
- new OptimizedDecisionTreeMetadata(numFeatures, numExamples, numClasses, numBins.max,
- strategy.categoricalFeaturesInfo, unorderedFeatures.toSet, numBins,
+ new OptimizedDecisionTreeMetadata(numFeatures, numExamples, weightSum, numClasses,
+ numBins.max, strategy.categoricalFeaturesInfo, unorderedFeatures.toSet, numBins,
strategy.impurity, strategy.quantileCalculationStrategy, strategy.maxDepth,
- strategy.minInstancesPerNode, strategy.minInfoGain, numTrees, numFeaturesPerNode)
+ strategy.minInstancesPerNode, strategy.minWeightFractionPerNode, strategy.minInfoGain,
+ numTrees, numFeaturesPerNode)
}
/**
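
Replacing input.count() with a single aggregate yields both the example count and the weight sum in one pass, which the metadata now needs for minWeightFractionPerNode. The same pattern in isolation, on a toy RDD of weighted rows (the SparkSession and the WeightedRow case class are illustrative assumptions):

    case class WeightedRow(label: Double, weight: Double)

    val rows = spark.sparkContext.parallelize(Seq(
      WeightedRow(1.0, 2.0),
      WeightedRow(0.0, 0.5)))

    val (numExamples, weightSum) = rows.aggregate((0L, 0.0))(
      seqOp = (cw, r) => (cw._1 + 1L, cw._2 + r.weight),
      combOp = (cw1, cw2) => (cw1._1 + cw2._1, cw1._2 + cw2._2))
    // numExamples == 2L, weightSum == 2.5
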
diff --git a/src/main/scala/org/apache/spark/ml/tree/impl/OptimizedRandomForest.scala b/src/main/scala/org/apache/spark/ml/tree/impl/OptimizedRandomForest.scala
index 02aaffe..61d113f 100755
--- a/src/main/scala/org/apache/spark/ml/tree/impl/OptimizedRandomForest.scala
+++ b/src/main/scala/org/apache/spark/ml/tree/impl/OptimizedRandomForest.scala
@@ -19,15 +19,12 @@
package org.apache.spark.ml.tree.impl
-import java.io.IOException
-
import org.apache.spark.Partitioner
import org.apache.spark.internal.Logging
import org.apache.spark.ml.classification.OptimizedDecisionTreeClassificationModel
-import org.apache.spark.ml.feature.LabeledPoint
+import org.apache.spark.ml.feature.{Instance, LabeledPoint}
import org.apache.spark.ml.regression.OptimizedDecisionTreeRegressionModel
import org.apache.spark.ml.tree._
-import org.apache.spark.ml.feature.Instance
import org.apache.spark.ml.util.Instrumentation
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, OptimizedForestStrategy => OldStrategy}
import org.apache.spark.mllib.tree.impurity.ImpurityCalculator
@@ -36,6 +33,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.random.{SamplingUtils, XORShiftRandom}
+import java.io.IOException
import scala.collection.{SeqView, mutable}
import scala.util.{Random, Try}
@@ -174,7 +172,7 @@ private[spark] object OptimizedRandomForest extends Logging {
val withReplacement = numTrees > 1
val baggedInput = BaggedPoint
- .convertToBaggedRDD(treeInput, strategy.subsamplingRate, numTrees, withReplacement, seed)
+ .convertToBaggedRDD(treeInput, strategy.subsamplingRate, numTrees, withReplacement, (treePoint: TreePoint) => treePoint.weight, seed = seed)
.persist(StorageLevel.MEMORY_AND_DISK)
val distributedMaxDepth = Math.min(strategy.maxDepth, 30)
@@ -340,11 +338,11 @@ private[spark] object OptimizedRandomForest extends Logging {
pointsWithNodeIds.flatMap {
case (baggedPoint, nodeIdsForTree) =>
nodeSetsBc.value.keys
- .filter(treeId => baggedPoint.subsampleWeights(treeId) > 0)
+ .filter(treeId => baggedPoint.subsampleCounts(treeId) > 0)
.map(treeId => (treeId, nodeIdsForTree(treeId)))
.filter { case (treeId, nodeId) => nodeSetsBc.value(treeId).contains(nodeId) }
.map { case (treeId, nodeId) =>
- ((treeId, nodeId), (baggedPoint.datum, baggedPoint.subsampleWeights(treeId) * baggedPoint.datum.sampleWeight))
+ ((treeId, nodeId), (baggedPoint.datum, baggedPoint.subsampleCounts(treeId) * baggedPoint.datum.sampleWeight))
}
}
}
@@ -676,7 +674,7 @@ private[spark] object OptimizedRandomForest extends Logging {
if (nodeInfo != null) {
val aggNodeIndex = nodeInfo.nodeIndexInGroup
val featuresForNode = nodeInfo.featureSubset
- val instanceWeight = baggedPoint.subsampleWeights(treeIndex) * baggedPoint.datum.sampleWeight
+ val instanceWeight = baggedPoint.subsampleCounts(treeIndex) * baggedPoint.datum.sampleWeight
if (metadata.unorderedFeatures.isEmpty) {
orderedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, instanceWeight, featuresForNode)
} else {
@@ -879,11 +877,11 @@ private[spark] object OptimizedRandomForest extends Logging {
// enqueue left child and right child if they are not leaves
if (!leftChildIsLeaf) {
- addTrainingTask(node.leftChild.get, treeIndex, stats.leftImpurityCalculator.count,
+ addTrainingTask(node.leftChild.get, treeIndex, stats.leftImpurityCalculator.count.toLong,
nodeLevel, stats.leftImpurity)
}
if (!rightChildIsLeaf) {
- addTrainingTask(node.rightChild.get, treeIndex, stats.rightImpurityCalculator.count,
+ addTrainingTask(node.rightChild.get, treeIndex, stats.rightImpurityCalculator.count.toLong,
nodeLevel, stats.rightImpurity)
}
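
The rename from subsampleWeights to subsampleCounts follows the Spark 3 BaggedPoint, which keeps the integer bootstrap count per tree separate from the row's real-valued sample weight (carried here on the TreePoint datum); the effective weight used during aggregation is their product, as the updated lines show. A toy illustration in plain Scala (no Spark internals, names are illustrative):

    case class ToyBaggedPoint(subsampleCounts: Array[Int], sampleWeight: Double)

    val p = ToyBaggedPoint(subsampleCounts = Array(2, 0, 1), sampleWeight = 1.5)
    val instanceWeightForTree0 = p.subsampleCounts(0) * p.sampleWeight   // 2 * 1.5 = 3.0
    val usedByTree1 = p.subsampleCounts(1) > 0                           // false: row not sampled for tree 1
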
diff --git a/src/main/scala/org/apache/spark/ml/tree/impl/OptimizedTreePoint.scala b/src/main/scala/org/apache/spark/ml/tree/impl/OptimizedTreePoint.scala
index 6b89734..245bc16 100644
--- a/src/main/scala/org/apache/spark/ml/tree/impl/OptimizedTreePoint.scala
+++ b/src/main/scala/org/apache/spark/ml/tree/impl/OptimizedTreePoint.scala
@@ -40,7 +40,7 @@ import org.apache.spark.rdd.RDD
* Same length as LabeledPoint.features, but values are bin indices.
*/
class OptimizedTreePoint(label: Double, binnedFeatures: Array[Int], val sampleWeight: Double)
- extends TreePoint(label, binnedFeatures) with Serializable {
+ extends TreePoint(label, binnedFeatures, sampleWeight) with Serializable {
}
object OptimizedTreePoint {
diff --git a/src/test/scala/org/apache/spark/ml/classification/OptimizedRandomForestClassifierSuite.scala b/src/test/scala/org/apache/spark/ml/classification/OptimizedRandomForestClassifierSuite.scala
index 5e242ce..7dd52eb 100755
--- a/src/test/scala/org/apache/spark/ml/classification/OptimizedRandomForestClassifierSuite.scala
+++ b/src/test/scala/org/apache/spark/ml/classification/OptimizedRandomForestClassifierSuite.scala
@@ -22,12 +22,8 @@ package org.apache.spark.ml.classification
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.feature.{Instance, LabeledPoint}
import org.apache.spark.ml.linalg.{Vector, Vectors}
-import org.apache.spark.ml.tree.impl.{OptimizedRandomForestSuite, OptimizedTreeTests, TreeTests}
+import org.apache.spark.ml.tree.impl.{OptimizedRandomForestSuite, OptimizedTreeTests}
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
-import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint}
-import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
-import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest}
-import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Row}
@@ -54,7 +50,7 @@ class OptimizedRandomForestClassifierSuite extends MLTest with DefaultReadWriteT
// Tests calling train()
/////////////////////////////////////////////////////////////////////////////
- def binaryClassificationTestWithContinuousFeatures(rf: RandomForestClassifier, orf: OptimizedRandomForestClassifier) {
+ def binaryClassificationTestWithContinuousFeatures(rf: RandomForestClassifier, orf: OptimizedRandomForestClassifier): Unit = {
val categoricalFeatures = Map.empty[Int, Int]
val numClasses = 2
val newRF = rf
@@ -72,6 +68,7 @@ class OptimizedRandomForestClassifierSuite extends MLTest with DefaultReadWriteT
compareAPIs(orderedInstances50_1000, newRF, optimizedRF, categoricalFeatures, numClasses)
}
+ // Fixed
test("Binary classification with continuous features:" +
" comparing DecisionTree vs. RandomForest(numTrees = 1)") {
val rf = new RandomForestClassifier()
@@ -79,6 +76,7 @@ class OptimizedRandomForestClassifierSuite extends MLTest with DefaultReadWriteT
binaryClassificationTestWithContinuousFeatures(rf, orf)
}
+ // Fixed
test("Binary classification with continuous features and node Id cache:" +
" comparing DecisionTree vs. RandomForest(numTrees = 1)") {
val rf = new RandomForestClassifier()
diff --git a/src/test/scala/org/apache/spark/ml/regression/OptimizedRandomForestRegressorSuite.scala b/src/test/scala/org/apache/spark/ml/regression/OptimizedRandomForestRegressorSuite.scala
index e796d0e..a9469e0 100755
--- a/src/test/scala/org/apache/spark/ml/regression/OptimizedRandomForestRegressorSuite.scala
+++ b/src/test/scala/org/apache/spark/ml/regression/OptimizedRandomForestRegressorSuite.scala
@@ -23,9 +23,6 @@ import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.feature.{Instance, LabeledPoint}
import org.apache.spark.ml.tree.impl.{OptimizedRandomForestSuite, OptimizedTreeTests}
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
-import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint}
-import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
-import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
@@ -49,7 +46,7 @@ class OptimizedRandomForestRegressorSuite extends MLTest with DefaultReadWriteTe
// Tests calling train()
/////////////////////////////////////////////////////////////////////////////
- def regressionTestWithContinuousFeatures(rf: RandomForestRegressor, orf: OptimizedRandomForestRegressor) {
+ def regressionTestWithContinuousFeatures(rf: RandomForestRegressor, orf: OptimizedRandomForestRegressor): Unit = {
val categoricalFeaturesInfo = Map.empty[Int, Int]
val newRF = rf
.setImpurity("variance")
@@ -68,6 +65,7 @@ class OptimizedRandomForestRegressorSuite extends MLTest with DefaultReadWriteTe
compareAPIs(orderedInstances50_1000, newRF, optimizedRF, categoricalFeaturesInfo)
}
+ // Fixed
test("Regression with continuous features:" +
" comparing DecisionTree vs. RandomForest(numTrees = 1)") {
val rf = new RandomForestRegressor()
diff --git a/src/test/scala/org/apache/spark/ml/tree/impl/LocalDecisionTreeRegressor.scala b/src/test/scala/org/apache/spark/ml/tree/impl/LocalDecisionTreeRegressor.scala
index 8c0968d..8586e8d 100755
--- a/src/test/scala/org/apache/spark/ml/tree/impl/LocalDecisionTreeRegressor.scala
+++ b/src/test/scala/org/apache/spark/ml/tree/impl/LocalDecisionTreeRegressor.scala
@@ -42,22 +42,21 @@ private[impl] final class LocalDecisionTreeRegressor(override val uid: String)
def this() = this(Identifiable.randomUID("local_dtr"))
// Override parameter setters from parent trait for Java API compatibility.
- override def setMaxDepth(value: Int): this.type = super.setMaxDepth(value)
+ def setMaxDepth(value: Int): this.type = set(maxDepth, value)
- override def setMaxBins(value: Int): this.type = super.setMaxBins(value)
+ def setMaxBins(value: Int): this.type = set(maxBins, value)
- override def setMinInstancesPerNode(value: Int): this.type =
- super.setMinInstancesPerNode(value)
+ def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value)
- override def setMinInfoGain(value: Double): this.type = super.setMinInfoGain(value)
+ def setMinInfoGain(value: Double): this.type = set(minInfoGain, value)
- override def setMaxMemoryInMB(value: Int): this.type = super.setMaxMemoryInMB(value)
+ def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value)
- override def setImpurity(value: String): this.type = super.setImpurity(value)
+ def setImpurity(value: String): this.type = set(impurity, value)
- override def setSeed(value: Long): this.type = super.setSeed(value)
+ def setSeed(value: Long): this.type = set(seed, value)
- override def copy(extra: ParamMap): LocalDecisionTreeRegressor = defaultCopy(extra)
+ def copy(extra: ParamMap): LocalDecisionTreeRegressor = defaultCopy(extra)
override protected def train(dataset: Dataset[_]): OptimizedDecisionTreeRegressionModel = {
val categoricalFeatures: Map[Int, Int] =
diff --git a/src/test/scala/org/apache/spark/ml/tree/impl/OptimizedRandomForestSuite.scala b/src/test/scala/org/apache/spark/ml/tree/impl/OptimizedRandomForestSuite.scala
index a8defb3..a0a36a8 100755
--- a/src/test/scala/org/apache/spark/ml/tree/impl/OptimizedRandomForestSuite.scala
+++ b/src/test/scala/org/apache/spark/ml/tree/impl/OptimizedRandomForestSuite.scala
@@ -94,10 +94,10 @@ class OptimizedRandomForestSuite extends SparkFunSuite with MLlibTestSparkContex
test("find splits for a continuous feature") {
// find splits for normal case
{
- val fakeMetadata = new OptimizedDecisionTreeMetadata(1, 200000, 0, 0,
+ val fakeMetadata = new OptimizedDecisionTreeMetadata(1, 200000, 200000, 0, 0,
Map(), Set(),
Array(6), Gini, QuantileStrategy.Sort,
- 0, 0, 0.0, 0, 0
+ 0, 0, 0.0, 0.0, 0, 0
)
val featureSamples = Array.fill(10000)(math.random).filter(_ != 0.0)
val splits = OptimizedRandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
@@ -110,10 +110,10 @@ class OptimizedRandomForestSuite extends SparkFunSuite with MLlibTestSparkContex
// SPARK-16957: Use midpoints for split values.
{
- val fakeMetadata = new OptimizedDecisionTreeMetadata(1, 8, 0, 0,
+ val fakeMetadata = new OptimizedDecisionTreeMetadata(1, 8, 8, 0, 0,
Map(), Set(),
Array(3), Gini, QuantileStrategy.Sort,
- 0, 0, 0.0, 0, 0
+ 0, 0, 0.0, 0.0, 0, 0
)
// TODO: Why doesn't this work after filtering 0.0?
@@ -137,10 +137,10 @@ class OptimizedRandomForestSuite extends SparkFunSuite with MLlibTestSparkContex
// find splits should not return identical splits
// when there are not enough split candidates, reduce the number of splits in metadata
{
- val fakeMetadata = new OptimizedDecisionTreeMetadata(1, 12, 0, 0,
+ val fakeMetadata = new OptimizedDecisionTreeMetadata(1, 12, 12, 0, 0,
Map(), Set(),
Array(5), Gini, QuantileStrategy.Sort,
- 0, 0, 0.0, 0, 0
+ 0, 0, 0.0, 0.0, 0, 0
)
val featureSamples = Array(1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3).map(_.toDouble)
val splits = OptimizedRandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
@@ -152,10 +152,10 @@ class OptimizedRandomForestSuite extends SparkFunSuite with MLlibTestSparkContex
// find splits when most samples close to the minimum
{
- val fakeMetadata = new OptimizedDecisionTreeMetadata(1, 18, 0, 0,
+ val fakeMetadata = new OptimizedDecisionTreeMetadata(1, 18, 18, 0, 0,
Map(), Set(),
Array(3), Gini, QuantileStrategy.Sort,
- 0, 0, 0.0, 0, 0
+ 0, 0, 0.0, 0.0, 0, 0
)
val featureSamples = Array(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 4, 5)
.map(_.toDouble)
@@ -166,10 +166,10 @@ class OptimizedRandomForestSuite extends SparkFunSuite with MLlibTestSparkContex
// find splits when most samples close to the maximum
{
- val fakeMetadata = new OptimizedDecisionTreeMetadata(1, 17, 0, 0,
+ val fakeMetadata = new OptimizedDecisionTreeMetadata(1, 17, 17, 0, 0,
Map(), Set(),
Array(2), Gini, QuantileStrategy.Sort,
- 0, 0, 0.0, 0, 0
+ 0, 0, 0.0, 0.0, 0, 0
)
val featureSamples = Array(0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2)
.map(_.toDouble).filter(_ != 0.0)
@@ -180,10 +180,10 @@ class OptimizedRandomForestSuite extends SparkFunSuite with MLlibTestSparkContex
// find splits for constant feature
{
- val fakeMetadata = new OptimizedDecisionTreeMetadata(1, 3, 0, 0,
+ val fakeMetadata = new OptimizedDecisionTreeMetadata(1, 3, 3, 0, 0,
Map(), Set(),
Array(3), Gini, QuantileStrategy.Sort,
- 0, 0, 0.0, 0, 0
+ 0, 0, 0.0, 0.0, 0, 0
)
val featureSamples = Array(0, 0, 0).map(_.toDouble).filter(_ != 0.0)
val featureSamplesEmpty = Array.empty[Double]
diff --git a/src/test/scala/org/apache/spark/ml/tree/impl/OptimizedTreeTests.scala b/src/test/scala/org/apache/spark/ml/tree/impl/OptimizedTreeTests.scala
index f407819..eefd4dc 100755
--- a/src/test/scala/org/apache/spark/ml/tree/impl/OptimizedTreeTests.scala
+++ b/src/test/scala/org/apache/spark/ml/tree/impl/OptimizedTreeTests.scala
@@ -200,10 +200,10 @@ private[ml] object OptimizedTreeTests extends SparkFunSuite {
}
new OptimizedDecisionTreeMetadata(numFeatures = numFeatures, numExamples = numExamples,
- numClasses = numClasses, maxBins = maxBins, minInfoGain = 0.0, featureArity = featureArity,
- unorderedFeatures = unordered, numBins = numBins, impurity = impurity,
+ weightedNumExamples = numExamples, numClasses = numClasses, maxBins = maxBins, minInfoGain = 0.0,
+ featureArity = featureArity, unorderedFeatures = unordered, numBins = numBins, impurity = impurity,
quantileStrategy = null, maxDepth = 5, minInstancesPerNode = 1, numTrees = 1,
- numFeaturesPerNode = 2)
+ numFeaturesPerNode = 2, minWeightFractionPerNode = 0.0)
}
/**
@@ -240,15 +240,13 @@ private[ml] object OptimizedTreeTests extends SparkFunSuite {
* make mistakes such as creating loops of Nodes.
*/
private def checkEqual(a: Node, b: OptimizedNode): Unit = {
- assert(a.prediction === b.prediction)
- assert(a.impurity === b.impurity)
-// assert(a.impurityStats.stats === b.impurityStats.stats)
(a, b) match {
- case (aye: InternalNode, bee: OptimizedInternalNode) =>
- assert(aye.split === bee.split)
- checkEqual(aye.leftChild, bee.leftChild)
- checkEqual(aye.rightChild, bee.rightChild)
- case (aye: LeafNode, bee: OptimizedLeafNode) => // do nothing
+ case (aa: InternalNode, bb: OptimizedInternalNode) =>
+ assert(aa.split === bb.split)
+ checkEqual(aa.leftChild, bb.leftChild)
+ checkEqual(aa.rightChild, bb.rightChild)
+ case (aa: LeafNode, bb: OptimizedLeafNode) => // do nothing
+ assert(aa.prediction === bb.prediction)
case _ =>
println(a.getClass.getCanonicalName, b.getClass.getCanonicalName)
throw new AssertionError("Found mismatched nodes")
@@ -264,11 +262,11 @@ private[ml] object OptimizedTreeTests extends SparkFunSuite {
assert(a.impurity === b.impurity)
// assert(a.impurityStats.stats === b.impurityStats.stats)
(a, b) match {
- case (aye: OptimizedInternalNode, bee: OptimizedInternalNode) =>
- assert(aye.split === bee.split)
- checkEqual(aye.leftChild, bee.leftChild)
- checkEqual(aye.rightChild, bee.rightChild)
- case (aye: OptimizedLeafNode, bee: OptimizedLeafNode) => // do nothing
+ case (aa: OptimizedInternalNode, bb: OptimizedInternalNode) =>
+ assert(aa.split === bb.split)
+ checkEqual(aa.leftChild, bb.leftChild)
+ checkEqual(aa.rightChild, bb.rightChild)
+ case (_: OptimizedLeafNode, _: OptimizedLeafNode) => // do nothing
case _ =>
println(a.getClass.getCanonicalName, b.getClass.getCanonicalName)
throw new AssertionError("Found mismatched nodes")