Repository: spark Updated Branches: refs/heads/branch-1.6 8a2b8fcbb -> b54a586af
[SPARK-17027][ML] Avoid integer overflow in PolynomialExpansion.getPolySize Replaces custom choose function with o.a.commons.math3.CombinatoricsUtils.binomialCoefficient. Tested with Spark unit tests. Author: zero323 <[email protected]> Closes #14614 from zero323/SPARK-17027. (cherry picked from commit 0ebf7c1bff736cf54ec47957d71394d5b75b47a7) Signed-off-by: Sean Owen <[email protected]> Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/b54a586a Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/b54a586a Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/b54a586a Branch: refs/heads/branch-1.6 Commit: b54a586af4b8ca7e8b97311bf5e75e00797de899 Parents: 8a2b8fc Author: zero323 <[email protected]> Authored: Sun Aug 14 11:59:24 2016 +0100 Committer: Sean Owen <[email protected]> Committed: Sun Aug 14 12:01:26 2016 +0100 ---------------------------------------------------------------------- .../spark/ml/feature/PolynomialExpansion.scala | 10 ++++---- .../ml/feature/PolynomialExpansionSuite.scala | 24 ++++++++++++++++++++ 2 files changed, 30 insertions(+), 4 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/b54a586a/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala index 0861059..684f71f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala @@ -19,6 +19,8 @@ package org.apache.spark.ml.feature import scala.collection.mutable +import org.apache.commons.math3.util.CombinatoricsUtils + import org.apache.spark.annotation.{Since, Experimental} import 
org.apache.spark.ml.UnaryTransformer import org.apache.spark.ml.param.{ParamMap, IntParam, ParamValidators} @@ -80,12 +82,12 @@ class PolynomialExpansion(override val uid: String) @Since("1.6.0") object PolynomialExpansion extends DefaultParamsReadable[PolynomialExpansion] { - private def choose(n: Int, k: Int): Int = { - Range(n, n - k, -1).product / Range(k, 1, -1).product + private def getPolySize(numFeatures: Int, degree: Int): Int = { + val n = CombinatoricsUtils.binomialCoefficient(numFeatures + degree, degree) + require(n <= Integer.MAX_VALUE) + n.toInt } - private def getPolySize(numFeatures: Int, degree: Int): Int = choose(numFeatures + degree, degree) - private def expandDense( values: Array[Double], lastIdx: Int, http://git-wip-us.apache.org/repos/asf/spark/blob/b54a586a/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala index 70892dc..9b062fe 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala @@ -108,5 +108,29 @@ class PolynomialExpansionSuite .setDegree(3) testDefaultReadWrite(t) } + + test("SPARK-17027. 
Integer overflow in PolynomialExpansion.getPolySize") { + val data: Array[(Vector, Int, Int)] = Array( + (Vectors.dense(1.0, 2.0, 3.0, 4.0, 5.0), 3002, 4367), + (Vectors.sparse(5, Seq((0, 1.0), (4, 5.0))), 3002, 4367), + (Vectors.dense(1.0, 2.0, 3.0, 4.0, 5.0, 6.0), 8007, 12375) + ) + + val df = spark.createDataFrame(data) + .toDF("features", "expectedPoly10size", "expectedPoly11size") + + val t = new PolynomialExpansion() + .setInputCol("features") + .setOutputCol("polyFeatures") + + for (i <- Seq(10, 11)) { + val transformed = t.setDegree(i) + .transform(df) + .select(s"expectedPoly${i}size", "polyFeatures") + .rdd.map { case Row(expected: Int, v: Vector) => expected == v.size } + + assert(transformed.collect.forall(identity)) + } + } } --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
