Repository: spark Updated Branches: refs/heads/branch-1.6 c3ed9504d -> 496496b25
[SPARK-14159][ML] Fixed bug in StringIndexer + related issue in RFormula - 1.6 backport Backport of [https://github.com/apache/spark/pull/11965] for branch-1.6. There were no merge conflicts. ## What changes were proposed in this pull request? StringIndexerModel.transform sets the output column metadata to use the name of inputCol. It should use the name of outputCol instead. Fixing this causes a problem with the metadata produced by RFormula. Fix in RFormula: I added the StringIndexer columns to prefixesToRewrite, and I modified VectorAttributeRewriter to find and replace all "prefixes" since attributes collect multiple prefixes from StringIndexer + Interaction. Note that the term "prefixes" is no longer accurate, since strings in the interior of an attribute name may now be replaced as well. ## How was this patch tested? Unit test which failed before this fix. Author: Joseph K. Bradley <[email protected]> Closes #12595 from jkbradley/StringIndexer-fix-1.6. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/496496b2 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/496496b2 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/496496b2 Branch: refs/heads/branch-1.6 Commit: 496496b2548387ad2889eaaf884ab039096436ef Parents: c3ed950 Author: Joseph K. Bradley <[email protected]> Authored: Tue Apr 26 14:00:39 2016 -0700 Committer: Joseph K. 
Bradley <[email protected]> Committed: Tue Apr 26 14:00:39 2016 -0700 ---------------------------------------------------------------------- .../scala/org/apache/spark/ml/feature/RFormula.scala | 15 ++++++--------- .../org/apache/spark/ml/feature/StringIndexer.scala | 7 +++---- .../apache/spark/ml/feature/StringIndexerSuite.scala | 13 +++++++++++++ 3 files changed, 22 insertions(+), 13 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/496496b2/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala index 5c43a41..564c867 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -101,6 +101,7 @@ class RFormula(override val uid: String) extends Estimator[RFormulaModel] with R encoderStages += new StringIndexer() .setInputCol(term) .setOutputCol(indexCol) + prefixesToRewrite(indexCol + "_") = term + "_" (term, indexCol) case _ => (term, term) @@ -198,7 +199,7 @@ class RFormulaModel private[feature]( override def copy(extra: ParamMap): RFormulaModel = copyValues( new RFormulaModel(uid, resolvedFormula, pipelineModel)) - override def toString: String = s"RFormulaModel(${resolvedFormula}) (uid=$uid)" + override def toString: String = s"RFormulaModel($resolvedFormula) (uid=$uid)" private def transformLabel(dataset: DataFrame): DataFrame = { val labelName = resolvedFormula.label @@ -268,14 +269,10 @@ private class VectorAttributeRewriter( val group = AttributeGroup.fromStructField(dataset.schema(vectorCol)) val attrs = group.attributes.get.map { attr => if (attr.name.isDefined) { - val name = attr.name.get - val replacement = prefixesToRewrite.filter { case (k, _) => 
name.startsWith(k) } - if (replacement.nonEmpty) { - val (k, v) = replacement.headOption.get - attr.withName(v + name.stripPrefix(k)) - } else { - attr + val name = prefixesToRewrite.foldLeft(attr.name.get) { case (curName, (from, to)) => + curName.replace(from, to) } + attr.withName(name) } else { attr } @@ -284,7 +281,7 @@ private class VectorAttributeRewriter( } val otherCols = dataset.columns.filter(_ != vectorCol).map(dataset.col) val rewrittenCol = dataset.col(vectorCol).as(vectorCol, metadata) - dataset.select((otherCols :+ rewrittenCol): _*) + dataset.select(otherCols :+ rewrittenCol : _*) } override def transformSchema(schema: StructType): StructType = { http://git-wip-us.apache.org/repos/asf/spark/blob/496496b2/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index b3413a1..a843cc8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -160,15 +160,14 @@ class StringIndexerModel ( } val metadata = NominalAttribute.defaultAttr - .withName($(inputCol)).withValues(labels).toMetadata() + .withName($(outputCol)).withValues(labels).toMetadata() // If we are skipping invalid records, filter them out. 
- val filteredDataset = (getHandleInvalid) match { - case "skip" => { + val filteredDataset = getHandleInvalid match { + case "skip" => val filterer = udf { label: String => labelToIndex.contains(label) } dataset.where(filterer(dataset($(inputCol)))) - } case _ => dataset } filteredDataset.select(col("*"), http://git-wip-us.apache.org/repos/asf/spark/blob/496496b2/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala index 26f4613..6ba4aaa 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala @@ -210,4 +210,17 @@ class StringIndexerSuite .setLabels(Array("a", "b", "c")) testDefaultReadWrite(t) } + + test("StringIndexer metadata") { + val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2) + val df = sqlContext.createDataFrame(data).toDF("id", "label") + val indexer = new StringIndexer() + .setInputCol("label") + .setOutputCol("labelIndex") + .fit(df) + val transformed = indexer.transform(df) + val attrs = + NominalAttribute.decodeStructField(transformed.schema("labelIndex"), preserveName = true) + assert(attrs.name.nonEmpty && attrs.name.get === "labelIndex") + } } --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
