Repository: spark
Updated Branches:
  refs/heads/master 075a0b658 -> 1bc435ae3


[SPARK-10064] [ML] Parallelize decision tree bin split calculations

Reimplement `DecisionTree.findSplitsBins` using `RDD` operations to parallelize the
bin calculation.

With large feature spaces the previous implementation is very slow: it collects the
sampled input to the driver and computes the splits for every feature serially. This
change distributes (and collects) only the continuous features, and performs their
split calculations in parallel. On a real multi-terabyte dataset it completes in less
than a minute instead of multiple hours.

Author: Nathan Howell <[email protected]>

Closes #8246 from NathanHowell/SPARK-10064.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/1bc435ae
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/1bc435ae
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/1bc435ae

Branch: refs/heads/master
Commit: 1bc435ae3afb7a007b8a8ff00dcad4738a9ff055
Parents: 075a0b6
Author: Nathan Howell <[email protected]>
Authored: Wed Oct 7 17:46:16 2015 -0700
Committer: Joseph K. Bradley <[email protected]>
Committed: Wed Oct 7 17:46:16 2015 -0700

----------------------------------------------------------------------
 .../apache/spark/mllib/tree/DecisionTree.scala  | 164 ++++++++++---------
 .../spark/mllib/tree/impl/NodeIdCache.scala     |  18 +-
 .../spark/mllib/tree/DecisionTreeSuite.scala    |   6 -
 .../spark/mllib/tree/EnsembleTestHelper.scala   |   4 +-
 4 files changed, 97 insertions(+), 95 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/1bc435ae/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index 4a77d4a..53d6482 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -19,7 +19,6 @@ package org.apache.spark.mllib.tree
 
 import scala.collection.JavaConverters._
 import scala.collection.mutable
-import scala.collection.mutable.ArrayBuilder
 
 import org.apache.spark.Logging
 import org.apache.spark.annotation.{Experimental, Since}
@@ -643,8 +642,8 @@ object DecisionTree extends Serializable with Logging {
 
     val nodeToBestSplits = partitionAggregates.reduceByKey((a, b) => 
a.merge(b))
         .map { case (nodeIndex, aggStats) =>
-          val featuresForNode = nodeToFeaturesBc.value.flatMap { 
nodeToFeatures =>
-            Some(nodeToFeatures(nodeIndex))
+          val featuresForNode = nodeToFeaturesBc.value.map { nodeToFeatures =>
+            nodeToFeatures(nodeIndex)
           }
 
           // find best split for each node
@@ -976,8 +975,8 @@ object DecisionTree extends Serializable with Logging {
     val numFeatures = metadata.numFeatures
 
     // Sample the input only if there are continuous features.
-    val hasContinuousFeatures = Range(0, 
numFeatures).exists(metadata.isContinuous)
-    val sampledInput = if (hasContinuousFeatures) {
+    val continuousFeatures = Range(0, 
numFeatures).filter(metadata.isContinuous)
+    val sampledInput = if (continuousFeatures.nonEmpty) {
       // Calculate the number of samples for approximate quantile calculation.
       val requiredSamples = math.max(metadata.maxBins * metadata.maxBins, 
10000)
       val fraction = if (requiredSamples < metadata.numExamples) {
@@ -986,81 +985,14 @@ object DecisionTree extends Serializable with Logging {
         1.0
       }
       logDebug("fraction of data used for calculating quantiles = " + fraction)
-      input.sample(withReplacement = false, fraction, new 
XORShiftRandom().nextInt()).collect()
+      input.sample(withReplacement = false, fraction, new 
XORShiftRandom().nextInt())
     } else {
-      new Array[LabeledPoint](0)
+      input.sparkContext.emptyRDD[LabeledPoint]
     }
 
     metadata.quantileStrategy match {
       case Sort =>
-        val splits = new Array[Array[Split]](numFeatures)
-        val bins = new Array[Array[Bin]](numFeatures)
-
-        // Find all splits.
-        // Iterate over all features.
-        var featureIndex = 0
-        while (featureIndex < numFeatures) {
-          if (metadata.isContinuous(featureIndex)) {
-            val featureSamples = sampledInput.map(lp => 
lp.features(featureIndex))
-            val featureSplits = findSplitsForContinuousFeature(featureSamples,
-              metadata, featureIndex)
-
-            val numSplits = featureSplits.length
-            val numBins = numSplits + 1
-            logDebug(s"featureIndex = $featureIndex, numSplits = $numSplits")
-            splits(featureIndex) = new Array[Split](numSplits)
-            bins(featureIndex) = new Array[Bin](numBins)
-
-            var splitIndex = 0
-            while (splitIndex < numSplits) {
-              val threshold = featureSplits(splitIndex)
-              splits(featureIndex)(splitIndex) =
-                new Split(featureIndex, threshold, Continuous, List())
-              splitIndex += 1
-            }
-            bins(featureIndex)(0) = new Bin(new DummyLowSplit(featureIndex, 
Continuous),
-              splits(featureIndex)(0), Continuous, Double.MinValue)
-
-            splitIndex = 1
-            while (splitIndex < numSplits) {
-              bins(featureIndex)(splitIndex) =
-                new Bin(splits(featureIndex)(splitIndex - 1), 
splits(featureIndex)(splitIndex),
-                  Continuous, Double.MinValue)
-              splitIndex += 1
-            }
-            bins(featureIndex)(numSplits) = new 
Bin(splits(featureIndex)(numSplits - 1),
-              new DummyHighSplit(featureIndex, Continuous), Continuous, 
Double.MinValue)
-          } else {
-            val numSplits = metadata.numSplits(featureIndex)
-            val numBins = metadata.numBins(featureIndex)
-            // Categorical feature
-            val featureArity = metadata.featureArity(featureIndex)
-            if (metadata.isUnordered(featureIndex)) {
-              // Unordered features
-              // 2^(maxFeatureValue - 1) - 1 combinations
-              splits(featureIndex) = new Array[Split](numSplits)
-              var splitIndex = 0
-              while (splitIndex < numSplits) {
-                val categories: List[Double] =
-                  extractMultiClassCategories(splitIndex + 1, featureArity)
-                splits(featureIndex)(splitIndex) =
-                  new Split(featureIndex, Double.MinValue, Categorical, 
categories)
-                splitIndex += 1
-              }
-            } else {
-              // Ordered features
-              //   Bins correspond to feature values, so we do not need to 
compute splits or bins
-              //   beforehand.  Splits are constructed as needed during 
training.
-              splits(featureIndex) = new Array[Split](0)
-            }
-            // For ordered features, bins correspond to feature values.
-            // For unordered categorical features, there is no need to 
construct the bins.
-            // since there is a one-to-one correspondence between the splits 
and the bins.
-            bins(featureIndex) = new Array[Bin](0)
-          }
-          featureIndex += 1
-        }
-        (splits, bins)
+        findSplitsBinsBySorting(sampledInput, metadata, continuousFeatures)
       case MinMax =>
         throw new UnsupportedOperationException("minmax not supported yet.")
       case ApproxHist =>
@@ -1068,6 +1000,82 @@ object DecisionTree extends Serializable with Logging {
     }
   }
 
+  private def findSplitsBinsBySorting(
+      input: RDD[LabeledPoint],
+      metadata: DecisionTreeMetadata,
+      continuousFeatures: IndexedSeq[Int]): (Array[Array[Split]], 
Array[Array[Bin]]) = {
+    def findSplits(
+        featureIndex: Int,
+        featureSamples: Iterable[Double]): (Int, (Array[Split], Array[Bin])) = 
{
+      val splits = {
+        val featureSplits = findSplitsForContinuousFeature(
+          featureSamples.toArray,
+          metadata,
+          featureIndex)
+        logDebug(s"featureIndex = $featureIndex, numSplits = 
${featureSplits.length}")
+
+        featureSplits.map(threshold => new Split(featureIndex, threshold, 
Continuous, Nil))
+      }
+
+      val bins = {
+        val lowSplit = new DummyLowSplit(featureIndex, Continuous)
+        val highSplit = new DummyHighSplit(featureIndex, Continuous)
+
+        // tack the dummy splits on either side of the computed splits
+        val allSplits = lowSplit +: splits.toSeq :+ highSplit
+
+        // slide across the split points pairwise to allocate the bins
+        allSplits.sliding(2).map {
+          case Seq(left, right) => new Bin(left, right, Continuous, 
Double.MinValue)
+        }.toArray
+      }
+
+      (featureIndex, (splits, bins))
+    }
+
+    val continuousSplits = {
+      // reduce the parallelism for split computations when there are less
+      // continuous features than input partitions. this prevents tasks from
+      // being spun up that will definitely do no work.
+      val numPartitions = math.min(continuousFeatures.length, 
input.partitions.length)
+
+      input
+        .flatMap(point => continuousFeatures.map(idx => (idx, 
point.features(idx))))
+        .groupByKey(numPartitions)
+        .map { case (k, v) => findSplits(k, v) }
+        .collectAsMap()
+    }
+
+    val numFeatures = metadata.numFeatures
+    val (splits, bins) = Range(0, numFeatures).unzip {
+      case i if metadata.isContinuous(i) =>
+        val (split, bin) = continuousSplits(i)
+        metadata.setNumSplits(i, split.length)
+        (split, bin)
+
+      case i if metadata.isCategorical(i) && metadata.isUnordered(i) =>
+        // Unordered features
+        // 2^(maxFeatureValue - 1) - 1 combinations
+        val featureArity = metadata.featureArity(i)
+        val split = Range(0, metadata.numSplits(i)).map { splitIndex =>
+          val categories = extractMultiClassCategories(splitIndex + 1, 
featureArity)
+          new Split(i, Double.MinValue, Categorical, categories)
+        }
+
+        // For unordered categorical features, there is no need to construct 
the bins.
+        // since there is a one-to-one correspondence between the splits and 
the bins.
+        (split.toArray, Array.empty[Bin])
+
+      case i if metadata.isCategorical(i) =>
+        // Ordered features
+        //   Bins correspond to feature values, so we do not need to compute 
splits or bins
+        //   beforehand.  Splits are constructed as needed during training.
+        (Array.empty[Split], Array.empty[Bin])
+    }
+
+    (splits.toArray, bins.toArray)
+  }
+
   /**
    * Nested method to extract list of eligible categories given an index. It 
extracts the
    * position of ones in a binary representation of the input. If binary
@@ -1131,7 +1139,7 @@ object DecisionTree extends Serializable with Logging {
         logDebug("stride = " + stride)
 
         // iterate `valueCount` to find splits
-        val splitsBuilder = ArrayBuilder.make[Double]
+        val splitsBuilder = Array.newBuilder[Double]
         var index = 1
         // currentCount: sum of counts of values that have been visited
         var currentCount = valueCounts(0)._2
@@ -1163,8 +1171,8 @@ object DecisionTree extends Serializable with Logging {
     assert(splits.length > 0,
       s"DecisionTree could not handle feature $featureIndex since it had only 
1 unique value." +
         "  Please remove this feature and then try again.")
-    // set number of splits accordingly
-    metadata.setNumSplits(featureIndex, splits.length)
+
+    // the split metadata must be updated on the driver
 
     splits
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/1bc435ae/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala
index 0abed54..1c61197 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala
@@ -108,21 +108,21 @@ private[spark] class NodeIdCache(
 
     prevNodeIdsForInstances = nodeIdsForInstances
     nodeIdsForInstances = data.zip(nodeIdsForInstances).map {
-      dataPoint => {
+      case (point, node) => {
         var treeId = 0
         while (treeId < nodeIdUpdaters.length) {
-          val nodeIdUpdater = 
nodeIdUpdaters(treeId).getOrElse(dataPoint._2(treeId), null)
+          val nodeIdUpdater = nodeIdUpdaters(treeId).getOrElse(node(treeId), 
null)
           if (nodeIdUpdater != null) {
             val newNodeIndex = nodeIdUpdater.updateNodeIndex(
-              binnedFeatures = dataPoint._1.datum.binnedFeatures,
+              binnedFeatures = point.datum.binnedFeatures,
               bins = bins)
-            dataPoint._2(treeId) = newNodeIndex
+            node(treeId) = newNodeIndex
           }
 
           treeId += 1
         }
 
-        dataPoint._2
+        node
       }
     }
 
@@ -138,7 +138,7 @@ private[spark] class NodeIdCache(
       while (checkpointQueue.size > 1 && canDelete) {
         // We can delete the oldest checkpoint iff
         // the next checkpoint actually exists in the file system.
-        if (checkpointQueue.get(1).get.getCheckpointFile != None) {
+        if (checkpointQueue.get(1).get.getCheckpointFile.isDefined) {
           val old = checkpointQueue.dequeue()
 
           // Since the old checkpoint is not deleted by Spark,
@@ -159,11 +159,11 @@ private[spark] class NodeIdCache(
    * Call this after training is finished to delete any remaining checkpoints.
    */
   def deleteAllCheckpoints(): Unit = {
-    while (checkpointQueue.size > 0) {
+    while (checkpointQueue.nonEmpty) {
       val old = checkpointQueue.dequeue()
-      if (old.getCheckpointFile != None) {
+      for (checkpointFile <- old.getCheckpointFile) {
         val fs = FileSystem.get(old.sparkContext.hadoopConfiguration)
-        fs.delete(new Path(old.getCheckpointFile.get), true)
+        fs.delete(new Path(checkpointFile), true)
       }
     }
     if (prevNodeIdsForInstances != null) {

http://git-wip-us.apache.org/repos/asf/spark/blob/1bc435ae/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index 356d957..1a4299d 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -135,8 +135,6 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
       val featureSamples = Array(1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 
3).map(_.toDouble)
       val splits = DecisionTree.findSplitsForContinuousFeature(featureSamples, 
fakeMetadata, 0)
       assert(splits.length === 3)
-      assert(fakeMetadata.numSplits(0) === 3)
-      assert(fakeMetadata.numBins(0) === 4)
       // check returned splits are distinct
       assert(splits.distinct.length === splits.length)
     }
@@ -151,8 +149,6 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
       val featureSamples = Array(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 4, 
5).map(_.toDouble)
       val splits = DecisionTree.findSplitsForContinuousFeature(featureSamples, 
fakeMetadata, 0)
       assert(splits.length === 2)
-      assert(fakeMetadata.numSplits(0) === 2)
-      assert(fakeMetadata.numBins(0) === 3)
       assert(splits(0) === 2.0)
       assert(splits(1) === 3.0)
     }
@@ -167,8 +163,6 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
       val featureSamples = Array(0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 
2).map(_.toDouble)
       val splits = DecisionTree.findSplitsForContinuousFeature(featureSamples, 
fakeMetadata, 0)
       assert(splits.length === 1)
-      assert(fakeMetadata.numSplits(0) === 1)
-      assert(fakeMetadata.numBins(0) === 2)
       assert(splits(0) === 1.0)
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/1bc435ae/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala
index 334bf37..3d3f800 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala
@@ -69,8 +69,8 @@ object EnsembleTestHelper {
       required: Double,
       metricName: String = "mse") {
     val predictions = input.map(x => model.predict(x.features))
-    val errors = predictions.zip(input.map(_.label)).map { case (prediction, 
label) =>
-      label - prediction
+    val errors = predictions.zip(input).map { case (prediction, point) =>
+      point.label - prediction
     }
     val metric = metricName match {
       case "mse" =>


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to