diff --git a/pom.xml b/pom.xml index d82d5f6..d054be5 100644 --- a/pom.xml +++ b/pom.xml @@ -3,7 +3,7 @@ 4.0.0 com.cisco.cognitive oraf - 1.3.0-SNAPSHOT + 2.0.0 jar ${project.groupId}:${project.artifactId} @@ -40,13 +40,13 @@ - 1.8 - 1.8 + 11 + 11 UTF-8 2.12 2.12 2.12.10 - 2.4.4 + 3.0.3 @@ -175,25 +175,25 @@ - - org.apache.maven.plugins - maven-gpg-plugin - 1.5 - - - sign-artifacts - verify - - sign - - - - + + + + + + + + + + + + + + net.alchim31.maven scala-maven-plugin - 4.1.0 + 4.8.0 diff --git a/src/main/scala/org/apache/spark/ml/classification/OptimizedDecisionTreeClassifier.scala b/src/main/scala/org/apache/spark/ml/classification/OptimizedDecisionTreeClassifier.scala index 86873a0..1edd3a8 100755 --- a/src/main/scala/org/apache/spark/ml/classification/OptimizedDecisionTreeClassifier.scala +++ b/src/main/scala/org/apache/spark/ml/classification/OptimizedDecisionTreeClassifier.scala @@ -59,27 +59,27 @@ class OptimizedDecisionTreeClassifier @Since("1.4.0") ( /** @group setParam */ @Since("1.4.0") - override def setMaxDepth(value: Int): this.type = set(maxDepth, value) + def setMaxDepth(value: Int): this.type = set(maxDepth, value) /** @group setParam */ @Since("1.4.0") - override def setMaxBins(value: Int): this.type = set(maxBins, value) + def setMaxBins(value: Int): this.type = set(maxBins, value) /** @group setParam */ @Since("1.4.0") - override def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value) + def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value) /** @group setParam */ @Since("1.4.0") - override def setMinInfoGain(value: Double): this.type = set(minInfoGain, value) + def setMinInfoGain(value: Double): this.type = set(minInfoGain, value) /** @group expertSetParam */ @Since("1.4.0") - override def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value) + def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value) /** @group expertSetParam */ @Since("1.4.0") - override def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value) + def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value) /** * Specifies how often to checkpoint the cached node IDs. 
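Alongside the pom bump to Java 11 and Spark 3.0.3, the `override` modifiers on these setters are dropped because Spark 3.x removed the deprecated setter definitions from the shared tree-params traits; there is no parent method left to override, so the fork declares the setters itself and callers are unaffected. A construction-only usage sketch (parameter values are illustrative):

import org.apache.spark.ml.classification.OptimizedDecisionTreeClassifier

// Chained setters behave exactly as before the override removal.
val dtc = new OptimizedDecisionTreeClassifier()
  .setMaxDepth(8)
  .setMaxBins(64)
  .setMinInstancesPerNode(2)
  .setImpurity("gini")
  .setSeed(42L)

// Print the resolved parameters; no SparkSession is needed just to configure an estimator.
println(dtc.explainParams())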
@@ -91,15 +91,15 @@ class OptimizedDecisionTreeClassifier @Since("1.4.0") ( * @group setParam */ @Since("1.4.0") - override def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) + def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) /** @group setParam */ @Since("1.4.0") - override def setImpurity(value: String): this.type = set(impurity, value) + def setImpurity(value: String): this.type = set(impurity, value) /** @group setParam */ @Since("1.6.0") - override def setSeed(value: Long): this.type = set(seed, value) + def setSeed(value: Long): this.type = set(seed, value) /** @group setParam */ @Since("2.0.0") @@ -261,7 +261,7 @@ class OptimizedDecisionTreeClassificationModel( } // TODO: Make sure this is correct - override protected def predictRaw(features: Vector): Vector = { + override def predictRaw(features: Vector): Vector = { val predictions = Array.fill[Double](numClasses)(0.0) predictions(rootNode.predict(features).toInt) = 1.0 diff --git a/src/main/scala/org/apache/spark/ml/classification/OptimizedRandomForestClassifier.scala b/src/main/scala/org/apache/spark/ml/classification/OptimizedRandomForestClassifier.scala index 0c398d7..b3cd055 100755 --- a/src/main/scala/org/apache/spark/ml/classification/OptimizedRandomForestClassifier.scala +++ b/src/main/scala/org/apache/spark/ml/classification/OptimizedRandomForestClassifier.scala @@ -61,27 +61,27 @@ class OptimizedRandomForestClassifier @Since("1.4.0") ( /** @group setParam */ @Since("1.4.0") - override def setMaxDepth(value: Int): this.type = set(maxDepth, value) + def setMaxDepth(value: Int): this.type = set(maxDepth, value) /** @group setParam */ @Since("1.4.0") - override def setMaxBins(value: Int): this.type = set(maxBins, value) + def setMaxBins(value: Int): this.type = set(maxBins, value) /** @group setParam */ @Since("1.4.0") - override def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value) + def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value) /** @group setParam */ @Since("1.4.0") - override def setMinInfoGain(value: Double): this.type = set(minInfoGain, value) + def setMinInfoGain(value: Double): this.type = set(minInfoGain, value) /** @group expertSetParam */ @Since("1.4.0") - override def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value) + def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value) /** @group expertSetParam */ @Since("1.4.0") - override def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value) + def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value) /** * Specifies how often to checkpoint the cached node IDs. 
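`predictRaw` loses its `protected` modifier because the corresponding method on `ClassificationModel` is public in Spark 3.x, and an override cannot narrow visibility. For a single tree the raw prediction is simply a one-hot indicator of the predicted class; a standalone restatement of that logic:

import org.apache.spark.ml.linalg.{Vector, Vectors}

// Same idea as the model's predictRaw above: mark the predicted class with 1.0.
def oneHotRaw(predictedClass: Int, numClasses: Int): Vector = {
  val predictions = Array.fill[Double](numClasses)(0.0)
  predictions(predictedClass) = 1.0
  Vectors.dense(predictions)
}

println(oneHotRaw(predictedClass = 1, numClasses = 3))  // [0.0,1.0,0.0]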
@@ -93,31 +93,31 @@ class OptimizedRandomForestClassifier @Since("1.4.0") ( * @group setParam */ @Since("1.4.0") - override def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) + def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) /** @group setParam */ @Since("1.4.0") - override def setImpurity(value: String): this.type = set(impurity, value) + def setImpurity(value: String): this.type = set(impurity, value) // Parameters from TreeEnsembleParams: /** @group setParam */ @Since("1.4.0") - override def setSubsamplingRate(value: Double): this.type = set(subsamplingRate, value) + def setSubsamplingRate(value: Double): this.type = set(subsamplingRate, value) /** @group setParam */ @Since("1.4.0") - override def setSeed(value: Long): this.type = set(seed, value) + def setSeed(value: Long): this.type = set(seed, value) // Parameters from RandomForestParams: /** @group setParam */ @Since("1.4.0") - override def setNumTrees(value: Int): this.type = set(numTrees, value) + def setNumTrees(value: Int): this.type = set(numTrees, value) /** @group setParam */ @Since("1.4.0") - override def setFeatureSubsetStrategy(value: String): this.type = + def setFeatureSubsetStrategy(value: String): this.type = set(featureSubsetStrategy, value) /** @group setParam */ @@ -250,15 +250,7 @@ class OptimizedRandomForestClassificationModel( @Since("1.4.0") override def treeWeights: Array[Double] = _treeWeights - override protected def transformImpl(dataset: Dataset[_]): DataFrame = { - val bcastModel = dataset.sparkSession.sparkContext.broadcast(this) - val predictUDF = udf { features: Any => - bcastModel.value.predict(features.asInstanceOf[Vector]) - } - dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) - } - - override protected def predictRaw(features: Vector): Vector = { + override def predictRaw(features: Vector): Vector = { // TODO: When we add a generic Bagging class, handle transform there: SPARK-7128 // Classifies using majority votes. // Ignore the tree weights since all are 1.0 for now. diff --git a/src/main/scala/org/apache/spark/ml/regression/OptimizedDecisionTreeRegressor.scala b/src/main/scala/org/apache/spark/ml/regression/OptimizedDecisionTreeRegressor.scala index 0b80297..f37ecf3 100755 --- a/src/main/scala/org/apache/spark/ml/regression/OptimizedDecisionTreeRegressor.scala +++ b/src/main/scala/org/apache/spark/ml/regression/OptimizedDecisionTreeRegressor.scala @@ -58,27 +58,27 @@ class OptimizedDecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override v // Override parameter setters from parent trait for Java API compatibility. 
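Deleting the hand-rolled `transformImpl` defers to the base `PredictionModel.transformImpl` in Spark 3.x, which already does essentially the same thing (wrap `predict` in a UDF over the features column), so the extra broadcast of the model bought nothing. For the ensemble, `predictRaw` sums one-hot votes across the member trees; a simplified, illustrative version of that majority-vote step (tree weights treated as 1.0, as the comment notes):

import org.apache.spark.ml.linalg.{Vector, Vectors}

// Each tree votes for one class; the raw prediction is the per-class vote count.
def majorityVotes(treeVotes: Seq[Int], numClasses: Int): Vector = {
  val counts = Array.fill[Double](numClasses)(0.0)
  treeVotes.foreach(c => counts(c) += 1.0)
  Vectors.dense(counts)
}

println(majorityVotes(Seq(0, 1, 1, 2, 1), numClasses = 3))  // [1.0,3.0,1.0]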
/** @group setParam */ @Since("1.4.0") - override def setMaxDepth(value: Int): this.type = set(maxDepth, value) + def setMaxDepth(value: Int): this.type = set(maxDepth, value) /** @group setParam */ @Since("1.4.0") - override def setMaxBins(value: Int): this.type = set(maxBins, value) + def setMaxBins(value: Int): this.type = set(maxBins, value) /** @group setParam */ @Since("1.4.0") - override def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value) + def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value) /** @group setParam */ @Since("1.4.0") - override def setMinInfoGain(value: Double): this.type = set(minInfoGain, value) + def setMinInfoGain(value: Double): this.type = set(minInfoGain, value) /** @group expertSetParam */ @Since("1.4.0") - override def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value) + def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value) /** @group expertSetParam */ @Since("1.4.0") - override def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value) + def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value) /** * Specifies how often to checkpoint the cached node IDs. @@ -90,15 +90,15 @@ class OptimizedDecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override v * @group setParam */ @Since("1.4.0") - override def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) + def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) /** @group setParam */ @Since("1.4.0") - override def setImpurity(value: String): this.type = set(impurity, value) + def setImpurity(value: String): this.type = set(impurity, value) /** @group setParam */ @Since("1.6.0") - override def setSeed(value: Long): this.type = set(seed, value) + def setSeed(value: Long): this.type = set(seed, value) /** @group setParam */ @Since("2.0.0") @@ -175,7 +175,7 @@ class OptimizedDecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override v object OptimizedDecisionTreeRegressor extends DefaultParamsReadable[OptimizedDecisionTreeRegressor] { /** Accessor for supported impurities: variance */ - final val supportedImpurities: Array[String] = TreeRegressorParams.supportedImpurities + final val supportedImpurities: Array[String] = Array("variance") @Since("2.0.0") override def load(path: String): OptimizedDecisionTreeRegressor = super.load(path) diff --git a/src/main/scala/org/apache/spark/ml/regression/OptimizedRandomForestRegressor.scala b/src/main/scala/org/apache/spark/ml/regression/OptimizedRandomForestRegressor.scala index df8cd38..c0e2ffd 100755 --- a/src/main/scala/org/apache/spark/ml/regression/OptimizedRandomForestRegressor.scala +++ b/src/main/scala/org/apache/spark/ml/regression/OptimizedRandomForestRegressor.scala @@ -59,27 +59,27 @@ class OptimizedRandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override v /** @group setParam */ @Since("1.4.0") - override def setMaxDepth(value: Int): this.type = set(maxDepth, value) + def setMaxDepth(value: Int): this.type = set(maxDepth, value) /** @group setParam */ @Since("1.4.0") - override def setMaxBins(value: Int): this.type = set(maxBins, value) + def setMaxBins(value: Int): this.type = set(maxBins, value) /** @group setParam */ @Since("1.4.0") - override def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value) + def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value) /** @group setParam */ @Since("1.4.0") - override def 
setMinInfoGain(value: Double): this.type = set(minInfoGain, value) + def setMinInfoGain(value: Double): this.type = set(minInfoGain, value) /** @group expertSetParam */ @Since("1.4.0") - override def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value) + def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value) /** @group expertSetParam */ @Since("1.4.0") - override def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value) + def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value) /** * Specifies how often to checkpoint the cached node IDs. @@ -91,31 +91,31 @@ class OptimizedRandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override v * @group setParam */ @Since("1.4.0") - override def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) + def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) /** @group setParam */ @Since("1.4.0") - override def setImpurity(value: String): this.type = set(impurity, value) + def setImpurity(value: String): this.type = set(impurity, value) // Parameters from TreeEnsembleParams: /** @group setParam */ @Since("1.4.0") - override def setSubsamplingRate(value: Double): this.type = set(subsamplingRate, value) + def setSubsamplingRate(value: Double): this.type = set(subsamplingRate, value) /** @group setParam */ @Since("1.4.0") - override def setSeed(value: Long): this.type = set(seed, value) + def setSeed(value: Long): this.type = set(seed, value) // Parameters from RandomForestParams: /** @group setParam */ @Since("1.4.0") - override def setNumTrees(value: Int): this.type = set(numTrees, value) + def setNumTrees(value: Int): this.type = set(numTrees, value) /** @group setParam */ @Since("1.4.0") - override def setFeatureSubsetStrategy(value: String): this.type = + def setFeatureSubsetStrategy(value: String): this.type = set(featureSubsetStrategy, value) /** @group setParam */ @@ -181,7 +181,7 @@ class OptimizedRandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override v object OptimizedRandomForestRegressor extends DefaultParamsReadable[OptimizedRandomForestRegressor]{ /** Accessor for supported impurity settings: variance */ @Since("1.4.0") - final val supportedImpurities: Array[String] = TreeRegressorParams.supportedImpurities + final val supportedImpurities: Array[String] = Array("variance") /** Accessor for supported featureSubsetStrategy settings: auto, all, onethird, sqrt, log2 */ @Since("1.4.0") diff --git a/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala b/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala new file mode 100644 index 0000000..19bd10a --- /dev/null +++ b/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala @@ -0,0 +1,199 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
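In both regressors above, `supportedImpurities` is now the literal `Array("variance")` instead of delegating to `TreeRegressorParams.supportedImpurities`; the advertised list is unchanged, since variance is the only impurity supported for regression. A small sketch of how such a whitelist is typically used to validate the parameter (the helper name is hypothetical):

import java.util.Locale

val supportedImpurities: Array[String] = Array("variance")

// Hypothetical validation helper, mirroring what the impurity Param's validator does.
def requireSupportedImpurity(impurity: String): Unit =
  require(supportedImpurities.contains(impurity.toLowerCase(Locale.ROOT)),
    s"impurity must be one of ${supportedImpurities.mkString(", ")}, but got: $impurity")

requireSupportedImpurity("variance")  // passes
// requireSupportedImpurity("gini")   // would throw IllegalArgumentException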
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.tree.impl + +import java.io.IOException + +import scala.collection.mutable + +import org.apache.hadoop.fs.Path + +import org.apache.spark.internal.Logging +import org.apache.spark.ml.tree.{LearningNode, Split} +import org.apache.spark.rdd.RDD +import org.apache.spark.storage.StorageLevel + + +/** + * This is used by the node id cache to find the child id that a data point would belong to. + * @param split Split information. + * @param nodeIndex The current node index of a data point that this will update. + */ +private[tree] case class NodeIndexUpdater(split: Split, nodeIndex: Int) { + + /** + * Determine a child node index based on the feature value and the split. + * @param binnedFeature Binned feature value. + * @param splits Split information to convert the bin indices to approximate feature values. + * @return Child node index to update to. + */ + def updateNodeIndex(binnedFeature: Int, splits: Array[Split]): Int = { + if (split.shouldGoLeft(binnedFeature, splits)) { + LearningNode.leftChildIndex(nodeIndex) + } else { + LearningNode.rightChildIndex(nodeIndex) + } + } +} + +/** + * Each TreePoint belongs to a particular node per tree. + * Each row in the nodeIdsForInstances RDD is an array over trees of the node index + * in each tree. Initially, values should all be 1 for root node. + * The nodeIdsForInstances RDD needs to be updated at each iteration. + * @param nodeIdsForInstances The initial values in the cache + * (should be an Array of all 1's (meaning the root nodes)). + * @param checkpointInterval The checkpointing interval + * (how often should the cache be checkpointed.). + */ +private[spark] class NodeIdCache( + var nodeIdsForInstances: RDD[Array[Int]], + val checkpointInterval: Int) extends Logging { + + // Keep a reference to a previous node Ids for instances. + // Because we will keep on re-persisting updated node Ids, + // we want to unpersist the previous RDD. + private var prevNodeIdsForInstances: RDD[Array[Int]] = null + + // To keep track of the past checkpointed RDDs. + private val checkpointQueue = mutable.Queue[RDD[Array[Int]]]() + private var rddUpdateCount = 0 + + // Indicates whether we can checkpoint + private val canCheckpoint = nodeIdsForInstances.sparkContext.getCheckpointDir.nonEmpty + + // Hadoop Configuration for deleting checkpoints as needed + private val hadoopConf = nodeIdsForInstances.sparkContext.hadoopConfiguration + + /** + * Update the node index values in the cache. + * This updates the RDD and its lineage. + * TODO: Passing bin information to executors seems unnecessary and costly. + * @param data The RDD of training rows. + * @param nodeIdUpdaters A map of node index updaters. + * The key is the indices of nodes that we want to update. + * @param splits Split information needed to find child node indices. + */ + def updateNodeIndices( + data: RDD[BaggedPoint[TreePoint]], + nodeIdUpdaters: Array[mutable.Map[Int, NodeIndexUpdater]], + splits: Array[Array[Split]]): Unit = { + if (prevNodeIdsForInstances != null) { + // Unpersist the previous one if one exists. 
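`NodeIndexUpdater.updateNodeIndex` routes a row to a child purely by index arithmetic: `LearningNode.leftChildIndex` / `rightChildIndex` use 1-based binary-heap numbering with the root at index 1, which is why the cache is initialized to all 1s. A standalone restatement of that arithmetic, assuming Spark's heap-style node numbering:

// Left child of node i is 2*i, right child is 2*i + 1, root is 1.
def leftChildIndex(nodeIndex: Int): Int = nodeIndex << 1
def rightChildIndex(nodeIndex: Int): Int = (nodeIndex << 1) + 1

// A row sitting at the root that passes the split moves to node 2; one that fails it moves to node 3.
println(leftChildIndex(1))   // 2
println(rightChildIndex(1))  // 3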
+ prevNodeIdsForInstances.unpersist(false) + } + + prevNodeIdsForInstances = nodeIdsForInstances + nodeIdsForInstances = data.zip(nodeIdsForInstances).map { case (point, ids) => + var treeId = 0 + while (treeId < nodeIdUpdaters.length) { + val nodeIdUpdater = nodeIdUpdaters(treeId).getOrElse(ids(treeId), null) + if (nodeIdUpdater != null) { + val featureIndex = nodeIdUpdater.split.featureIndex + val newNodeIndex = nodeIdUpdater.updateNodeIndex( + binnedFeature = point.datum.binnedFeatures(featureIndex), + splits = splits(featureIndex)) + ids(treeId) = newNodeIndex + } + treeId += 1 + } + ids + } + + // Keep on persisting new ones. + nodeIdsForInstances.persist(StorageLevel.MEMORY_AND_DISK) + rddUpdateCount += 1 + + // Handle checkpointing if the directory is not None. + if (canCheckpoint && checkpointInterval != -1 && (rddUpdateCount % checkpointInterval) == 0) { + // Let's see if we can delete previous checkpoints. + var canDelete = true + while (checkpointQueue.size > 1 && canDelete) { + // We can delete the oldest checkpoint iff + // the next checkpoint actually exists in the file system. + if (checkpointQueue(1).getCheckpointFile.isDefined) { + val old = checkpointQueue.dequeue() + // Since the old checkpoint is not deleted by Spark, we'll manually delete it here. + try { + val path = new Path(old.getCheckpointFile.get) + val fs = path.getFileSystem(hadoopConf) + fs.delete(path, true) + } catch { + case e: IOException => + logError("Decision Tree learning using cacheNodeIds failed to remove checkpoint" + + s" file: ${old.getCheckpointFile.get}") + } + } else { + canDelete = false + } + } + + nodeIdsForInstances.checkpoint() + checkpointQueue.enqueue(nodeIdsForInstances) + } + } + + /** + * Call this after training is finished to delete any remaining checkpoints. + */ + def deleteAllCheckpoints(): Unit = { + while (checkpointQueue.nonEmpty) { + val old = checkpointQueue.dequeue() + if (old.getCheckpointFile.isDefined) { + try { + val path = new Path(old.getCheckpointFile.get) + val fs = path.getFileSystem(hadoopConf) + fs.delete(path, true) + } catch { + case e: IOException => + logError("Decision Tree learning using cacheNodeIds failed to remove checkpoint" + + s" file: ${old.getCheckpointFile.get}") + } + } + } + if (nodeIdsForInstances != null) { + // Unpersist current one if one exists. + nodeIdsForInstances.unpersist(false) + } + if (prevNodeIdsForInstances != null) { + // Unpersist the previous one if one exists. + prevNodeIdsForInstances.unpersist(false) + } + } +} + +private[spark] object NodeIdCache { + /** + * Initialize the node Id cache with initial node Id values. + * @param data The RDD of training rows. + * @param numTrees The number of trees that we want to create cache for. + * @param checkpointInterval The checkpointing interval + * (how often should the cache be checkpointed.). + * @param initVal The initial values in the cache. + * @return A node Id cache containing an RDD of initial root node Indices. 
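Checkpointing of the node-id RDD is throttled: a checkpoint is taken only when a checkpoint directory is configured, the interval is enabled (not -1), and the number of cache updates is a multiple of the interval; older checkpoints are deleted once the next one has materialized on the file system. The trigger condition above, restated as a small pure function:

// Mirrors the condition used in updateNodeIndices above.
def shouldCheckpoint(canCheckpoint: Boolean, checkpointInterval: Int, rddUpdateCount: Int): Boolean =
  canCheckpoint && checkpointInterval != -1 && rddUpdateCount % checkpointInterval == 0

println(shouldCheckpoint(canCheckpoint = true, checkpointInterval = 10, rddUpdateCount = 30))  // true
println(shouldCheckpoint(canCheckpoint = true, checkpointInterval = -1, rddUpdateCount = 30))  // false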
+ */ + def init( + data: RDD[BaggedPoint[TreePoint]], + numTrees: Int, + checkpointInterval: Int, + initVal: Int = 1): NodeIdCache = { + new NodeIdCache( + data.map(_ => Array.fill[Int](numTrees)(initVal)), + checkpointInterval) + } +} diff --git a/src/main/scala/org/apache/spark/ml/tree/impl/OptimizedDTStatsAggregator.scala b/src/main/scala/org/apache/spark/ml/tree/impl/OptimizedDTStatsAggregator.scala index da1e890..369374e 100644 --- a/src/main/scala/org/apache/spark/ml/tree/impl/OptimizedDTStatsAggregator.scala +++ b/src/main/scala/org/apache/spark/ml/tree/impl/OptimizedDTStatsAggregator.scala @@ -108,14 +108,14 @@ class OptimizedDTStatsAggregator( */ def update(featureIndex: Int, binIndex: Int, label: Double, instanceWeight: Double): Unit = { val i = featureOffsets(featureIndex) + binIndex * statsSize - impurityAggregator.update(allStats, i, label, instanceWeight) + impurityAggregator.update(allStats, i, label, 1, instanceWeight) } /** * Update the parent node stats using the given label. */ def updateParent(label: Double, instanceWeight: Double): Unit = { - impurityAggregator.update(parentStats, 0, label, instanceWeight) + impurityAggregator.update(parentStats, 0, label, 1, instanceWeight) } /** @@ -130,8 +130,7 @@ class OptimizedDTStatsAggregator( binIndex: Int, label: Double, instanceWeight: Double): Unit = { - impurityAggregator.update(allStats, featureOffset + binIndex * statsSize, - label, instanceWeight) + impurityAggregator.update(allStats, featureOffset + binIndex * statsSize, label, 1, instanceWeight) } /** diff --git a/src/main/scala/org/apache/spark/ml/tree/impl/OptimizedDecisionTreeMetadata.scala b/src/main/scala/org/apache/spark/ml/tree/impl/OptimizedDecisionTreeMetadata.scala index 0860ecc..e2fcf77 100644 --- a/src/main/scala/org/apache/spark/ml/tree/impl/OptimizedDecisionTreeMetadata.scala +++ b/src/main/scala/org/apache/spark/ml/tree/impl/OptimizedDecisionTreeMetadata.scala @@ -44,6 +44,7 @@ import scala.util.Try class OptimizedDecisionTreeMetadata( numFeatures: Int, numExamples: Long, + weightedNumExamples: Double, numClasses: Int, maxBins: Int, featureArity: Map[Int, Int], @@ -53,10 +54,27 @@ class OptimizedDecisionTreeMetadata( quantileStrategy: QuantileStrategy, maxDepth: Int, minInstancesPerNode: Int, + minWeightFractionPerNode: Double, minInfoGain: Double, numTrees: Int, - numFeaturesPerNode: Int) extends DecisionTreeMetadata(numFeatures, - numExamples, numClasses, maxBins, featureArity, unorderedFeatures, numBins, impurity, quantileStrategy, maxDepth, minInstancesPerNode, minInfoGain, numTrees, numFeaturesPerNode) with Serializable { + numFeaturesPerNode: Int) extends DecisionTreeMetadata( + numFeatures, + numExamples, + weightedNumExamples, + numClasses, + maxBins, + featureArity, + unorderedFeatures, + numBins, + impurity, + quantileStrategy, + maxDepth, + minInstancesPerNode, + minWeightFractionPerNode, + minInfoGain, + numTrees, + numFeaturesPerNode +) with Serializable { } object OptimizedDecisionTreeMetadata extends Logging { @@ -78,7 +96,11 @@ object OptimizedDecisionTreeMetadata extends Logging { } require(numFeatures > 0, s"DecisionTree requires number of features > 0, " + s"but was given an empty features vector") - val numExamples = input.count() + val (numExamples, weightSum) = input.aggregate((0L, 0.0))( + seqOp = (cw, instance) => (cw._1 + 1L, cw._2 + instance.weight), + combOp = (cw1, cw2) => (cw1._1 + cw2._1, cw1._2 + cw2._2) + ) + val numClasses = strategy.algo match { case Classification => strategy.numClasses case Regression => 0 @@ -169,10 
+191,11 @@ object OptimizedDecisionTreeMetadata extends Logging { } } - new OptimizedDecisionTreeMetadata(numFeatures, numExamples, numClasses, numBins.max, - strategy.categoricalFeaturesInfo, unorderedFeatures.toSet, numBins, + new OptimizedDecisionTreeMetadata(numFeatures, numExamples, weightSum, numClasses, + numBins.max, strategy.categoricalFeaturesInfo, unorderedFeatures.toSet, numBins, strategy.impurity, strategy.quantileCalculationStrategy, strategy.maxDepth, - strategy.minInstancesPerNode, strategy.minInfoGain, numTrees, numFeaturesPerNode) + strategy.minInstancesPerNode, strategy.minWeightFractionPerNode, strategy.minInfoGain, + numTrees, numFeaturesPerNode) } /** diff --git a/src/main/scala/org/apache/spark/ml/tree/impl/OptimizedRandomForest.scala b/src/main/scala/org/apache/spark/ml/tree/impl/OptimizedRandomForest.scala index 02aaffe..61d113f 100755 --- a/src/main/scala/org/apache/spark/ml/tree/impl/OptimizedRandomForest.scala +++ b/src/main/scala/org/apache/spark/ml/tree/impl/OptimizedRandomForest.scala @@ -19,15 +19,12 @@ package org.apache.spark.ml.tree.impl -import java.io.IOException - import org.apache.spark.Partitioner import org.apache.spark.internal.Logging import org.apache.spark.ml.classification.OptimizedDecisionTreeClassificationModel -import org.apache.spark.ml.feature.LabeledPoint +import org.apache.spark.ml.feature.{Instance, LabeledPoint} import org.apache.spark.ml.regression.OptimizedDecisionTreeRegressionModel import org.apache.spark.ml.tree._ -import org.apache.spark.ml.feature.Instance import org.apache.spark.ml.util.Instrumentation import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, OptimizedForestStrategy => OldStrategy} import org.apache.spark.mllib.tree.impurity.ImpurityCalculator @@ -36,6 +33,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel import org.apache.spark.util.random.{SamplingUtils, XORShiftRandom} +import java.io.IOException import scala.collection.{SeqView, mutable} import scala.util.{Random, Try} @@ -174,7 +172,7 @@ private[spark] object OptimizedRandomForest extends Logging { val withReplacement = numTrees > 1 val baggedInput = BaggedPoint - .convertToBaggedRDD(treeInput, strategy.subsamplingRate, numTrees, withReplacement, seed) + .convertToBaggedRDD(treeInput, strategy.subsamplingRate, numTrees, withReplacement, (treePoint: TreePoint) => treePoint.weight, seed = seed) .persist(StorageLevel.MEMORY_AND_DISK) val distributedMaxDepth = Math.min(strategy.maxDepth, 30) @@ -340,11 +338,11 @@ private[spark] object OptimizedRandomForest extends Logging { pointsWithNodeIds.flatMap { case (baggedPoint, nodeIdsForTree) => nodeSetsBc.value.keys - .filter(treeId => baggedPoint.subsampleWeights(treeId) > 0) + .filter(treeId => baggedPoint.subsampleCounts(treeId) > 0) .map(treeId => (treeId, nodeIdsForTree(treeId))) .filter { case (treeId, nodeId) => nodeSetsBc.value(treeId).contains(nodeId) } .map { case (treeId, nodeId) => - ((treeId, nodeId), (baggedPoint.datum, baggedPoint.subsampleWeights(treeId) * baggedPoint.datum.sampleWeight)) + ((treeId, nodeId), (baggedPoint.datum, baggedPoint.subsampleCounts(treeId) * baggedPoint.datum.sampleWeight)) } } } @@ -676,7 +674,7 @@ private[spark] object OptimizedRandomForest extends Logging { if (nodeInfo != null) { val aggNodeIndex = nodeInfo.nodeIndexInGroup val featuresForNode = nodeInfo.featureSubset - val instanceWeight = baggedPoint.subsampleWeights(treeIndex) * baggedPoint.datum.sampleWeight + val instanceWeight = baggedPoint.subsampleCounts(treeIndex) * 
baggedPoint.datum.sampleWeight if (metadata.unorderedFeatures.isEmpty) { orderedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, instanceWeight, featuresForNode) } else { @@ -879,11 +877,11 @@ private[spark] object OptimizedRandomForest extends Logging { // enqueue left child and right child if they are not leaves if (!leftChildIsLeaf) { - addTrainingTask(node.leftChild.get, treeIndex, stats.leftImpurityCalculator.count, + addTrainingTask(node.leftChild.get, treeIndex, stats.leftImpurityCalculator.count.toLong, nodeLevel, stats.leftImpurity) } if (!rightChildIsLeaf) { - addTrainingTask(node.rightChild.get, treeIndex, stats.rightImpurityCalculator.count, + addTrainingTask(node.rightChild.get, treeIndex, stats.rightImpurityCalculator.count.toLong, nodeLevel, stats.rightImpurity) } diff --git a/src/main/scala/org/apache/spark/ml/tree/impl/OptimizedTreePoint.scala b/src/main/scala/org/apache/spark/ml/tree/impl/OptimizedTreePoint.scala index 6b89734..245bc16 100644 --- a/src/main/scala/org/apache/spark/ml/tree/impl/OptimizedTreePoint.scala +++ b/src/main/scala/org/apache/spark/ml/tree/impl/OptimizedTreePoint.scala @@ -40,7 +40,7 @@ import org.apache.spark.rdd.RDD * Same length as LabeledPoint.features, but values are bin indices. */ class OptimizedTreePoint(label: Double, binnedFeatures: Array[Int], val sampleWeight: Double) - extends TreePoint(label, binnedFeatures) with Serializable { + extends TreePoint(label, binnedFeatures, sampleWeight) with Serializable { } object OptimizedTreePoint { diff --git a/src/test/scala/org/apache/spark/ml/classification/OptimizedRandomForestClassifierSuite.scala b/src/test/scala/org/apache/spark/ml/classification/OptimizedRandomForestClassifierSuite.scala index 5e242ce..7dd52eb 100755 --- a/src/test/scala/org/apache/spark/ml/classification/OptimizedRandomForestClassifierSuite.scala +++ b/src/test/scala/org/apache/spark/ml/classification/OptimizedRandomForestClassifierSuite.scala @@ -22,12 +22,8 @@ package org.apache.spark.ml.classification import org.apache.spark.SparkFunSuite import org.apache.spark.ml.feature.{Instance, LabeledPoint} import org.apache.spark.ml.linalg.{Vector, Vectors} -import org.apache.spark.ml.tree.impl.{OptimizedRandomForestSuite, OptimizedTreeTests, TreeTests} +import org.apache.spark.ml.tree.impl.{OptimizedRandomForestSuite, OptimizedTreeTests} import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} -import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint} -import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} -import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest} -import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Row} @@ -54,7 +50,7 @@ class OptimizedRandomForestClassifierSuite extends MLTest with DefaultReadWriteT // Tests calling train() ///////////////////////////////////////////////////////////////////////////// - def binaryClassificationTestWithContinuousFeatures(rf: RandomForestClassifier, orf: OptimizedRandomForestClassifier) { + def binaryClassificationTestWithContinuousFeatures(rf: RandomForestClassifier, orf: OptimizedRandomForestClassifier): Unit = { val categoricalFeatures = Map.empty[Int, Int] val numClasses = 2 val newRF = rf @@ -72,6 +68,7 @@ class OptimizedRandomForestClassifierSuite extends MLTest with DefaultReadWriteT compareAPIs(orderedInstances50_1000, newRF, optimizedRF, categoricalFeatures, numClasses) } + // Fixed test("Binary classification with continuous 
features:" + " comparing DecisionTree vs. RandomForest(numTrees = 1)") { val rf = new RandomForestClassifier() @@ -79,6 +76,7 @@ class OptimizedRandomForestClassifierSuite extends MLTest with DefaultReadWriteT binaryClassificationTestWithContinuousFeatures(rf, orf) } + // Fixed test("Binary classification with continuous features and node Id cache:" + " comparing DecisionTree vs. RandomForest(numTrees = 1)") { val rf = new RandomForestClassifier() diff --git a/src/test/scala/org/apache/spark/ml/regression/OptimizedRandomForestRegressorSuite.scala b/src/test/scala/org/apache/spark/ml/regression/OptimizedRandomForestRegressorSuite.scala index e796d0e..a9469e0 100755 --- a/src/test/scala/org/apache/spark/ml/regression/OptimizedRandomForestRegressorSuite.scala +++ b/src/test/scala/org/apache/spark/ml/regression/OptimizedRandomForestRegressorSuite.scala @@ -23,9 +23,6 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.ml.feature.{Instance, LabeledPoint} import org.apache.spark.ml.tree.impl.{OptimizedRandomForestSuite, OptimizedTreeTests} import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} -import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint} -import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} -import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest} import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame @@ -49,7 +46,7 @@ class OptimizedRandomForestRegressorSuite extends MLTest with DefaultReadWriteTe // Tests calling train() ///////////////////////////////////////////////////////////////////////////// - def regressionTestWithContinuousFeatures(rf: RandomForestRegressor, orf: OptimizedRandomForestRegressor) { + def regressionTestWithContinuousFeatures(rf: RandomForestRegressor, orf: OptimizedRandomForestRegressor): Unit = { val categoricalFeaturesInfo = Map.empty[Int, Int] val newRF = rf .setImpurity("variance") @@ -68,6 +65,7 @@ class OptimizedRandomForestRegressorSuite extends MLTest with DefaultReadWriteTe compareAPIs(orderedInstances50_1000, newRF, optimizedRF, categoricalFeaturesInfo) } + // Fixed test("Regression with continuous features:" + " comparing DecisionTree vs. RandomForest(numTrees = 1)") { val rf = new RandomForestRegressor() diff --git a/src/test/scala/org/apache/spark/ml/tree/impl/LocalDecisionTreeRegressor.scala b/src/test/scala/org/apache/spark/ml/tree/impl/LocalDecisionTreeRegressor.scala index 8c0968d..8586e8d 100755 --- a/src/test/scala/org/apache/spark/ml/tree/impl/LocalDecisionTreeRegressor.scala +++ b/src/test/scala/org/apache/spark/ml/tree/impl/LocalDecisionTreeRegressor.scala @@ -42,22 +42,21 @@ private[impl] final class LocalDecisionTreeRegressor(override val uid: String) def this() = this(Identifiable.randomUID("local_dtr")) // Override parameter setters from parent trait for Java API compatibility. 
- override def setMaxDepth(value: Int): this.type = super.setMaxDepth(value) + def setMaxDepth(value: Int): this.type = set(maxDepth, value) - override def setMaxBins(value: Int): this.type = super.setMaxBins(value) + def setMaxBins(value: Int): this.type = set(maxBins, value) - override def setMinInstancesPerNode(value: Int): this.type = - super.setMinInstancesPerNode(value) + def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value) - override def setMinInfoGain(value: Double): this.type = super.setMinInfoGain(value) + def setMinInfoGain(value: Double): this.type = set(minInfoGain, value) - override def setMaxMemoryInMB(value: Int): this.type = super.setMaxMemoryInMB(value) + def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value) - override def setImpurity(value: String): this.type = super.setImpurity(value) + def setImpurity(value: String): this.type = set(impurity, value) - override def setSeed(value: Long): this.type = super.setSeed(value) + def setSeed(value: Long): this.type = set(seed, value) - override def copy(extra: ParamMap): LocalDecisionTreeRegressor = defaultCopy(extra) + def copy(extra: ParamMap): LocalDecisionTreeRegressor = defaultCopy(extra) override protected def train(dataset: Dataset[_]): OptimizedDecisionTreeRegressionModel = { val categoricalFeatures: Map[Int, Int] = diff --git a/src/test/scala/org/apache/spark/ml/tree/impl/OptimizedRandomForestSuite.scala b/src/test/scala/org/apache/spark/ml/tree/impl/OptimizedRandomForestSuite.scala index a8defb3..a0a36a8 100755 --- a/src/test/scala/org/apache/spark/ml/tree/impl/OptimizedRandomForestSuite.scala +++ b/src/test/scala/org/apache/spark/ml/tree/impl/OptimizedRandomForestSuite.scala @@ -94,10 +94,10 @@ class OptimizedRandomForestSuite extends SparkFunSuite with MLlibTestSparkContex test("find splits for a continuous feature") { // find splits for normal case { - val fakeMetadata = new OptimizedDecisionTreeMetadata(1, 200000, 0, 0, + val fakeMetadata = new OptimizedDecisionTreeMetadata(1, 200000, 200000, 0, 0, Map(), Set(), Array(6), Gini, QuantileStrategy.Sort, - 0, 0, 0.0, 0, 0 + 0, 0, 0.0, 0.0, 0, 0 ) val featureSamples = Array.fill(10000)(math.random).filter(_ != 0.0) val splits = OptimizedRandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) @@ -110,10 +110,10 @@ class OptimizedRandomForestSuite extends SparkFunSuite with MLlibTestSparkContex // SPARK-16957: Use midpoints for split values. { - val fakeMetadata = new OptimizedDecisionTreeMetadata(1, 8, 0, 0, + val fakeMetadata = new OptimizedDecisionTreeMetadata(1, 8, 8, 0, 0, Map(), Set(), Array(3), Gini, QuantileStrategy.Sort, - 0, 0, 0.0, 0, 0 + 0, 0, 0.0, 0.0, 0, 0 ) // TODO: Why doesn't this work after filtering 0.0? 
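The fake metadata in these tests now threads two extra positional arguments: `weightedNumExamples` right after `numExamples`, and `minWeightFractionPerNode` right after `minInstancesPerNode`. Sixteen positional arguments are easy to misread, so the same instance can be written with named arguments; a sketch that follows the constructor shown earlier in this diff:

import org.apache.spark.mllib.tree.configuration.QuantileStrategy
import org.apache.spark.mllib.tree.impurity.Gini
import org.apache.spark.ml.tree.impl.OptimizedDecisionTreeMetadata

// Equivalent to new OptimizedDecisionTreeMetadata(1, 200000, 200000, 0, 0, ...) in the
// first test case above, just with the parameters spelled out.
val fakeMetadata = new OptimizedDecisionTreeMetadata(
  numFeatures = 1,
  numExamples = 200000L,
  weightedNumExamples = 200000.0,
  numClasses = 0,
  maxBins = 0,
  featureArity = Map.empty[Int, Int],
  unorderedFeatures = Set.empty[Int],
  numBins = Array(6),
  impurity = Gini,
  quantileStrategy = QuantileStrategy.Sort,
  maxDepth = 0,
  minInstancesPerNode = 0,
  minWeightFractionPerNode = 0.0,
  minInfoGain = 0.0,
  numTrees = 0,
  numFeaturesPerNode = 0)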
@@ -137,10 +137,10 @@ class OptimizedRandomForestSuite extends SparkFunSuite with MLlibTestSparkContex // find splits should not return identical splits // when there are not enough split candidates, reduce the number of splits in metadata { - val fakeMetadata = new OptimizedDecisionTreeMetadata(1, 12, 0, 0, + val fakeMetadata = new OptimizedDecisionTreeMetadata(1, 12, 12, 0, 0, Map(), Set(), Array(5), Gini, QuantileStrategy.Sort, - 0, 0, 0.0, 0, 0 + 0, 0, 0.0, 0.0, 0, 0 ) val featureSamples = Array(1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3).map(_.toDouble) val splits = OptimizedRandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) @@ -152,10 +152,10 @@ class OptimizedRandomForestSuite extends SparkFunSuite with MLlibTestSparkContex // find splits when most samples close to the minimum { - val fakeMetadata = new OptimizedDecisionTreeMetadata(1, 18, 0, 0, + val fakeMetadata = new OptimizedDecisionTreeMetadata(1, 18, 18, 0, 0, Map(), Set(), Array(3), Gini, QuantileStrategy.Sort, - 0, 0, 0.0, 0, 0 + 0, 0, 0.0, 0.0, 0, 0 ) val featureSamples = Array(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 4, 5) .map(_.toDouble) @@ -166,10 +166,10 @@ class OptimizedRandomForestSuite extends SparkFunSuite with MLlibTestSparkContex // find splits when most samples close to the maximum { - val fakeMetadata = new OptimizedDecisionTreeMetadata(1, 17, 0, 0, + val fakeMetadata = new OptimizedDecisionTreeMetadata(1, 17, 17, 0, 0, Map(), Set(), Array(2), Gini, QuantileStrategy.Sort, - 0, 0, 0.0, 0, 0 + 0, 0, 0.0, 0.0, 0, 0 ) val featureSamples = Array(0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2) .map(_.toDouble).filter(_ != 0.0) @@ -180,10 +180,10 @@ class OptimizedRandomForestSuite extends SparkFunSuite with MLlibTestSparkContex // find splits for constant feature { - val fakeMetadata = new OptimizedDecisionTreeMetadata(1, 3, 0, 0, + val fakeMetadata = new OptimizedDecisionTreeMetadata(1, 3, 3, 0, 0, Map(), Set(), Array(3), Gini, QuantileStrategy.Sort, - 0, 0, 0.0, 0, 0 + 0, 0, 0.0, 0.0, 0, 0 ) val featureSamples = Array(0, 0, 0).map(_.toDouble).filter(_ != 0.0) val featureSamplesEmpty = Array.empty[Double] diff --git a/src/test/scala/org/apache/spark/ml/tree/impl/OptimizedTreeTests.scala b/src/test/scala/org/apache/spark/ml/tree/impl/OptimizedTreeTests.scala index f407819..eefd4dc 100755 --- a/src/test/scala/org/apache/spark/ml/tree/impl/OptimizedTreeTests.scala +++ b/src/test/scala/org/apache/spark/ml/tree/impl/OptimizedTreeTests.scala @@ -200,10 +200,10 @@ private[ml] object OptimizedTreeTests extends SparkFunSuite { } new OptimizedDecisionTreeMetadata(numFeatures = numFeatures, numExamples = numExamples, - numClasses = numClasses, maxBins = maxBins, minInfoGain = 0.0, featureArity = featureArity, - unorderedFeatures = unordered, numBins = numBins, impurity = impurity, + weightedNumExamples = numExamples, numClasses = numClasses, maxBins = maxBins, minInfoGain = 0.0, + featureArity = featureArity, unorderedFeatures = unordered, numBins = numBins, impurity = impurity, quantileStrategy = null, maxDepth = 5, minInstancesPerNode = 1, numTrees = 1, - numFeaturesPerNode = 2) + numFeaturesPerNode = 2, minWeightFractionPerNode = 0.0) } /** @@ -240,15 +240,13 @@ private[ml] object OptimizedTreeTests extends SparkFunSuite { * make mistakes such as creating loops of Nodes. 
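The `checkEqual` helpers below do a lockstep recursion over the two node hierarchies, comparing splits at internal nodes and recursing into both children; a malformed tree containing a cycle would recurse without terminating, which is the failure mode the comment above warns about. A simplified, self-contained analogue of the idea, using hypothetical node types:

// Hypothetical stand-ins for the ml Node / OptimizedNode hierarchies.
sealed trait SimpleNode
case class SimpleLeaf(prediction: Double) extends SimpleNode
case class SimpleInternal(featureIndex: Int, left: SimpleNode, right: SimpleNode) extends SimpleNode

// Lockstep structural comparison: same shape, same splits, same leaf predictions.
def structurallyEqual(a: SimpleNode, b: SimpleNode): Boolean = (a, b) match {
  case (SimpleLeaf(pa), SimpleLeaf(pb)) => pa == pb
  case (SimpleInternal(fa, la, ra), SimpleInternal(fb, lb, rb)) =>
    fa == fb && structurallyEqual(la, lb) && structurallyEqual(ra, rb)
  case _ => false
}

println(structurallyEqual(
  SimpleInternal(0, SimpleLeaf(1.0), SimpleLeaf(0.0)),
  SimpleInternal(0, SimpleLeaf(1.0), SimpleLeaf(0.0))))  // true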
*/ private def checkEqual(a: Node, b: OptimizedNode): Unit = { - assert(a.prediction === b.prediction) - assert(a.impurity === b.impurity) -// assert(a.impurityStats.stats === b.impurityStats.stats) (a, b) match { - case (aye: InternalNode, bee: OptimizedInternalNode) => - assert(aye.split === bee.split) - checkEqual(aye.leftChild, bee.leftChild) - checkEqual(aye.rightChild, bee.rightChild) - case (aye: LeafNode, bee: OptimizedLeafNode) => // do nothing + case (aa: InternalNode, bb: OptimizedInternalNode) => + assert(aa.split === bb.split) + checkEqual(aa.leftChild, bb.leftChild) + checkEqual(aa.rightChild, bb.rightChild) + case (aa: LeafNode, bb: OptimizedLeafNode) => // do nothing + assert(aa.prediction === bb.prediction) case _ => println(a.getClass.getCanonicalName, b.getClass.getCanonicalName) throw new AssertionError("Found mismatched nodes") @@ -264,11 +262,11 @@ private[ml] object OptimizedTreeTests extends SparkFunSuite { assert(a.impurity === b.impurity) // assert(a.impurityStats.stats === b.impurityStats.stats) (a, b) match { - case (aye: OptimizedInternalNode, bee: OptimizedInternalNode) => - assert(aye.split === bee.split) - checkEqual(aye.leftChild, bee.leftChild) - checkEqual(aye.rightChild, bee.rightChild) - case (aye: OptimizedLeafNode, bee: OptimizedLeafNode) => // do nothing + case (aa: OptimizedInternalNode, bb: OptimizedInternalNode) => + assert(aa.split === bb.split) + checkEqual(aa.leftChild, bb.leftChild) + checkEqual(aa.rightChild, bb.rightChild) + case (_: OptimizedLeafNode, _: OptimizedLeafNode) => // do nothing case _ => println(a.getClass.getCanonicalName, b.getClass.getCanonicalName) throw new AssertionError("Found mismatched nodes")
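For completeness, an end-to-end smoke sketch against the upgraded Spark 3.x API; the data, column names, and parameter values are purely illustrative, and the fitted model is assumed to transform like the stock RandomForestRegressionModel:

import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.regression.OptimizedRandomForestRegressor
import org.apache.spark.sql.SparkSession

object OrafSmoke {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("oraf-smoke").getOrCreate()

    // Tiny toy dataset with the default "label"/"features" columns.
    val df = spark.createDataFrame(Seq(
      (1.0, Vectors.dense(0.0, 1.0)),
      (2.0, Vectors.dense(1.0, 0.0)),
      (3.0, Vectors.dense(1.0, 1.0)),
      (4.0, Vectors.dense(0.0, 0.0))
    )).toDF("label", "features")

    val model = new OptimizedRandomForestRegressor()
      .setNumTrees(5)
      .setMaxDepth(3)
      .setSeed(42L)
      .fit(df)

    model.transform(df).show()
    spark.stop()
  }
}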