Repository: spark Updated Branches: refs/heads/master 0182d9599 -> 507bea5ca
[SPARK-14143] Options for parsing NaNs, Infinity and nulls for numeric types 1. Adds the following options for parsing NaNs: nanValue 2. Adds the following options for parsing infinity: positiveInf, negativeInf. `TypeCast.castTo` is unit tested and an end-to-end test is added to `CSVSuite` Author: Hossein <[email protected]> Closes #11947 from falaki/SPARK-14143. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/507bea5c Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/507bea5c Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/507bea5c Branch: refs/heads/master Commit: 507bea5ca6d95c995f8152b8473713c136e23754 Parents: 0182d95 Author: Hossein <[email protected]> Authored: Sat Apr 30 18:11:56 2016 -0700 Committer: Reynold Xin <[email protected]> Committed: Sat Apr 30 18:12:03 2016 -0700 ---------------------------------------------------------------------- .../datasources/csv/CSVInferSchema.scala | 83 ++++++++++++-------- .../execution/datasources/csv/CSVOptions.scala | 15 ++++ .../execution/datasources/csv/CSVRelation.scala | 3 +- sql/core/src/test/resources/numbers.csv | 9 +++ .../execution/datasources/csv/CSVSuite.scala | 23 ++++++ .../datasources/csv/CSVTypeCastSuite.scala | 83 +++++++++++++++++--- 6 files changed, 174 insertions(+), 42 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/507bea5c/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala index a26a808..cfd66af 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala @@ -190,40 +190,61 @@ private[csv] object CSVTypeCast { datum: String, castType: DataType, nullable: Boolean = true, - nullValue: String = "", - dateFormat: SimpleDateFormat = null): Any = { + options: CSVOptions = CSVOptions()): Any = { - if (datum == nullValue && nullable && (!castType.isInstanceOf[StringType])) { - null - } else { - castType match { - case _: ByteType => datum.toByte - case _: ShortType => datum.toShort - case _: IntegerType => datum.toInt - case _: LongType => datum.toLong - case _: FloatType => Try(datum.toFloat) - .getOrElse(NumberFormat.getInstance(Locale.getDefault).parse(datum).floatValue()) - case _: DoubleType => Try(datum.toDouble) - .getOrElse(NumberFormat.getInstance(Locale.getDefault).parse(datum).doubleValue()) - case _: BooleanType => datum.toBoolean - case dt: DecimalType => + castType match { + case _: ByteType => if (datum == options.nullValue && nullable) null else datum.toByte + case _: ShortType => if (datum == options.nullValue && nullable) null else datum.toShort + case _: IntegerType => if (datum == options.nullValue && nullable) null else datum.toInt + case _: LongType => if (datum == options.nullValue && nullable) null else datum.toLong + case _: FloatType => + if (datum == options.nullValue && nullable) { + null + } else if (datum == options.nanValue) { + Float.NaN + } else if (datum == options.negativeInf) { + Float.NegativeInfinity + } else if (datum == options.positiveInf) { + Float.PositiveInfinity + } else { + Try(datum.toFloat) + .getOrElse(NumberFormat.getInstance(Locale.getDefault).parse(datum).floatValue()) + } + case _: DoubleType => + if (datum == options.nullValue && nullable) { + null + } else if (datum == options.nanValue) { + Double.NaN + } else if (datum == options.negativeInf) { + Double.NegativeInfinity + } else if (datum == options.positiveInf) { + Double.PositiveInfinity + } else { + Try(datum.toDouble) + .getOrElse(NumberFormat.getInstance(Locale.getDefault).parse(datum).doubleValue()) + } + case _: BooleanType => datum.toBoolean + case dt: DecimalType => + if (datum == options.nullValue && nullable) { + null + } else { val value = new BigDecimal(datum.replaceAll(",", "")) Decimal(value, dt.precision, dt.scale) - case _: TimestampType if dateFormat != null => - // This one will lose microseconds parts. - // See https://issues.apache.org/jira/browse/SPARK-10681. - dateFormat.parse(datum).getTime * 1000L - case _: TimestampType => - // This one will lose microseconds parts. - // See https://issues.apache.org/jira/browse/SPARK-10681. - DateTimeUtils.stringToTime(datum).getTime * 1000L - case _: DateType if dateFormat != null => - DateTimeUtils.millisToDays(dateFormat.parse(datum).getTime) - case _: DateType => - DateTimeUtils.millisToDays(DateTimeUtils.stringToTime(datum).getTime) - case _: StringType => UTF8String.fromString(datum) - case _ => throw new RuntimeException(s"Unsupported type: ${castType.typeName}") - } + } + case _: TimestampType if options.dateFormat != null => + // This one will lose microseconds parts. + // See https://issues.apache.org/jira/browse/SPARK-10681. + options.dateFormat.parse(datum).getTime * 1000L + case _: TimestampType => + // This one will lose microseconds parts. + // See https://issues.apache.org/jira/browse/SPARK-10681. + DateTimeUtils.stringToTime(datum).getTime * 1000L + case _: DateType if options.dateFormat != null => + DateTimeUtils.millisToDays(options.dateFormat.parse(datum).getTime) + case _: DateType => + DateTimeUtils.millisToDays(DateTimeUtils.stringToTime(datum).getTime) + case _: StringType => UTF8String.fromString(datum) + case _ => throw new RuntimeException(s"Unsupported type: ${castType.typeName}") } } http://git-wip-us.apache.org/repos/asf/spark/blob/507bea5c/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala index b87d19f..e4fd094 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala @@ -90,6 +90,12 @@ private[sql] class CSVOptions(@transient private val parameters: Map[String, Str val nullValue = parameters.getOrElse("nullValue", "") + val nanValue = parameters.getOrElse("nanValue", "NaN") + + val positiveInf = parameters.getOrElse("positiveInf", "Inf") + val negativeInf = parameters.getOrElse("negativeInf", "-Inf") + + val compressionCodec: Option[String] = { val name = parameters.get("compression").orElse(parameters.get("codec")) name.map(CompressionCodecs.getCodecClassName) @@ -111,3 +117,12 @@ private[sql] class CSVOptions(@transient private val parameters: Map[String, Str val rowSeparator = "\n" } + +object CSVOptions { + + def apply(): CSVOptions = new CSVOptions(Map.empty) + + def apply(paramName: String, paramValue: String): CSVOptions = { + new CSVOptions(Map(paramName -> paramValue)) + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/507bea5c/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala index 9a72363..4f2d438 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala @@ -99,8 +99,7 @@ object CSVRelation extends Logging { indexSafeTokens(index), field.dataType, field.nullable, - params.nullValue, - params.dateFormat) + params) if (subIndex < requiredSize) { row(subIndex) = value } http://git-wip-us.apache.org/repos/asf/spark/blob/507bea5c/sql/core/src/test/resources/numbers.csv ---------------------------------------------------------------------- diff --git a/sql/core/src/test/resources/numbers.csv b/sql/core/src/test/resources/numbers.csv new file mode 100644 index 0000000..af8feac --- /dev/null +++ b/sql/core/src/test/resources/numbers.csv @@ -0,0 +1,9 @@ +int,long,float,double +8,1000000,1.042,23848545.0374 +--,34232323,98.343,184721.23987223 +34,--,98.343,184721.23987223 +34,43323123,--,184721.23987223 +34,43323123,223823.9484,-- +34,43323123,223823.NAN,NAN +34,43323123,223823.INF,INF +34,43323123,223823.-INF,-INF http://git-wip-us.apache.org/repos/asf/spark/blob/507bea5c/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index 8847c76..07f00a0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -46,6 +46,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { private val disableCommentsFile = "disable_comments.csv" private val boolFile = "bool.csv" private val simpleSparseFile = "simple_sparse.csv" + private val numbersFile = "numbers.csv" private val datesFile = "dates.csv" private val unescapedQuotesFile = "unescaped-quotes.csv" @@ -535,4 +536,26 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { verifyCars(cars, withHeader = false, checkTypes = false) } + + test("nulls, NaNs and Infinity values can be parsed") { + val numbers = sqlContext + .read + .format("csv") + .schema(StructType(List( + StructField("int", IntegerType, true), + StructField("long", LongType, true), + StructField("float", FloatType, true), + StructField("double", DoubleType, true) + ))) + .options(Map( + "header" -> "true", + "mode" -> "DROPMALFORMED", + "nullValue" -> "--", + "nanValue" -> "NAN", + "negativeInf" -> "-INF", + "positiveInf" -> "INF")) + .load(testFile(numbersFile)) + + assert(numbers.count() == 8) + } } http://git-wip-us.apache.org/repos/asf/spark/blob/507bea5c/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala index 8b59bc1..26b33b2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala @@ -29,6 +29,8 @@ import org.apache.spark.unsafe.types.UTF8String class CSVTypeCastSuite extends SparkFunSuite { + private def assertNull(v: Any) = assert(v == null) + test("Can parse decimal type values") { val stringValues = Seq("10.05", "1,000.01", "158,058,049.001") val decimalValues = Seq(10.05, 1000.01, 158058049.001) @@ -66,17 +68,21 @@ class CSVTypeCastSuite extends SparkFunSuite { } test("Nullable types are handled") { - assert(CSVTypeCast.castTo("", IntegerType, nullable = true) == null) + assert(CSVTypeCast.castTo("", IntegerType, nullable = true, CSVOptions()) == null) } test("String type should always return the same as the input") { - assert(CSVTypeCast.castTo("", StringType, nullable = true) == UTF8String.fromString("")) - assert(CSVTypeCast.castTo("", StringType, nullable = false) == UTF8String.fromString("")) + assert( + CSVTypeCast.castTo("", StringType, nullable = true, CSVOptions()) == + UTF8String.fromString("")) + assert( + CSVTypeCast.castTo("", StringType, nullable = false, CSVOptions()) == + UTF8String.fromString("")) } test("Throws exception for empty string with non null type") { val exception = intercept[NumberFormatException]{ - CSVTypeCast.castTo("", IntegerType, nullable = false) + CSVTypeCast.castTo("", IntegerType, nullable = false, CSVOptions()) } assert(exception.getMessage.contains("For input string: \"\"")) } @@ -90,12 +96,12 @@ class CSVTypeCastSuite extends SparkFunSuite { assert(CSVTypeCast.castTo("1.00", DoubleType) == 1.0) assert(CSVTypeCast.castTo("true", BooleanType) == true) - val dateFormat = new SimpleDateFormat("dd/MM/yyyy hh:mm") + val options = CSVOptions("dateFormat", "dd/MM/yyyy hh:mm") val customTimestamp = "31/01/2015 00:00" - val expectedTime = dateFormat.parse("31/01/2015 00:00").getTime - assert(CSVTypeCast.castTo(customTimestamp, TimestampType, dateFormat = dateFormat) - == expectedTime * 1000L) - assert(CSVTypeCast.castTo(customTimestamp, DateType, dateFormat = dateFormat) == + val expectedTime = options.dateFormat.parse("31/01/2015 00:00").getTime + assert(CSVTypeCast.castTo(customTimestamp, TimestampType, nullable = true, options) == + expectedTime * 1000L) + assert(CSVTypeCast.castTo(customTimestamp, DateType, nullable = true, options) == DateTimeUtils.millisToDays(expectedTime)) val timestamp = "2015-01-01 00:00:00" @@ -116,4 +122,63 @@ class CSVTypeCastSuite extends SparkFunSuite { Locale.setDefault(originalLocale) } } + + test("Float NaN values are parsed correctly") { + val floatVal: Float = CSVTypeCast.castTo( + "nn", FloatType, nullable = true, CSVOptions("nanValue", "nn")).asInstanceOf[Float] + + // Java implements the IEEE-754 floating point standard which guarantees that any comparison + // against NaN will return false (except != which returns true) + assert(floatVal != floatVal) + } + + test("Double NaN values are parsed correctly") { + val doubleVal: Double = CSVTypeCast.castTo( + "-", DoubleType, nullable = true, CSVOptions("nanValue", "-")).asInstanceOf[Double] + + assert(doubleVal.isNaN) + } + + test("Float infinite values can be parsed") { + val floatVal1 = CSVTypeCast.castTo( + "max", FloatType, nullable = true, CSVOptions("negativeInf", "max")).asInstanceOf[Float] + + assert(floatVal1 == Float.NegativeInfinity) + + val floatVal2 = CSVTypeCast.castTo( + "max", FloatType, nullable = true, CSVOptions("positiveInf", "max")).asInstanceOf[Float] + + assert(floatVal2 == Float.PositiveInfinity) + } + + test("Double infinite values can be parsed") { + val doubleVal1 = CSVTypeCast.castTo( + "max", DoubleType, nullable = true, CSVOptions("negativeInf", "max") + ).asInstanceOf[Double] + + assert(doubleVal1 == Double.NegativeInfinity) + + val doubleVal2 = CSVTypeCast.castTo( + "max", DoubleType, nullable = true, CSVOptions("positiveInf", "max") + ).asInstanceOf[Double] + + assert(doubleVal2 == Double.PositiveInfinity) + } + + test("Type-specific null values are used for casting") { + assertNull( + CSVTypeCast.castTo("-", ByteType, nullable = true, CSVOptions("nullValue", "-"))) + assertNull( + CSVTypeCast.castTo("-", ShortType, nullable = true, CSVOptions("nullValue", "-"))) + assertNull( + CSVTypeCast.castTo("-", IntegerType, nullable = true, CSVOptions("nullValue", "-"))) + assertNull( + CSVTypeCast.castTo("-", LongType, nullable = true, CSVOptions("nullValue", "-"))) + assertNull( + CSVTypeCast.castTo("-", FloatType, nullable = true, CSVOptions("nullValue", "-"))) + assertNull( + CSVTypeCast.castTo("-", DoubleType, nullable = true, CSVOptions("nullValue", "-"))) + assertNull( + CSVTypeCast.castTo("-", DecimalType.DoubleDecimal, true, CSVOptions("nullValue", "-"))) + } } --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
