Repository: spark Updated Branches: refs/heads/master 7e28fabdf -> 909c6d812
[SPARK-16307][ML] Add test to verify the predicted variances of a DT on toy data ## What changes were proposed in this pull request? The current tests assume that `impurity.calculate()` returns the variance correctly. It would be better to make the tests independent of this assumption. In other words, verify that the variance computed equals the variance computed manually on a small tree. ## How was this patch tested? The patch is a test. Author: MechCoder <[email protected]> Closes #13981 from MechCoder/dt_variance. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/909c6d81 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/909c6d81 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/909c6d81 Branch: refs/heads/master Commit: 909c6d812f6ca3a3305e4611a700c8c17905b953 Parents: 7e28fab Author: MechCoder <[email protected]> Authored: Wed Jul 6 02:54:44 2016 -0700 Committer: Yanbo Liang <[email protected]> Committed: Wed Jul 6 02:54:44 2016 -0700 ---------------------------------------------------------------------- .../regression/DecisionTreeRegressorSuite.scala | 20 ++++++++++++++++++++ .../apache/spark/ml/tree/impl/TreeTests.scala | 12 ++++++++++++ 2 files changed, 32 insertions(+) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/909c6d81/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala index 9afb742..15fa26e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala @@ -22,6 +22,7 @@ import 
org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.linalg.Vector import org.apache.spark.ml.tree.impl.TreeTests import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint} import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree, DecisionTreeSuite => OldDecisionTreeSuite} @@ -96,6 +97,25 @@ class DecisionTreeRegressorSuite assert(variance === expectedVariance, s"Expected variance $expectedVariance but got $variance.") } + + val varianceData: RDD[LabeledPoint] = TreeTests.varianceData(sc) + val varianceDF = TreeTests.setMetadata(varianceData, Map.empty[Int, Int], 0) + dt.setMaxDepth(1) + .setMaxBins(6) + .setSeed(0) + val transformVarDF = dt.fit(varianceDF).transform(varianceDF) + val calculatedVariances = transformVarDF.select(dt.getVarianceCol).collect().map { + case Row(variance: Double) => variance + } + + // Since max depth is set to 1, the best split point is that which splits the data + // into (0.0, 1.0, 2.0) and (10.0, 12.0, 14.0). 
The predicted variance for each + // data point in the left node is 0.667 and for each data point in the right node + // is 2.667 + val expectedVariances = Array(0.667, 0.667, 0.667, 2.667, 2.667, 2.667) + calculatedVariances.zip(expectedVariances).foreach { case (actual, expected) => + assert(actual ~== expected absTol 1e-3) + } } test("Feature importance with toy data") { http://git-wip-us.apache.org/repos/asf/spark/blob/909c6d81/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala index d2fa8d0..c90cb8c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala @@ -183,6 +183,18 @@ private[ml] object TreeTests extends SparkFunSuite { )) /** + * Create some toy data for testing correctness of variance. + */ + def varianceData(sc: SparkContext): RDD[LabeledPoint] = sc.parallelize(Seq( + new LabeledPoint(1.0, Vectors.dense(Array(0.0))), + new LabeledPoint(2.0, Vectors.dense(Array(1.0))), + new LabeledPoint(3.0, Vectors.dense(Array(2.0))), + new LabeledPoint(10.0, Vectors.dense(Array(3.0))), + new LabeledPoint(12.0, Vectors.dense(Array(4.0))), + new LabeledPoint(14.0, Vectors.dense(Array(5.0))) + )) + + /** * Mapping from all Params to valid settings which differ from the defaults. * This is useful for tests which need to exercise all Params, such as save/load. * This excludes input columns to simplify some tests. --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
