Repository: spark
Updated Branches:
  refs/heads/master 96941b12f -> 1b070637f


[SPARK-14295][SPARK-14274][SQL] Implements buildReader() for LibSVM

## What changes were proposed in this pull request?

This PR implements `FileFormat.buildReader()` for the LibSVM data source. 
Besides that, a new interface method `prepareRead()` is added to `FileFormat`:

```scala
  def prepareRead(
      sqlContext: SQLContext,
      options: Map[String, String],
      files: Seq[FileStatus]): Map[String, String] = options
```

After migrating from `buildInternalScan()` to `buildReader()`, we lost the 
opportunity to collect necessary global information, since `buildReader()` 
works in a per-partition manner. For example, LibSVM needs to infer the total 
number of features if the `numFeatures` data source option is not set. Any
global information collected this way should be returned through the data
source options map. By default, this method just returns the original options
untouched.
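
A minimal, Spark-free sketch of the `prepareRead()` contract described above,
assuming input files can be modeled as in-memory line sequences rather than
`FileStatus` objects. `FileFormatSketch`, `DefaultFormatSketch` and
`LibSVMFormatSketch` are hypothetical names. The override mirrors the LibSVM
logic in this patch: it keeps a positive user-supplied `numFeatures` and
otherwise derives it from the largest feature index seen.

```scala
// Hypothetical, Spark-free model of FileFormat.prepareRead(); "files" are in-memory line lists.
trait FileFormatSketch {
  // Default behavior: no global pass, options come back untouched.
  def prepareRead(options: Map[String, String], files: Seq[Seq[String]]): Map[String, String] =
    options
}

object DefaultFormatSketch extends FileFormatSketch

object LibSVMFormatSketch extends FileFormatSketch {
  // Largest 1-based feature index on one LIBSVM line (0 if the line has no features).
  private def maxIndex(line: String): Int =
    line.split(' ').tail.filter(_.nonEmpty).map(_.split(':')(0).toInt).foldLeft(0)(math.max)

  override def prepareRead(
      options: Map[String, String],
      files: Seq[Seq[String]]): Map[String, String] = {
    // Keep a positive user-supplied numFeatures; otherwise derive it with one pass over all files.
    val numFeatures = options.get("numFeatures").filter(_.toInt > 0).getOrElse {
      files.flatten
        .map(_.trim)
        .filterNot(line => line.isEmpty || line.startsWith("#"))
        .map(maxIndex)
        .foldLeft(0)(math.max)
        .toString
    }
    // The enriched value travels to the per-file readers through the options map.
    options + ("numFeatures" -> numFeatures)
  }
}
```

Unlike the real implementation, this sketch returns `"0"` instead of failing
when there is no data, and it does not enforce the single-input-path
restriction.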

An alternative approach is to absorb `inferSchema()` into `prepareRead()`,
since schema inference is also a kind of global information gathering.
However, this approach wasn't chosen because schema inference is optional,
while `prepareRead()` must be called whenever a `HadoopFsRelation`-based data
source relation is instantiated.

One unaddressed problem is that, when `numFeatures` is absent, the input data
is now scanned twice. The `buildInternalScan()` code path avoids this because
it caches the raw parsed RDD in memory before computing the total number of
features. With `FileScanRDD`, however, the raw parsed RDD is created
differently from the final RDD (e.g. with different partitioning), so the
cached copy cannot be reused for the final scan.
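
The trade-off can be pictured with a small, hypothetical Scala model in which
`CountingSource` stands in for an input file and each call to `lines()` counts
as one full scan. It is an analogy for the two code paths, not Spark code.

```scala
// Hypothetical stand-in for an input file; every call to lines() counts as one full scan.
final class CountingSource(data: Seq[String]) {
  var scans = 0
  def lines(): Iterator[String] = { scans += 1; data.iterator }
}

object DoubleScanSketch {
  // Largest 1-based feature index on a LIBSVM line (0 if it has no features).
  private def maxIndex(line: String): Int =
    line.split(' ').tail.filter(_.nonEmpty).map(_.split(':')(0).toInt).foldLeft(0)(math.max)

  // Analogous to the buildReader() path: one pass infers numFeatures, a second produces rows.
  def readUncached(src: CountingSource): (Int, Int) = {
    val numFeatures = src.lines().map(maxIndex).foldLeft(0)(math.max)
    val rowCount = src.lines().size
    (numFeatures, rowCount)
  }

  // Analogous to the old buildInternalScan() path: read once, keep the parsed data, reuse it.
  def readCached(src: CountingSource): (Int, Int) = {
    val cached = src.lines().toVector
    val numFeatures = cached.map(maxIndex).foldLeft(0)(math.max)
    (numFeatures, cached.size)
  }
}
```

Both functions return the same `(numFeatures, rowCount)` pair; only the number
of scans differs, which is exactly the cost this patch leaves unaddressed.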

## How was this patch tested?

Tested using existing test suites.

Author: Cheng Lian <[email protected]>

Closes #12088 from liancheng/spark-14295-libsvm-build-reader.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/1b070637
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/1b070637
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/1b070637

Branch: refs/heads/master
Commit: 1b070637fa03ab4966f76427b15e433050eaa956
Parents: 96941b1
Author: Cheng Lian <[email protected]>
Authored: Thu Mar 31 23:46:08 2016 -0700
Committer: Xiangrui Meng <[email protected]>
Committed: Thu Mar 31 23:46:08 2016 -0700

----------------------------------------------------------------------
 .../spark/ml/source/libsvm/LibSVMRelation.scala | 87 +++++++++++++++++++-
 .../org/apache/spark/mllib/util/MLUtils.scala   | 73 +++++++++-------
 .../sql/execution/datasources/DataSource.scala  |  5 +-
 .../datasources/FileSourceStrategy.scala        |  1 +
 .../apache/spark/sql/sources/interfaces.scala   |  9 ++
 5 files changed, 141 insertions(+), 34 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/1b070637/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
index 13a13f0..2e9b6be 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
@@ -19,6 +19,7 @@ package org.apache.spark.ml.source.libsvm
 
 import java.io.IOException
 
+import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.{FileStatus, Path}
 import org.apache.hadoop.io.{NullWritable, Text}
 import org.apache.hadoop.mapreduce.{Job, RecordWriter, TaskAttemptContext}
@@ -26,12 +27,16 @@ import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat
 
 import org.apache.spark.annotation.Since
 import org.apache.spark.broadcast.Broadcast
-import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
+import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT}
+import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.util.MLUtils
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, DataFrameReader, Row, SQLContext}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.encoders.RowEncoder
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, JoinedRow}
+import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
+import org.apache.spark.sql.execution.datasources.{CaseInsensitiveMap, HadoopFileLinesReader, PartitionedFile}
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.types._
 import org.apache.spark.util.SerializableConfiguration
@@ -110,13 +115,16 @@ class DefaultSource extends FileFormat with DataSourceRegister {
   @Since("1.6.0")
   override def shortName(): String = "libsvm"
 
+  override def toString: String = "LibSVM"
+
   private def verifySchema(dataSchema: StructType): Unit = {
     if (dataSchema.size != 2 ||
       (!dataSchema(0).dataType.sameType(DataTypes.DoubleType)
         || !dataSchema(1).dataType.sameType(new VectorUDT()))) {
-      throw new IOException(s"Illegal schema for libsvm data, schema=${dataSchema}")
+      throw new IOException(s"Illegal schema for libsvm data, schema=$dataSchema")
     }
   }
+
   override def inferSchema(
       sqlContext: SQLContext,
       options: Map[String, String],
@@ -127,6 +135,32 @@ class DefaultSource extends FileFormat with DataSourceRegister {
         StructField("features", new VectorUDT(), nullable = false) :: Nil))
   }
 
+  override def prepareRead(
+      sqlContext: SQLContext,
+      options: Map[String, String],
+      files: Seq[FileStatus]): Map[String, String] = {
+    def computeNumFeatures(): Int = {
+      val dataFiles = files.filterNot(_.getPath.getName startsWith "_")
+      val path = if (dataFiles.length == 1) {
+        dataFiles.head.getPath.toUri.toString
+      } else if (dataFiles.isEmpty) {
+        throw new IOException("No input path specified for libsvm data")
+      } else {
+        throw new IOException("Multiple input paths are not supported for libsvm data.")
+      }
+
+      val sc = sqlContext.sparkContext
+      val parsed = MLUtils.parseLibSVMFile(sc, path, sc.defaultParallelism)
+      MLUtils.computeNumFeatures(parsed)
+    }
+
+    val numFeatures = options.get("numFeatures").filter(_.toInt > 0).getOrElse {
+      computeNumFeatures()
+    }
+
+    new CaseInsensitiveMap(options + ("numFeatures" -> numFeatures.toString))
+  }
+
   override def prepareWrite(
       sqlContext: SQLContext,
       job: Job,
@@ -158,7 +192,7 @@ class DefaultSource extends FileFormat with DataSourceRegister {
     verifySchema(dataSchema)
     val dataFiles = inputFiles.filterNot(_.getPath.getName startsWith "_")
 
-    val path = if (dataFiles.length == 1) dataFiles(0).getPath.toUri.toString
+    val path = if (dataFiles.length == 1) dataFiles.head.getPath.toUri.toString
     else if (dataFiles.isEmpty) throw new IOException("No input path specified for libsvm data")
     else throw new IOException("Multiple input paths are not supported for libsvm data.")
 
@@ -176,4 +210,51 @@ class DefaultSource extends FileFormat with DataSourceRegister {
       externalRows.map(converter.toRow)
     }
   }
+
+  override def buildReader(
+      sqlContext: SQLContext,
+      dataSchema: StructType,
+      partitionSchema: StructType,
+      requiredSchema: StructType,
+      filters: Seq[Filter],
+      options: Map[String, String]): (PartitionedFile) => Iterator[InternalRow] = {
+    val numFeatures = options("numFeatures").toInt
+    assert(numFeatures > 0)
+
+    val sparse = options.getOrElse("vectorType", "sparse") == "sparse"
+
+    val broadcastedConf = sqlContext.sparkContext.broadcast(
+      new SerializableConfiguration(new Configuration(sqlContext.sparkContext.hadoopConfiguration))
+    )
+
+    (file: PartitionedFile) => {
+      val points =
+        new HadoopFileLinesReader(file, broadcastedConf.value.value)
+          .map(_.toString.trim)
+          .filterNot(line => line.isEmpty || line.startsWith("#"))
+          .map { line =>
+            val (label, indices, values) = MLUtils.parseLibSVMRecord(line)
+            LabeledPoint(label, Vectors.sparse(numFeatures, indices, values))
+          }
+
+      val converter = RowEncoder(requiredSchema)
+
+      val unsafeRowIterator = points.map { pt =>
+        val features = if (sparse) pt.features.toSparse else pt.features.toDense
+        converter.toRow(Row(pt.label, features))
+      }
+
+      def toAttribute(f: StructField): AttributeReference =
+        AttributeReference(f.name, f.dataType, f.nullable, f.metadata)()
+
+      // Appends partition values
+      val fullOutput = (requiredSchema ++ partitionSchema).map(toAttribute)
+      val joinedRow = new JoinedRow()
+      val appendPartitionColumns = GenerateUnsafeProjection.generate(fullOutput, fullOutput)
+
+      unsafeRowIterator.map { dataRow =>
+        appendPartitionColumns(joinedRow(dataRow, file.partitionValues))
+      }
+    }
+  }
 }
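
For readers new to the `buildReader()` shape, here is a Spark-free sketch of
the closure it returns: options are resolved once, and the returned function
turns one file into rows, appending the file's partition values after the data
columns (the role `JoinedRow` plays above). `FakeFile`, `SparseVec` and
`DenseVec` are hypothetical stand-ins for `PartitionedFile` and MLlib vectors,
and rows are modeled as `Seq[Any]` instead of `InternalRow`.

```scala
object BuildReaderSketch {
  // Hypothetical stand-ins for PartitionedFile, vectors and rows (not Spark types).
  final case class FakeFile(lines: Seq[String], partitionValues: Seq[Any])

  sealed trait Vec
  final case class SparseVec(size: Int, indices: Array[Int], values: Array[Double]) extends Vec
  final case class DenseVec(values: Array[Double]) extends Vec

  def buildReader(options: Map[String, String]): FakeFile => Iterator[Seq[Any]] = {
    // Resolved once, outside the per-file closure; prepareRead() is expected to have set it.
    val numFeatures = options("numFeatures").toInt
    require(numFeatures > 0, s"numFeatures must be positive, got $numFeatures")
    val sparse = options.getOrElse("vectorType", "sparse") == "sparse"

    (file: FakeFile) => {
      file.lines.iterator
        .map(_.trim)
        .filterNot(line => line.isEmpty || line.startsWith("#"))
        .map { line =>
          val items = line.split(' ')
          val (indices, values) = items.tail.filter(_.nonEmpty).map { item =>
            val parts = item.split(':')
            (parts(0).toInt - 1, parts(1).toDouble) // 1-based on disk, 0-based in memory
          }.unzip
          val features: Vec =
            if (sparse) SparseVec(numFeatures, indices, values)
            else {
              val dense = new Array[Double](numFeatures)
              indices.zip(values).foreach { case (i, v) => dense(i) = v }
              DenseVec(dense)
            }
          // Data columns first, then this file's partition values (cf. JoinedRow).
          Seq[Any](items.head.toDouble, features) ++ file.partitionValues
        }
    }
  }
}
```

Resolving `numFeatures` and `vectorType` outside the closure mirrors the diff:
they are read once when the reader is built, and only the closure itself is
applied per file.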

http://git-wip-us.apache.org/repos/asf/spark/blob/1b070637/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
index c3b1d5c..4b9d779 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
@@ -67,42 +67,14 @@ object MLUtils {
       path: String,
       numFeatures: Int,
       minPartitions: Int): RDD[LabeledPoint] = {
-    val parsed = sc.textFile(path, minPartitions)
-      .map(_.trim)
-      .filter(line => !(line.isEmpty || line.startsWith("#")))
-      .map { line =>
-        val items = line.split(' ')
-        val label = items.head.toDouble
-        val (indices, values) = items.tail.filter(_.nonEmpty).map { item =>
-          val indexAndValue = item.split(':')
-          val index = indexAndValue(0).toInt - 1 // Convert 1-based indices to 0-based.
-          val value = indexAndValue(1).toDouble
-          (index, value)
-        }.unzip
-
-        // check if indices are one-based and in ascending order
-        var previous = -1
-        var i = 0
-        val indicesLength = indices.length
-        while (i < indicesLength) {
-          val current = indices(i)
-          require(current > previous, s"indices should be one-based and in ascending order;"
-            + " found current=$current, previous=$previous; line=\"$line\"")
-          previous = current
-          i += 1
-        }
-
-        (label, indices.toArray, values.toArray)
-      }
+    val parsed = parseLibSVMFile(sc, path, minPartitions)
 
     // Determine number of features.
     val d = if (numFeatures > 0) {
       numFeatures
     } else {
       parsed.persist(StorageLevel.MEMORY_ONLY)
-      parsed.map { case (label, indices, values) =>
-        indices.lastOption.getOrElse(0)
-      }.reduce(math.max) + 1
+      computeNumFeatures(parsed)
     }
 
     parsed.map { case (label, indices, values) =>
@@ -110,6 +82,47 @@ object MLUtils {
     }
   }
 
+  private[spark] def computeNumFeatures(rdd: RDD[(Double, Array[Int], Array[Double])]): Int = {
+    rdd.map { case (label, indices, values) =>
+      indices.lastOption.getOrElse(0)
+    }.reduce(math.max) + 1
+  }
+
+  private[spark] def parseLibSVMFile(
+      sc: SparkContext,
+      path: String,
+      minPartitions: Int): RDD[(Double, Array[Int], Array[Double])] = {
+    sc.textFile(path, minPartitions)
+      .map(_.trim)
+      .filter(line => !(line.isEmpty || line.startsWith("#")))
+      .map(parseLibSVMRecord)
+  }
+
+  private[spark] def parseLibSVMRecord(line: String): (Double, Array[Int], Array[Double]) = {
+    val items = line.split(' ')
+    val label = items.head.toDouble
+    val (indices, values) = items.tail.filter(_.nonEmpty).map { item =>
+      val indexAndValue = item.split(':')
+      val index = indexAndValue(0).toInt - 1 // Convert 1-based indices to 0-based.
+    val value = indexAndValue(1).toDouble
+      (index, value)
+    }.unzip
+
+    // check if indices are one-based and in ascending order
+    var previous = -1
+    var i = 0
+    val indicesLength = indices.length
+    while (i < indicesLength) {
+      val current = indices(i)
+      require(current > previous, s"indices should be one-based and in ascending order;"
+        + " found current=$current, previous=$previous; line=\"$line\"")
+      previous = current
+      i += 1
+    }
+
+    (label, indices, values)
+  }
+
   /**
    * Loads labeled data in the LIBSVM format into an RDD[LabeledPoint], with the default number of
    * partitions.
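
The two helpers factored out of `loadLibSVMFile()` can be exercised without a
`SparkContext`. The sketch below reimplements `parseLibSVMRecord()` and
`computeNumFeatures()` over a plain `Seq`; the `Seq`-based signature is an
assumption, since the patch works on an `RDD`. One deliberate difference: in
the diff, only the first fragment of the `require` message carries the `s`
interpolator, so `$current`, `$previous` and `$line` are printed literally.
Here both fragments are interpolated, and the line is quoted with single
quotes.

```scala
object LibSVMParsingSketch {
  def parseLibSVMRecord(line: String): (Double, Array[Int], Array[Double]) = {
    val items = line.split(' ')
    val label = items.head.toDouble
    val (indices, values) = items.tail.filter(_.nonEmpty).map { item =>
      val indexAndValue = item.split(':')
      val index = indexAndValue(0).toInt - 1 // Convert 1-based indices to 0-based.
      val value = indexAndValue(1).toDouble
      (index, value)
    }.unzip

    // Indices must be strictly increasing; both message fragments are interpolated.
    var previous = -1
    var i = 0
    while (i < indices.length) {
      val current = indices(i)
      require(current > previous, s"indices should be one-based and in ascending order;" +
        s" found current=$current, previous=$previous; line='$line'")
      previous = current
      i += 1
    }
    (label, indices, values)
  }

  // numFeatures = largest 0-based index + 1; a record without features contributes 0.
  def computeNumFeatures(records: Seq[(Double, Array[Int], Array[Double])]): Int =
    records.map { case (_, indices, _) => indices.lastOption.getOrElse(0) }.max + 1
}
```

Because indices are validated as ascending, taking `lastOption` per record is
enough to find that record's largest index.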

http://git-wip-us.apache.org/repos/asf/spark/blob/1b070637/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
index c66921f..1850810 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
@@ -299,6 +299,9 @@ case class DataSource(
             "It must be specified manually")
         }
 
+        val enrichedOptions =
+          format.prepareRead(sqlContext, caseInsensitiveOptions, fileCatalog.allFiles())
+
         HadoopFsRelation(
           sqlContext,
           fileCatalog,
@@ -306,7 +309,7 @@ case class DataSource(
           dataSchema = dataSchema.asNullable,
           bucketSpec = bucketSpec,
           format,
-          options)
+          enrichedOptions)
 
       case _ =>
         throw new AnalysisException(
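
Here `DataSource` hands `caseInsensitiveOptions` to `prepareRead()`, and the
LibSVM implementation wraps its result in a `CaseInsensitiveMap`, so a later
`options("numFeatures")` lookup does not depend on how the user spelled the
key, assuming the enriched map is the one that reaches `buildReader()`. The
class below is a hypothetical, simplified stand-in for that behavior, not
Spark's `CaseInsensitiveMap`.

```scala
// Hypothetical, simplified stand-in for a case-insensitive options map (not Spark's class).
final class CaseInsensitiveOptions private (lowered: Map[String, String]) {
  // Lookups lower-case the requested key, matching how keys were stored.
  def get(key: String): Option[String] = lowered.get(key.toLowerCase)

  def apply(key: String): String =
    get(key).getOrElse(throw new NoSuchElementException(s"key not found: $key"))

  // Adding an entry (e.g. a computed numFeatures) keeps the map case-insensitive.
  def +(kv: (String, String)): CaseInsensitiveOptions =
    new CaseInsensitiveOptions(lowered + (kv._1.toLowerCase -> kv._2))
}

object CaseInsensitiveOptions {
  // Keys are lower-cased once, when the map is built.
  def apply(options: Map[String, String]): CaseInsensitiveOptions =
    new CaseInsensitiveOptions(options.map { case (k, v) => k.toLowerCase -> v })
}
```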

http://git-wip-us.apache.org/repos/asf/spark/blob/1b070637/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala
index 5542987..a143ac6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala
@@ -59,6 +59,7 @@ private[sql] object FileSourceStrategy extends Strategy with Logging {
       if (files.fileFormat.toString == "TestFileFormat" ||
          files.fileFormat.isInstanceOf[parquet.DefaultSource] ||
          files.fileFormat.toString == "ORC" ||
+         files.fileFormat.toString == "LibSVM" ||
          files.fileFormat.isInstanceOf[csv.DefaultSource] ||
          files.fileFormat.isInstanceOf[text.DefaultSource] ||
          files.fileFormat.isInstanceOf[json.DefaultSource]) &&

http://git-wip-us.apache.org/repos/asf/spark/blob/1b070637/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
index 6b95a3d..e8834d0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
@@ -439,6 +439,15 @@ trait FileFormat {
       files: Seq[FileStatus]): Option[StructType]
 
   /**
+   * Prepares a read job and returns a potentially updated data source option [[Map]]. This method
+   * can be useful for collecting necessary global information for scanning input data.
+   */
+  def prepareRead(
+      sqlContext: SQLContext,
+      options: Map[String, String],
+      files: Seq[FileStatus]): Map[String, String] = options
+
+  /**
    * Prepares a write job and returns an [[OutputWriterFactory]].  Client side job preparation can
    * be put here.  For example, user defined output committer can be configured here
    * by setting the output committer class in the conf of spark.sql.sources.outputCommitterClass.

