Repository: spark Updated Branches: refs/heads/master fbad72288 -> 73ab7f141
http://git-wip-us.apache.org/repos/asf/spark/blob/73ab7f14/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index a5c49a3..2f36fd9 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -23,10 +23,10 @@ import org.scalatest.FunSuite import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.FeatureType._ -import org.apache.spark.mllib.tree.configuration.{FeatureType, Strategy} -import org.apache.spark.mllib.tree.impl.TreePoint +import org.apache.spark.mllib.tree.configuration.Strategy +import org.apache.spark.mllib.tree.impl.{DecisionTreeMetadata, TreePoint} import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance} -import org.apache.spark.mllib.tree.model.{DecisionTreeModel, Filter, Split} +import org.apache.spark.mllib.tree.model.{DecisionTreeModel, Node} import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.LocalSparkContext import org.apache.spark.mllib.regression.LabeledPoint @@ -64,7 +64,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(arr.length === 1000) val rdd = sc.parallelize(arr) val strategy = new Strategy(Classification, Gini, 3, 2, 100) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) assert(splits.length === 2) assert(bins.length === 2) assert(splits(0).length === 99) @@ -82,7 +83,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { numClassesForClassification = 2, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 2, 1-> 2)) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) assert(splits.length === 2) assert(bins.length === 2) assert(splits(0).length === 99) @@ -162,7 +164,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { numClassesForClassification = 2, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) // Check splits. @@ -279,7 +282,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { numClassesForClassification = 100, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) // Expecting 2^2 - 1 = 3 bins/splits assert(splits(0)(0).feature === 0) @@ -373,7 +377,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { numClassesForClassification = 100, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 10, 1-> 10)) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) // 2^10 - 1 > 100, so categorical variables will be ordered @@ -428,10 +433,11 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { maxDepth = 2, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) - val treeInput = TreePoint.convertToTreeRDD(rdd, strategy, bins) - val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(7), strategy, 0, - Array[List[Filter]](), splits, bins, 10) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) + val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) + val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(7), metadata, 0, + new Array[Node](0), splits, bins, 10) val split = bestSplits(0)._1 assert(split.categories.length === 1) @@ -456,10 +462,11 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { maxDepth = 2, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) - val (splits, bins) = DecisionTree.findSplitsBins(rdd,strategy) - val treeInput = TreePoint.convertToTreeRDD(rdd, strategy, bins) - val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(7), strategy, 0, - Array[List[Filter]](), splits, bins, 10) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) + val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) + val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(7), metadata, 0, + new Array[Node](0), splits, bins, 10) val split = bestSplits(0)._1 assert(split.categories.length === 1) @@ -495,7 +502,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(arr.length === 1000) val rdd = sc.parallelize(arr) val strategy = new Strategy(Classification, Gini, 3, 2, 100) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) assert(splits.length === 2) assert(splits(0).length === 99) assert(bins.length === 2) @@ -503,9 +511,9 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(splits(0).length === 99) assert(bins(0).length === 100) - val treeInput = TreePoint.convertToTreeRDD(rdd, strategy, bins) - val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(7), strategy, 0, - Array[List[Filter]](), splits, bins, 10) + val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) + val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(7), metadata, 0, + new Array[Node](0), splits, bins, 10) assert(bestSplits.length === 1) assert(bestSplits(0)._1.feature === 0) assert(bestSplits(0)._2.gain === 0) @@ -518,7 +526,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(arr.length === 1000) val rdd = sc.parallelize(arr) val strategy = new Strategy(Classification, Gini, 3, 2, 100) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) assert(splits.length === 2) assert(splits(0).length === 99) assert(bins.length === 2) @@ -526,9 +535,9 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(splits(0).length === 99) assert(bins(0).length === 100) - val treeInput = TreePoint.convertToTreeRDD(rdd, strategy, bins) - val bestSplits = DecisionTree.findBestSplits(treeInput, Array(0.0), strategy, 0, - Array[List[Filter]](), splits, bins, 10) + val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) + val bestSplits = DecisionTree.findBestSplits(treeInput, Array(0.0), metadata, 0, + new Array[Node](0), splits, bins, 10) assert(bestSplits.length === 1) assert(bestSplits(0)._1.feature === 0) assert(bestSplits(0)._2.gain === 0) @@ -542,7 +551,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(arr.length === 1000) val rdd = sc.parallelize(arr) val strategy = new Strategy(Classification, Entropy, 3, 2, 100) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) assert(splits.length === 2) assert(splits(0).length === 99) assert(bins.length === 2) @@ -550,9 +560,9 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(splits(0).length === 99) assert(bins(0).length === 100) - val treeInput = TreePoint.convertToTreeRDD(rdd, strategy, bins) - val bestSplits = DecisionTree.findBestSplits(treeInput, Array(0.0), strategy, 0, - Array[List[Filter]](), splits, bins, 10) + val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) + val bestSplits = DecisionTree.findBestSplits(treeInput, Array(0.0), metadata, 0, + new Array[Node](0), splits, bins, 10) assert(bestSplits.length === 1) assert(bestSplits(0)._1.feature === 0) assert(bestSplits(0)._2.gain === 0) @@ -566,7 +576,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(arr.length === 1000) val rdd = sc.parallelize(arr) val strategy = new Strategy(Classification, Entropy, 3, 2, 100) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) assert(splits.length === 2) assert(splits(0).length === 99) assert(bins.length === 2) @@ -574,9 +585,9 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(splits(0).length === 99) assert(bins(0).length === 100) - val treeInput = TreePoint.convertToTreeRDD(rdd, strategy, bins) - val bestSplits = DecisionTree.findBestSplits(treeInput, Array(0.0), strategy, 0, - Array[List[Filter]](), splits, bins, 10) + val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) + val bestSplits = DecisionTree.findBestSplits(treeInput, Array(0.0), metadata, 0, + new Array[Node](0), splits, bins, 10) assert(bestSplits.length === 1) assert(bestSplits(0)._1.feature === 0) assert(bestSplits(0)._2.gain === 0) @@ -590,7 +601,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(arr.length === 1000) val rdd = sc.parallelize(arr) val strategy = new Strategy(Classification, Entropy, 3, 2, 100) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) assert(splits.length === 2) assert(splits(0).length === 99) assert(bins.length === 2) @@ -598,14 +610,19 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(splits(0).length === 99) assert(bins(0).length === 100) - val leftFilter = Filter(new Split(0, 400, FeatureType.Continuous, List()), -1) - val rightFilter = Filter(new Split(0, 400, FeatureType.Continuous, List()) ,1) - val filters = Array[List[Filter]](List(), List(leftFilter), List(rightFilter)) + // Train a 1-node model + val strategyOneNode = new Strategy(Classification, Entropy, 1, 2, 100) + val modelOneNode = DecisionTree.train(rdd, strategyOneNode) + val nodes: Array[Node] = new Array[Node](7) + nodes(0) = modelOneNode.topNode + nodes(0).leftNode = None + nodes(0).rightNode = None + val parentImpurities = Array(0.5, 0.5, 0.5) // Single group second level tree construction. - val treeInput = TreePoint.convertToTreeRDD(rdd, strategy, bins) - val bestSplits = DecisionTree.findBestSplits(treeInput, parentImpurities, strategy, 1, filters, + val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) + val bestSplits = DecisionTree.findBestSplits(treeInput, parentImpurities, metadata, 1, nodes, splits, bins, 10) assert(bestSplits.length === 2) assert(bestSplits(0)._2.gain > 0) @@ -613,8 +630,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { // maxLevelForSingleGroup parameter is set to 0 to force splitting into groups for second // level tree construction. - val bestSplitsWithGroups = DecisionTree.findBestSplits(treeInput, parentImpurities, strategy, 1, - filters, splits, bins, 0) + val bestSplitsWithGroups = DecisionTree.findBestSplits(treeInput, parentImpurities, metadata, 1, + nodes, splits, bins, 0) assert(bestSplitsWithGroups.length === 2) assert(bestSplitsWithGroups(0)._2.gain > 0) assert(bestSplitsWithGroups(1)._2.gain > 0) @@ -629,19 +646,19 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bestSplits(i)._2.rightImpurity === bestSplitsWithGroups(i)._2.rightImpurity) assert(bestSplits(i)._2.predict === bestSplitsWithGroups(i)._2.predict) } - } test("stump with categorical variables for multiclass classification") { val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlass() - val input = sc.parallelize(arr) + val rdd = sc.parallelize(arr) val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) assert(strategy.isMulticlassClassification) - val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) - val treeInput = TreePoint.convertToTreeRDD(input, strategy, bins) - val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), strategy, 0, - Array[List[Filter]](), splits, bins, 10) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) + val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) + val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), metadata, 0, + new Array[Node](0), splits, bins, 10) assert(bestSplits.length === 1) val bestSplit = bestSplits(0)._1 @@ -657,11 +674,11 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { arr(1) = new LabeledPoint(1.0, Vectors.dense(1.0)) arr(2) = new LabeledPoint(1.0, Vectors.dense(2.0)) arr(3) = new LabeledPoint(1.0, Vectors.dense(3.0)) - val input = sc.parallelize(arr) + val rdd = sc.parallelize(arr) val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, numClassesForClassification = 2) - val model = DecisionTree.train(input, strategy) + val model = DecisionTree.train(rdd, strategy) validateClassifier(model, arr, 1.0) assert(model.numNodes === 3) assert(model.depth === 1) @@ -688,20 +705,22 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { test("stump with categorical variables for multiclass classification, with just enough bins") { val maxBins = math.pow(2, 3 - 1).toInt // just enough bins to allow unordered features val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlass() - val input = sc.parallelize(arr) + val rdd = sc.parallelize(arr) val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, - numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) + numClassesForClassification = 3, maxBins = maxBins, + categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) assert(strategy.isMulticlassClassification) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) - val model = DecisionTree.train(input, strategy) + val model = DecisionTree.train(rdd, strategy) validateClassifier(model, arr, 1.0) assert(model.numNodes === 3) assert(model.depth === 1) - val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) - val treeInput = TreePoint.convertToTreeRDD(input, strategy, bins) - val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), strategy, 0, - Array[List[Filter]](), splits, bins, 10) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) + val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) + val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), metadata, 0, + new Array[Node](0), splits, bins, 10) assert(bestSplits.length === 1) val bestSplit = bestSplits(0)._1 @@ -716,18 +735,19 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { test("stump with continuous variables for multiclass classification") { val arr = DecisionTreeSuite.generateContinuousDataPointsForMulticlass() - val input = sc.parallelize(arr) + val rdd = sc.parallelize(arr) val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, numClassesForClassification = 3) assert(strategy.isMulticlassClassification) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) - val model = DecisionTree.train(input, strategy) + val model = DecisionTree.train(rdd, strategy) validateClassifier(model, arr, 0.9) - val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) - val treeInput = TreePoint.convertToTreeRDD(input, strategy, bins) - val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), strategy, 0, - Array[List[Filter]](), splits, bins, 10) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) + val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) + val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), metadata, 0, + new Array[Node](0), splits, bins, 10) assert(bestSplits.length === 1) val bestSplit = bestSplits(0)._1 @@ -741,18 +761,19 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { test("stump with continuous + categorical variables for multiclass classification") { val arr = DecisionTreeSuite.generateContinuousDataPointsForMulticlass() - val input = sc.parallelize(arr) + val rdd = sc.parallelize(arr) val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 3)) assert(strategy.isMulticlassClassification) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) - val model = DecisionTree.train(input, strategy) + val model = DecisionTree.train(rdd, strategy) validateClassifier(model, arr, 0.9) - val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) - val treeInput = TreePoint.convertToTreeRDD(input, strategy, bins) - val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), strategy, 0, - Array[List[Filter]](), splits, bins, 10) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) + val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) + val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), metadata, 0, + new Array[Node](0), splits, bins, 10) assert(bestSplits.length === 1) val bestSplit = bestSplits(0)._1 @@ -765,14 +786,16 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { test("stump with categorical variables for ordered multiclass classification") { val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures() - val input = sc.parallelize(arr) + val rdd = sc.parallelize(arr) val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 10, 1 -> 10)) assert(strategy.isMulticlassClassification) - val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) - val treeInput = TreePoint.convertToTreeRDD(input, strategy, bins) - val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), strategy, 0, - Array[List[Filter]](), splits, bins, 10) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + + val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) + val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) + val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), metadata, 0, + new Array[Node](0), splits, bins, 10) assert(bestSplits.length === 1) val bestSplit = bestSplits(0)._1 --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
