Repository: spark
Updated Branches:
  refs/heads/master 1c90347a4 -> 2e4eae3a5


[SPARK-3366][MLLIB]Compute best splits distributively in decision tree

Currently, all best splits are computed on the driver, which makes the driver a
bottleneck for both communication and computation. This PR fixes the problem by
computing best splits on the executors.
Instead of sending all aggregate stats to the driver node, we send the aggregate
stats for each node to a particular executor using a `reduceByKey` operation,
and then compute the best split for that node there.
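Without Spark, this idea amounts to grouping stats by node and then reducing each group where it lands. Below is a minimal sketch that simulates it with plain Scala collections: each partition is a `Seq`, and `groupBy` plus `reduce` stands in for `reduceByKey`. All names (`NodeShuffleSketch`, `Point`, `Stats`) are hypothetical, and the use of variance impurity with threshold splits over ordered bins is an assumption to keep the example small. It is not the MLlib implementation.

```scala
// Minimal, Spark-free sketch of the per-node shuffle described above.
// Hypothetical throughout: partitions are plain Seqs, groupBy + reduce stands in
// for reduceByKey, and variance impurity over ordered bins is an assumed criterion.
object NodeShuffleSketch {
  /** Hypothetical input: the node a point currently falls in, its binned features, its label. */
  final case class Point(nodeIndex: Int, binned: Array[Int], label: Double)

  /** Sufficient statistics for variance impurity: count, sum, sum of squares. */
  final case class Stats(n: Double, sum: Double, sumSq: Double) {
    def +(o: Stats): Stats = Stats(n + o.n, sum + o.sum, sumSq + o.sumSq)
    def -(o: Stats): Stats = Stats(n - o.n, sum - o.sum, sumSq - o.sumSq)
    def variance: Double = if (n == 0) 0.0 else sumSq / n - (sum / n) * (sum / n)
  }
  val Zero: Stats = Stats(0, 0, 0)

  /** Stats for one node, indexed (feature)(bin). */
  type NodeAgg = Array[Array[Stats]]

  def merge(a: NodeAgg, b: NodeAgg): NodeAgg =
    a.zip(b).map { case (fa, fb) => fa.zip(fb).map { case (x, y) => x + y } }

  /** mapPartitions step: one aggregate per node, emitted as (nodeIndex, agg) pairs. */
  def aggregatePartition(points: Seq[Point], numNodes: Int, numFeatures: Int,
      numBins: Int): Seq[(Int, NodeAgg)] = {
    val aggs = Array.fill(numNodes, numFeatures, numBins)(Zero)
    for (p <- points; f <- 0 until numFeatures) {
      val b = p.binned(f)
      aggs(p.nodeIndex)(f)(b) = aggs(p.nodeIndex)(f)(b) + Stats(1, p.label, p.label * p.label)
    }
    aggs.toSeq.zipWithIndex.map(_.swap)
  }

  /** Runs where a node's merged stats land: best (feature, threshold bin, gain). */
  def bestSplit(agg: NodeAgg): (Int, Int, Double) = {
    // Each point adds to exactly one bin per feature, so feature 0's bins sum to the node total.
    val total = agg(0).foldLeft(Zero)(_ + _)
    def weighted(s: Stats): Double = if (total.n == 0) 0.0 else s.n / total.n * s.variance
    val candidates = for {
      f <- agg.indices
      prefix = agg(f).scanLeft(Zero)(_ + _) // prefix(t + 1) = stats of bins 0..t
      t <- 0 until agg(f).length - 1
    } yield {
      val left = prefix(t + 1)
      val right = total - left
      (f, t, total.variance - weighted(left) - weighted(right))
    }
    candidates.maxBy(_._3) // the first maximum wins on ties
  }

  /** Whole pipeline; only the small per-node results reach the "driver". */
  def findBestSplits(partitions: Seq[Seq[Point]], numNodes: Int, numFeatures: Int,
      numBins: Int): Map[Int, (Int, Int, Double)] =
    partitions
      .flatMap(aggregatePartition(_, numNodes, numFeatures, numBins)) // ~ mapPartitions
      .groupBy(_._1)                                                   // ~ shuffle by node index
      .map { case (node, pairs) => node -> pairs.map(_._2).reduce(merge(_, _)) } // ~ reduceByKey
      .map { case (node, agg) => node -> bestSplit(agg) }              // per-node best split
}
```

In Spark the same shape would be `mapPartitions(...).reduceByKey(...).map(...).collectAsMap()`. The driver then receives one small result per node instead of every node's full statistics.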

Implementation details:

Each node now has a nodeStatsAggregator, which saves aggregate stats for all
features and bins.
First, use mapPartitions to compute the node aggregate stats for all nodes in
each partition.
Then, transform the node aggregate stats into (nodeIndex, nodeStatsAggregator)
pairs and use a `reduceByKey` operation to combine the nodeStatsAggregators for
the same node.
After all stats have been combined, the best split for each node is computed
from that node's aggregate stats. Only the best-split results are collected to
the driver to construct the decision tree.
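The per-node aggregator stores the statistics for every (feature, bin) of one node in a single flat `Array[Double]`. The start index of a (feature, bin) slot is `featureOffsets(featureIndex) + binIndex * statsSize`. Below is a minimal sketch of that layout and of the element-wise merge that `reduceByKey` relies on. `FlatNodeStats` is a hypothetical, simplified stand-in for `DTStatsAggregator`, which also wraps an `ImpurityAggregator` and sizes its layout to the feature subset when features are subsampled.

```scala
/**
 * Hypothetical, simplified stand-in for the per-node DTStatsAggregator:
 * raw sufficient statistics only, no impurity logic, no feature subsetting.
 * Stats for (feature, bin) start at featureOffsets(feature) + bin * statsSize.
 */
final class FlatNodeStats(numBins: Array[Int], val statsSize: Int) extends Serializable {
  // featureOffsets(f) = statsSize * (numBins(0) + ... + numBins(f - 1)); last = total size.
  private val featureOffsets: Array[Int] =
    numBins.scanLeft(0)((total, nBins) => total + statsSize * nBins)

  val allStatsSize: Int = featureOffsets.last
  val allStats: Array[Double] = new Array[Double](allStatsSize)

  /** Start index of the stats for one (feature, bin). */
  def index(featureIndex: Int, binIndex: Int): Int =
    featureOffsets(featureIndex) + binIndex * statsSize

  /** Adds one instance's statistics to its (feature, bin) slot. */
  def update(featureIndex: Int, binIndex: Int, stats: Array[Double]): Unit = {
    require(stats.length == statsSize, s"expected $statsSize stats, got ${stats.length}")
    val start = index(featureIndex, binIndex)
    var k = 0
    while (k < statsSize) {
      allStats(start + k) += stats(k)
      k += 1
    }
  }

  /** Element-wise sum in place, as when two partitions' stats for one node are combined. */
  def merge(other: FlatNodeStats): FlatNodeStats = {
    require(allStatsSize == other.allStatsSize, "aggregators must share one layout")
    var i = 0
    while (i < allStatsSize) {
      allStats(i) += other.allStats(i)
      i += 1
    }
    this
  }
}
```

Because `merge` adds in place and returns `this`, it matches the shape `reduceByKey((a, b) => a.merge(b))` expects and avoids allocating a new array per combine.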

CC: mengxr manishamde jkbradley, please help me review this, thanks.

Author: qiping.lqp <[email protected]>
Author: chouqin <[email protected]>

Closes #2595 from chouqin/dt-dist-agg and squashes the following commits:

db0d24a [chouqin] fix a minor bug and adjust code
a0d9de3 [chouqin] adjust code based on comments
9f201a6 [chouqin] fix bug: statsSize -> allStatsSize
a8a7ed0 [chouqin] Merge branch 'master' of https://github.com/apache/spark into dt-dist-agg
f13b346 [chouqin] adjust randomforest comments
c32636e [chouqin] adjust code based on comments
ac6a505 [chouqin] adjust code based on comments
7bbb787 [chouqin] add comments
bdd2a63 [qiping.lqp] fix test suite
a75df27 [qiping.lqp] fix test suite
b5b0bc2 [qiping.lqp] fix style
e76414f [qiping.lqp] fix testsuite
748bd45 [qiping.lqp] fix type-mismatch bug
24eacd8 [qiping.lqp] fix type-mismatch bug
5f63d6c [qiping.lqp] add multiclassification using One-Vs-All strategy
4f56496 [qiping.lqp] fix bug
f00fc22 [qiping.lqp] fix bug
532993a [qiping.lqp] Compute best splits distributively in decision tree


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/2e4eae3a
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/2e4eae3a
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/2e4eae3a

Branch: refs/heads/master
Commit: 2e4eae3a52e3d04895b00447d1ac56ae3c1b98ae
Parents: 1c90347
Author: qiping.lqp <[email protected]>
Authored: Fri Oct 3 03:26:17 2014 -0700
Committer: Xiangrui Meng <[email protected]>
Committed: Fri Oct 3 03:26:17 2014 -0700

----------------------------------------------------------------------
 .../apache/spark/mllib/tree/DecisionTree.scala  | 140 +++++----
 .../apache/spark/mllib/tree/RandomForest.scala  |   5 +-
 .../mllib/tree/impl/DTStatsAggregator.scala     | 292 +++++--------------
 .../mllib/tree/model/InformationGainStats.scala |  11 +
 .../spark/mllib/tree/RandomForestSuite.scala    |   1 +
 5 files changed, 182 insertions(+), 267 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/2e4eae3a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index b7dc373..b311d10 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -23,7 +23,6 @@ import scala.collection.mutable
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.Logging
-import org.apache.spark.mllib.rdd.RDDFunctions._
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.tree.RandomForest.NodeIndexInfo
 import org.apache.spark.mllib.tree.configuration.Strategy
@@ -36,6 +35,7 @@ import org.apache.spark.mllib.tree.impurity._
 import org.apache.spark.mllib.tree.model._
 import org.apache.spark.rdd.RDD
 import org.apache.spark.util.random.XORShiftRandom
+import org.apache.spark.SparkContext._
 
 
 /**
@@ -328,9 +328,8 @@ object DecisionTree extends Serializable with Logging {
    * for each subset is updated.
    *
    * @param agg  Array storing aggregate calculation, with a set of sufficient statistics for
-   *             each (node, feature, bin).
+   *             each (feature, bin).
    * @param treePoint  Data point being aggregated.
-   * @param nodeIndex  Node corresponding to treePoint.  agg is indexed in [0, numNodes).
    * @param bins possible bins for all features, indexed (numFeatures)(numBins)
    * @param unorderedFeatures  Set of indices of unordered features.
    * @param instanceWeight  Weight (importance) of instance in dataset.
@@ -338,7 +337,6 @@ object DecisionTree extends Serializable with Logging {
   private def mixedBinSeqOp(
       agg: DTStatsAggregator,
       treePoint: TreePoint,
-      nodeIndex: Int,
       bins: Array[Array[Bin]],
       unorderedFeatures: Set[Int],
       instanceWeight: Double,
@@ -350,7 +348,6 @@ object DecisionTree extends Serializable with Logging {
       // Use all features
       agg.metadata.numFeatures
     }
-    val nodeOffset = agg.getNodeOffset(nodeIndex)
     // Iterate over features.
     var featureIndexIdx = 0
     while (featureIndexIdx < numFeaturesPerNode) {
@@ -363,16 +360,16 @@ object DecisionTree extends Serializable with Logging {
         // Unordered feature
         val featureValue = treePoint.binnedFeatures(featureIndex)
         val (leftNodeFeatureOffset, rightNodeFeatureOffset) =
-          agg.getLeftRightNodeFeatureOffsets(nodeIndex, featureIndexIdx)
+          agg.getLeftRightFeatureOffsets(featureIndexIdx)
         // Update the left or right bin for each split.
         val numSplits = agg.metadata.numSplits(featureIndex)
         var splitIndex = 0
         while (splitIndex < numSplits) {
           if (bins(featureIndex)(splitIndex).highSplit.categories.contains(featureValue)) {
-            agg.nodeFeatureUpdate(leftNodeFeatureOffset, splitIndex, treePoint.label,
+            agg.featureUpdate(leftNodeFeatureOffset, splitIndex, treePoint.label,
               instanceWeight)
           } else {
-            agg.nodeFeatureUpdate(rightNodeFeatureOffset, splitIndex, treePoint.label,
+            agg.featureUpdate(rightNodeFeatureOffset, splitIndex, treePoint.label,
               instanceWeight)
           }
           splitIndex += 1
@@ -380,8 +377,7 @@ object DecisionTree extends Serializable with Logging {
       } else {
         // Ordered feature
         val binIndex = treePoint.binnedFeatures(featureIndex)
-        agg.nodeUpdate(nodeOffset, nodeIndex, featureIndexIdx, binIndex, treePoint.label,
-          instanceWeight)
+        agg.update(featureIndexIdx, binIndex, treePoint.label, instanceWeight)
       }
       featureIndexIdx += 1
     }
@@ -393,26 +389,24 @@ object DecisionTree extends Serializable with Logging {
    * For each feature, the sufficient statistics of one bin are updated.
    *
    * @param agg  Array storing aggregate calculation, with a set of sufficient statistics for
-   *             each (node, feature, bin).
+   *             each (feature, bin).
    * @param treePoint  Data point being aggregated.
-   * @param nodeIndex  Node corresponding to treePoint.  agg is indexed in [0, numNodes).
    * @param instanceWeight  Weight (importance) of instance in dataset.
    */
   private def orderedBinSeqOp(
       agg: DTStatsAggregator,
       treePoint: TreePoint,
-      nodeIndex: Int,
       instanceWeight: Double,
       featuresForNode: Option[Array[Int]]): Unit = {
     val label = treePoint.label
-    val nodeOffset = agg.getNodeOffset(nodeIndex)
+
     // Iterate over features.
     if (featuresForNode.nonEmpty) {
       // Use subsampled features
       var featureIndexIdx = 0
       while (featureIndexIdx < featuresForNode.get.size) {
         val binIndex = treePoint.binnedFeatures(featuresForNode.get.apply(featureIndexIdx))
-        agg.nodeUpdate(nodeOffset, nodeIndex, featureIndexIdx, binIndex, label, instanceWeight)
+        agg.update(featureIndexIdx, binIndex, label, instanceWeight)
         featureIndexIdx += 1
       }
     } else {
@@ -421,7 +415,7 @@ object DecisionTree extends Serializable with Logging {
       var featureIndex = 0
       while (featureIndex < numFeatures) {
         val binIndex = treePoint.binnedFeatures(featureIndex)
-        agg.nodeUpdate(nodeOffset, nodeIndex, featureIndex, binIndex, label, instanceWeight)
+        agg.update(featureIndex, binIndex, label, instanceWeight)
         featureIndex += 1
       }
     }
@@ -496,8 +490,8 @@ object DecisionTree extends Serializable with Logging {
      * @return  agg
      */
     def binSeqOp(
-        agg: DTStatsAggregator,
-        baggedPoint: BaggedPoint[TreePoint]): DTStatsAggregator = {
+        agg: Array[DTStatsAggregator],
+        baggedPoint: BaggedPoint[TreePoint]): Array[DTStatsAggregator] = {
       treeToNodeToIndexInfo.foreach { case (treeIndex, nodeIndexToInfo) =>
         val nodeIndex = predictNodeIndex(topNodes(treeIndex), baggedPoint.datum.binnedFeatures,
           bins, metadata.unorderedFeatures)
@@ -508,9 +502,9 @@ object DecisionTree extends Serializable with Logging {
           val featuresForNode = nodeInfo.featureSubset
           val instanceWeight = baggedPoint.subsampleWeights(treeIndex)
           if (metadata.unorderedFeatures.isEmpty) {
-            orderedBinSeqOp(agg, baggedPoint.datum, aggNodeIndex, instanceWeight, featuresForNode)
+            orderedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, instanceWeight, featuresForNode)
           } else {
-            mixedBinSeqOp(agg, baggedPoint.datum, aggNodeIndex, bins, metadata.unorderedFeatures,
+            mixedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, bins, metadata.unorderedFeatures,
               instanceWeight, featuresForNode)
           }
         }
@@ -518,30 +512,76 @@ object DecisionTree extends Serializable with Logging {
       agg
     }
 
-    // Calculate bin aggregates.
-    timer.start("aggregation")
-    val binAggregates: DTStatsAggregator = {
-      val initAgg = if (metadata.subsamplingFeatures) {
-        new DTStatsAggregatorSubsampledFeatures(metadata, treeToNodeToIndexInfo)
-      } else {
-        new DTStatsAggregatorFixedFeatures(metadata, numNodes)
+    /**
+     * Get node index in group --> features indices map,
+     * which is a short cut to find feature indices for a node given node index in group
+     * @param treeToNodeToIndexInfo
+     * @return
+     */
+    def getNodeToFeatures(treeToNodeToIndexInfo: Map[Int, Map[Int, NodeIndexInfo]])
+      : Option[Map[Int, Array[Int]]] = if (!metadata.subsamplingFeatures) {
+      None
+    } else {
+      val mutableNodeToFeatures = new mutable.HashMap[Int, Array[Int]]()
+      treeToNodeToIndexInfo.values.foreach { nodeIdToNodeInfo =>
+        nodeIdToNodeInfo.values.foreach { nodeIndexInfo =>
+          assert(nodeIndexInfo.featureSubset.isDefined)
+          mutableNodeToFeatures(nodeIndexInfo.nodeIndexInGroup) = nodeIndexInfo.featureSubset.get
+        }
       }
-      input.treeAggregate(initAgg)(binSeqOp, DTStatsAggregator.binCombOp)
+      Some(mutableNodeToFeatures.toMap)
     }
-    timer.stop("aggregation")
 
     // Calculate best splits for all nodes in the group
     timer.start("chooseSplits")
 
+    // In each partition, iterate all instances and compute aggregate stats for each node,
+    // yield an (nodeIndex, nodeAggregateStats) pair for each node.
+    // After a `reduceByKey` operation,
+    // stats of a node will be shuffled to a particular partition and be combined together,
+    // then best splits for nodes are found there.
+    // Finally, only best Splits for nodes are collected to driver to construct decision tree.
+    val nodeToFeatures = getNodeToFeatures(treeToNodeToIndexInfo)
+    val nodeToFeaturesBc = input.sparkContext.broadcast(nodeToFeatures)
+    val nodeToBestSplits =
+      input.mapPartitions { points =>
+        // Construct a nodeStatsAggregators array to hold node aggregate stats,
+        // each node will have a nodeStatsAggregator
+        val nodeStatsAggregators = Array.tabulate(numNodes) { nodeIndex =>
+          val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures =>
+            Some(nodeToFeatures(nodeIndex))
+          }
+          new DTStatsAggregator(metadata, featuresForNode)
+        }
+
+        // iterator all instances in current partition and update aggregate stats
+        points.foreach(binSeqOp(nodeStatsAggregators, _))
+
+        // transform nodeStatsAggregators array to (nodeIndex, nodeAggregateStats) pairs,
+        // which can be combined with other partition using `reduceByKey`
+        nodeStatsAggregators.view.zipWithIndex.map(_.swap).iterator
+      }.reduceByKey((a, b) => a.merge(b))
+        .map { case (nodeIndex, aggStats) =>
+          val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures =>
+            Some(nodeToFeatures(nodeIndex))
+          }
+
+          // find best split for each node
+          val (split: Split, stats: InformationGainStats, predict: Predict) =
+            binsToBestSplit(aggStats, splits, featuresForNode)
+          (nodeIndex, (split, stats, predict))
+        }.collectAsMap()
+
+    timer.stop("chooseSplits")
+
     // Iterate over all nodes in this group.
     nodesForGroup.foreach { case (treeIndex, nodesForTree) =>
       nodesForTree.foreach { node =>
         val nodeIndex = node.id
         val nodeInfo = treeToNodeToIndexInfo(treeIndex)(nodeIndex)
         val aggNodeIndex = nodeInfo.nodeIndexInGroup
-        val featuresForNode = nodeInfo.featureSubset
         val (split: Split, stats: InformationGainStats, predict: Predict) =
-          binsToBestSplit(binAggregates, aggNodeIndex, splits, featuresForNode)
+          nodeToBestSplits(aggNodeIndex)
         logDebug("best split = " + split)
 
         // Extract info for this node.  Create children if not leaf.
@@ -565,7 +605,7 @@ object DecisionTree extends Serializable with Logging {
         }
       }
     }
-    timer.stop("chooseSplits")
+
   }
 
   /**
@@ -633,36 +673,33 @@ object DecisionTree extends Serializable with Logging {
   /**
    * Find the best split for a node.
    * @param binAggregates Bin statistics.
-   * @param nodeIndex Index into aggregates for node to split in this group.
    * @return tuple for best split: (Split, information gain, prediction at node)
    */
   private def binsToBestSplit(
       binAggregates: DTStatsAggregator,
-      nodeIndex: Int,
       splits: Array[Array[Split]],
       featuresForNode: Option[Array[Int]]): (Split, InformationGainStats, Predict) = {
 
-    val metadata: DecisionTreeMetadata = binAggregates.metadata
-
     // calculate predict only once
     var predict: Option[Predict] = None
 
     // For each (feature, split), calculate the gain, and select the best (feature, split).
-    val (bestSplit, bestSplitStats) = Range(0, metadata.numFeaturesPerNode).map { featureIndexIdx =>
+    val (bestSplit, bestSplitStats) =
+      Range(0, binAggregates.metadata.numFeaturesPerNode).map { featureIndexIdx =>
       val featureIndex = if (featuresForNode.nonEmpty) {
         featuresForNode.get.apply(featureIndexIdx)
       } else {
         featureIndexIdx
       }
-      val numSplits = metadata.numSplits(featureIndex)
-      if (metadata.isContinuous(featureIndex)) {
+      val numSplits = binAggregates.metadata.numSplits(featureIndex)
+      if (binAggregates.metadata.isContinuous(featureIndex)) {
         // Cumulative sum (scanLeft) of bin statistics.
         // Afterwards, binAggregates for a bin is the sum of aggregates for
         // that bin + all preceding bins.
-        val nodeFeatureOffset = binAggregates.getNodeFeatureOffset(nodeIndex, featureIndexIdx)
+        val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx)
         var splitIndex = 0
         while (splitIndex < numSplits) {
-          binAggregates.mergeForNodeFeature(nodeFeatureOffset, splitIndex + 1, splitIndex)
+          binAggregates.mergeForFeature(nodeFeatureOffset, splitIndex + 1, splitIndex)
           splitIndex += 1
         }
         // Find best split.
@@ -672,27 +709,29 @@ object DecisionTree extends Serializable with Logging {
             val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits)
             rightChildStats.subtract(leftChildStats)
             predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats)))
-            val gainStats = calculateGainForSplit(leftChildStats, rightChildStats, metadata)
+            val gainStats = calculateGainForSplit(leftChildStats,
+              rightChildStats, binAggregates.metadata)
             (splitIdx, gainStats)
           }.maxBy(_._2.gain)
         (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
-      } else if (metadata.isUnordered(featureIndex)) {
+      } else if (binAggregates.metadata.isUnordered(featureIndex)) {
         // Unordered categorical feature
         val (leftChildOffset, rightChildOffset) =
-          binAggregates.getLeftRightNodeFeatureOffsets(nodeIndex, featureIndexIdx)
+          binAggregates.getLeftRightFeatureOffsets(featureIndexIdx)
         val (bestFeatureSplitIndex, bestFeatureGainStats) =
           Range(0, numSplits).map { splitIndex =>
             val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex)
             val rightChildStats = binAggregates.getImpurityCalculator(rightChildOffset, splitIndex)
             predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats)))
-            val gainStats = calculateGainForSplit(leftChildStats, rightChildStats, metadata)
+            val gainStats = calculateGainForSplit(leftChildStats,
+              rightChildStats, binAggregates.metadata)
             (splitIndex, gainStats)
           }.maxBy(_._2.gain)
         (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
       } else {
         // Ordered categorical feature
-        val nodeFeatureOffset = binAggregates.getNodeFeatureOffset(nodeIndex, featureIndexIdx)
-        val numBins = metadata.numBins(featureIndex)
+        val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx)
+        val numBins = binAggregates.metadata.numBins(featureIndex)
 
         /* Each bin is one category (feature value).
          * The bins are ordered based on centroidForCategories, and this ordering determines which
@@ -700,7 +739,7 @@ object DecisionTree extends Serializable with Logging {
          *
          * centroidForCategories is a list: (category, centroid)
          */
-        val centroidForCategories = if (metadata.isMulticlass) {
+        val centroidForCategories = if (binAggregates.metadata.isMulticlass) {
           // For categorical variables in multiclass classification,
           // the bins are ordered by the impurity of their corresponding labels.
           Range(0, numBins).map { case featureValue =>
@@ -741,7 +780,7 @@ object DecisionTree extends Serializable with Logging {
         while (splitIndex < numSplits) {
           val currentCategory = categoriesSortedByCentroid(splitIndex)._1
           val nextCategory = categoriesSortedByCentroid(splitIndex + 1)._1
-          binAggregates.mergeForNodeFeature(nodeFeatureOffset, nextCategory, currentCategory)
+          binAggregates.mergeForFeature(nodeFeatureOffset, nextCategory, currentCategory)
           splitIndex += 1
         }
         // lastCategory = index of bin with total aggregates for this (node, feature)
@@ -756,7 +795,8 @@ object DecisionTree extends Serializable with Logging {
               binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory)
             rightChildStats.subtract(leftChildStats)
             predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats)))
-            val gainStats = calculateGainForSplit(leftChildStats, rightChildStats, metadata)
+            val gainStats = calculateGainForSplit(leftChildStats,
+              rightChildStats, binAggregates.metadata)
             (splitIndex, gainStats)
           }.maxBy(_._2.gain)
         val categoriesForSplit =

http://git-wip-us.apache.org/repos/asf/spark/blob/2e4eae3a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
index 7fa7725..fa7a26f 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
@@ -171,8 +171,8 @@ private class RandomForest (
 
       // Choose node splits, and enqueue new nodes as needed.
       timer.start("findBestSplits")
-      DecisionTree.findBestSplits(baggedInput,
-        metadata, topNodes, nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue, timer)
+      DecisionTree.findBestSplits(baggedInput, metadata, topNodes, nodesForGroup,
+        treeToNodeToIndexInfo, splits, bins, nodeQueue, timer)
       timer.stop("findBestSplits")
     }
 
@@ -382,6 +382,7 @@ object RandomForest extends Serializable with Logging {
    * @param maxMemoryUsage  Bound on size of aggregate statistics.
    * @return  (nodesForGroup, treeToNodeToIndexInfo).
    *          nodesForGroup holds the nodes to split: treeIndex --> nodes in tree.
+   *
    *          treeToNodeToIndexInfo holds indices selected features for each node:
    *            treeIndex --> (global) node index --> (node index in group, feature indices).
    *          The (global) node index is the index in the tree; the node index in group is the

http://git-wip-us.apache.org/repos/asf/spark/blob/2e4eae3a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala
index d49df7a..55f422d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala
@@ -17,17 +17,19 @@
 
 package org.apache.spark.mllib.tree.impl
 
-import org.apache.spark.mllib.tree.RandomForest.NodeIndexInfo
 import org.apache.spark.mllib.tree.impurity._
 
+
+
 /**
- * DecisionTree statistics aggregator.
- * This holds a flat array of statistics for a set of (nodes, features, bins)
+ * DecisionTree statistics aggregator for a node.
+ * This holds a flat array of statistics for a set of (features, bins)
  * and helps with indexing.
  * This class is abstract to support learning with and without feature subsampling.
  */
-private[tree] abstract class DTStatsAggregator(
-    val metadata: DecisionTreeMetadata) extends Serializable {
+private[tree] class DTStatsAggregator(
+    val metadata: DecisionTreeMetadata,
+    featureSubset: Option[Array[Int]]) extends Serializable {
 
   /**
    * [[ImpurityAggregator]] instance specifying the impurity type.
@@ -42,7 +44,25 @@ private[tree] abstract class DTStatsAggregator(
   /**
    * Number of elements (Double values) used for the sufficient statistics of each bin.
    */
-  val statsSize: Int = impurityAggregator.statsSize
+  private val statsSize: Int = impurityAggregator.statsSize
+
+  /**
+   * Number of bins for each feature.  This is indexed by the feature index.
+   */
+  private val numBins: Array[Int] = {
+    if (featureSubset.isDefined) {
+      featureSubset.get.map(metadata.numBins(_))
+    } else {
+      metadata.numBins
+    }
+  }
+
+  /**
+   * Offset for each feature for calculating indices into the [[allStats]] array.
+   */
+  private val featureOffsets: Array[Int] = {
+    numBins.scanLeft(0)((total, nBins) => total + statsSize * nBins)
+  }
 
   /**
    * Indicator for each feature of whether that feature is an unordered feature.
@@ -51,107 +71,95 @@ private[tree] abstract class DTStatsAggregator(
   def isUnordered(featureIndex: Int): Boolean = metadata.isUnordered(featureIndex)
 
   /**
-   * Total number of elements stored in this aggregator.
+   * Total number of elements stored in this aggregator
    */
-  def allStatsSize: Int
+  private val allStatsSize: Int = featureOffsets.last
 
   /**
-   * Get flat array of elements stored in this aggregator.
+   * Flat array of elements.
+   * Index for start of stats for a (feature, bin) is:
+   *   index = featureOffsets(featureIndex) + binIndex * statsSize
+   * Note: For unordered features,
+   *       the left child stats have binIndex in [0, numBins(featureIndex) / 2))
+   *       and the right child stats in [numBins(featureIndex) / 2), numBins(featureIndex))
    */
-  protected def allStats: Array[Double]
+  private val allStats: Array[Double] = new Array[Double](allStatsSize)
+
 
   /**
    * Get an [[ImpurityCalculator]] for a given (node, feature, bin).
-   * @param nodeFeatureOffset  For ordered features, this is a pre-computed (node, feature) offset
-   *                           from [[getNodeFeatureOffset]].
+   * @param featureOffset  For ordered features, this is a pre-computed (node, feature) offset
+   *                           from [[getFeatureOffset]].
    *                           For unordered features, this is a pre-computed
    *                           (node, feature, left/right child) offset from
-   *                           [[getLeftRightNodeFeatureOffsets]].
+   *                           [[getLeftRightFeatureOffsets]].
    */
-  def getImpurityCalculator(nodeFeatureOffset: Int, binIndex: Int): ImpurityCalculator = {
-    impurityAggregator.getCalculator(allStats, nodeFeatureOffset + binIndex * statsSize)
+  def getImpurityCalculator(featureOffset: Int, binIndex: Int): ImpurityCalculator = {
+    impurityAggregator.getCalculator(allStats, featureOffset + binIndex * statsSize)
   }
 
   /**
-   * Update the stats for a given (node, feature, bin) for ordered features, using the given label.
+   * Update the stats for a given (feature, bin) for ordered features, using the given label.
    */
-  def update(
-      nodeIndex: Int,
-      featureIndex: Int,
-      binIndex: Int,
-      label: Double,
-      instanceWeight: Double): Unit = {
-    val i = getNodeFeatureOffset(nodeIndex, featureIndex) + binIndex * statsSize
+  def update(featureIndex: Int, binIndex: Int, label: Double, instanceWeight: Double): Unit = {
+    val i = featureOffsets(featureIndex) + binIndex * statsSize
     impurityAggregator.update(allStats, i, label, instanceWeight)
   }
 
   /**
-   * Pre-compute node offset for use with [[nodeUpdate]].
-   */
-  def getNodeOffset(nodeIndex: Int): Int
-
-  /**
    * Faster version of [[update]].
-   * Update the stats for a given (node, feature, bin) for ordered features, using the given label.
-   * @param nodeOffset  Pre-computed node offset from [[getNodeOffset]].
+   * Update the stats for a given (feature, bin), using the given label.
+   * @param featureOffset  For ordered features, this is a pre-computed feature offset
+   *                           from [[getFeatureOffset]].
+   *                           For unordered features, this is a pre-computed
+   *                           (feature, left/right child) offset from
+   *                           [[getLeftRightFeatureOffsets]].
    */
-  def nodeUpdate(
-      nodeOffset: Int,
-      nodeIndex: Int,
-      featureIndex: Int,
+  def featureUpdate(
+      featureOffset: Int,
       binIndex: Int,
       label: Double,
-      instanceWeight: Double): Unit
+      instanceWeight: Double): Unit = {
+    impurityAggregator.update(allStats, featureOffset + binIndex * statsSize,
+      label, instanceWeight)
+  }
 
   /**
-   * Pre-compute (node, feature) offset for use with [[nodeFeatureUpdate]].
+   * Pre-compute feature offset for use with [[featureUpdate]].
    * For ordered features only.
    */
-  def getNodeFeatureOffset(nodeIndex: Int, featureIndex: Int): Int
+  def getFeatureOffset(featureIndex: Int): Int = {
+    require(!isUnordered(featureIndex),
+      s"DTStatsAggregator.getFeatureOffset is for ordered features only, but was called" +
+        s" for unordered feature $featureIndex.")
+    featureOffsets(featureIndex)
+  }
 
   /**
-   * Pre-compute (node, feature) offset for use with [[nodeFeatureUpdate]].
+   * Pre-compute feature offset for use with [[featureUpdate]].
    * For unordered features only.
    */
-  def getLeftRightNodeFeatureOffsets(nodeIndex: Int, featureIndex: Int): (Int, Int) = {
+  def getLeftRightFeatureOffsets(featureIndex: Int): (Int, Int) = {
     require(isUnordered(featureIndex),
-      s"DTStatsAggregator.getLeftRightNodeFeatureOffsets is for unordered features only," +
+      s"DTStatsAggregator.getLeftRightFeatureOffsets is for unordered features only," +
         s" but was called for ordered feature $featureIndex.")
-    val baseOffset = getNodeFeatureOffset(nodeIndex, featureIndex)
-    (baseOffset, baseOffset + (metadata.numBins(featureIndex) >> 1) * statsSize)
-  }
-
-  /**
-   * Faster version of [[update]].
-   * Update the stats for a given (node, feature, bin), using the given label.
-   * @param nodeFeatureOffset  For ordered features, this is a pre-computed (node, feature) offset
-   *                           from [[getNodeFeatureOffset]].
-   *                           For unordered features, this is a pre-computed
-   *                           (node, feature, left/right child) offset from
-   *                           [[getLeftRightNodeFeatureOffsets]].
-   */
-  def nodeFeatureUpdate(
-      nodeFeatureOffset: Int,
-      binIndex: Int,
-      label: Double,
-      instanceWeight: Double): Unit = {
-    impurityAggregator.update(allStats, nodeFeatureOffset + binIndex * statsSize, label,
-      instanceWeight)
+    val baseOffset = featureOffsets(featureIndex)
+    (baseOffset, baseOffset + (numBins(featureIndex) >> 1) * statsSize)
   }
 
   /**
-   * For a given (node, feature), merge the stats for two bins.
-   * @param nodeFeatureOffset  For ordered features, this is a pre-computed (node, feature) offset
-   *                           from [[getNodeFeatureOffset]].
+   * For a given feature, merge the stats for two bins.
+   * @param featureOffset  For ordered features, this is a pre-computed feature offset
+   *                           from [[getFeatureOffset]].
    *                           For unordered features, this is a pre-computed
-   *                           (node, feature, left/right child) offset from
-   *                           [[getLeftRightNodeFeatureOffsets]].
+   *                           (feature, left/right child) offset from
+   *                           [[getLeftRightFeatureOffsets]].
    * @param binIndex  The other bin is merged into this bin.
    * @param otherBinIndex  This bin is not modified.
    */
-  def mergeForNodeFeature(nodeFeatureOffset: Int, binIndex: Int, otherBinIndex: Int): Unit = {
-    impurityAggregator.merge(allStats, nodeFeatureOffset + binIndex * statsSize,
-      nodeFeatureOffset + otherBinIndex * statsSize)
+  def mergeForFeature(featureOffset: Int, binIndex: Int, otherBinIndex: Int): Unit = {
+    impurityAggregator.merge(allStats, featureOffset + binIndex * statsSize,
+      featureOffset + otherBinIndex * statsSize)
   }
 
   /**
@@ -161,7 +169,7 @@ private[tree] abstract class DTStatsAggregator(
   def merge(other: DTStatsAggregator): DTStatsAggregator = {
     require(allStatsSize == other.allStatsSize,
       s"DTStatsAggregator.merge requires that both aggregators have the same length stats vectors."
-      + s" This aggregator is of length $allStatsSize, but the other is ${other.allStatsSize}.")
+        + s" This aggregator is of length $allStatsSize, but the other is ${other.allStatsSize}.")
     var i = 0
     // TODO: Test BLAS.axpy
     while (i < allStatsSize) {
@@ -171,149 +179,3 @@ private[tree] abstract class DTStatsAggregator(
     this
   }
 }
-
-/**
- * DecisionTree statistics aggregator.
- * This holds a flat array of statistics for a set of (nodes, features, bins)
- * and helps with indexing.
- *
- * This instance of [[DTStatsAggregator]] is used when not subsampling features.
- *
- * @param numNodes  Number of nodes to collect statistics for.
- */
-private[tree] class DTStatsAggregatorFixedFeatures(
-    metadata: DecisionTreeMetadata,
-    numNodes: Int) extends DTStatsAggregator(metadata) {
-
-  /**
-   * Offset for each feature for calculating indices into the [[allStats]] array.
-   * Mapping: featureIndex --> offset
-   */
-  private val featureOffsets: Array[Int] = {
-    metadata.numBins.scanLeft(0)((total, nBins) => total + statsSize * nBins)
-  }
-
-  /**
-   * Number of elements for each node, corresponding to stride between nodes in [[allStats]].
-   */
-  private val nodeStride: Int = featureOffsets.last
-
-  override val allStatsSize: Int = numNodes * nodeStride
-
-  /**
-   * Flat array of elements.
-   * Index for start of stats for a (node, feature, bin) is:
-   *   index = nodeIndex * nodeStride + featureOffsets(featureIndex) + binIndex * statsSize
-   * Note: For unordered features, the left child stats precede the right child stats
-   *       in the binIndex order.
-   */
-  override protected val allStats: Array[Double] = new Array[Double](allStatsSize)
-
-  override def getNodeOffset(nodeIndex: Int): Int = nodeIndex * nodeStride
-
-  override def nodeUpdate(
-      nodeOffset: Int,
-      nodeIndex: Int,
-      featureIndex: Int,
-      binIndex: Int,
-      label: Double,
-      instanceWeight: Double): Unit = {
-    val i = nodeOffset + featureOffsets(featureIndex) + binIndex * statsSize
-    impurityAggregator.update(allStats, i, label, instanceWeight)
-  }
-
-  override def getNodeFeatureOffset(nodeIndex: Int, featureIndex: Int): Int = {
-    nodeIndex * nodeStride + featureOffsets(featureIndex)
-  }
-}
-
-/**
- * DecisionTree statistics aggregator.
- * This holds a flat array of statistics for a set of (nodes, features, bins)
- * and helps with indexing.
- *
- * This instance of [[DTStatsAggregator]] is used when subsampling features.
- *
- * @param treeToNodeToIndexInfo Mapping: treeIndex --> nodeIndex --> nodeIndexInfo,
- *                              where nodeIndexInfo stores the index in the group and the
- *                              feature subsets (if using feature subsets).
- */
-private[tree] class DTStatsAggregatorSubsampledFeatures(
-    metadata: DecisionTreeMetadata,
-    treeToNodeToIndexInfo: Map[Int, Map[Int, NodeIndexInfo]]) extends DTStatsAggregator(metadata) {
-
-  /**
-   * For each node, offset for each feature for calculating indices into the [[allStats]] array.
-   * Mapping: nodeIndex --> featureIndex --> offset
-   */
-  private val featureOffsets: Array[Array[Int]] = {
-    val numNodes: Int = treeToNodeToIndexInfo.values.map(_.size).sum
-    val offsets = new Array[Array[Int]](numNodes)
-    treeToNodeToIndexInfo.foreach { case (treeIndex, nodeToIndexInfo) =>
-      nodeToIndexInfo.foreach { case (globalNodeIndex, nodeInfo) =>
-        offsets(nodeInfo.nodeIndexInGroup) = nodeInfo.featureSubset.get.map(metadata.numBins(_))
-          .scanLeft(0)((total, nBins) => total + statsSize * nBins)
-      }
-    }
-    offsets
-  }
-
-  /**
-   * For each node, offset for each feature for calculating indices into the [[allStats]] array.
-   */
-  protected val nodeOffsets: Array[Int] = featureOffsets.map(_.last).scanLeft(0)(_ + _)
-
-  override val allStatsSize: Int = nodeOffsets.last
-
-  /**
-   * Flat array of elements.
-   * Index for start of stats for a (node, feature, bin) is:
-   *   index = nodeOffsets(nodeIndex) + featureOffsets(featureIndex) + binIndex * statsSize
-   * Note: For unordered features, the left child stats precede the right child stats
-   *       in the binIndex order.
-   */
-  override protected val allStats: Array[Double] = new Array[Double](allStatsSize)
-
-  override def getNodeOffset(nodeIndex: Int): Int = nodeOffsets(nodeIndex)
-
-  /**
-   * Faster version of [[update]].
-   * Update the stats for a given (node, feature, bin) for ordered features, using the given label.
-   * @param nodeOffset  Pre-computed node offset from [[getNodeOffset]].
-   * @param featureIndex  Index of feature in featuresForNodes(nodeIndex).
-   *                      Note: This is NOT the original feature index.
-   */
-  override def nodeUpdate(
-      nodeOffset: Int,
-      nodeIndex: Int,
-      featureIndex: Int,
-      binIndex: Int,
-      label: Double,
-      instanceWeight: Double): Unit = {
-    val i = nodeOffset + featureOffsets(nodeIndex)(featureIndex) + binIndex * statsSize
-    impurityAggregator.update(allStats, i, label, instanceWeight)
-  }
-
-  /**
-   * Pre-compute (node, feature) offset for use with [[nodeFeatureUpdate]].
-   * For ordered features only.
-   * @param featureIndex  Index of feature in featuresForNodes(nodeIndex).
-   *                      Note: This is NOT the original feature index.
-   */
-  override def getNodeFeatureOffset(nodeIndex: Int, featureIndex: Int): Int = {
-    nodeOffsets(nodeIndex) + featureOffsets(nodeIndex)(featureIndex)
-  }
-}
-
-private[tree] object DTStatsAggregator extends Serializable {
-
-  /**
-   * Combines two aggregates (modifying the first) and returns the combination.
-   */
-  def binCombOp(
-      agg1: DTStatsAggregator,
-      agg2: DTStatsAggregator): DTStatsAggregator = {
-    agg1.merge(agg2)
-  }
-
-}
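After this commit each node gets its own `DTStatsAggregator` (per the commit message, one aggregator per node holding stats for all features and bins), so the flat layout reduces to (feature, bin) indexing. Mirroring the `featureOffsets` computation in the removed `DTStatsAggregatorFixedFeatures`, offsets are a running sum over `numBins`; bin `b` of feature `f` starts at `featureOffsets(f) + b * statsSize`; and for an unordered feature the right-child bins start halfway through its block, as in the added `getLeftRightFeatureOffsets` lines. The sketch below is a simplified standalone stand-in, not Spark's class: `NodeStatsSketch`, `index`, `leftRightOffsets` and `mergeBins` are hypothetical names, and `mergeBins` assumes the impurity stats are additive (an elementwise sum) in place of `impurityAggregator.merge`.

```scala
// Hypothetical single-node stats layout; not Spark's DTStatsAggregator.
// statsSize = number of Doubles per bin; numBins(f) = bin count of feature f.
final class NodeStatsSketch(val statsSize: Int, val numBins: Array[Int]) {

  // featureOffsets(f) = start of feature f's block; the last entry is the total length.
  val featureOffsets: Array[Int] =
    numBins.scanLeft(0)((total, nBins) => total + statsSize * nBins)

  val allStatsSize: Int = featureOffsets.last

  // One flat array for all (feature, bin) stats of this node.
  val allStats: Array[Double] = new Array[Double](allStatsSize)

  // Start index of the stats for (featureIndex, binIndex).
  def index(featureIndex: Int, binIndex: Int): Int =
    featureOffsets(featureIndex) + binIndex * statsSize

  // Unordered features: left-child bins first, right-child bins in the second half.
  def leftRightOffsets(featureIndex: Int): (Int, Int) = {
    val base = featureOffsets(featureIndex)
    (base, base + (numBins(featureIndex) >> 1) * statsSize)
  }

  // Adds the stats of otherBinIndex into binIndex; otherBinIndex is not modified.
  // Assumes additive stats (an elementwise sum stands in for impurityAggregator.merge).
  def mergeBins(featureOffset: Int, binIndex: Int, otherBinIndex: Int): Unit = {
    val to = featureOffset + binIndex * statsSize
    val from = featureOffset + otherBinIndex * statsSize
    var j = 0
    while (j < statsSize) {
      allStats(to + j) += allStats(from + j)
      j += 1
    }
  }
}
```

With `statsSize = 3` and `numBins = Array(2, 4)`, the offsets are `0, 6, 18`: feature 1 occupies indices 6 to 17, and its right-child bins start at 12.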

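The commit message describes combining per-node aggregates with `reduceByKey`, and `DTStatsAggregator.merge` above is the combine step: it requires equal-length stats vectors and adds them elementwise into the first aggregator. The following plain-Scala sketch imitates that flow without Spark. `NodeAggMergeSketch`, `merge` and `combineByNode` are hypothetical names, and the local `groupBy`/`reduce` pair only stands in for `reduceByKey`, which additionally combines values map-side before shuffling by key.

```scala
// Hypothetical local imitation of "reduceByKey over (nodeIndex, stats)" pairs.
object NodeAggMergeSketch {

  // Elementwise sum of two equal-length stats vectors. Mutates and returns `a`,
  // following the in-place style of DTStatsAggregator.merge.
  def merge(a: Array[Double], b: Array[Double]): Array[Double] = {
    require(a.length == b.length,
      s"merge requires equal-length stats vectors, got ${a.length} and ${b.length}")
    var i = 0
    while (i < a.length) {
      a(i) += b(i)
      i += 1
    }
    a
  }

  // Combines all partial aggregates that share a node index.
  def combineByNode(partials: Seq[(Int, Array[Double])]): Map[Int, Array[Double]] =
    partials.groupBy(_._1).map { case (node, group) =>
      node -> group.map(_._2).reduce(merge)
    }
}
```

As with the original `merge`, the first argument is reused as the accumulator, so the input arrays passed to `combineByNode` may be modified.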
http://git-wip-us.apache.org/repos/asf/spark/blob/2e4eae3a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
index f3e2619..a89e71e 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
@@ -38,6 +38,17 @@ class InformationGainStats(
     "gain = %f, impurity = %f, left impurity = %f, right impurity = %f"
       .format(gain, impurity, leftImpurity, rightImpurity)
   }
+
+  override def equals(o: Any) =
+    o match {
+      case other: InformationGainStats => {
+        gain == other.gain &&
+        impurity == other.impurity &&
+        leftImpurity == other.leftImpurity &&
+        rightImpurity == other.rightImpurity
+      }
+      case _ => false
+    }
 }
 
 

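The added `equals` compares the four gain and impurity fields but leaves `hashCode` untouched. The JVM contract for `equals`/`hashCode` requires that equal objects return equal hash codes; otherwise hash-based collections may treat equal instances as distinct. A minimal sketch of a consistent pair, on a hypothetical `GainStatsSketch` class that keeps only the four fields the override compares:

```scala
// Hypothetical stand-in for InformationGainStats with only the compared fields.
final class GainStatsSketch(
    val gain: Double,
    val impurity: Double,
    val leftImpurity: Double,
    val rightImpurity: Double) {

  override def equals(o: Any): Boolean = o match {
    case other: GainStatsSketch =>
      gain == other.gain &&
        impurity == other.impurity &&
        leftImpurity == other.leftImpurity &&
        rightImpurity == other.rightImpurity
    case _ => false
  }

  // Hash exactly the fields used by equals so equal instances hash equally.
  override def hashCode(): Int =
    (gain, impurity, leftImpurity, rightImpurity).##
}
```

Deriving the hash from a tuple of the same fields keeps the two methods in sync if a field is later added to one of them.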
http://git-wip-us.apache.org/repos/asf/spark/blob/2e4eae3a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
index 30669fc..20d372d 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
@@ -145,6 +145,7 @@ class RandomForestSuite extends FunSuite with LocalSparkContext {
 
         assert(nodesForGroup.size === numTrees, failString)
         assert(nodesForGroup.values.forall(_.size == 1), failString) // 1 node per tree
+
         if (numFeaturesPerNode == numFeatures) {
           // featureSubset values should all be None
           assert(treeToNodeToIndexInfo.values.forall(_.values.forall(_.featureSubset.isEmpty)),

