Repository: spark
Updated Branches:
  refs/heads/master fbad72288 -> 73ab7f141


http://git-wip-us.apache.org/repos/asf/spark/blob/73ab7f14/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala 
b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index a5c49a3..2f36fd9 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -23,10 +23,10 @@ import org.scalatest.FunSuite
 
 import org.apache.spark.mllib.tree.configuration.Algo._
 import org.apache.spark.mllib.tree.configuration.FeatureType._
-import org.apache.spark.mllib.tree.configuration.{FeatureType, Strategy}
-import org.apache.spark.mllib.tree.impl.TreePoint
+import org.apache.spark.mllib.tree.configuration.Strategy
+import org.apache.spark.mllib.tree.impl.{DecisionTreeMetadata, TreePoint}
 import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance}
-import org.apache.spark.mllib.tree.model.{DecisionTreeModel, Filter, Split}
+import org.apache.spark.mllib.tree.model.{DecisionTreeModel, Node}
 import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.mllib.util.LocalSparkContext
 import org.apache.spark.mllib.regression.LabeledPoint
@@ -64,7 +64,8 @@ class DecisionTreeSuite extends FunSuite with 
LocalSparkContext {
     assert(arr.length === 1000)
     val rdd = sc.parallelize(arr)
     val strategy = new Strategy(Classification, Gini, 3, 2, 100)
-    val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
+    val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+    val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
     assert(splits.length === 2)
     assert(bins.length === 2)
     assert(splits(0).length === 99)
@@ -82,7 +83,8 @@ class DecisionTreeSuite extends FunSuite with 
LocalSparkContext {
       numClassesForClassification = 2,
       maxBins = 100,
       categoricalFeaturesInfo = Map(0 -> 2, 1-> 2))
-    val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
+    val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+    val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
     assert(splits.length === 2)
     assert(bins.length === 2)
     assert(splits(0).length === 99)
@@ -162,7 +164,8 @@ class DecisionTreeSuite extends FunSuite with 
LocalSparkContext {
       numClassesForClassification = 2,
       maxBins = 100,
       categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3))
-    val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
+    val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+    val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
 
     // Check splits.
 
@@ -279,7 +282,8 @@ class DecisionTreeSuite extends FunSuite with 
LocalSparkContext {
       numClassesForClassification = 100,
       maxBins = 100,
       categoricalFeaturesInfo = Map(0 -> 3, 1-> 3))
-    val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
+    val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+    val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
 
     // Expecting 2^2 - 1 = 3 bins/splits
     assert(splits(0)(0).feature === 0)
@@ -373,7 +377,8 @@ class DecisionTreeSuite extends FunSuite with 
LocalSparkContext {
       numClassesForClassification = 100,
       maxBins = 100,
       categoricalFeaturesInfo = Map(0 -> 10, 1-> 10))
-    val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
+    val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+    val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
 
     // 2^10 - 1 > 100, so categorical variables will be ordered
 
@@ -428,10 +433,11 @@ class DecisionTreeSuite extends FunSuite with 
LocalSparkContext {
       maxDepth = 2,
       maxBins = 100,
       categoricalFeaturesInfo = Map(0 -> 3, 1-> 3))
-    val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
-    val treeInput = TreePoint.convertToTreeRDD(rdd, strategy, bins)
-    val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(7), 
strategy, 0,
-      Array[List[Filter]](), splits, bins, 10)
+    val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+    val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
+    val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
+    val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(7), 
metadata, 0,
+      new Array[Node](0), splits, bins, 10)
 
     val split = bestSplits(0)._1
     assert(split.categories.length === 1)
@@ -456,10 +462,11 @@ class DecisionTreeSuite extends FunSuite with 
LocalSparkContext {
       maxDepth = 2,
       maxBins = 100,
       categoricalFeaturesInfo = Map(0 -> 3, 1-> 3))
-    val (splits, bins) = DecisionTree.findSplitsBins(rdd,strategy)
-    val treeInput = TreePoint.convertToTreeRDD(rdd, strategy, bins)
-    val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(7), 
strategy, 0,
-      Array[List[Filter]](), splits, bins, 10)
+    val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+    val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
+    val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
+    val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(7), 
metadata, 0,
+      new Array[Node](0), splits, bins, 10)
 
     val split = bestSplits(0)._1
     assert(split.categories.length === 1)
@@ -495,7 +502,8 @@ class DecisionTreeSuite extends FunSuite with 
LocalSparkContext {
     assert(arr.length === 1000)
     val rdd = sc.parallelize(arr)
     val strategy = new Strategy(Classification, Gini, 3, 2, 100)
-    val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
+    val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+    val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
     assert(splits.length === 2)
     assert(splits(0).length === 99)
     assert(bins.length === 2)
@@ -503,9 +511,9 @@ class DecisionTreeSuite extends FunSuite with 
LocalSparkContext {
     assert(splits(0).length === 99)
     assert(bins(0).length === 100)
 
-    val treeInput = TreePoint.convertToTreeRDD(rdd, strategy, bins)
-    val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(7), 
strategy, 0,
-      Array[List[Filter]](), splits, bins, 10)
+    val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
+    val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(7), 
metadata, 0,
+      new Array[Node](0), splits, bins, 10)
     assert(bestSplits.length === 1)
     assert(bestSplits(0)._1.feature === 0)
     assert(bestSplits(0)._2.gain === 0)
@@ -518,7 +526,8 @@ class DecisionTreeSuite extends FunSuite with 
LocalSparkContext {
     assert(arr.length === 1000)
     val rdd = sc.parallelize(arr)
     val strategy = new Strategy(Classification, Gini, 3, 2, 100)
-    val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
+    val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+    val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
     assert(splits.length === 2)
     assert(splits(0).length === 99)
     assert(bins.length === 2)
@@ -526,9 +535,9 @@ class DecisionTreeSuite extends FunSuite with 
LocalSparkContext {
     assert(splits(0).length === 99)
     assert(bins(0).length === 100)
 
-    val treeInput = TreePoint.convertToTreeRDD(rdd, strategy, bins)
-    val bestSplits = DecisionTree.findBestSplits(treeInput, Array(0.0), 
strategy, 0,
-      Array[List[Filter]](), splits, bins, 10)
+    val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
+    val bestSplits = DecisionTree.findBestSplits(treeInput, Array(0.0), 
metadata, 0,
+      new Array[Node](0), splits, bins, 10)
     assert(bestSplits.length === 1)
     assert(bestSplits(0)._1.feature === 0)
     assert(bestSplits(0)._2.gain === 0)
@@ -542,7 +551,8 @@ class DecisionTreeSuite extends FunSuite with 
LocalSparkContext {
     assert(arr.length === 1000)
     val rdd = sc.parallelize(arr)
     val strategy = new Strategy(Classification, Entropy, 3, 2, 100)
-    val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
+    val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+    val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
     assert(splits.length === 2)
     assert(splits(0).length === 99)
     assert(bins.length === 2)
@@ -550,9 +560,9 @@ class DecisionTreeSuite extends FunSuite with 
LocalSparkContext {
     assert(splits(0).length === 99)
     assert(bins(0).length === 100)
 
-    val treeInput = TreePoint.convertToTreeRDD(rdd, strategy, bins)
-    val bestSplits = DecisionTree.findBestSplits(treeInput, Array(0.0), 
strategy, 0,
-      Array[List[Filter]](), splits, bins, 10)
+    val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
+    val bestSplits = DecisionTree.findBestSplits(treeInput, Array(0.0), 
metadata, 0,
+      new Array[Node](0), splits, bins, 10)
     assert(bestSplits.length === 1)
     assert(bestSplits(0)._1.feature === 0)
     assert(bestSplits(0)._2.gain === 0)
@@ -566,7 +576,8 @@ class DecisionTreeSuite extends FunSuite with 
LocalSparkContext {
     assert(arr.length === 1000)
     val rdd = sc.parallelize(arr)
     val strategy = new Strategy(Classification, Entropy, 3, 2, 100)
-    val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
+    val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+    val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
     assert(splits.length === 2)
     assert(splits(0).length === 99)
     assert(bins.length === 2)
@@ -574,9 +585,9 @@ class DecisionTreeSuite extends FunSuite with 
LocalSparkContext {
     assert(splits(0).length === 99)
     assert(bins(0).length === 100)
 
-    val treeInput = TreePoint.convertToTreeRDD(rdd, strategy, bins)
-    val bestSplits = DecisionTree.findBestSplits(treeInput, Array(0.0), 
strategy, 0,
-      Array[List[Filter]](), splits, bins, 10)
+    val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
+    val bestSplits = DecisionTree.findBestSplits(treeInput, Array(0.0), 
metadata, 0,
+      new Array[Node](0), splits, bins, 10)
     assert(bestSplits.length === 1)
     assert(bestSplits(0)._1.feature === 0)
     assert(bestSplits(0)._2.gain === 0)
@@ -590,7 +601,8 @@ class DecisionTreeSuite extends FunSuite with 
LocalSparkContext {
     assert(arr.length === 1000)
     val rdd = sc.parallelize(arr)
     val strategy = new Strategy(Classification, Entropy, 3, 2, 100)
-    val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
+    val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+    val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
     assert(splits.length === 2)
     assert(splits(0).length === 99)
     assert(bins.length === 2)
@@ -598,14 +610,19 @@ class DecisionTreeSuite extends FunSuite with 
LocalSparkContext {
     assert(splits(0).length === 99)
     assert(bins(0).length === 100)
 
-    val leftFilter = Filter(new Split(0, 400, FeatureType.Continuous, List()), 
-1)
-    val rightFilter = Filter(new Split(0, 400, FeatureType.Continuous, List()) 
,1)
-    val filters = Array[List[Filter]](List(), List(leftFilter), 
List(rightFilter))
+    // Train a 1-node model
+    val strategyOneNode = new Strategy(Classification, Entropy, 1, 2, 100)
+    val modelOneNode = DecisionTree.train(rdd, strategyOneNode)
+    val nodes: Array[Node] = new Array[Node](7)
+    nodes(0) = modelOneNode.topNode
+    nodes(0).leftNode = None
+    nodes(0).rightNode = None
+
     val parentImpurities = Array(0.5, 0.5, 0.5)
 
     // Single group second level tree construction.
-    val treeInput = TreePoint.convertToTreeRDD(rdd, strategy, bins)
-    val bestSplits = DecisionTree.findBestSplits(treeInput, parentImpurities, 
strategy, 1, filters,
+    val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
+    val bestSplits = DecisionTree.findBestSplits(treeInput, parentImpurities, 
metadata, 1, nodes,
       splits, bins, 10)
     assert(bestSplits.length === 2)
     assert(bestSplits(0)._2.gain > 0)
@@ -613,8 +630,8 @@ class DecisionTreeSuite extends FunSuite with 
LocalSparkContext {
 
     // maxLevelForSingleGroup parameter is set to 0 to force splitting into 
groups for second
     // level tree construction.
-    val bestSplitsWithGroups = DecisionTree.findBestSplits(treeInput, 
parentImpurities, strategy, 1,
-      filters, splits, bins, 0)
+    val bestSplitsWithGroups = DecisionTree.findBestSplits(treeInput, 
parentImpurities, metadata, 1,
+      nodes, splits, bins, 0)
     assert(bestSplitsWithGroups.length === 2)
     assert(bestSplitsWithGroups(0)._2.gain > 0)
     assert(bestSplitsWithGroups(1)._2.gain > 0)
@@ -629,19 +646,19 @@ class DecisionTreeSuite extends FunSuite with 
LocalSparkContext {
       assert(bestSplits(i)._2.rightImpurity === 
bestSplitsWithGroups(i)._2.rightImpurity)
       assert(bestSplits(i)._2.predict === bestSplitsWithGroups(i)._2.predict)
     }
-
   }
 
   test("stump with categorical variables for multiclass classification") {
     val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlass()
-    val input = sc.parallelize(arr)
+    val rdd = sc.parallelize(arr)
     val strategy = new Strategy(algo = Classification, impurity = Gini, 
maxDepth = 4,
       numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 3, 1 
-> 3))
+    val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
     assert(strategy.isMulticlassClassification)
-    val (splits, bins) = DecisionTree.findSplitsBins(input, strategy)
-    val treeInput = TreePoint.convertToTreeRDD(input, strategy, bins)
-    val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), 
strategy, 0,
-      Array[List[Filter]](), splits, bins, 10)
+    val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
+    val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
+    val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), 
metadata, 0,
+      new Array[Node](0), splits, bins, 10)
 
     assert(bestSplits.length === 1)
     val bestSplit = bestSplits(0)._1
@@ -657,11 +674,11 @@ class DecisionTreeSuite extends FunSuite with 
LocalSparkContext {
     arr(1) = new LabeledPoint(1.0, Vectors.dense(1.0))
     arr(2) = new LabeledPoint(1.0, Vectors.dense(2.0))
     arr(3) = new LabeledPoint(1.0, Vectors.dense(3.0))
-    val input = sc.parallelize(arr)
+    val rdd = sc.parallelize(arr)
     val strategy = new Strategy(algo = Classification, impurity = Gini, 
maxDepth = 4,
       numClassesForClassification = 2)
 
-    val model = DecisionTree.train(input, strategy)
+    val model = DecisionTree.train(rdd, strategy)
     validateClassifier(model, arr, 1.0)
     assert(model.numNodes === 3)
     assert(model.depth === 1)
@@ -688,20 +705,22 @@ class DecisionTreeSuite extends FunSuite with 
LocalSparkContext {
   test("stump with categorical variables for multiclass classification, with 
just enough bins") {
     val maxBins = math.pow(2, 3 - 1).toInt // just enough bins to allow 
unordered features
     val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlass()
-    val input = sc.parallelize(arr)
+    val rdd = sc.parallelize(arr)
     val strategy = new Strategy(algo = Classification, impurity = Gini, 
maxDepth = 4,
-      numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 3, 1 
-> 3))
+      numClassesForClassification = 3, maxBins = maxBins,
+      categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3))
     assert(strategy.isMulticlassClassification)
+    val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
 
-    val model = DecisionTree.train(input, strategy)
+    val model = DecisionTree.train(rdd, strategy)
     validateClassifier(model, arr, 1.0)
     assert(model.numNodes === 3)
     assert(model.depth === 1)
 
-    val (splits, bins) = DecisionTree.findSplitsBins(input, strategy)
-    val treeInput = TreePoint.convertToTreeRDD(input, strategy, bins)
-    val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), 
strategy, 0,
-      Array[List[Filter]](), splits, bins, 10)
+    val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
+    val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
+    val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), 
metadata, 0,
+      new Array[Node](0), splits, bins, 10)
 
     assert(bestSplits.length === 1)
     val bestSplit = bestSplits(0)._1
@@ -716,18 +735,19 @@ class DecisionTreeSuite extends FunSuite with 
LocalSparkContext {
 
   test("stump with continuous variables for multiclass classification") {
     val arr = DecisionTreeSuite.generateContinuousDataPointsForMulticlass()
-    val input = sc.parallelize(arr)
+    val rdd = sc.parallelize(arr)
     val strategy = new Strategy(algo = Classification, impurity = Gini, 
maxDepth = 4,
       numClassesForClassification = 3)
     assert(strategy.isMulticlassClassification)
+    val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
 
-    val model = DecisionTree.train(input, strategy)
+    val model = DecisionTree.train(rdd, strategy)
     validateClassifier(model, arr, 0.9)
 
-    val (splits, bins) = DecisionTree.findSplitsBins(input, strategy)
-    val treeInput = TreePoint.convertToTreeRDD(input, strategy, bins)
-    val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), 
strategy, 0,
-      Array[List[Filter]](), splits, bins, 10)
+    val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
+    val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
+    val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), 
metadata, 0,
+      new Array[Node](0), splits, bins, 10)
 
     assert(bestSplits.length === 1)
     val bestSplit = bestSplits(0)._1
@@ -741,18 +761,19 @@ class DecisionTreeSuite extends FunSuite with 
LocalSparkContext {
 
   test("stump with continuous + categorical variables for multiclass 
classification") {
     val arr = DecisionTreeSuite.generateContinuousDataPointsForMulticlass()
-    val input = sc.parallelize(arr)
+    val rdd = sc.parallelize(arr)
     val strategy = new Strategy(algo = Classification, impurity = Gini, 
maxDepth = 4,
       numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 3))
     assert(strategy.isMulticlassClassification)
+    val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
 
-    val model = DecisionTree.train(input, strategy)
+    val model = DecisionTree.train(rdd, strategy)
     validateClassifier(model, arr, 0.9)
 
-    val (splits, bins) = DecisionTree.findSplitsBins(input, strategy)
-    val treeInput = TreePoint.convertToTreeRDD(input, strategy, bins)
-    val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), 
strategy, 0,
-      Array[List[Filter]](), splits, bins, 10)
+    val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
+    val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
+    val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), 
metadata, 0,
+      new Array[Node](0), splits, bins, 10)
 
     assert(bestSplits.length === 1)
     val bestSplit = bestSplits(0)._1
@@ -765,14 +786,16 @@ class DecisionTreeSuite extends FunSuite with 
LocalSparkContext {
 
   test("stump with categorical variables for ordered multiclass 
classification") {
     val arr = 
DecisionTreeSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures()
-    val input = sc.parallelize(arr)
+    val rdd = sc.parallelize(arr)
     val strategy = new Strategy(algo = Classification, impurity = Gini, 
maxDepth = 4,
       numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 10, 
1 -> 10))
     assert(strategy.isMulticlassClassification)
-    val (splits, bins) = DecisionTree.findSplitsBins(input, strategy)
-    val treeInput = TreePoint.convertToTreeRDD(input, strategy, bins)
-    val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), 
strategy, 0,
-      Array[List[Filter]](), splits, bins, 10)
+    val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+
+    val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
+    val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
+    val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), 
metadata, 0,
+      new Array[Node](0), splits, bins, 10)
 
     assert(bestSplits.length === 1)
     val bestSplit = bestSplits(0)._1


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to