Repository: spark
Updated Branches:
  refs/heads/master 0182d9599 -> 507bea5ca


[SPARK-14143] Options for parsing NaNs, Infinity and nulls for numeric types

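1. Adds the following option for parsing NaNs: nanValue.

2. Adds the following options for parsing infinity: positiveInf, negativeInf.

3. `nullValue` is now checked separately for each numeric type (byte, short, int, long, float, double and decimal).

`CSVTypeCast.castTo` is unit-tested, and an end-to-end test is added to `CSVSuite`.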

Author: Hossein <[email protected]>

Closes #11947 from falaki/SPARK-14143.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/507bea5c
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/507bea5c
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/507bea5c

Branch: refs/heads/master
Commit: 507bea5ca6d95c995f8152b8473713c136e23754
Parents: 0182d95
Author: Hossein <[email protected]>
Authored: Sat Apr 30 18:11:56 2016 -0700
Committer: Reynold Xin <[email protected]>
Committed: Sat Apr 30 18:12:03 2016 -0700

----------------------------------------------------------------------
 .../datasources/csv/CSVInferSchema.scala        | 83 ++++++++++++--------
 .../execution/datasources/csv/CSVOptions.scala  | 15 ++++
 .../execution/datasources/csv/CSVRelation.scala |  3 +-
 sql/core/src/test/resources/numbers.csv         |  9 +++
 .../execution/datasources/csv/CSVSuite.scala    | 23 ++++++
 .../datasources/csv/CSVTypeCastSuite.scala      | 83 +++++++++++++++++---
 6 files changed, 174 insertions(+), 42 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/507bea5c/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala
index a26a808..cfd66af 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala
@@ -190,40 +190,61 @@ private[csv] object CSVTypeCast {
       datum: String,
       castType: DataType,
       nullable: Boolean = true,
-      nullValue: String = "",
-      dateFormat: SimpleDateFormat = null): Any = {
+      options: CSVOptions = CSVOptions()): Any = {
 
-    if (datum == nullValue && nullable && 
(!castType.isInstanceOf[StringType])) {
-      null
-    } else {
-      castType match {
-        case _: ByteType => datum.toByte
-        case _: ShortType => datum.toShort
-        case _: IntegerType => datum.toInt
-        case _: LongType => datum.toLong
-        case _: FloatType => Try(datum.toFloat)
-          
.getOrElse(NumberFormat.getInstance(Locale.getDefault).parse(datum).floatValue())
-        case _: DoubleType => Try(datum.toDouble)
-          
.getOrElse(NumberFormat.getInstance(Locale.getDefault).parse(datum).doubleValue())
-        case _: BooleanType => datum.toBoolean
-        case dt: DecimalType =>
+    castType match {
+      case _: ByteType => if (datum == options.nullValue && nullable) null 
else datum.toByte
+      case _: ShortType => if (datum == options.nullValue && nullable) null 
else datum.toShort
+      case _: IntegerType => if (datum == options.nullValue && nullable) null 
else datum.toInt
+      case _: LongType => if (datum == options.nullValue && nullable) null 
else datum.toLong
+      case _: FloatType =>
+        if (datum == options.nullValue && nullable) {
+          null
+        } else if (datum == options.nanValue) {
+          Float.NaN
+        } else if (datum == options.negativeInf) {
+          Float.NegativeInfinity
+        } else if (datum == options.positiveInf) {
+          Float.PositiveInfinity
+        } else {
+          Try(datum.toFloat)
+            
.getOrElse(NumberFormat.getInstance(Locale.getDefault).parse(datum).floatValue())
+        }
+      case _: DoubleType =>
+        if (datum == options.nullValue && nullable) {
+          null
+        } else if (datum == options.nanValue) {
+          Double.NaN
+        } else if (datum == options.negativeInf) {
+          Double.NegativeInfinity
+        } else if (datum == options.positiveInf) {
+          Double.PositiveInfinity
+        } else {
+          Try(datum.toDouble)
+            
.getOrElse(NumberFormat.getInstance(Locale.getDefault).parse(datum).doubleValue())
+        }
+      case _: BooleanType => datum.toBoolean
+      case dt: DecimalType =>
+        if (datum == options.nullValue && nullable) {
+          null
+        } else {
           val value = new BigDecimal(datum.replaceAll(",", ""))
           Decimal(value, dt.precision, dt.scale)
-        case _: TimestampType if dateFormat != null =>
-          // This one will lose microseconds parts.
-          // See https://issues.apache.org/jira/browse/SPARK-10681.
-          dateFormat.parse(datum).getTime * 1000L
-        case _: TimestampType =>
-          // This one will lose microseconds parts.
-          // See https://issues.apache.org/jira/browse/SPARK-10681.
-          DateTimeUtils.stringToTime(datum).getTime  * 1000L
-        case _: DateType if dateFormat != null =>
-          DateTimeUtils.millisToDays(dateFormat.parse(datum).getTime)
-        case _: DateType =>
-          DateTimeUtils.millisToDays(DateTimeUtils.stringToTime(datum).getTime)
-        case _: StringType => UTF8String.fromString(datum)
-        case _ => throw new RuntimeException(s"Unsupported type: 
${castType.typeName}")
-      }
+        }
+      case _: TimestampType if options.dateFormat != null =>
+        // This one will lose microseconds parts.
+        // See https://issues.apache.org/jira/browse/SPARK-10681.
+        options.dateFormat.parse(datum).getTime * 1000L
+      case _: TimestampType =>
+        // This one will lose microseconds parts.
+        // See https://issues.apache.org/jira/browse/SPARK-10681.
+        DateTimeUtils.stringToTime(datum).getTime  * 1000L
+      case _: DateType if options.dateFormat != null =>
+        DateTimeUtils.millisToDays(options.dateFormat.parse(datum).getTime)
+      case _: DateType =>
+        DateTimeUtils.millisToDays(DateTimeUtils.stringToTime(datum).getTime)
+      case _: StringType => UTF8String.fromString(datum)
+      case _ => throw new RuntimeException(s"Unsupported type: 
${castType.typeName}")
     }
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/507bea5c/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
index b87d19f..e4fd094 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
@@ -90,6 +90,12 @@ private[sql] class CSVOptions(@transient private val parameters: Map[String, Str
 
   val nullValue = parameters.getOrElse("nullValue", "")
 
+  val nanValue = parameters.getOrElse("nanValue", "NaN")
+
+  val positiveInf = parameters.getOrElse("positiveInf", "Inf")
+  val negativeInf = parameters.getOrElse("negativeInf", "-Inf")
+
+
   val compressionCodec: Option[String] = {
     val name = parameters.get("compression").orElse(parameters.get("codec"))
     name.map(CompressionCodecs.getCodecClassName)
@@ -111,3 +117,12 @@ private[sql] class CSVOptions(@transient private val parameters: Map[String, Str
 
   val rowSeparator = "\n"
 }
+
+object CSVOptions {
+
+  def apply(): CSVOptions = new CSVOptions(Map.empty)
+
+  def apply(paramName: String, paramValue: String): CSVOptions = {
+    new CSVOptions(Map(paramName -> paramValue))
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/507bea5c/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
index 9a72363..4f2d438 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
@@ -99,8 +99,7 @@ object CSVRelation extends Logging {
               indexSafeTokens(index),
               field.dataType,
               field.nullable,
-              params.nullValue,
-              params.dateFormat)
+              params)
             if (subIndex < requiredSize) {
               row(subIndex) = value
             }

http://git-wip-us.apache.org/repos/asf/spark/blob/507bea5c/sql/core/src/test/resources/numbers.csv
----------------------------------------------------------------------
diff --git a/sql/core/src/test/resources/numbers.csv b/sql/core/src/test/resources/numbers.csv
new file mode 100644
index 0000000..af8feac
--- /dev/null
+++ b/sql/core/src/test/resources/numbers.csv
@@ -0,0 +1,9 @@
+int,long,float,double
+8,1000000,1.042,23848545.0374
+--,34232323,98.343,184721.23987223
+34,--,98.343,184721.23987223
+34,43323123,--,184721.23987223
+34,43323123,223823.9484,--
+34,43323123,223823.NAN,NAN
+34,43323123,223823.INF,INF
+34,43323123,223823.-INF,-INF

http://git-wip-us.apache.org/repos/asf/spark/blob/507bea5c/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
index 8847c76..07f00a0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
@@ -46,6 +46,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
   private val disableCommentsFile = "disable_comments.csv"
   private val boolFile = "bool.csv"
   private val simpleSparseFile = "simple_sparse.csv"
+  private val numbersFile = "numbers.csv"
   private val datesFile = "dates.csv"
   private val unescapedQuotesFile = "unescaped-quotes.csv"
 
@@ -535,4 +536,26 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
 
     verifyCars(cars, withHeader = false, checkTypes = false)
   }
+
+  test("nulls, NaNs and Infinity values can be parsed") {
+    val numbers = sqlContext
+      .read
+      .format("csv")
+      .schema(StructType(List(
+        StructField("int", IntegerType, true),
+        StructField("long", LongType, true),
+        StructField("float", FloatType, true),
+        StructField("double", DoubleType, true)
+      )))
+      .options(Map(
+        "header" -> "true",
+        "mode" -> "DROPMALFORMED",
+        "nullValue" -> "--",
+        "nanValue" -> "NAN",
+        "negativeInf" -> "-INF",
+        "positiveInf" -> "INF"))
+      .load(testFile(numbersFile))
+
+    assert(numbers.count() == 8)
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/507bea5c/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala
index 8b59bc1..26b33b2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala
@@ -29,6 +29,8 @@ import org.apache.spark.unsafe.types.UTF8String
 
 class CSVTypeCastSuite extends SparkFunSuite {
 
+  private def assertNull(v: Any) = assert(v == null)
+
   test("Can parse decimal type values") {
     val stringValues = Seq("10.05", "1,000.01", "158,058,049.001")
     val decimalValues = Seq(10.05, 1000.01, 158058049.001)
@@ -66,17 +68,21 @@ class CSVTypeCastSuite extends SparkFunSuite {
   }
 
   test("Nullable types are handled") {
-    assert(CSVTypeCast.castTo("", IntegerType, nullable = true) == null)
+    assert(CSVTypeCast.castTo("", IntegerType, nullable = true, CSVOptions()) 
== null)
   }
 
   test("String type should always return the same as the input") {
-    assert(CSVTypeCast.castTo("", StringType, nullable = true) == 
UTF8String.fromString(""))
-    assert(CSVTypeCast.castTo("", StringType, nullable = false) == 
UTF8String.fromString(""))
+    assert(
+      CSVTypeCast.castTo("", StringType, nullable = true, CSVOptions()) ==
+        UTF8String.fromString(""))
+    assert(
+      CSVTypeCast.castTo("", StringType, nullable = false, CSVOptions()) ==
+        UTF8String.fromString(""))
   }
 
   test("Throws exception for empty string with non null type") {
     val exception = intercept[NumberFormatException]{
-      CSVTypeCast.castTo("", IntegerType, nullable = false)
+      CSVTypeCast.castTo("", IntegerType, nullable = false, CSVOptions())
     }
     assert(exception.getMessage.contains("For input string: \"\""))
   }
@@ -90,12 +96,12 @@ class CSVTypeCastSuite extends SparkFunSuite {
     assert(CSVTypeCast.castTo("1.00", DoubleType) == 1.0)
     assert(CSVTypeCast.castTo("true", BooleanType) == true)
 
-    val dateFormat = new SimpleDateFormat("dd/MM/yyyy hh:mm")
+    val options = CSVOptions("dateFormat", "dd/MM/yyyy hh:mm")
     val customTimestamp = "31/01/2015 00:00"
-    val expectedTime = dateFormat.parse("31/01/2015 00:00").getTime
-    assert(CSVTypeCast.castTo(customTimestamp, TimestampType, dateFormat = 
dateFormat)
-      == expectedTime * 1000L)
-    assert(CSVTypeCast.castTo(customTimestamp, DateType, dateFormat = 
dateFormat) ==
+    val expectedTime = options.dateFormat.parse("31/01/2015 00:00").getTime
+    assert(CSVTypeCast.castTo(customTimestamp, TimestampType, nullable = true, 
options) ==
+      expectedTime * 1000L)
+    assert(CSVTypeCast.castTo(customTimestamp, DateType, nullable = true, 
options) ==
       DateTimeUtils.millisToDays(expectedTime))
 
     val timestamp = "2015-01-01 00:00:00"
@@ -116,4 +122,63 @@ class CSVTypeCastSuite extends SparkFunSuite {
       Locale.setDefault(originalLocale)
     }
   }
+
+  test("Float NaN values are parsed correctly") {
+    val floatVal: Float = CSVTypeCast.castTo(
+      "nn", FloatType, nullable = true, CSVOptions("nanValue", 
"nn")).asInstanceOf[Float]
+
+    // Java implements the IEEE-754 floating point standard which guarantees 
that any comparison
+    // against NaN will return false (except != which returns true)
+    assert(floatVal != floatVal)
+  }
+
+  test("Double NaN values are parsed correctly") {
+    val doubleVal: Double = CSVTypeCast.castTo(
+      "-", DoubleType, nullable = true, CSVOptions("nanValue", 
"-")).asInstanceOf[Double]
+
+    assert(doubleVal.isNaN)
+  }
+
+  test("Float infinite values can be parsed") {
+    val floatVal1 = CSVTypeCast.castTo(
+      "max", FloatType, nullable = true, CSVOptions("negativeInf", 
"max")).asInstanceOf[Float]
+
+    assert(floatVal1 == Float.NegativeInfinity)
+
+    val floatVal2 = CSVTypeCast.castTo(
+      "max", FloatType, nullable = true, CSVOptions("positiveInf", 
"max")).asInstanceOf[Float]
+
+    assert(floatVal2 == Float.PositiveInfinity)
+  }
+
+  test("Double infinite values can be parsed") {
+    val doubleVal1 = CSVTypeCast.castTo(
+      "max", DoubleType, nullable = true, CSVOptions("negativeInf", "max")
+    ).asInstanceOf[Double]
+
+    assert(doubleVal1 == Double.NegativeInfinity)
+
+    val doubleVal2 = CSVTypeCast.castTo(
+      "max", DoubleType, nullable = true, CSVOptions("positiveInf", "max")
+    ).asInstanceOf[Double]
+
+    assert(doubleVal2 == Double.PositiveInfinity)
+  }
+
+  test("Type-specific null values are used for casting") {
+    assertNull(
+      CSVTypeCast.castTo("-", ByteType, nullable = true, 
CSVOptions("nullValue", "-")))
+    assertNull(
+      CSVTypeCast.castTo("-", ShortType, nullable = true, 
CSVOptions("nullValue", "-")))
+    assertNull(
+      CSVTypeCast.castTo("-", IntegerType, nullable = true, 
CSVOptions("nullValue", "-")))
+    assertNull(
+      CSVTypeCast.castTo("-", LongType, nullable = true, 
CSVOptions("nullValue", "-")))
+    assertNull(
+      CSVTypeCast.castTo("-", FloatType, nullable = true, 
CSVOptions("nullValue", "-")))
+    assertNull(
+      CSVTypeCast.castTo("-", DoubleType, nullable = true, 
CSVOptions("nullValue", "-")))
+    assertNull(
+      CSVTypeCast.castTo("-", DecimalType.DoubleDecimal, true, 
CSVOptions("nullValue", "-")))
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to