[SPARK-1545] [mllib] Add Random Forests

This PR adds RandomForest to MLlib.  The implementation is basic, and future 
performance optimizations will be important.  (Note: RFs = Random Forests.)

# Overview

## RandomForest
* trains multiple trees at once to reduce the number of passes over the data
* allows feature subsets at each node
* uses a queue of nodes instead of fixed groups for each level (sketched below)
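
For intuition, here is a minimal sketch of the queue-based loop (placeholder 
types, and unconditional splitting to maxDepth; the real code batches a whole 
group of nodes per data pass and decides leaf vs. split from information gain):

```scala
import scala.collection.mutable

// Toy sketch of queue-based training.  A single FIFO queue holds
// (treeIndex, node) pairs from all trees, so one pass over the data can
// split nodes from multiple trees and multiple levels at once.
case class SketchNode(id: Int, var isLeaf: Boolean = false,
    var left: Option[SketchNode] = None, var right: Option[SketchNode] = None)

object QueueTrainingSketch {
  def train(numTrees: Int, maxDepth: Int): Array[SketchNode] = {
    val topNodes = Array.tabulate(numTrees)(_ => SketchNode(id = 1))
    val nodeQueue = mutable.Queue[(Int, SketchNode)]()
    topNodes.zipWithIndex.foreach { case (n, t) => nodeQueue.enqueue((t, n)) }
    while (nodeQueue.nonEmpty) {
      // The real code dequeues a whole group per data pass
      // (see selectNodesToSplit below); one node here for brevity.
      val (treeIndex, node) = nodeQueue.dequeue()
      val level = 31 - Integer.numberOfLeadingZeros(node.id) // root id 1 = level 0
      if (level < maxDepth) {
        node.left = Some(SketchNode(2 * node.id))
        node.right = Some(SketchNode(2 * node.id + 1))
        nodeQueue.enqueue((treeIndex, node.left.get))
        nodeQueue.enqueue((treeIndex, node.right.get))
      } else {
        node.isLeaf = true
      }
    }
    topNodes
  }
}
```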

This implementation is based on an implementation by manishamde and the [Alpine 
Labs Sequoia Forest](https://github.com/AlpineNow/SparkML2) by codedeft (in 
particular, the TreePoint, BaggedPoint, and node queue implementations).  Thank 
you for your inputs!

## Testing

Correctness: Verified with the test suites and with DecisionTreeRunner on 
example datasets.

Performance: Measured using [this branch of 
spark-perf](https://github.com/jkbradley/spark-perf/tree/rfs).  Results below.

### Regression tests for DecisionTree

Summary: For training 1 tree, there are small regressions, especially from 
the added support for feature subsampling.

In the table below, each row is a single (random) dataset.  The 2 different 
sets of result columns are for 2 different RF implementations:
* (numTrees): This is from an earlier commit, after implementing RandomForest 
to train multiple trees at once.  It does not include any code for feature 
subsampling.
* (feature subsets): This is from the current PR's code, after implementing 
feature subsampling.

These tests were meant to identify regressions in DecisionTree, so they train 
1 tree with all of the features (i.e., no feature subsampling).

These were run on an EC2 cluster with 15 workers, training 1 tree with maxDepth 
= 5 (= 6 levels).  Speedup values < 1 indicate slowdowns relative to the old 
DecisionTree implementation.

numInstances | numFeatures | runtime, sec (numTrees) | speedup (numTrees) | runtime, sec (feature subsets) | speedup (feature subsets)
---- | ---- | ---- | ---- | ---- | ----
20000 | 100 | 4.051 | 1.044433473 | 4.478 | 0.9448414471
20000 | 500 | 8.472 | 1.104461756 | 9.315 | 1.004508857
20000 | 1500 | 19.354 | 1.05854087 | 20.863 | 0.9819776638
20000 | 3500 | 43.674 | 1.072033704 | 45.887 | 1.020332556
200000 | 100 | 4.196 | 1.171830315 | 4.848 | 1.014232673
200000 | 500 | 8.926 | 1.082791844 | 9.771 | 0.989151571
200000 | 1500 | 20.58 | 1.068415938 | 22.134 | 0.9934038131
200000 | 3500 | 48.043 | 1.075203464 | 52.249 | 0.9886505005
2000000 | 100 | 4.944 | 1.01355178 | 5.796 | 0.8645617667
2000000 | 500 | 11.11 | 1.016831683 | 12.482 | 0.9050632911
2000000 | 1500 | 31.144 | 1.017852556 | 35.274 | 0.8986789136
2000000 | 3500 | 79.981 | 1.085382778 | 101.105 | 0.8586123337
20000000 | 100 | 8.304 | 0.9270231214 | 9.073 | 0.8484514494
20000000 | 500 | 28.174 | 1.083268262 | 34.236 | 0.8914592826
20000000 | 1500 | 143.97 | 0.9579634646 | 159.275 | 0.8659111599

### Tests for forests

I have run other tests with numTrees=10 and with sqrt(numFeatures), and those 
indicate that multi-model training and feature subsets can speed up training 
for forests, especially when training deeper trees.

# Details on specific classes

## Changes to DecisionTree
* Main train() method is now in RandomForest.
* findBestSplits() is no longer needed.  (It split levels into groups, but we 
now use a queue of nodes.)
* Many small changes to support RFs.  (Note: These methods should be moved to 
RandomForest.scala in a later PR, but are in DecisionTree.scala to make code 
comparison easier.)

## RandomForest
* Main train() method is from old DecisionTree.
* selectNodesToSplit: Note that it selects nodes and feature subsets jointly to 
track memory usage.
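
A sketch of that joint selection, with a hypothetical `memUsage` standing in 
for the per-node estimate (`RandomForest.aggregateSizeForNode` in the diff):

```scala
import scala.collection.mutable

// Sketch of selectNodesToSplit: dequeue nodes while the estimated aggregate
// memory for the group stays under the budget.  The feature subset must be
// chosen here, since memory usage per node depends on which features that
// node will consider.
def selectNodesToSplit[N](
    nodeQueue: mutable.Queue[(Int, N)],
    maxMemoryUsage: Long,
    memUsage: N => Long): Seq[(Int, N)] = {
  val group = mutable.ArrayBuffer[(Int, N)]()
  var memNeeded = 0L
  // The "group.isEmpty" clause guarantees progress: at least one node is
  // always selected (the caller checks that one node fits in the budget).
  while (nodeQueue.nonEmpty &&
      (group.isEmpty || memNeeded + memUsage(nodeQueue.head._2) <= maxMemoryUsage)) {
    val (treeIndex, node) = nodeQueue.dequeue()
    memNeeded += memUsage(node)
    group += ((treeIndex, node))
  }
  group.toSeq
}
```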

## RandomForestModel
* Stores an Array[DecisionTreeModel]
* Prediction:
  * For classification, most common label.  For regression, mean.
  * We could support other methods later.
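
In code, the two combination rules amount to the following (a sketch over raw 
per-tree predictions, not the actual model API):

```scala
// Sketch of ensemble prediction: majority vote for classification,
// mean for regression, over one prediction per tree.
def predictClassification(treePredictions: Seq[Double]): Double =
  treePredictions.groupBy(identity).maxBy(_._2.size)._1  // most common label

def predictRegression(treePredictions: Seq[Double]): Double =
  treePredictions.sum / treePredictions.size             // mean
```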

## examples/.../DecisionTreeRunner
* This now takes numTrees and featureSubsetStrategy, to support RFs.
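
For example, with numTrees > 1 the runner takes the RandomForest path.  A 
trimmed sketch of that call (mirroring the DecisionTreeRunner diff below, with 
scala.util.Random substituted for Spark's internal Utils.random):

```scala
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.RandomForest
import org.apache.spark.mllib.tree.configuration.Strategy
import org.apache.spark.rdd.RDD

// Classification path for numTrees > 1; regression uses trainRegressor.
def trainForest(
    training: RDD[LabeledPoint],
    strategy: Strategy,
    numTrees: Int,
    featureSubsetStrategy: String) = {
  val randomSeed = new scala.util.Random().nextInt()
  RandomForest.trainClassifier(training, strategy, numTrees,
    featureSubsetStrategy, randomSeed)
}
```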

## DTStatsAggregator
* 2 types of functionality (w/ and w/o subsampling features): These require 
different indexing methods.  (We could treat both as subsampling, but this is 
less efficient.)
* DTStatsAggregator is now abstract, and 2 child classes implement these 2 
types of functionality.

## impurities
* These now take instance weights.
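
For example, the weighted sufficient statistics and the Gini impurity computed 
from them look roughly like this (a sketch, not the ImpurityAggregator API):

```scala
// Weighted class-count update: a bootstrapped instance with weight w
// contributes w to its label's count.
def update(counts: Array[Double], label: Double, instanceWeight: Double): Unit =
  counts(label.toInt) += instanceWeight

// Gini impurity from weighted counts: 1 - sum_k (count_k / total)^2
def giniImpurity(counts: Array[Double]): Double = {
  val total = counts.sum
  if (total == 0.0) 0.0
  else 1.0 - counts.map(c => (c / total) * (c / total)).sum
}
```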

## Node
* Some vals changed to vars.
  * This is unfortunately a public API change (DeveloperApi).  This could be 
avoided by creating a LearningNode struct, but that would be awkward.

## RandomForestSuite
Please let me know if there are missing tests!

## BaggedPoint
This wraps TreePoint and holds bootstrap weights/counts.
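
Poisson(1) counts approximate bootstrap resampling without materializing 
resampled copies of the data.  Here is an illustrative sketch (the toy Poisson 
sampler and the names are mine; the real code may use a library sampler):

```scala
import scala.util.Random

// Sketch of BaggedPoint: the datum plus one bootstrap weight per tree.
case class BaggedPointSketch[D](datum: D, subsampleWeights: Array[Double])

// Knuth's algorithm for sampling from Poisson(mean); fine for mean = 1.
def poisson(mean: Double, rng: Random): Int = {
  val limit = math.exp(-mean)
  var k = 0
  var p = 1.0
  do { k += 1; p *= rng.nextDouble() } while (p > limit)
  k - 1
}

def bag[D](datum: D, numTrees: Int, rng: Random): BaggedPointSketch[D] =
  BaggedPointSketch(datum, Array.fill(numTrees)(poisson(1.0, rng).toDouble))
```

For numTrees = 1 the diff skips sampling entirely 
(convertToBaggedRDDWithoutSampling), so every weight is 1.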

# Design decisions

* BaggedPoint: BaggedPoint is separate from TreePoint since it may be useful 
for other bagging algorithms later on.

* RandomForest public API: What options should be easily supported by the 
train* methods?  Should ALL options be in the Java-friendly constructors?  
Should there be a constructor taking Strategy?

* Feature subsampling options: What options should be supported?  scikit-learn 
supports the same options, except for "onethird."  One option would be to 
allow users to specify fractions (e.g., "0.1"): the current options could 
still be supported, and any unrecognized values would be parsed as Doubles in 
[0,1] (see the sketch after this list).

* Splits and bins are computed before bootstrapping, so all trees use the same 
discretization.

* One queue, instead of one queue per tree.
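
A hypothetical sketch of the fraction-parsing proposal above (names and 
rounding are illustrative, not from this PR):

```scala
// Hypothetical parsing of featureSubsetStrategy, extended with fractions:
// try the named strategies first, then fall back to a Double in (0, 1].
// ("auto" is assumed to have been resolved to a concrete strategy already.)
def numFeaturesPerNode(strategy: String, numFeatures: Int): Int = strategy match {
  case "all"      => numFeatures
  case "sqrt"     => math.ceil(math.sqrt(numFeatures)).toInt
  case "log2"     => math.max(1, math.ceil(math.log(numFeatures) / math.log(2)).toInt)
  case "onethird" => math.ceil(numFeatures / 3.0).toInt
  case fraction   =>
    val f = fraction.toDouble  // NumberFormatException for unrecognized names
    require(f > 0.0 && f <= 1.0, s"Invalid feature fraction: $fraction")
    math.max(1, math.ceil(f * numFeatures).toInt)
}
```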

CC: mengxr manishamde codedeft chouqin  Please let me know if you have 
suggestions---thanks!

Author: Joseph K. Bradley <[email protected]>
Author: qiping.lqp <[email protected]>
Author: chouqin <[email protected]>

Closes #2435 from jkbradley/rfs-new and squashes the following commits:

c694174 [Joseph K. Bradley] Fixed typo
cc59d78 [Joseph K. Bradley] fixed imports
e25909f [Joseph K. Bradley] Simplified node group maps.  Specifically, created 
NodeIndexInfo to store node index in agg and feature subsets, and no longer 
create extra maps in findBestSplits
fbe9a1e [Joseph K. Bradley] Changed default featureSubsetStrategy to be sqrt 
for classification, onethird for regression.  Updated docs with references.
ef7c293 [Joseph K. Bradley] Updates based on code review.  Most substantial 
changes: * Simplified DTStatsAggregator * Made RandomForestModel.trees public * 
Added test for regression to RandomForestSuite
593b13c [Joseph K. Bradley] Fixed bug in metadata for computing log2(num 
features).  Now it checks >= 1.
a1a08df [Joseph K. Bradley] Removed old comments
866e766 [Joseph K. Bradley] Changed RandomForestSuite randomized tests to use 
multiple fixed random seeds.
ff8bb96 [Joseph K. Bradley] removed usage of null from RandomForest and 
replaced with Option
bf1a4c5 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into 
rfs-new
6b79c07 [Joseph K. Bradley] Added RandomForestSuite, and fixed small bugs, 
style issues.
d7753d4 [Joseph K. Bradley] Added numTrees and featureSubsetStrategy to 
DecisionTreeRunner (to support RandomForest).  Fixed bugs so that RandomForest 
now runs.
746d43c [Joseph K. Bradley] Implemented feature subsampling.  Tested 
DecisionTree but not RandomForest.
6309d1d [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into 
rfs-new.  Added RandomForestModel.toString
b7ae594 [Joseph K. Bradley] Updated docs.  Small fix for bug which does not 
cause errors: No longer allocate unused child nodes for leaf nodes.
121c74e [Joseph K. Bradley] Basic random forests are implemented.  Random 
features per node not yet implemented.  Test suite not implemented.
325d18a [Joseph K. Bradley] Merge branch 'chouqin-dt-preprune' into rfs-new
4ef9bf1 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into 
rfs-new
61b2e72 [Joseph K. Bradley] Added max of 10GB for maxMemoryInMB in Strategy.
a95e7c8 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into 
chouqin-dt-preprune
6da8571 [Joseph K. Bradley] RFs partly implemented, not done yet
eddd1eb [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into 
rfs-new
5c4ac33 [Joseph K. Bradley] Added check in Strategy to make sure 
minInstancesPerNode >= 1
0dd4d87 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into 
dt-spark-3160
95c479d [Joseph K. Bradley] * Fixed typo in tree suite test "do not choose 
split that does not satisfy min instance per node requirements" * small style 
fixes
e2628b6 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into 
chouqin-dt-preprune
19b01af [Joseph K. Bradley] Merge remote-tracking branch 'chouqin/dt-preprune' 
into chouqin-dt-preprune
f1d11d1 [chouqin] fix typo
c7ebaf1 [chouqin] fix typo
39f9b60 [chouqin] change edge `minInstancesPerNode` to 2 and add one more test
c6e2dfc [Joseph K. Bradley] Added minInstancesPerNode and minInfoGain 
parameters to DecisionTreeRunner.scala and to Python API in tree.py
306120f [Joseph K. Bradley] Fixed typo in DecisionTreeModel.scala doc
eaa1dcf [Joseph K. Bradley] Added topNode doc in DecisionTree and scalastyle fix
d4d7864 [Joseph K. Bradley] Marked Node.build as deprecated
d4dbb99 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into 
dt-spark-3160
1a8f0ad [Joseph K. Bradley] Eliminated pre-allocated nodes array in main 
train() method. * Nodes are constructed and added to the tree structure as 
needed during training.
0278a11 [chouqin] remove `noSplit` and set `Predict` private to tree
d593ec7 [chouqin] fix docs and change minInstancesPerNode to 1
2ab763b [Joseph K. Bradley] Simplifications to DecisionTree code:
efcc736 [qiping.lqp] fix bug
10b8012 [qiping.lqp] fix style
6728fad [qiping.lqp] minor fix: remove empty lines
bb465ca [qiping.lqp] Merge branch 'master' of https://github.com/apache/spark 
into dt-preprune
cadd569 [qiping.lqp] add api docs
46b891f [qiping.lqp] fix bug
e72c7e4 [qiping.lqp] add comments
845c6fa [qiping.lqp] fix style
f195e83 [qiping.lqp] fix style
987cbf4 [qiping.lqp] fix bug
ff34845 [qiping.lqp] separate calculation of predict of node from calculation 
of info gain
ac42378 [qiping.lqp] add min info gain and min instances per node parameters in 
decision tree


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/0dc2b636
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/0dc2b636
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/0dc2b636

Branch: refs/heads/master
Commit: 0dc2b6361d61b7d94cba3dc83da2abb7e08ba6fe
Parents: f350cd3
Author: Joseph K. Bradley <[email protected]>
Authored: Sun Sep 28 21:44:50 2014 -0700
Committer: Xiangrui Meng <[email protected]>
Committed: Sun Sep 28 21:44:50 2014 -0700

----------------------------------------------------------------------
 .../examples/mllib/DecisionTreeRunner.scala     |  76 ++-
 .../apache/spark/mllib/tree/DecisionTree.scala  | 457 +++++++------------
 .../apache/spark/mllib/tree/RandomForest.scala  | 451 ++++++++++++++++++
 .../spark/mllib/tree/impl/BaggedPoint.scala     |  80 ++++
 .../mllib/tree/impl/DTStatsAggregator.scala     | 219 ++++++---
 .../mllib/tree/impl/DecisionTreeMetadata.scala  |  47 +-
 .../spark/mllib/tree/impurity/Entropy.scala     |   4 +-
 .../apache/spark/mllib/tree/impurity/Gini.scala |   4 +-
 .../spark/mllib/tree/impurity/Impurity.scala    |   2 +-
 .../spark/mllib/tree/impurity/Variance.scala    |   8 +-
 .../apache/spark/mllib/tree/model/Node.scala    |  13 +-
 .../mllib/tree/model/RandomForestModel.scala    | 105 +++++
 .../spark/mllib/tree/DecisionTreeSuite.scala    | 210 ++++-----
 .../spark/mllib/tree/RandomForestSuite.scala    | 245 ++++++++++
 14 files changed, 1410 insertions(+), 511 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/0dc2b636/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
----------------------------------------------------------------------
diff --git 
a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
 
b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
index 4683e6e..96fb068 100644
--- 
a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
+++ 
b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
@@ -21,16 +21,18 @@ import scopt.OptionParser
 
 import org.apache.spark.{SparkConf, SparkContext}
 import org.apache.spark.SparkContext._
+import org.apache.spark.mllib.evaluation.MulticlassMetrics
 import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.mllib.tree.{DecisionTree, impurity}
+import org.apache.spark.mllib.tree.{RandomForest, DecisionTree, impurity}
 import org.apache.spark.mllib.tree.configuration.{Algo, Strategy}
 import org.apache.spark.mllib.tree.configuration.Algo._
-import org.apache.spark.mllib.tree.model.DecisionTreeModel
+import org.apache.spark.mllib.tree.model.{RandomForestModel, DecisionTreeModel}
 import org.apache.spark.mllib.util.MLUtils
 import org.apache.spark.rdd.RDD
+import org.apache.spark.util.Utils
 
 /**
- * An example runner for decision tree. Run with
+ * An example runner for decision trees and random forests. Run with
  * {{{
  * ./bin/run-example org.apache.spark.examples.mllib.DecisionTreeRunner 
[options]
  * }}}
@@ -57,6 +59,8 @@ object DecisionTreeRunner {
       maxBins: Int = 32,
       minInstancesPerNode: Int = 1,
       minInfoGain: Double = 0.0,
+      numTrees: Int = 1,
+      featureSubsetStrategy: String = "auto",
       fracTest: Double = 0.2)
 
   def main(args: Array[String]) {
@@ -79,11 +83,20 @@ object DecisionTreeRunner {
         .action((x, c) => c.copy(maxBins = x))
       opt[Int]("minInstancesPerNode")
         .text(s"min number of instances required at child nodes to create the 
parent split," +
-        s" default: ${defaultParams.minInstancesPerNode}")
+          s" default: ${defaultParams.minInstancesPerNode}")
         .action((x, c) => c.copy(minInstancesPerNode = x))
       opt[Double]("minInfoGain")
         .text(s"min info gain required to create a split, default: 
${defaultParams.minInfoGain}")
         .action((x, c) => c.copy(minInfoGain = x))
+      opt[Int]("numTrees")
+        .text(s"number of trees (1 = decision tree, 2+ = random forest)," +
+          s" default: ${defaultParams.numTrees}")
+        .action((x, c) => c.copy(numTrees = x))
+      opt[String]("featureSubsetStrategy")
+        .text(s"feature subset sampling strategy" +
+          s" (${RandomForest.supportedFeatureSubsetStrategies.mkString(", 
")}}), " +
+          s"default: ${defaultParams.featureSubsetStrategy}")
+        .action((x, c) => c.copy(featureSubsetStrategy = x))
       opt[Double]("fracTest")
         .text(s"fraction of data to hold out for testing, default: 
${defaultParams.fracTest}")
         .action((x, c) => c.copy(fracTest = x))
@@ -191,18 +204,35 @@ object DecisionTreeRunner {
           numClassesForClassification = numClasses,
           minInstancesPerNode = params.minInstancesPerNode,
           minInfoGain = params.minInfoGain)
-    val model = DecisionTree.train(training, strategy)
-
-    println(model)
-
-    if (params.algo == Classification) {
-      val accuracy = accuracyScore(model, test)
-      println(s"Test accuracy = $accuracy")
-    }
-
-    if (params.algo == Regression) {
-      val mse = meanSquaredError(model, test)
-      println(s"Test mean squared error = $mse")
+    if (params.numTrees == 1) {
+      val model = DecisionTree.train(training, strategy)
+      println(model)
+      if (params.algo == Classification) {
+        val accuracy =
+          new MulticlassMetrics(test.map(lp => (model.predict(lp.features), 
lp.label))).precision
+        println(s"Test accuracy = $accuracy")
+      }
+      if (params.algo == Regression) {
+        val mse = meanSquaredError(model, test)
+        println(s"Test mean squared error = $mse")
+      }
+    } else {
+      val randomSeed = Utils.random.nextInt()
+      if (params.algo == Classification) {
+        val model = RandomForest.trainClassifier(training, strategy, 
params.numTrees,
+          params.featureSubsetStrategy, randomSeed)
+        println(model)
+        val accuracy =
+          new MulticlassMetrics(test.map(lp => (model.predict(lp.features), 
lp.label))).precision
+        println(s"Test accuracy = $accuracy")
+      }
+      if (params.algo == Regression) {
+        val model = RandomForest.trainRegressor(training, strategy, 
params.numTrees,
+          params.featureSubsetStrategy, randomSeed)
+        println(model)
+        val mse = meanSquaredError(model, test)
+        println(s"Test mean squared error = $mse")
+      }
     }
 
     sc.stop()
@@ -211,9 +241,7 @@ object DecisionTreeRunner {
   /**
    * Calculates the classifier accuracy.
    */
-  private def accuracyScore(
-      model: DecisionTreeModel,
-      data: RDD[LabeledPoint]): Double = {
+  private def accuracyScore(model: DecisionTreeModel, data: 
RDD[LabeledPoint]): Double = {
     val correctCount = data.filter(y => model.predict(y.features) == 
y.label).count()
     val count = data.count()
     correctCount.toDouble / count
@@ -228,4 +256,14 @@ object DecisionTreeRunner {
       err * err
     }.mean()
   }
+
+  /**
+   * Calculates the mean squared error for regression.
+   */
+  private def meanSquaredError(tree: RandomForestModel, data: 
RDD[LabeledPoint]): Double = {
+    data.map { y =>
+      val err = tree.predict(y.features) - y.label
+      err * err
+    }.mean()
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/0dc2b636/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala 
b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index c7f2576..b7dc373 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -18,12 +18,14 @@
 package org.apache.spark.mllib.tree
 
 import scala.collection.JavaConverters._
+import scala.collection.mutable
 
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.Logging
 import org.apache.spark.mllib.rdd.RDDFunctions._
 import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.RandomForest.NodeIndexInfo
 import org.apache.spark.mllib.tree.configuration.Strategy
 import org.apache.spark.mllib.tree.configuration.Algo._
 import org.apache.spark.mllib.tree.configuration.FeatureType._
@@ -33,7 +35,6 @@ import org.apache.spark.mllib.tree.impurity.{Impurities, 
Impurity}
 import org.apache.spark.mllib.tree.impurity._
 import org.apache.spark.mllib.tree.model._
 import org.apache.spark.rdd.RDD
-import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.random.XORShiftRandom
 
 
@@ -56,99 +57,10 @@ class DecisionTree (private val strategy: Strategy) extends 
Serializable with Lo
    * @return DecisionTreeModel that can be used for prediction
    */
   def train(input: RDD[LabeledPoint]): DecisionTreeModel = {
-
-    val timer = new TimeTracker()
-
-    timer.start("total")
-
-    timer.start("init")
-
-    val retaggedInput = input.retag(classOf[LabeledPoint])
-    val metadata = DecisionTreeMetadata.buildMetadata(retaggedInput, strategy)
-    logDebug("algo = " + strategy.algo)
-    logDebug("maxBins = " + metadata.maxBins)
-
-    // Find the splits and the corresponding bins (interval between the 
splits) using a sample
-    // of the input data.
-    timer.start("findSplitsBins")
-    val (splits, bins) = DecisionTree.findSplitsBins(retaggedInput, metadata)
-    timer.stop("findSplitsBins")
-    logDebug("numBins: feature: number of bins")
-    logDebug(Range(0, metadata.numFeatures).map { featureIndex =>
-        s"\t$featureIndex\t${metadata.numBins(featureIndex)}"
-      }.mkString("\n"))
-
-    // Bin feature values (TreePoint representation).
-    // Cache input RDD for speedup during multiple passes.
-    val treeInput = TreePoint.convertToTreeRDD(retaggedInput, bins, metadata)
-      .persist(StorageLevel.MEMORY_AND_DISK)
-
-    // depth of the decision tree
-    val maxDepth = strategy.maxDepth
-    require(maxDepth <= 30,
-      s"DecisionTree currently only supports maxDepth <= 30, but was given 
maxDepth = $maxDepth.")
-
-    // Calculate level for single group construction
-
-    // Max memory usage for aggregates
-    val maxMemoryUsage = strategy.maxMemoryInMB * 1024L * 1024L
-    logDebug("max memory usage for aggregates = " + maxMemoryUsage + " bytes.")
-    // TODO: Calculate memory usage more precisely.
-    val numElementsPerNode = DecisionTree.getElementsPerNode(metadata)
-
-    logDebug("numElementsPerNode = " + numElementsPerNode)
-    val arraySizePerNode = 8 * numElementsPerNode // approx. memory usage for 
bin aggregate array
-    val maxNumberOfNodesPerGroup = math.max(maxMemoryUsage / arraySizePerNode, 
1)
-    logDebug("maxNumberOfNodesPerGroup = " + maxNumberOfNodesPerGroup)
-    // nodes at a level is 2^level. level is zero indexed.
-    val maxLevelForSingleGroup = math.max(
-      (math.log(maxNumberOfNodesPerGroup) / math.log(2)).floor.toInt, 0)
-    logDebug("max level for single group = " + maxLevelForSingleGroup)
-
-    timer.stop("init")
-
-    /*
-     * The main idea here is to perform level-wise training of the decision 
tree nodes thus
-     * reducing the passes over the data from l to log2(l) where l is the 
total number of nodes.
-     * Each data sample is handled by a particular node at that level (or it 
reaches a leaf
-     * beforehand and is not used in later levels.
-     */
-
-    var topNode: Node = null // set on first iteration
-    var level = 0
-    var break = false
-    while (level <= maxDepth && !break) {
-      logDebug("#####################################")
-      logDebug("level = " + level)
-      logDebug("#####################################")
-
-      // Find best split for all nodes at a level.
-      timer.start("findBestSplits")
-      val (tmpTopNode: Node, doneTraining: Boolean) = 
DecisionTree.findBestSplits(treeInput,
-        metadata, level, topNode, splits, bins, maxLevelForSingleGroup, timer)
-      timer.stop("findBestSplits")
-
-      if (level == 0) {
-        topNode = tmpTopNode
-      }
-      if (doneTraining) {
-        break = true
-        logDebug("done training")
-      }
-
-      level += 1
-    }
-
-    logDebug("#####################################")
-    logDebug("Extracting tree model")
-    logDebug("#####################################")
-
-    timer.stop("total")
-
-    logInfo("Internal timing for DecisionTree:")
-    logInfo(s"$timer")
-
-    new DecisionTreeModel(topNode, strategy.algo)
+    // Note: random seed will not be used since numTrees = 1.
+    val rf = new RandomForest(strategy, numTrees = 1, featureSubsetStrategy = 
"all", seed = 0)
+    val rfModel = rf.train(input)
+    rfModel.trees(0)
   }
 
 }
@@ -353,57 +265,9 @@ object DecisionTree extends Serializable with Logging {
   }
 
   /**
-   * Returns an array of optimal splits for all nodes at a given level. Splits 
the task into
-   * multiple groups if the level-wise training task could lead to memory 
overflow.
-   *
-   * @param input Training data: RDD of 
[[org.apache.spark.mllib.tree.impl.TreePoint]]
-   * @param metadata Learning and dataset metadata
-   * @param level Level of the tree
-   * @param topNode Root node of the tree (or invalid node when training first 
level).
-   * @param splits possible splits for all features, indexed 
(numFeatures)(numSplits)
-   * @param bins possible bins for all features, indexed (numFeatures)(numBins)
-   * @param maxLevelForSingleGroup the deepest level for single-group 
level-wise computation.
-   * @return  (root, doneTraining) where:
-   *          root = Root node (which is newly created on the first iteration),
-   *          doneTraining = true if no more internal nodes were created.
-   */
-  private[tree] def findBestSplits(
-      input: RDD[TreePoint],
-      metadata: DecisionTreeMetadata,
-      level: Int,
-      topNode: Node,
-      splits: Array[Array[Split]],
-      bins: Array[Array[Bin]],
-      maxLevelForSingleGroup: Int,
-      timer: TimeTracker = new TimeTracker): (Node, Boolean) = {
-
-    // split into groups to avoid memory overflow during aggregation
-    if (level > maxLevelForSingleGroup) {
-      // When information for all nodes at a given level cannot be stored in 
memory,
-      // the nodes are divided into multiple groups at each level with the 
number of groups
-      // increasing exponentially per level. For example, if 
maxLevelForSingleGroup is 10,
-      // numGroups is equal to 2 at level 11 and 4 at level 12, respectively.
-      val numGroups = 1 << level - maxLevelForSingleGroup
-      logDebug("numGroups = " + numGroups)
-      // Iterate over each group of nodes at a level.
-      var groupIndex = 0
-      var doneTraining = true
-      while (groupIndex < numGroups) {
-        val (_, doneTrainingGroup) = findBestSplitsPerGroup(input, metadata, 
level,
-          topNode, splits, bins, timer, numGroups, groupIndex)
-        doneTraining = doneTraining && doneTrainingGroup
-        groupIndex += 1
-      }
-      (topNode, doneTraining) // Not first iteration, so topNode was already 
set.
-    } else {
-      findBestSplitsPerGroup(input, metadata, level, topNode, splits, bins, 
timer)
-    }
-  }
-
-  /**
    * Get the node index corresponding to this data point.
-   * This function mimics prediction, passing an example from the root node 
down to a node
-   * at the current level being trained; that node's index is returned.
+   * This function mimics prediction, passing an example from the root node 
down to a leaf
+   * or unsplit node; that node's index is returned.
    *
    * @param node  Node in tree from which to classify the given data point.
    * @param binnedFeatures  Binned feature vector for data point.
@@ -413,14 +277,15 @@ object DecisionTree extends Serializable with Logging {
    *          Otherwise, last node reachable in tree matching this example.
    *          Note: This is the global node index, i.e., the index used in the 
tree.
    *                This index is different from the index used during 
training a particular
-   *                set of nodes in a (level, group).
+   *                group of nodes on one call to [[findBestSplits()]].
    */
   private def predictNodeIndex(
       node: Node,
       binnedFeatures: Array[Int],
       bins: Array[Array[Bin]],
       unorderedFeatures: Set[Int]): Int = {
-    if (node.isLeaf) {
+    if (node.isLeaf || node.split.isEmpty) {
+      // Node is either leaf, or has not yet been split.
       node.id
     } else {
       val featureIndex = node.split.get.feature
@@ -465,43 +330,60 @@ object DecisionTree extends Serializable with Logging {
    * @param agg  Array storing aggregate calculation, with a set of sufficient 
statistics for
    *             each (node, feature, bin).
    * @param treePoint  Data point being aggregated.
-   * @param nodeIndex  Node corresponding to treePoint. Indexed from 0 at 
start of (level, group).
+   * @param nodeIndex  Node corresponding to treePoint.  agg is indexed in [0, 
numNodes).
    * @param bins possible bins for all features, indexed (numFeatures)(numBins)
    * @param unorderedFeatures  Set of indices of unordered features.
+   * @param instanceWeight  Weight (importance) of instance in dataset.
    */
   private def mixedBinSeqOp(
       agg: DTStatsAggregator,
       treePoint: TreePoint,
       nodeIndex: Int,
       bins: Array[Array[Bin]],
-      unorderedFeatures: Set[Int]): Unit = {
-    // Iterate over all features.
-    val numFeatures = treePoint.binnedFeatures.size
+      unorderedFeatures: Set[Int],
+      instanceWeight: Double,
+      featuresForNode: Option[Array[Int]]): Unit = {
+    val numFeaturesPerNode = if (featuresForNode.nonEmpty) {
+      // Use subsampled features
+      featuresForNode.get.size
+    } else {
+      // Use all features
+      agg.metadata.numFeatures
+    }
     val nodeOffset = agg.getNodeOffset(nodeIndex)
-    var featureIndex = 0
-    while (featureIndex < numFeatures) {
+    // Iterate over features.
+    var featureIndexIdx = 0
+    while (featureIndexIdx < numFeaturesPerNode) {
+      val featureIndex = if (featuresForNode.nonEmpty) {
+        featuresForNode.get.apply(featureIndexIdx)
+      } else {
+        featureIndexIdx
+      }
       if (unorderedFeatures.contains(featureIndex)) {
         // Unordered feature
         val featureValue = treePoint.binnedFeatures(featureIndex)
         val (leftNodeFeatureOffset, rightNodeFeatureOffset) =
-          agg.getLeftRightNodeFeatureOffsets(nodeIndex, featureIndex)
+          agg.getLeftRightNodeFeatureOffsets(nodeIndex, featureIndexIdx)
         // Update the left or right bin for each split.
-        val numSplits = agg.numSplits(featureIndex)
+        val numSplits = agg.metadata.numSplits(featureIndex)
         var splitIndex = 0
         while (splitIndex < numSplits) {
           if 
(bins(featureIndex)(splitIndex).highSplit.categories.contains(featureValue)) {
-            agg.nodeFeatureUpdate(leftNodeFeatureOffset, splitIndex, 
treePoint.label)
+            agg.nodeFeatureUpdate(leftNodeFeatureOffset, splitIndex, 
treePoint.label,
+              instanceWeight)
           } else {
-            agg.nodeFeatureUpdate(rightNodeFeatureOffset, splitIndex, 
treePoint.label)
+            agg.nodeFeatureUpdate(rightNodeFeatureOffset, splitIndex, 
treePoint.label,
+              instanceWeight)
           }
           splitIndex += 1
         }
       } else {
         // Ordered feature
         val binIndex = treePoint.binnedFeatures(featureIndex)
-        agg.nodeUpdate(nodeOffset, featureIndex, binIndex, treePoint.label)
+        agg.nodeUpdate(nodeOffset, nodeIndex, featureIndexIdx, binIndex, 
treePoint.label,
+          instanceWeight)
       }
-      featureIndex += 1
+      featureIndexIdx += 1
     }
   }
 
@@ -513,66 +395,77 @@ object DecisionTree extends Serializable with Logging {
    * @param agg  Array storing aggregate calculation, with a set of sufficient 
statistics for
    *             each (node, feature, bin).
    * @param treePoint  Data point being aggregated.
-   * @param nodeIndex  Node corresponding to treePoint. Indexed from 0 at 
start of (level, group).
-   * @return agg
+   * @param nodeIndex  Node corresponding to treePoint.  agg is indexed in [0, 
numNodes).
+   * @param instanceWeight  Weight (importance) of instance in dataset.
    */
   private def orderedBinSeqOp(
       agg: DTStatsAggregator,
       treePoint: TreePoint,
-      nodeIndex: Int): Unit = {
+      nodeIndex: Int,
+      instanceWeight: Double,
+      featuresForNode: Option[Array[Int]]): Unit = {
     val label = treePoint.label
     val nodeOffset = agg.getNodeOffset(nodeIndex)
-    // Iterate over all features.
-    val numFeatures = agg.numFeatures
-    var featureIndex = 0
-    while (featureIndex < numFeatures) {
-      val binIndex = treePoint.binnedFeatures(featureIndex)
-      agg.nodeUpdate(nodeOffset, featureIndex, binIndex, label)
-      featureIndex += 1
+    // Iterate over features.
+    if (featuresForNode.nonEmpty) {
+      // Use subsampled features
+      var featureIndexIdx = 0
+      while (featureIndexIdx < featuresForNode.get.size) {
+        val binIndex = 
treePoint.binnedFeatures(featuresForNode.get.apply(featureIndexIdx))
+        agg.nodeUpdate(nodeOffset, nodeIndex, featureIndexIdx, binIndex, 
label, instanceWeight)
+        featureIndexIdx += 1
+      }
+    } else {
+      // Use all features
+      val numFeatures = agg.metadata.numFeatures
+      var featureIndex = 0
+      while (featureIndex < numFeatures) {
+        val binIndex = treePoint.binnedFeatures(featureIndex)
+        agg.nodeUpdate(nodeOffset, nodeIndex, featureIndex, binIndex, label, 
instanceWeight)
+        featureIndex += 1
+      }
     }
   }
 
   /**
-   * Returns an array of optimal splits for a group of nodes at a given level
+   * Given a group of nodes, this finds the best split for each node.
    *
    * @param input Training data: RDD of 
[[org.apache.spark.mllib.tree.impl.TreePoint]]
    * @param metadata Learning and dataset metadata
-   * @param level Level of the tree
-   * @param topNode Root node of the tree (or invalid node when training first 
level).
+   * @param topNodes Root node for each tree.  Used for matching instances 
with nodes.
+   * @param nodesForGroup Mapping: treeIndex --> nodes to be split in tree
+   * @param treeToNodeToIndexInfo Mapping: treeIndex --> nodeIndex --> 
nodeIndexInfo,
+   *                              where nodeIndexInfo stores the index in the 
group and the
+   *                              feature subsets (if using feature subsets).
    * @param splits possible splits for all features, indexed 
(numFeatures)(numSplits)
    * @param bins possible bins for all features, indexed (numFeatures)(numBins)
-   * @param numGroups total number of node groups at the current level. 
Default value is set to 1.
-   * @param groupIndex index of the node group being processed. Default value 
is set to 0.
-   * @return  (root, doneTraining) where:
-   *          root = Root node (which is newly created on the first iteration),
-   *          doneTraining = true if no more internal nodes were created.
+   * @param nodeQueue  Queue of nodes to split, with values (treeIndex, node).
+   *                   Updated with new non-leaf nodes which are created.
    */
-  private def findBestSplitsPerGroup(
-      input: RDD[TreePoint],
+  private[tree] def findBestSplits(
+      input: RDD[BaggedPoint[TreePoint]],
       metadata: DecisionTreeMetadata,
-      level: Int,
-      topNode: Node,
+      topNodes: Array[Node],
+      nodesForGroup: Map[Int, Array[Node]],
+      treeToNodeToIndexInfo: Map[Int, Map[Int, NodeIndexInfo]],
       splits: Array[Array[Split]],
       bins: Array[Array[Bin]],
-      timer: TimeTracker,
-      numGroups: Int = 1,
-      groupIndex: Int = 0): (Node, Boolean) = {
+      nodeQueue: mutable.Queue[(Int, Node)],
+      timer: TimeTracker = new TimeTracker): Unit = {
 
     /*
      * The high-level descriptions of the best split optimizations are noted 
here.
      *
-     * *Level-wise training*
-     * We perform bin calculations for all nodes at the given level to avoid 
making multiple
-     * passes over the data. Thus, for a slightly increased computation and 
storage cost we save
-     * several iterations over the data especially at higher levels of the 
decision tree.
+     * *Group-wise training*
+     * We perform bin calculations for groups of nodes to reduce the number of
+     * passes over the data.  Each iteration requires more computation and 
storage,
+     * but saves several iterations over the data.
      *
      * *Bin-wise computation*
      * We use a bin-wise best split computation strategy instead of a 
straightforward best split
      * computation strategy. Instead of analyzing each sample for contribution 
to the left/right
      * child node impurity of every split, we first categorize each feature of 
a sample into a
-     * bin. Each bin is an interval between a low and high split. Since each 
split, and thus bin,
-     * is ordered (read ordering for categorical variables in the 
findSplitsBins method),
-     * we exploit this structure to calculate aggregates for bins and then use 
these aggregates
+     * bin. We exploit this structure to calculate aggregates for bins and 
then use these aggregates
      * to calculate information gain for each split.
      *
      * *Aggregation over partitions*
@@ -582,42 +475,15 @@ object DecisionTree extends Serializable with Logging {
      * drastically reduce the communication overhead.
      */
 
-    // Common calculations for multiple nested methods:
-
-    // numNodes:  Number of nodes in this (level of tree, group),
-    //            where nodes at deeper (larger) levels may be divided into 
groups.
-    val numNodes = Node.maxNodesInLevel(level) / numGroups
+    // numNodes:  Number of nodes in this group
+    val numNodes = nodesForGroup.values.map(_.size).sum
     logDebug("numNodes = " + numNodes)
-
     logDebug("numFeatures = " + metadata.numFeatures)
     logDebug("numClasses = " + metadata.numClasses)
     logDebug("isMulticlass = " + metadata.isMulticlass)
     logDebug("isMulticlassWithCategoricalFeatures = " +
       metadata.isMulticlassWithCategoricalFeatures)
 
-    // shift when more than one group is used at deep tree level
-    val groupShift = numNodes * groupIndex
-
-    // Used for treePointToNodeIndex to get an index for this (level, group).
-    // - Node.startIndexInLevel(level) gives the global index offset for nodes 
at this level.
-    // - groupShift corrects for groups in this level before the current group.
-    val globalNodeIndexOffset = Node.startIndexInLevel(level) + groupShift
-
-    /**
-     * Find the node index for the given example.
-     * Nodes are indexed from 0 at the start of this (level, group).
-     * If the example does not reach this level, returns a value < 0.
-     */
-    def treePointToNodeIndex(treePoint: TreePoint): Int = {
-      if (level == 0) {
-        0
-      } else {
-        val globalNodeIndex =
-          predictNodeIndex(topNode, treePoint.binnedFeatures, bins, 
metadata.unorderedFeatures)
-        globalNodeIndex - globalNodeIndexOffset
-      }
-    }
-
     /**
      * Performs a sequential aggregation over a partition.
      *
@@ -626,21 +492,27 @@ object DecisionTree extends Serializable with Logging {
      *
      * @param agg  Array storing aggregate calculation, with a set of 
sufficient statistics for
      *             each (node, feature, bin).
-     * @param treePoint   Data point being aggregated.
+     * @param baggedPoint   Data point being aggregated.
      * @return  agg
      */
     def binSeqOp(
         agg: DTStatsAggregator,
-        treePoint: TreePoint): DTStatsAggregator = {
-      val nodeIndex = treePointToNodeIndex(treePoint)
-      // If the example does not reach this level, then nodeIndex < 0.
-      // If the example reaches this level but is handled in a different group,
-      //  then either nodeIndex < 0 (previous group) or nodeIndex >= numNodes 
(later group).
-      if (nodeIndex >= 0 && nodeIndex < numNodes) {
-        if (metadata.unorderedFeatures.isEmpty) {
-          orderedBinSeqOp(agg, treePoint, nodeIndex)
-        } else {
-          mixedBinSeqOp(agg, treePoint, nodeIndex, bins, 
metadata.unorderedFeatures)
+        baggedPoint: BaggedPoint[TreePoint]): DTStatsAggregator = {
+      treeToNodeToIndexInfo.foreach { case (treeIndex, nodeIndexToInfo) =>
+        val nodeIndex = predictNodeIndex(topNodes(treeIndex), 
baggedPoint.datum.binnedFeatures,
+          bins, metadata.unorderedFeatures)
+        val nodeInfo = nodeIndexToInfo.getOrElse(nodeIndex, null)
+        // If the example does not reach a node in this group, then nodeIndex 
= null.
+        if (nodeInfo != null) {
+          val aggNodeIndex = nodeInfo.nodeIndexInGroup
+          val featuresForNode = nodeInfo.featureSubset
+          val instanceWeight = baggedPoint.subsampleWeights(treeIndex)
+          if (metadata.unorderedFeatures.isEmpty) {
+            orderedBinSeqOp(agg, baggedPoint.datum, aggNodeIndex, 
instanceWeight, featuresForNode)
+          } else {
+            mixedBinSeqOp(agg, baggedPoint.datum, aggNodeIndex, bins, 
metadata.unorderedFeatures,
+              instanceWeight, featuresForNode)
+          }
         }
       }
       agg
@@ -649,71 +521,62 @@ object DecisionTree extends Serializable with Logging {
     // Calculate bin aggregates.
     timer.start("aggregation")
     val binAggregates: DTStatsAggregator = {
-      val initAgg = new DTStatsAggregator(metadata, numNodes)
+      val initAgg = if (metadata.subsamplingFeatures) {
+        new DTStatsAggregatorSubsampledFeatures(metadata, 
treeToNodeToIndexInfo)
+      } else {
+        new DTStatsAggregatorFixedFeatures(metadata, numNodes)
+      }
       input.treeAggregate(initAgg)(binSeqOp, DTStatsAggregator.binCombOp)
     }
     timer.stop("aggregation")
 
-    // Calculate best splits for all nodes at a given level
+    // Calculate best splits for all nodes in the group
     timer.start("chooseSplits")
-    // On the first iteration, we need to get and return the newly created 
root node.
-    var newTopNode: Node = topNode
-
-    // Iterate over all nodes at this level
-    var nodeIndex = 0
-    var internalNodeCount = 0
-    while (nodeIndex < numNodes) {
-      val (split: Split, stats: InformationGainStats, predict: Predict) =
-        binsToBestSplit(binAggregates, nodeIndex, level, metadata, splits)
-      logDebug("best split = " + split)
-
-      val globalNodeIndex = globalNodeIndexOffset + nodeIndex
 
-      // Extract info for this node at the current level.
-      val isLeaf = (stats.gain <= 0) || (level == metadata.maxDepth)
-      val node =
-        new Node(globalNodeIndex, predict.predict, isLeaf, Some(split), None, 
None, Some(stats))
-      logDebug("Node = " + node)
-
-      if (!isLeaf) {
-        internalNodeCount += 1
-      }
-      if (level == 0) {
-        newTopNode = node
-      } else {
-        // Set parent.
-        val parentNode = Node.getNode(Node.parentIndex(globalNodeIndex), 
topNode)
-        if (Node.isLeftChild(globalNodeIndex)) {
-          parentNode.leftNode = Some(node)
-        } else {
-          parentNode.rightNode = Some(node)
+    // Iterate over all nodes in this group.
+    nodesForGroup.foreach { case (treeIndex, nodesForTree) =>
+      nodesForTree.foreach { node =>
+        val nodeIndex = node.id
+        val nodeInfo = treeToNodeToIndexInfo(treeIndex)(nodeIndex)
+        val aggNodeIndex = nodeInfo.nodeIndexInGroup
+        val featuresForNode = nodeInfo.featureSubset
+        val (split: Split, stats: InformationGainStats, predict: Predict) =
+          binsToBestSplit(binAggregates, aggNodeIndex, splits, featuresForNode)
+        logDebug("best split = " + split)
+
+        // Extract info for this node.  Create children if not leaf.
+        val isLeaf = (stats.gain <= 0) || (Node.indexToLevel(nodeIndex) == 
metadata.maxDepth)
+        assert(node.id == nodeIndex)
+        node.predict = predict.predict
+        node.isLeaf = isLeaf
+        node.stats = Some(stats)
+        logDebug("Node = " + node)
+
+        if (!isLeaf) {
+          node.split = Some(split)
+          node.leftNode = Some(Node.emptyNode(Node.leftChildIndex(nodeIndex)))
+          node.rightNode = 
Some(Node.emptyNode(Node.rightChildIndex(nodeIndex)))
+          nodeQueue.enqueue((treeIndex, node.leftNode.get))
+          nodeQueue.enqueue((treeIndex, node.rightNode.get))
+          logDebug("leftChildIndex = " + node.leftNode.get.id +
+            ", impurity = " + stats.leftImpurity)
+          logDebug("rightChildIndex = " + node.rightNode.get.id +
+            ", impurity = " + stats.rightImpurity)
         }
       }
-      if (level < metadata.maxDepth) {
-        logDebug("leftChildIndex = " + Node.leftChildIndex(globalNodeIndex) +
-          ", impurity = " + stats.leftImpurity)
-        logDebug("rightChildIndex = " + Node.rightChildIndex(globalNodeIndex) +
-          ", impurity = " + stats.rightImpurity)
-      }
-
-      nodeIndex += 1
     }
     timer.stop("chooseSplits")
-
-    val doneTraining = internalNodeCount == 0
-    (newTopNode, doneTraining)
   }
 
   /**
    * Calculate the information gain for a given (feature, split) based upon 
left/right aggregates.
    * @param leftImpurityCalculator left node aggregates for this (feature, 
split)
    * @param rightImpurityCalculator right node aggregate for this (feature, 
split)
-   * @return information gain and statistics for all splits
+   * @return information gain and statistics for split
    */
   private def calculateGainForSplit(
       leftImpurityCalculator: ImpurityCalculator,
       rightImpurityCalculator: ImpurityCalculator,
-      level: Int,
       metadata: DecisionTreeMetadata): InformationGainStats = {
     val leftCount = leftImpurityCalculator.count
     val rightCount = rightImpurityCalculator.count
@@ -753,7 +616,7 @@ object DecisionTree extends Serializable with Logging {
    * Calculate predict value for current node, given stats of any split.
    * Note that this function is called only once for each node.
    * @param leftImpurityCalculator left node aggregates for a split
-   * @param rightImpurityCalculator right node aggregates for a node
+   * @param rightImpurityCalculator right node aggregates for a split
    * @return predict value for current node
    */
   private def calculatePredict(
@@ -770,27 +633,33 @@ object DecisionTree extends Serializable with Logging {
   /**
    * Find the best split for a node.
    * @param binAggregates Bin statistics.
-   * @param nodeIndex Index for node to split in this (level, group).
-   * @return tuple for best split: (Split, information gain)
+   * @param nodeIndex Index into aggregates for node to split in this group.
+   * @return tuple for best split: (Split, information gain, prediction at 
node)
    */
   private def binsToBestSplit(
       binAggregates: DTStatsAggregator,
       nodeIndex: Int,
-      level: Int,
-      metadata: DecisionTreeMetadata,
-      splits: Array[Array[Split]]): (Split, InformationGainStats, Predict) = {
+      splits: Array[Array[Split]],
+      featuresForNode: Option[Array[Int]]): (Split, InformationGainStats, 
Predict) = {
+
+    val metadata: DecisionTreeMetadata = binAggregates.metadata
 
     // calculate predict only once
     var predict: Option[Predict] = None
 
     // For each (feature, split), calculate the gain, and select the best 
(feature, split).
-    val (bestSplit, bestSplitStats) = Range(0, metadata.numFeatures).map { 
featureIndex =>
+    val (bestSplit, bestSplitStats) = Range(0, 
metadata.numFeaturesPerNode).map { featureIndexIdx =>
+      val featureIndex = if (featuresForNode.nonEmpty) {
+        featuresForNode.get.apply(featureIndexIdx)
+      } else {
+        featureIndexIdx
+      }
       val numSplits = metadata.numSplits(featureIndex)
       if (metadata.isContinuous(featureIndex)) {
         // Cumulative sum (scanLeft) of bin statistics.
         // Afterwards, binAggregates for a bin is the sum of aggregates for
         // that bin + all preceding bins.
-        val nodeFeatureOffset = binAggregates.getNodeFeatureOffset(nodeIndex, 
featureIndex)
+        val nodeFeatureOffset = binAggregates.getNodeFeatureOffset(nodeIndex, 
featureIndexIdx)
         var splitIndex = 0
         while (splitIndex < numSplits) {
           binAggregates.mergeForNodeFeature(nodeFeatureOffset, splitIndex + 1, 
splitIndex)
@@ -803,26 +672,26 @@ object DecisionTree extends Serializable with Logging {
             val rightChildStats = 
binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits)
             rightChildStats.subtract(leftChildStats)
             predict = Some(predict.getOrElse(calculatePredict(leftChildStats, 
rightChildStats)))
-            val gainStats = calculateGainForSplit(leftChildStats, 
rightChildStats, level, metadata)
+            val gainStats = calculateGainForSplit(leftChildStats, 
rightChildStats, metadata)
             (splitIdx, gainStats)
           }.maxBy(_._2.gain)
         (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
       } else if (metadata.isUnordered(featureIndex)) {
         // Unordered categorical feature
         val (leftChildOffset, rightChildOffset) =
-          binAggregates.getLeftRightNodeFeatureOffsets(nodeIndex, featureIndex)
+          binAggregates.getLeftRightNodeFeatureOffsets(nodeIndex, 
featureIndexIdx)
         val (bestFeatureSplitIndex, bestFeatureGainStats) =
           Range(0, numSplits).map { splitIndex =>
             val leftChildStats = 
binAggregates.getImpurityCalculator(leftChildOffset, splitIndex)
             val rightChildStats = 
binAggregates.getImpurityCalculator(rightChildOffset, splitIndex)
             predict = Some(predict.getOrElse(calculatePredict(leftChildStats, 
rightChildStats)))
-            val gainStats = calculateGainForSplit(leftChildStats, 
rightChildStats, level, metadata)
+            val gainStats = calculateGainForSplit(leftChildStats, 
rightChildStats, metadata)
             (splitIndex, gainStats)
           }.maxBy(_._2.gain)
         (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
       } else {
         // Ordered categorical feature
-        val nodeFeatureOffset = binAggregates.getNodeFeatureOffset(nodeIndex, 
featureIndex)
+        val nodeFeatureOffset = binAggregates.getNodeFeatureOffset(nodeIndex, 
featureIndexIdx)
         val numBins = metadata.numBins(featureIndex)
 
         /* Each bin is one category (feature value).
@@ -887,7 +756,7 @@ object DecisionTree extends Serializable with Logging {
               binAggregates.getImpurityCalculator(nodeFeatureOffset, 
lastCategory)
             rightChildStats.subtract(leftChildStats)
             predict = Some(predict.getOrElse(calculatePredict(leftChildStats, 
rightChildStats)))
-            val gainStats = calculateGainForSplit(leftChildStats, 
rightChildStats, level, metadata)
+            val gainStats = calculateGainForSplit(leftChildStats, 
rightChildStats, metadata)
             (splitIndex, gainStats)
           }.maxBy(_._2.gain)
         val categoriesForSplit =
@@ -904,18 +773,6 @@ object DecisionTree extends Serializable with Logging {
   }
 
   /**
-   * Get the number of values to be stored per node in the bin aggregates.
-   */
-  private def getElementsPerNode(metadata: DecisionTreeMetadata): Long = {
-    val totalBins = metadata.numBins.map(_.toLong).sum
-    if (metadata.isClassification) {
-      metadata.numClasses * totalBins
-    } else {
-      3 * totalBins
-    }
-  }
-
-  /**
    * Returns splits and bins for decision tree calculation.
    * Continuous and categorical features are handled differently.
    *

http://git-wip-us.apache.org/repos/asf/spark/blob/0dc2b636/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala 
b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
new file mode 100644
index 0000000..7fa7725
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
@@ -0,0 +1,451 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable
+
+import org.apache.spark.Logging
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.api.java.JavaRDD
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.configuration.Algo._
+import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
+import org.apache.spark.mllib.tree.configuration.Strategy
+import org.apache.spark.mllib.tree.impl.{BaggedPoint, TreePoint, 
DecisionTreeMetadata, TimeTracker}
+import org.apache.spark.mllib.tree.impurity.Impurities
+import org.apache.spark.mllib.tree.model._
+import org.apache.spark.rdd.RDD
+import org.apache.spark.storage.StorageLevel
+import org.apache.spark.util.Utils
+
+/**
+ * :: Experimental ::
+ * A class which implements a random forest learning algorithm for 
classification and regression.
+ * It supports both continuous and categorical features.
+ *
+ * The settings for featureSubsetStrategy are based on the following 
references:
+ *  - log2: tested in Breiman (2001)
+ *  - sqrt: recommended by Breiman manual for random forests
+ *  - The defaults of sqrt (classification) and onethird (regression) match 
the R randomForest
+ *    package.
+ * @see [[http://www.stat.berkeley.edu/~breiman/randomforest2001.pdf  Breiman 
(2001)]]
+ * @see [[http://www.stat.berkeley.edu/~breiman/Using_random_forests_V3.1.pdf  
Breiman manual for
+ *     random forests]]
+ *
+ * @param strategy The configuration parameters for the random forest 
algorithm which specify
+ *                 the type of algorithm (classification, regression, etc.), 
feature type
+ *                 (continuous, categorical), depth of the tree, quantile 
calculation strategy,
+ *                 etc.
+ * @param numTrees If 1, then no bootstrapping is used.  If > 1, then 
bootstrapping is done.
+ * @param featureSubsetStrategy Number of features to consider for splits at 
each node.
+ *                              Supported: "auto" (default), "all", "sqrt", 
"log2", "onethird".
+ *                              If "auto" is set, this parameter is set based 
on numTrees:
+ *                                if numTrees == 1, set to "all";
+ *                                if numTrees > 1 (forest) set to "sqrt" for 
classification and
+ *                                  to "onethird" for regression.
+ * @param seed  Random seed for bootstrapping and choosing feature subsets.
+ */
+@Experimental
+private class RandomForest (
+    private val strategy: Strategy,
+    private val numTrees: Int,
+    featureSubsetStrategy: String,
+    private val seed: Int)
+  extends Serializable with Logging {
+
+  strategy.assertValid()
+  require(numTrees > 0, s"RandomForest requires numTrees > 0, but was given 
numTrees = $numTrees.")
+  
require(RandomForest.supportedFeatureSubsetStrategies.contains(featureSubsetStrategy),
+    s"RandomForest given invalid featureSubsetStrategy: 
$featureSubsetStrategy." +
+    s" Supported values: 
${RandomForest.supportedFeatureSubsetStrategies.mkString(", ")}.")
+
+  /**
+   * Method to train a decision tree model over an RDD
+   * @param input Training data: RDD of 
[[org.apache.spark.mllib.regression.LabeledPoint]]
+   * @return RandomForestModel that can be used for prediction
+   */
+  def train(input: RDD[LabeledPoint]): RandomForestModel = {
+
+    val timer = new TimeTracker()
+
+    timer.start("total")
+
+    timer.start("init")
+
+    val retaggedInput = input.retag(classOf[LabeledPoint])
+    val metadata =
+      DecisionTreeMetadata.buildMetadata(retaggedInput, strategy, numTrees, 
featureSubsetStrategy)
+    logDebug("algo = " + strategy.algo)
+    logDebug("numTrees = " + numTrees)
+    logDebug("seed = " + seed)
+    logDebug("maxBins = " + metadata.maxBins)
+    logDebug("featureSubsetStrategy = " + featureSubsetStrategy)
+    logDebug("numFeaturesPerNode = " + metadata.numFeaturesPerNode)
+
+    // Find the splits and the corresponding bins (interval between the 
splits) using a sample
+    // of the input data.
+    timer.start("findSplitsBins")
+    val (splits, bins) = DecisionTree.findSplitsBins(retaggedInput, metadata)
+    timer.stop("findSplitsBins")
+    logDebug("numBins: feature: number of bins")
+    logDebug(Range(0, metadata.numFeatures).map { featureIndex =>
+        s"\t$featureIndex\t${metadata.numBins(featureIndex)}"
+      }.mkString("\n"))
+
+    // Bin feature values (TreePoint representation).
+    // Cache input RDD for speedup during multiple passes.
+    val treeInput = TreePoint.convertToTreeRDD(retaggedInput, bins, metadata)
+    val baggedInput = if (numTrees > 1) {
+      BaggedPoint.convertToBaggedRDD(treeInput, numTrees, seed)
+    } else {
+      BaggedPoint.convertToBaggedRDDWithoutSampling(treeInput)
+    }.persist(StorageLevel.MEMORY_AND_DISK)
+
+    // depth of the decision tree
+    val maxDepth = strategy.maxDepth
+    require(maxDepth <= 30,
+      s"DecisionTree currently only supports maxDepth <= 30, but was given 
maxDepth = $maxDepth.")
+
+    // Max memory usage for aggregates
+    // TODO: Calculate memory usage more precisely.
+    val maxMemoryUsage: Long = strategy.maxMemoryInMB * 1024L * 1024L
+    logDebug("max memory usage for aggregates = " + maxMemoryUsage + " bytes.")
+    val maxMemoryPerNode = {
+      val featureSubset: Option[Array[Int]] = if (metadata.subsamplingFeatures) {
+        // Find numFeaturesPerNode largest bins to get an upper bound on memory usage.
+        Some(metadata.numBins.zipWithIndex.sortBy(- _._1)
+          .take(metadata.numFeaturesPerNode).map(_._2))
+      } else {
+        None
+      }
+      RandomForest.aggregateSizeForNode(metadata, featureSubset) * 8L
+    }
+    require(maxMemoryPerNode <= maxMemoryUsage,
+      s"RandomForest/DecisionTree given maxMemoryInMB = 
${strategy.maxMemoryInMB}," +
+      " which is too small for the given features." +
+      s"  Minimum value = ${maxMemoryPerNode / (1024L * 1024L)}")
+
+    timer.stop("init")
+
+    /*
+     * The main idea here is to perform group-wise training of the decision tree nodes thus
+     * reducing the passes over the data from (# nodes) to (# nodes / maxNumberOfNodesPerGroup).
+     * Each data sample is handled by a particular node (or it reaches a leaf and is not used
+     * in lower levels).
+     */
+
+    // FIFO queue of nodes to train: (treeIndex, node)
+    val nodeQueue = new mutable.Queue[(Int, Node)]()
+
+    val rng = new scala.util.Random()
+    rng.setSeed(seed)
+
+    // Allocate and queue root nodes.
+    val topNodes: Array[Node] = Array.fill[Node](numTrees)(Node.emptyNode(nodeIndex = 1))
+    Range(0, numTrees).foreach(treeIndex => nodeQueue.enqueue((treeIndex, topNodes(treeIndex))))
+
+    while (nodeQueue.nonEmpty) {
+      // Collect some nodes to split, and choose features for each node (if subsampling).
+      // Each group of nodes may come from one or multiple trees, and at multiple levels.
+      val (nodesForGroup, treeToNodeToIndexInfo) =
+        RandomForest.selectNodesToSplit(nodeQueue, maxMemoryUsage, metadata, rng)
+      // Sanity check (should never occur):
+      assert(nodesForGroup.size > 0,
+        s"RandomForest selected empty nodesForGroup.  Error for unknown reason.")
+
+      // Choose node splits, and enqueue new nodes as needed.
+      timer.start("findBestSplits")
+      DecisionTree.findBestSplits(baggedInput,
+        metadata, topNodes, nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue, timer)
+      timer.stop("findBestSplits")
+    }
+
+    timer.stop("total")
+
+    logInfo("Internal timing for DecisionTree:")
+    logInfo(s"$timer")
+
+    val trees = topNodes.map(topNode => new DecisionTreeModel(topNode, strategy.algo))
+    RandomForestModel.build(trees)
+  }
+
+}
+
+object RandomForest extends Serializable with Logging {
+
+  /**
+   * Method to train a random forest model for binary or multiclass classification.
+   *
+   * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
+   *              Labels should take values {0, 1, ..., numClasses-1}.
+   * @param strategy Parameters for training each tree in the forest.
+   * @param numTrees Number of trees in the random forest.
+   * @param featureSubsetStrategy Number of features to consider for splits at each node.
+   *                              Supported: "auto" (default), "all", "sqrt", "log2", "onethird".
+   *                              If "auto" is set, this parameter is set based on numTrees:
+   *                                if numTrees == 1, set to "all";
+   *                                if numTrees > 1 (forest), set to "sqrt" for classification and
+   *                                  to "onethird" for regression.
+   * @param seed  Random seed for bootstrapping and choosing feature subsets.
+   * @return RandomForestModel that can be used for prediction
+   */
+   */
+  def trainClassifier(
+      input: RDD[LabeledPoint],
+      strategy: Strategy,
+      numTrees: Int,
+      featureSubsetStrategy: String,
+      seed: Int): RandomForestModel = {
+    require(strategy.algo == Classification,
+      s"RandomForest.trainClassifier given Strategy with invalid algo: 
${strategy.algo}")
+    val rf = new RandomForest(strategy, numTrees, featureSubsetStrategy, seed)
+    rf.train(input)
+  }
+
+  /**
+   * Method to train a random forest model for binary or multiclass classification.
+   *
+   * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
+   *              Labels should take values {0, 1, ..., numClasses-1}.
+   * @param numClassesForClassification Number of classes for classification.
+   * @param categoricalFeaturesInfo Map storing arity of categorical features.
+   *                                E.g., an entry (n -> k) indicates that feature n is categorical
+   *                                with k categories indexed from 0: {0, 1, ..., k-1}.
+   * @param numTrees Number of trees in the random forest.
+   * @param featureSubsetStrategy Number of features to consider for splits at each node.
+   *                              Supported: "auto" (default), "all", "sqrt", "log2", "onethird".
+   *                              If "auto" is set, this parameter is set based on numTrees:
+   *                                if numTrees == 1, set to "all";
+   *                                if numTrees > 1 (forest), set to "sqrt" for classification and
+   *                                  to "onethird" for regression.
+   * @param impurity Criterion used for information gain calculation.
+   *                 Supported values: "gini" (recommended) or "entropy".
+   * @param maxDepth Maximum depth of the tree.
+   *                 E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.
+   *                 (suggested value: 4)
+   * @param maxBins Maximum number of bins used for splitting features
+   *                (suggested value: 100)
+   * @param seed  Random seed for bootstrapping and choosing feature subsets.
+   * @return RandomForestModel that can be used for prediction
+   */
+  def trainClassifier(
+      input: RDD[LabeledPoint],
+      numClassesForClassification: Int,
+      categoricalFeaturesInfo: Map[Int, Int],
+      numTrees: Int,
+      featureSubsetStrategy: String,
+      impurity: String,
+      maxDepth: Int,
+      maxBins: Int,
+      seed: Int = Utils.random.nextInt()): RandomForestModel = {
+    val impurityType = Impurities.fromString(impurity)
+    val strategy = new Strategy(Classification, impurityType, maxDepth,
+      numClassesForClassification, maxBins, Sort, categoricalFeaturesInfo)
+    trainClassifier(input, strategy, numTrees, featureSubsetStrategy, seed)
+  }
+
+  /**
+   * Java-friendly API for [[org.apache.spark.mllib.tree.RandomForest$#trainClassifier]]
+   */
+  def trainClassifier(
+      input: JavaRDD[LabeledPoint],
+      numClassesForClassification: Int,
+      categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer],
+      numTrees: Int,
+      featureSubsetStrategy: String,
+      impurity: String,
+      maxDepth: Int,
+      maxBins: Int,
+      seed: Int): RandomForestModel = {
+    trainClassifier(input.rdd, numClassesForClassification,
+      categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap,
+      numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins, seed)
+  }
+
+  /**
+   * Method to train a random forest model for regression.
+   *
+   * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
+   *              Labels are real numbers.
+   * @param strategy Parameters for training each tree in the forest.
+   * @param numTrees Number of trees in the random forest.
+   * @param featureSubsetStrategy Number of features to consider for splits at each node.
+   *                              Supported: "auto" (default), "all", "sqrt", "log2", "onethird".
+   *                              If "auto" is set, this parameter is set based on numTrees:
+   *                                if numTrees == 1, set to "all";
+   *                                if numTrees > 1 (forest), set to "sqrt" for classification and
+   *                                  to "onethird" for regression.
+   * @param seed  Random seed for bootstrapping and choosing feature subsets.
+   * @return RandomForestModel that can be used for prediction
+   */
+  def trainRegressor(
+      input: RDD[LabeledPoint],
+      strategy: Strategy,
+      numTrees: Int,
+      featureSubsetStrategy: String,
+      seed: Int): RandomForestModel = {
+    require(strategy.algo == Regression,
+      s"RandomForest.trainRegressor given Strategy with invalid algo: 
${strategy.algo}")
+    val rf = new RandomForest(strategy, numTrees, featureSubsetStrategy, seed)
+    rf.train(input)
+  }
+
+  /**
+   * Method to train a random forest model for regression.
+   *
+   * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
+   *              Labels are real numbers.
+   * @param categoricalFeaturesInfo Map storing arity of categorical features.
+   *                                E.g., an entry (n -> k) indicates that feature n is categorical
+   *                                with k categories indexed from 0: {0, 1, ..., k-1}.
+   * @param numTrees Number of trees in the random forest.
+   * @param featureSubsetStrategy Number of features to consider for splits at each node.
+   *                              Supported: "auto" (default), "all", "sqrt", "log2", "onethird".
+   *                              If "auto" is set, this parameter is set based on numTrees:
+   *                                if numTrees == 1, set to "all";
+   *                                if numTrees > 1 (forest), set to "sqrt" for classification and
+   *                                  to "onethird" for regression.
+   * @param impurity Criterion used for information gain calculation.
+   *                 Supported values: "variance".
+   * @param maxDepth Maximum depth of the tree.
+   *                 E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.
+   *                 (suggested value: 4)
+   * @param maxBins Maximum number of bins used for splitting features
+   *                (suggested value: 100)
+   * @param seed  Random seed for bootstrapping and choosing feature subsets.
+   * @return RandomForestModel that can be used for prediction
+   */
+  def trainRegressor(
+      input: RDD[LabeledPoint],
+      categoricalFeaturesInfo: Map[Int, Int],
+      numTrees: Int,
+      featureSubsetStrategy: String,
+      impurity: String,
+      maxDepth: Int,
+      maxBins: Int,
+      seed: Int = Utils.random.nextInt()): RandomForestModel = {
+    val impurityType = Impurities.fromString(impurity)
+    val strategy = new Strategy(Regression, impurityType, maxDepth,
+      0, maxBins, Sort, categoricalFeaturesInfo)
+    trainRegressor(input, strategy, numTrees, featureSubsetStrategy, seed)
+  }
+
+  /**
+   * Java-friendly API for [[org.apache.spark.mllib.tree.RandomForest$#trainRegressor]]
+   */
+  def trainRegressor(
+      input: JavaRDD[LabeledPoint],
+      categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer],
+      numTrees: Int,
+      featureSubsetStrategy: String,
+      impurity: String,
+      maxDepth: Int,
+      maxBins: Int,
+      seed: Int): RandomForestModel = {
+    trainRegressor(input.rdd,
+      categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap,
+      numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins, seed)
+  }
+
+  /**
+   * List of supported feature subset sampling strategies.
+   */
+  val supportedFeatureSubsetStrategies: Array[String] =
+    Array("auto", "all", "sqrt", "log2", "onethird")
+
+  private[tree] class NodeIndexInfo(
+      val nodeIndexInGroup: Int,
+      val featureSubset: Option[Array[Int]]) extends Serializable
+
+  /**
+   * Pull nodes off of the queue, and collect a group of nodes to be split on this iteration.
+   * This tracks the memory usage for aggregates and stops adding nodes when too much memory
+   * will be needed; this allows an adaptive number of nodes since different nodes may require
+   * different amounts of memory (if featureSubsetStrategy is not "all").
+   *
+   * @param nodeQueue  Queue of nodes to split.
+   * @param maxMemoryUsage  Bound on size of aggregate statistics.
+   * @return  (nodesForGroup, treeToNodeToIndexInfo).
+   *          nodesForGroup holds the nodes to split: treeIndex --> nodes in tree.
+   *          treeToNodeToIndexInfo holds the indices of selected features for each node:
+   *            treeIndex --> (global) node index --> (node index in group, feature indices).
+   *          The (global) node index is the index in the tree; the node index in group is the
+   *            index in [0, numNodesInGroup) of the node in this group.
+   *          The feature indices are None if not subsampling features.
+   */
+  private[tree] def selectNodesToSplit(
+      nodeQueue: mutable.Queue[(Int, Node)],
+      maxMemoryUsage: Long,
+      metadata: DecisionTreeMetadata,
+      rng: scala.util.Random): (Map[Int, Array[Node]], Map[Int, Map[Int, NodeIndexInfo]]) = {
+    // Collect some nodes to split:
+    //  nodesForGroup(treeIndex) = nodes to split
+    val mutableNodesForGroup = new mutable.HashMap[Int, mutable.ArrayBuffer[Node]]()
+    val mutableTreeToNodeToIndexInfo =
+      new mutable.HashMap[Int, mutable.HashMap[Int, NodeIndexInfo]]()
+    var memUsage: Long = 0L
+    var numNodesInGroup = 0
+    while (nodeQueue.nonEmpty && memUsage < maxMemoryUsage) {
+      val (treeIndex, node) = nodeQueue.head
+      // Choose subset of features for node (if subsampling).
+      val featureSubset: Option[Array[Int]] = if (metadata.subsamplingFeatures) {
+        // TODO: Use more efficient subsampling?  (use selection-and-rejection or reservoir)
+        Some(rng.shuffle(Range(0, metadata.numFeatures).toList)
+          .take(metadata.numFeaturesPerNode).toArray)
+      } else {
+        None
+      }
+      // Check if enough memory remains to add this node to the group.
+      val nodeMemUsage = RandomForest.aggregateSizeForNode(metadata, featureSubset) * 8L
+      if (memUsage + nodeMemUsage <= maxMemoryUsage) {
+        nodeQueue.dequeue()
+        mutableNodesForGroup.getOrElseUpdate(treeIndex, new mutable.ArrayBuffer[Node]()) += node
+        mutableTreeToNodeToIndexInfo
+          .getOrElseUpdate(treeIndex, new mutable.HashMap[Int, NodeIndexInfo]())(node.id)
+          = new NodeIndexInfo(numNodesInGroup, featureSubset)
+      }
+      numNodesInGroup += 1
+      memUsage += nodeMemUsage
+    }
+    // Convert mutable maps to immutable ones.
+    val nodesForGroup: Map[Int, Array[Node]] = mutableNodesForGroup.mapValues(_.toArray).toMap
+    val treeToNodeToIndexInfo = mutableTreeToNodeToIndexInfo.mapValues(_.toMap).toMap
+    (nodesForGroup, treeToNodeToIndexInfo)
+  }
+
+  /**
+   * Get the number of values to be stored for this node in the bin aggregates.
+   * @param featureSubset  Indices of features which may be split at this node.
+   *                       If None, then use all features.
+   */
+  private[tree] def aggregateSizeForNode(
+      metadata: DecisionTreeMetadata,
+      featureSubset: Option[Array[Int]]): Long = {
+    val totalBins = if (featureSubset.nonEmpty) {
+      featureSubset.get.map(featureIndex => metadata.numBins(featureIndex).toLong).sum
+    } else {
+      metadata.numBins.map(_.toLong).sum
+    }
+    if (metadata.isClassification) {
+      metadata.numClasses * totalBins
+    } else {
+      3 * totalBins
+    }
+  }
+
+}
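
As a worked example of the memory bound enforced in train() above (the sizes are illustrative, not taken from the tests): for binary classification (numClasses = 2) with 500 features of 32 bins each and no feature subsampling, aggregateSizeForNode = 2 * 500 * 32 = 32,000 values, i.e. 32,000 * 8 = 256,000 bytes of aggregate statistics per node. With maxMemoryInMB = 128, selectNodesToSplit can therefore group roughly 500 nodes into a single pass over the data.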

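To make the new public API concrete, here is a minimal usage sketch (not part of the diff). It assumes an existing SparkContext `sc`; the data path and parameter values are hypothetical.

    import org.apache.spark.mllib.tree.RandomForest
    import org.apache.spark.mllib.util.MLUtils

    // Load an RDD[LabeledPoint]; any labeled dataset works here.
    val data = MLUtils.loadLibSVMFile(sc, "data/sample_libsvm_data.txt")

    // Train 10 trees for binary classification.  With numTrees > 1,
    // featureSubsetStrategy = "auto" resolves to "sqrt".
    val model = RandomForest.trainClassifier(
      input = data,
      numClassesForClassification = 2,
      categoricalFeaturesInfo = Map[Int, Int](),  // all features continuous
      numTrees = 10,
      featureSubsetStrategy = "auto",
      impurity = "gini",
      maxDepth = 4,
      maxBins = 100,
      seed = 12345)

    // Predict by voting over the 10 trees.
    val prediction = model.predict(data.first().features)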
http://git-wip-us.apache.org/repos/asf/spark/blob/0dc2b636/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/BaggedPoint.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/BaggedPoint.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/BaggedPoint.scala
new file mode 100644
index 0000000..937c8a2
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/BaggedPoint.scala
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.impl
+
+import cern.jet.random.Poisson
+import cern.jet.random.engine.DRand
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.util.Utils
+
+/**
+ * Internal representation of a datapoint which belongs to several subsamples of the same dataset,
+ * particularly for bagging (e.g., for random forests).
+ *
+ * This holds one instance, as well as an array of weights which represent the (weighted)
+ * number of times which this instance appears in each subsample.
+ * E.g., (datum, [1, 0, 4]) indicates that there are 3 subsamples of the dataset and that
+ * this datum has 1 copy, 0 copies, and 4 copies in the 3 subsamples, respectively.
+ *
+ * @param datum  Data instance
+ * @param subsampleWeights  Weight of this instance in each subsampled dataset.
+ *
+ * TODO: This does not currently support (Double) weighted instances.  Once MLlib has weighted
+ *       dataset support, update.  (We store subsampleWeights as Double for this future extension.)
+ */
+private[tree] class BaggedPoint[Datum](val datum: Datum, val subsampleWeights: Array[Double])
+  extends Serializable
+
+private[tree] object BaggedPoint {
+
+  /**
+   * Convert an input dataset into its BaggedPoint representation,
+   * choosing subsample counts for each instance.
+   * Each subsample has the same number of instances as the original dataset,
+   * and is created by subsampling with replacement.
+   * @param input     Input dataset.
+   * @param numSubsamples  Number of subsamples of this RDD to take.
+   * @param seed   Random seed.
+   * @return  BaggedPoint dataset representation
+   */
+  def convertToBaggedRDD[Datum](
+      input: RDD[Datum],
+      numSubsamples: Int,
+      seed: Int = Utils.random.nextInt()): RDD[BaggedPoint[Datum]] = {
+    input.mapPartitionsWithIndex { (partitionIndex, instances) =>
+      // TODO: Support different sampling rates, and sampling without replacement.
+      // Use random seed = seed + partitionIndex + 1 to make generation reproducible.
+      val poisson = new Poisson(1.0, new DRand(seed + partitionIndex + 1))
+      instances.map { instance =>
+        val subsampleWeights = new Array[Double](numSubsamples)
+        var subsampleIndex = 0
+        while (subsampleIndex < numSubsamples) {
+          subsampleWeights(subsampleIndex) = poisson.nextInt()
+          subsampleIndex += 1
+        }
+        new BaggedPoint(instance, subsampleWeights)
+      }
+    }
+  }
+
+  def convertToBaggedRDDWithoutSampling[Datum](input: RDD[Datum]): RDD[BaggedPoint[Datum]] = {
+    input.map(datum => new BaggedPoint(datum, Array(1.0)))
+  }
+
+}
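
The Poisson(1.0) draw above is what makes one-pass bagging work: in a bootstrap sample of n instances drawn with replacement from n, each instance's count is Binomial(n, 1/n), which is approximately Poisson(1) for large n, so the counts can be drawn independently per instance without materializing copies. Below is a self-contained sketch of the same idea using only the standard library; the Knuth product-of-uniforms sampler stands in for colt's Poisson and is purely illustrative, not the code under review.

    import scala.util.Random

    // Knuth's Poisson(lambda = 1) sampler: count uniform draws until their
    // running product falls below e^(-1).
    def poisson1(rng: Random): Int = {
      val limit = math.exp(-1.0)
      var count = -1
      var product = 1.0
      do {
        count += 1
        product *= rng.nextDouble()
      } while (product > limit)
      count
    }

    // The counts average ~1.0, matching the expected number of times an
    // instance appears in a same-size bootstrap sample.
    val rng = new Random(42)
    val draws = Seq.fill(100000)(poisson1(rng))
    println(draws.sum.toDouble / draws.size)  // ~1.0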

http://git-wip-us.apache.org/repos/asf/spark/blob/0dc2b636/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala
index 61a9424..d49df7a 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala
@@ -17,16 +17,17 @@
 
 package org.apache.spark.mllib.tree.impl
 
+import org.apache.spark.mllib.tree.RandomForest.NodeIndexInfo
 import org.apache.spark.mllib.tree.impurity._
 
 /**
  * DecisionTree statistics aggregator.
  * This holds a flat array of statistics for a set of (nodes, features, bins)
  * and helps with indexing.
+ * This class is abstract to support learning with and without feature subsampling.
  */
-private[tree] class DTStatsAggregator(
-    val metadata: DecisionTreeMetadata,
-    val numNodes: Int) extends Serializable {
+private[tree] abstract class DTStatsAggregator(
+    val metadata: DecisionTreeMetadata) extends Serializable {
 
   /**
    * [[ImpurityAggregator]] instance specifying the impurity type.
@@ -43,18 +44,6 @@ private[tree] class DTStatsAggregator(
    */
   val statsSize: Int = impurityAggregator.statsSize
 
-  val numFeatures: Int = metadata.numFeatures
-
-  /**
-   * Number of bins for each feature.  This is indexed by the feature index.
-   */
-  val numBins: Array[Int] = metadata.numBins
-
-  /**
-   * Number of splits for the given feature.
-   */
-  def numSplits(featureIndex: Int): Int = metadata.numSplits(featureIndex)
-
   /**
    * Indicator for each feature of whether that feature is an unordered feature.
    * TODO: Is Array[Boolean] any faster?
@@ -62,30 +51,14 @@ private[tree] class DTStatsAggregator(
   def isUnordered(featureIndex: Int): Boolean = metadata.isUnordered(featureIndex)
 
   /**
-   * Offset for each feature for calculating indices into the [[allStats]] array.
-   */
-  private val featureOffsets: Array[Int] = {
-    numBins.scanLeft(0)((total, nBins) => total + statsSize * nBins)
-  }
-
-  /**
-   * Number of elements for each node, corresponding to stride between nodes in [[allStats]].
-   */
-  private val nodeStride: Int = featureOffsets.last
-
-  /**
    * Total number of elements stored in this aggregator.
    */
-  val allStatsSize: Int = numNodes * nodeStride
+  def allStatsSize: Int
 
   /**
-   * Flat array of elements.
-   * Index for start of stats for a (node, feature, bin) is:
-   *   index = nodeIndex * nodeStride + featureOffsets(featureIndex) + binIndex * statsSize
-   * Note: For unordered features, the left child stats have binIndex in [0, numBins(featureIndex))
-   *       and the right child stats in [numBins(featureIndex), 2 * numBins(featureIndex))
+   * Get flat array of elements stored in this aggregator.
    */
-  val allStats: Array[Double] = new Array[Double](allStatsSize)
+  protected def allStats: Array[Double]
 
   /**
    * Get an [[ImpurityCalculator]] for a given (node, feature, bin).
@@ -102,36 +75,39 @@ private[tree] class DTStatsAggregator(
   /**
    * Update the stats for a given (node, feature, bin) for ordered features, using the given label.
    */
-  def update(nodeIndex: Int, featureIndex: Int, binIndex: Int, label: Double): Unit = {
-    val i = nodeIndex * nodeStride + featureOffsets(featureIndex) + binIndex * statsSize
-    impurityAggregator.update(allStats, i, label)
+  def update(
+      nodeIndex: Int,
+      featureIndex: Int,
+      binIndex: Int,
+      label: Double,
+      instanceWeight: Double): Unit = {
+    val i = getNodeFeatureOffset(nodeIndex, featureIndex) + binIndex * statsSize
+    impurityAggregator.update(allStats, i, label, instanceWeight)
   }
 
   /**
    * Pre-compute node offset for use with [[nodeUpdate]].
    */
-  def getNodeOffset(nodeIndex: Int): Int = nodeIndex * nodeStride
+  def getNodeOffset(nodeIndex: Int): Int
 
   /**
    * Faster version of [[update]].
    * Update the stats for a given (node, feature, bin) for ordered features, 
using the given label.
    * @param nodeOffset  Pre-computed node offset from [[getNodeOffset]].
    */
-  def nodeUpdate(nodeOffset: Int, featureIndex: Int, binIndex: Int, label: 
Double): Unit = {
-    val i = nodeOffset + featureOffsets(featureIndex) + binIndex * statsSize
-    impurityAggregator.update(allStats, i, label)
-  }
+  def nodeUpdate(
+      nodeOffset: Int,
+      nodeIndex: Int,
+      featureIndex: Int,
+      binIndex: Int,
+      label: Double,
+      instanceWeight: Double): Unit
 
   /**
    * Pre-compute (node, feature) offset for use with [[nodeFeatureUpdate]].
    * For ordered features only.
    */
-  def getNodeFeatureOffset(nodeIndex: Int, featureIndex: Int): Int = {
-    require(!isUnordered(featureIndex),
-      s"DTStatsAggregator.getNodeFeatureOffset is for ordered features only, 
but was called" +
-      s" for unordered feature $featureIndex.")
-    nodeIndex * nodeStride + featureOffsets(featureIndex)
-  }
+  def getNodeFeatureOffset(nodeIndex: Int, featureIndex: Int): Int
 
   /**
    * Pre-compute (node, feature) offset for use with [[nodeFeatureUpdate]].
@@ -140,9 +116,9 @@ private[tree] class DTStatsAggregator(
   def getLeftRightNodeFeatureOffsets(nodeIndex: Int, featureIndex: Int): (Int, Int) = {
     require(isUnordered(featureIndex),
       s"DTStatsAggregator.getLeftRightNodeFeatureOffsets is for unordered features only," +
-      s" but was called for ordered feature $featureIndex.")
-    val baseOffset = nodeIndex * nodeStride + featureOffsets(featureIndex)
-    (baseOffset, baseOffset + (numBins(featureIndex) >> 1) * statsSize)
+        s" but was called for ordered feature $featureIndex.")
+    val baseOffset = getNodeFeatureOffset(nodeIndex, featureIndex)
+    (baseOffset, baseOffset + (metadata.numBins(featureIndex) >> 1) * statsSize)
   }
 
   /**
@@ -154,8 +130,13 @@ private[tree] class DTStatsAggregator(
    *                           (node, feature, left/right child) offset from
    *                           [[getLeftRightNodeFeatureOffsets]].
    */
-  def nodeFeatureUpdate(nodeFeatureOffset: Int, binIndex: Int, label: Double): Unit = {
-    impurityAggregator.update(allStats, nodeFeatureOffset + binIndex * statsSize, label)
+  def nodeFeatureUpdate(
+      nodeFeatureOffset: Int,
+      binIndex: Int,
+      label: Double,
+      instanceWeight: Double): Unit = {
+    impurityAggregator.update(allStats, nodeFeatureOffset + binIndex * statsSize, label,
+      instanceWeight)
   }
 
   /**
@@ -189,7 +170,139 @@ private[tree] class DTStatsAggregator(
     }
     this
   }
+}
+
+/**
+ * DecisionTree statistics aggregator.
+ * This holds a flat array of statistics for a set of (nodes, features, bins)
+ * and helps with indexing.
+ *
+ * This instance of [[DTStatsAggregator]] is used when not subsampling features.
+ *
+ * @param numNodes  Number of nodes to collect statistics for.
+ */
+private[tree] class DTStatsAggregatorFixedFeatures(
+    metadata: DecisionTreeMetadata,
+    numNodes: Int) extends DTStatsAggregator(metadata) {
+
+  /**
+   * Offset for each feature for calculating indices into the [[allStats]] array.
+   * Mapping: featureIndex --> offset
+   */
+  private val featureOffsets: Array[Int] = {
+    metadata.numBins.scanLeft(0)((total, nBins) => total + statsSize * nBins)
+  }
+
+  /**
+   * Number of elements for each node, corresponding to stride between nodes in [[allStats]].
+   */
+  private val nodeStride: Int = featureOffsets.last
 
+  override val allStatsSize: Int = numNodes * nodeStride
+
+  /**
+   * Flat array of elements.
+   * Index for start of stats for a (node, feature, bin) is:
+   *   index = nodeIndex * nodeStride + featureOffsets(featureIndex) + binIndex * statsSize
+   * Note: For unordered features, the left child stats precede the right child stats
+   *       in the binIndex order.
+   */
+  override protected val allStats: Array[Double] = new Array[Double](allStatsSize)
+
+  override def getNodeOffset(nodeIndex: Int): Int = nodeIndex * nodeStride
+
+  override def nodeUpdate(
+      nodeOffset: Int,
+      nodeIndex: Int,
+      featureIndex: Int,
+      binIndex: Int,
+      label: Double,
+      instanceWeight: Double): Unit = {
+    val i = nodeOffset + featureOffsets(featureIndex) + binIndex * statsSize
+    impurityAggregator.update(allStats, i, label, instanceWeight)
+  }
+
+  override def getNodeFeatureOffset(nodeIndex: Int, featureIndex: Int): Int = {
+    nodeIndex * nodeStride + featureOffsets(featureIndex)
+  }
+}
+
+/**
+ * DecisionTree statistics aggregator.
+ * This holds a flat array of statistics for a set of (nodes, features, bins)
+ * and helps with indexing.
+ *
+ * This instance of [[DTStatsAggregator]] is used when subsampling features.
+ *
+ * @param treeToNodeToIndexInfo Mapping: treeIndex --> nodeIndex --> nodeIndexInfo,
+ *                              where nodeIndexInfo stores the index in the group and the
+ *                              feature subsets (if using feature subsets).
+ */
+private[tree] class DTStatsAggregatorSubsampledFeatures(
+    metadata: DecisionTreeMetadata,
+    treeToNodeToIndexInfo: Map[Int, Map[Int, NodeIndexInfo]]) extends DTStatsAggregator(metadata) {
+
+  /**
+   * For each node, offset for each feature for calculating indices into the [[allStats]] array.
+   * Mapping: nodeIndex --> featureIndex --> offset
+   */
+  private val featureOffsets: Array[Array[Int]] = {
+    val numNodes: Int = treeToNodeToIndexInfo.values.map(_.size).sum
+    val offsets = new Array[Array[Int]](numNodes)
+    treeToNodeToIndexInfo.foreach { case (treeIndex, nodeToIndexInfo) =>
+      nodeToIndexInfo.foreach { case (globalNodeIndex, nodeInfo) =>
+        offsets(nodeInfo.nodeIndexInGroup) = nodeInfo.featureSubset.get.map(metadata.numBins(_))
+          .scanLeft(0)((total, nBins) => total + statsSize * nBins)
+      }
+    }
+    offsets
+  }
+
+  /**
+   * Offset for each node for calculating indices into the [[allStats]] array.
+   */
+  protected val nodeOffsets: Array[Int] = featureOffsets.map(_.last).scanLeft(0)(_ + _)
+
+  override val allStatsSize: Int = nodeOffsets.last
+
+  /**
+   * Flat array of elements.
+   * Index for start of stats for a (node, feature, bin) is:
+   *   index = nodeOffsets(nodeIndex) + featureOffsets(nodeIndex)(featureIndex)
+   *             + binIndex * statsSize
+   * Note: For unordered features, the left child stats precede the right child stats
+   *       in the binIndex order.
+   */
+  override protected val allStats: Array[Double] = new Array[Double](allStatsSize)
+
+  override def getNodeOffset(nodeIndex: Int): Int = nodeOffsets(nodeIndex)
+
+  /**
+   * Faster version of [[update]].
+   * Update the stats for a given (node, feature, bin) for ordered features, using the given label.
+   * @param nodeOffset  Pre-computed node offset from [[getNodeOffset]].
+   * @param featureIndex  Index of the feature within this node's feature subset.
+   *                      Note: This is NOT the original feature index.
+   */
+  override def nodeUpdate(
+      nodeOffset: Int,
+      nodeIndex: Int,
+      featureIndex: Int,
+      binIndex: Int,
+      label: Double,
+      instanceWeight: Double): Unit = {
+    val i = nodeOffset + featureOffsets(nodeIndex)(featureIndex) + binIndex * statsSize
+    impurityAggregator.update(allStats, i, label, instanceWeight)
+  }
+
+  /**
+   * Pre-compute (node, feature) offset for use with [[nodeFeatureUpdate]].
+   * For ordered features only.
+   * @param featureIndex  Index of the feature within this node's feature subset.
+   *                      Note: This is NOT the original feature index.
+   */
+  override def getNodeFeatureOffset(nodeIndex: Int, featureIndex: Int): Int = {
+    nodeOffsets(nodeIndex) + featureOffsets(nodeIndex)(featureIndex)
+  }
 }
 
 private[tree] object DTStatsAggregator extends Serializable {

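For reference, the fixed-features indexing scheme in DTStatsAggregatorFixedFeatures reduces to plain offset arithmetic. Below is a standalone sketch with made-up sizes (statsSize = 3 matches the 3 values per bin that regression aggregates, per aggregateSizeForNode; the lookup is illustrative, not the class's actual fields):

    // Illustrative sizes: 3 stats per bin, and three features with 4, 8, and 2 bins.
    val statsSize = 3
    val numBins = Array(4, 8, 2)

    // featureOffsets(f) = start of feature f within one node's block.
    // scanLeft yields [0, 12, 36, 42]; the last entry is the node stride.
    val featureOffsets = numBins.scanLeft(0)((total, nBins) => total + statsSize * nBins)
    val nodeStride = featureOffsets.last

    // Start of stats for (node 2, feature 1, bin 5):
    //   2 * 42 + 12 + 5 * 3 = 111
    val index = 2 * nodeStride + featureOffsets(1) + 5 * statsSize
    println(index)  // 111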
