Repository: spark Updated Branches: refs/heads/branch-1.6 290279808 -> d31854da5
[SPARK-12746][ML] ArrayType(_, true) should also accept ArrayType(_, false) fix for branch-1.6 https://issues.apache.org/jira/browse/SPARK-13359 Author: Earthson Lu <[email protected]> Closes #11237 from Earthson/SPARK-13359. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/d31854da Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/d31854da Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/d31854da Branch: refs/heads/branch-1.6 Commit: d31854da5155550f4e9c5e717c92dfec87d0ff6a Parents: 2902798 Author: Earthson Lu <[email protected]> Authored: Mon Feb 22 23:40:36 2016 -0800 Committer: Xiangrui Meng <[email protected]> Committed: Mon Feb 22 23:40:36 2016 -0800 ---------------------------------------------------------------------- .../apache/spark/ml/feature/CountVectorizer.scala | 3 ++- .../org/apache/spark/ml/util/SchemaUtils.scala | 17 +++++++++++++++++ 2 files changed, 19 insertions(+), 1 deletion(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/d31854da/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala index b9e2144..fab3e74 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala @@ -70,7 +70,8 @@ private[feature] trait CountVectorizerParams extends Params with HasInputCol wit /** Validates and transforms the input schema. 
*/ protected def validateAndTransformSchema(schema: StructType): StructType = { - SchemaUtils.checkColumnType(schema, $(inputCol), new ArrayType(StringType, true)) + val typeCandidates = List(ArrayType(StringType, true), ArrayType(StringType, false)) + SchemaUtils.checkColumnTypes(schema, $(inputCol), typeCandidates) SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT) } http://git-wip-us.apache.org/repos/asf/spark/blob/d31854da/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala index 7decbbd..76021ad 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala @@ -44,6 +44,23 @@ private[spark] object SchemaUtils { } /** + * Check whether the given schema contains a column of one of the required data types. + * @param colName column name + * @param dataTypes required column data types + */ + def checkColumnTypes( + schema: StructType, + colName: String, + dataTypes: Seq[DataType], + msg: String = ""): Unit = { + val actualDataType = schema(colName).dataType + val message = if (msg != null && msg.trim.length > 0) " " + msg else "" + require(dataTypes.exists(actualDataType.equals), + s"Column $colName must be of type equal to one of the following types: " + + s"${dataTypes.mkString("[", ", ", "]")} but was actually of type $actualDataType.$message") + } + + /** * Appends a new column to the input schema. This fails if the given output column already exists. * @param schema input schema * @param colName new column name. If this column name is an empty string "", this method returns --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
