Repository: spark
Updated Branches:
  refs/heads/master b39e80d39 -> 1614485fd


[SPARK-10788][MLLIB][ML] Remove duplicate bins for decision trees

Decision trees in spark.ml (RandomForest.scala) communicate twice as much data 
as needed for unordered categorical features. Here's an example.

Say there are 3 categories A, B, C. We consider 3 splits:

* A vs. B, C
* A, B vs. C
* A, C vs. B

Currently, we collect statistics for each of the 6 subsets of categories (3 
splits * 2 sides per split = 6). However, we could instead collect statistics 
for only the 3 subsets on the left-hand side of the 3 possible splits: A and 
A,B and A,C. If we also have stats for the entire node, then we can compute the 
stats for the 3 subsets on the right-hand side of the splits by subtraction. In 
pseudomath: stats(B,C) = stats(A,B,C) - stats(A).

This patch adds a parent stats array to the `DTStatsAggregator` so that the 
right child stats do not need to be stored. The right child stats are computed 
by subtracting left child stats from the parent stats for unordered categorical 
features.

Author: sethah <[email protected]>

Closes #9474 from sethah/SPARK-10788.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/1614485f
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/1614485f
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/1614485f

Branch: refs/heads/master
Commit: 1614485fd92fc94bc3989da49be612e542b93fb8
Parents: b39e80d
Author: sethah <[email protected]>
Authored: Thu Mar 17 16:44:41 2016 -0700
Committer: Joseph K. Bradley <[email protected]>
Committed: Thu Mar 17 16:44:41 2016 -0700

----------------------------------------------------------------------
 .../spark/ml/tree/impl/RandomForest.scala       | 15 ++---
 .../apache/spark/mllib/tree/DecisionTree.scala  | 15 ++---
 .../mllib/tree/impl/DTStatsAggregator.scala     | 59 ++++++++++++--------
 .../mllib/tree/impl/DecisionTreeMetadata.scala  |  6 +-
 .../spark/mllib/tree/impurity/Entropy.scala     |  1 -
 .../apache/spark/mllib/tree/impurity/Gini.scala |  1 -
 .../spark/mllib/tree/impurity/Impurity.scala    |  1 -
 .../spark/mllib/tree/impurity/Variance.scala    |  1 -
 .../spark/mllib/tree/DecisionTreeSuite.scala    |  4 ++
 9 files changed, 54 insertions(+), 49 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/1614485f/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala 
b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
index 91dc985..dd9a5f2 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
@@ -244,8 +244,7 @@ private[ml] object RandomForest extends Logging {
       if (unorderedFeatures.contains(featureIndex)) {
         // Unordered feature
         val featureValue = treePoint.binnedFeatures(featureIndex)
-        val (leftNodeFeatureOffset, rightNodeFeatureOffset) =
-          agg.getLeftRightFeatureOffsets(featureIndexIdx)
+        val leftNodeFeatureOffset = agg.getFeatureOffset(featureIndexIdx)
         // Update the left or right bin for each split.
         val numSplits = agg.metadata.numSplits(featureIndex)
         val featureSplits = splits(featureIndex)
@@ -253,8 +252,6 @@ private[ml] object RandomForest extends Logging {
         while (splitIndex < numSplits) {
           if (featureSplits(splitIndex).shouldGoLeft(featureValue, 
featureSplits)) {
             agg.featureUpdate(leftNodeFeatureOffset, splitIndex, 
treePoint.label, instanceWeight)
-          } else {
-            agg.featureUpdate(rightNodeFeatureOffset, splitIndex, 
treePoint.label, instanceWeight)
           }
           splitIndex += 1
         }
@@ -394,6 +391,7 @@ private[ml] object RandomForest extends Logging {
           mixedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, splits,
             metadata.unorderedFeatures, instanceWeight, featuresForNode)
         }
+        agg(aggNodeIndex).updateParent(baggedPoint.datum.label, instanceWeight)
       }
     }
 
@@ -658,7 +656,7 @@ private[ml] object RandomForest extends Logging {
 
     // Calculate InformationGain and ImpurityStats if current node is top node
     val level = LearningNode.indexToLevel(node.id)
-    var gainAndImpurityStats: ImpurityStats = if (level ==0) {
+    var gainAndImpurityStats: ImpurityStats = if (level == 0) {
       null
     } else {
       node.stats
@@ -697,13 +695,12 @@ private[ml] object RandomForest extends Logging {
           (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
         } else if (binAggregates.metadata.isUnordered(featureIndex)) {
           // Unordered categorical feature
-          val (leftChildOffset, rightChildOffset) =
-            binAggregates.getLeftRightFeatureOffsets(featureIndexIdx)
+          val leftChildOffset = binAggregates.getFeatureOffset(featureIndexIdx)
           val (bestFeatureSplitIndex, bestFeatureGainStats) =
             Range(0, numSplits).map { splitIndex =>
               val leftChildStats = 
binAggregates.getImpurityCalculator(leftChildOffset, splitIndex)
-              val rightChildStats =
-                binAggregates.getImpurityCalculator(rightChildOffset, 
splitIndex)
+              val rightChildStats = binAggregates.getParentImpurityCalculator()
+                .subtract(leftChildStats)
               gainAndImpurityStats = 
calculateImpurityStats(gainAndImpurityStats,
                 leftChildStats, rightChildStats, binAggregates.metadata)
               (splitIndex, gainAndImpurityStats)

http://git-wip-us.apache.org/repos/asf/spark/blob/1614485f/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala 
b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index 18f66e6..c0934d2 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -52,6 +52,7 @@ class DecisionTree @Since("1.0.0") (private val strategy: 
Strategy)
 
   /**
    * Method to train a decision tree model over an RDD
+   *
    * @param input Training data: RDD of 
[[org.apache.spark.mllib.regression.LabeledPoint]].
    * @return DecisionTreeModel that can be used for prediction.
    */
@@ -368,8 +369,7 @@ object DecisionTree extends Serializable with Logging {
       if (unorderedFeatures.contains(featureIndex)) {
         // Unordered feature
         val featureValue = treePoint.binnedFeatures(featureIndex)
-        val (leftNodeFeatureOffset, rightNodeFeatureOffset) =
-          agg.getLeftRightFeatureOffsets(featureIndexIdx)
+        val leftNodeFeatureOffset = agg.getFeatureOffset(featureIndexIdx)
         // Update the left or right bin for each split.
         val numSplits = agg.metadata.numSplits(featureIndex)
         var splitIndex = 0
@@ -377,9 +377,6 @@ object DecisionTree extends Serializable with Logging {
           if 
(splits(featureIndex)(splitIndex).categories.contains(featureValue)) {
             agg.featureUpdate(leftNodeFeatureOffset, splitIndex, 
treePoint.label,
               instanceWeight)
-          } else {
-            agg.featureUpdate(rightNodeFeatureOffset, splitIndex, 
treePoint.label,
-              instanceWeight)
           }
           splitIndex += 1
         }
@@ -521,6 +518,7 @@ object DecisionTree extends Serializable with Logging {
           mixedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, splits,
             metadata.unorderedFeatures, instanceWeight, featuresForNode)
         }
+        agg(aggNodeIndex).updateParent(baggedPoint.datum.label, instanceWeight)
       }
     }
 
@@ -847,13 +845,12 @@ object DecisionTree extends Serializable with Logging {
           (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
         } else if (binAggregates.metadata.isUnordered(featureIndex)) {
           // Unordered categorical feature
-          val (leftChildOffset, rightChildOffset) =
-            binAggregates.getLeftRightFeatureOffsets(featureIndexIdx)
+          val leftChildOffset = binAggregates.getFeatureOffset(featureIndexIdx)
           val (bestFeatureSplitIndex, bestFeatureGainStats) =
             Range(0, numSplits).map { splitIndex =>
               val leftChildStats = 
binAggregates.getImpurityCalculator(leftChildOffset, splitIndex)
-              val rightChildStats =
-                binAggregates.getImpurityCalculator(rightChildOffset, 
splitIndex)
+              val rightChildStats = binAggregates.getParentImpurityCalculator()
+                .subtract(leftChildStats)
               predictWithImpurity = Some(predictWithImpurity.getOrElse(
                 calculatePredictImpurity(leftChildStats, rightChildStats)))
               val gainStats = calculateGainForSplit(leftChildStats,

http://git-wip-us.apache.org/repos/asf/spark/blob/1614485f/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala 
b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala
index 7985ed4..c745e9f 100644
--- 
a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala
@@ -73,26 +73,34 @@ private[spark] class DTStatsAggregator(
    * Flat array of elements.
    * Index for start of stats for a (feature, bin) is:
    *   index = featureOffsets(featureIndex) + binIndex * statsSize
-   * Note: For unordered features,
-   *       the left child stats have binIndex in [0, numBins(featureIndex) / 
2))
-   *       and the right child stats in [numBins(featureIndex) / 2), 
numBins(featureIndex))
    */
   private val allStats: Array[Double] = new Array[Double](allStatsSize)
 
+  /**
+   * Array of parent node sufficient stats.
+   *
+   * Note: this is necessary because stats for the parent node are not 
available
+   *       on the first iteration of tree learning.
+   */
+  private val parentStats: Array[Double] = new Array[Double](statsSize)
 
   /**
    * Get an [[ImpurityCalculator]] for a given (node, feature, bin).
-   * @param featureOffset  For ordered features, this is a pre-computed (node, 
feature) offset
+   * @param featureOffset  This is a pre-computed (node, feature) offset
    *                           from [[getFeatureOffset]].
-   *                           For unordered features, this is a pre-computed
-   *                           (node, feature, left/right child) offset from
-   *                           [[getLeftRightFeatureOffsets]].
    */
   def getImpurityCalculator(featureOffset: Int, binIndex: Int): 
ImpurityCalculator = {
     impurityAggregator.getCalculator(allStats, featureOffset + binIndex * 
statsSize)
   }
 
   /**
+   * Get an [[ImpurityCalculator]] for the parent node.
+   */
+  def getParentImpurityCalculator(): ImpurityCalculator = {
+    impurityAggregator.getCalculator(parentStats, 0)
+  }
+
+  /**
    * Update the stats for a given (feature, bin) for ordered features, using 
the given label.
    */
   def update(featureIndex: Int, binIndex: Int, label: Double, instanceWeight: 
Double): Unit = {
@@ -101,13 +109,17 @@ private[spark] class DTStatsAggregator(
   }
 
   /**
+   * Update the parent node stats using the given label.
+   */
+  def updateParent(label: Double, instanceWeight: Double): Unit = {
+    impurityAggregator.update(parentStats, 0, label, instanceWeight)
+  }
+
+  /**
    * Faster version of [[update]].
    * Update the stats for a given (feature, bin), using the given label.
-   * @param featureOffset  For ordered features, this is a pre-computed 
feature offset
+   * @param featureOffset  This is a pre-computed feature offset
    *                           from [[getFeatureOffset]].
-   *                           For unordered features, this is a pre-computed
-   *                           (feature, left/right child) offset from
-   *                           [[getLeftRightFeatureOffsets]].
    */
   def featureUpdate(
       featureOffset: Int,
@@ -125,21 +137,9 @@ private[spark] class DTStatsAggregator(
   def getFeatureOffset(featureIndex: Int): Int = featureOffsets(featureIndex)
 
   /**
-   * Pre-compute feature offset for use with [[featureUpdate]].
-   * For unordered features only.
-   */
-  def getLeftRightFeatureOffsets(featureIndex: Int): (Int, Int) = {
-    val baseOffset = featureOffsets(featureIndex)
-    (baseOffset, baseOffset + (numBins(featureIndex) >> 1) * statsSize)
-  }
-
-  /**
    * For a given feature, merge the stats for two bins.
-   * @param featureOffset  For ordered features, this is a pre-computed 
feature offset
+   * @param featureOffset  This is a pre-computed feature offset
    *                           from [[getFeatureOffset]].
-   *                           For unordered features, this is a pre-computed
-   *                           (feature, left/right child) offset from
-   *                           [[getLeftRightFeatureOffsets]].
    * @param binIndex  The other bin is merged into this bin.
    * @param otherBinIndex  This bin is not modified.
    */
@@ -162,6 +162,17 @@ private[spark] class DTStatsAggregator(
       allStats(i) += other.allStats(i)
       i += 1
     }
+
+    require(statsSize == other.statsSize,
+      s"DTStatsAggregator.merge requires that both aggregators have the same 
length parent " +
+        s"stats vectors. This aggregator's parent stats are length $statsSize, 
" +
+        s"but the other is ${other.statsSize}.")
+    var j = 0
+    while (j < statsSize) {
+      parentStats(j) += other.parentStats(j)
+      j += 1
+    }
+
     this
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/1614485f/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
 
b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
index df13d29..4f27dc4 100644
--- 
a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
@@ -67,11 +67,11 @@ private[spark] class DecisionTreeMetadata(
 
   /**
    * Number of splits for the given feature.
-   * For unordered features, there are 2 bins per split.
+   * For unordered features, there is 1 bin per split.
    * For ordered features, there is 1 more bin than split.
    */
   def numSplits(featureIndex: Int): Int = if (isUnordered(featureIndex)) {
-    numBins(featureIndex) >> 1
+    numBins(featureIndex)
   } else {
     numBins(featureIndex) - 1
   }
@@ -212,6 +212,6 @@ private[spark] object DecisionTreeMetadata extends Logging {
    * there are math.pow(2, arity - 1) - 1 such splits.
    * Each split has 2 corresponding bins.
    */
-  def numUnorderedBins(arity: Int): Int = 2 * ((1 << arity - 1) - 1)
+  def numUnorderedBins(arity: Int): Int = (1 << arity - 1) - 1
 
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/1614485f/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala 
b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
index 73df6b0..13aff11 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
@@ -113,7 +113,6 @@ private[tree] class EntropyAggregator(numClasses: Int)
   def getCalculator(allStats: Array[Double], offset: Int): EntropyCalculator = 
{
     new EntropyCalculator(allStats.view(offset, offset + statsSize).toArray)
   }
-
 }
 
 /**

http://git-wip-us.apache.org/repos/asf/spark/blob/1614485f/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala 
b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
index f21845b..39c7f9c 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
@@ -109,7 +109,6 @@ private[tree] class GiniAggregator(numClasses: Int)
   def getCalculator(allStats: Array[Double], offset: Int): GiniCalculator = {
     new GiniCalculator(allStats.view(offset, offset + statsSize).toArray)
   }
-
 }
 
 /**

http://git-wip-us.apache.org/repos/asf/spark/blob/1614485f/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala 
b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
index b2c6e2b..65f0163 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
@@ -89,7 +89,6 @@ private[spark] abstract class ImpurityAggregator(val 
statsSize: Int) extends Ser
    * @param offset    Start index of stats for this (node, feature, bin).
    */
   def getCalculator(allStats: Array[Double], offset: Int): ImpurityCalculator
-
 }
 
 /**

http://git-wip-us.apache.org/repos/asf/spark/blob/1614485f/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala 
b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
index 09017d4..92d74a1 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
@@ -93,7 +93,6 @@ private[tree] class VarianceAggregator()
   def getCalculator(allStats: Array[Double], offset: Int): VarianceCalculator 
= {
     new VarianceCalculator(allStats.view(offset, offset + statsSize).toArray)
   }
-
 }
 
 /**

http://git-wip-us.apache.org/repos/asf/spark/blob/1614485f/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala 
b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index 5518bdf..89b64fc 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -189,6 +189,10 @@ class DecisionTreeSuite extends SparkFunSuite with 
MLlibTestSparkContext {
     assert(bins.length === 2)
     assert(splits(0).length === 3)
     assert(bins(0).length === 0)
+    assert(metadata.numSplits(0) === 3)
+    assert(metadata.numBins(0) === 3)
+    assert(metadata.numSplits(1) === 3)
+    assert(metadata.numBins(1) === 3)
 
     // Expecting 2^2 - 1 = 3 bins/splits
     assert(splits(0)(0).feature === 0)


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to