Repository: spark Updated Branches: refs/heads/master a1b136d05 -> c8b612dec
[SPARK-17870][MLLIB][ML] Change statistic to pValue for SelectKBest and SelectPercentile because of DoF difference ## What changes were proposed in this pull request? The feature selection method ChiSquareSelector selects features based on ChiSquareTestResult.statistic (the chi-square value): it keeps the features with the largest chi-square values. However, the degrees of freedom (df) of the chi-square statistic differ from feature to feature in Statistics.chiSqTest(RDD), and chi-square values computed with different df are not directly comparable, so they cannot be used to rank features. The p-value already accounts for the df, so we change statistic to pValue for SelectKBest and SelectPercentile (selecting the features with the smallest p-values). ## How was this patch tested? Changed the existing tests. Author: Peng <[email protected]> Closes #15444 from mpjlu/chisqure-bug. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/c8b612de Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/c8b612de Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/c8b612de Branch: refs/heads/master Commit: c8b612decba28e51789891f7881b6d4ebc50e2bb Parents: a1b136d Author: Peng <[email protected]> Authored: Fri Oct 14 12:48:57 2016 +0100 Committer: Sean Owen <[email protected]> Committed: Fri Oct 14 12:48:57 2016 +0100 ---------------------------------------------------------------------- .../scala/org/apache/spark/mllib/feature/ChiSqSelector.scala | 4 ++-- .../org/apache/spark/ml/feature/ChiSqSelectorSuite.scala | 6 +++--- .../org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala | 8 ++++---- python/pyspark/ml/feature.py | 4 ++-- python/pyspark/mllib/feature.py | 8 ++++---- 5 files changed, 15 insertions(+), 15 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/c8b612de/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala ---------------------------------------------------------------------- diff --git 
a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala index c305b36..f8276de 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala @@ -234,11 +234,11 @@ class ChiSqSelector @Since("2.1.0") () extends Serializable { val features = selectorType match { case ChiSqSelector.KBest => chiSqTestResult - .sortBy { case (res, _) => -res.statistic } + .sortBy { case (res, _) => res.pValue } .take(numTopFeatures) case ChiSqSelector.Percentile => chiSqTestResult - .sortBy { case (res, _) => -res.statistic } + .sortBy { case (res, _) => res.pValue } .take((chiSqTestResult.length * percentile).toInt) case ChiSqSelector.FPR => chiSqTestResult http://git-wip-us.apache.org/repos/asf/spark/blob/c8b612de/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala index dfebfc8..6af06d8 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala @@ -38,10 +38,10 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext ) val preFilteredData = Seq( - Vectors.dense(0.0), - Vectors.dense(6.0), Vectors.dense(8.0), - Vectors.dense(5.0) + Vectors.dense(0.0), + Vectors.dense(0.0), + Vectors.dense(8.0) ) val df = sc.parallelize(data.zip(preFilteredData)) http://git-wip-us.apache.org/repos/asf/spark/blob/c8b612de/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala ---------------------------------------------------------------------- diff --git 
a/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala index ec23a4a..ac702b4 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala @@ -54,10 +54,10 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext { LabeledPoint(1.0, Vectors.dense(Array(0.0, 9.0, 8.0))), LabeledPoint(2.0, Vectors.dense(Array(8.0, 9.0, 5.0)))), 2) val preFilteredData = - Set(LabeledPoint(0.0, Vectors.dense(Array(0.0))), - LabeledPoint(1.0, Vectors.dense(Array(6.0))), - LabeledPoint(1.0, Vectors.dense(Array(8.0))), - LabeledPoint(2.0, Vectors.dense(Array(5.0)))) + Set(LabeledPoint(0.0, Vectors.dense(Array(8.0))), + LabeledPoint(1.0, Vectors.dense(Array(0.0))), + LabeledPoint(1.0, Vectors.dense(Array(0.0))), + LabeledPoint(2.0, Vectors.dense(Array(8.0)))) val model = new ChiSqSelector(1).fit(labeledDiscreteData) val filteredData = labeledDiscreteData.map { lp => LabeledPoint(lp.label, model.transform(lp.features)) http://git-wip-us.apache.org/repos/asf/spark/blob/c8b612de/python/pyspark/ml/feature.py ---------------------------------------------------------------------- diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index a33c3e7..7683360 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -2592,9 +2592,9 @@ class ChiSqSelector(JavaEstimator, HasFeaturesCol, HasOutputCol, HasLabelCol, Ja >>> selector = ChiSqSelector(numTopFeatures=1, outputCol="selectedFeatures") >>> model = selector.fit(df) >>> model.transform(df).head().selectedFeatures - DenseVector([1.0]) + DenseVector([18.0]) >>> model.selectedFeatures - [3] + [2] >>> chiSqSelectorPath = temp_path + "/chi-sq-selector" >>> selector.save(chiSqSelectorPath) >>> loadedSelector = ChiSqSelector.load(chiSqSelectorPath) 
http://git-wip-us.apache.org/repos/asf/spark/blob/c8b612de/python/pyspark/mllib/feature.py ---------------------------------------------------------------------- diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py index 4aea818..50ef7c7 100644 --- a/python/pyspark/mllib/feature.py +++ b/python/pyspark/mllib/feature.py @@ -288,15 +288,15 @@ class ChiSqSelector(object): ... ] >>> model = ChiSqSelector().setNumTopFeatures(1).fit(sc.parallelize(data)) >>> model.transform(SparseVector(3, {1: 9.0, 2: 6.0})) - SparseVector(1, {0: 6.0}) + SparseVector(1, {}) >>> model.transform(DenseVector([8.0, 9.0, 5.0])) - DenseVector([5.0]) + DenseVector([8.0]) >>> model = ChiSqSelector().setSelectorType("percentile").setPercentile(0.34).fit( ... sc.parallelize(data)) >>> model.transform(SparseVector(3, {1: 9.0, 2: 6.0})) - SparseVector(1, {0: 6.0}) + SparseVector(1, {}) >>> model.transform(DenseVector([8.0, 9.0, 5.0])) - DenseVector([5.0]) + DenseVector([8.0]) >>> data = [ ... LabeledPoint(0.0, SparseVector(4, {0: 8.0, 1: 7.0})), ... LabeledPoint(1.0, SparseVector(4, {1: 9.0, 2: 6.0, 3: 4.0})), --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
