Repository: spark
Updated Branches:
  refs/heads/master 4c477117b -> 158ad0bba


http://git-wip-us.apache.org/repos/asf/spark/blob/158ad0bb/sql/core/src/main/scala/org/apache/spark/sql/api/java/UDFRegistration.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/api/java/UDFRegistration.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/api/java/UDFRegistration.scala
new file mode 100644
index 0000000..158f26e
--- /dev/null
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/api/java/UDFRegistration.scala
@@ -0,0 +1,252 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements.  See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License.  You may obtain a copy of the License at
+*
+*    http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+package org.apache.spark.sql.api.java
+
+import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUdf}
+import org.apache.spark.sql.types.util.DataTypeConversions._
+
+/**
+ * A collection of functions that allow Java users to register UDFs.  In order to handle functions
+ * of varying arities with minimal boilerplate for our users, we generate classes and functions
+ * for each arity up to 22.  The code for this generation can be found in comments in this trait.
+ *
+ * Each `registerFunction` overload takes the name to register under, a Java `UDFn` instance
+ * (n = 1 to 22), and the Java-API `DataType` of the function's result.
+ */
+private[java] trait UDFRegistration {
+  self: JavaSQLContext =>
+
+  /* The following functions and required interfaces are generated with these code fragments.
+     The first fragment prints the Scala `registerFunction` overloads to stdout; the second
+     writes UDF1.java through UDF22.java into `directory`.
+
+     NOTE(review): the checked-in registerFunction overloads below differ cosmetically from
+     what the first fragment prints (they have no `@transient` on `dataType` and put the
+     signature on one line); reconcile the two before regenerating.
+
+   (1 to 22).foreach { i =>
+     val extTypeArgs = (1 to i).map(_ => "_").mkString(", ")
+     val anyTypeArgs = (1 to i).map(_ => "Any").mkString(", ")
+     val anyCast = s".asInstanceOf[UDF$i[$anyTypeArgs, Any]]"
+     val anyParams = (1 to i).map(_ => "_: Any").mkString(", ")
+     println(s"""
+         |def registerFunction(
+         |    name: String, f: UDF$i[$extTypeArgs, _], @transient dataType: DataType) = {
+         |  val scalaType = asScalaDataType(dataType)
+         |  sqlContext.functionRegistry.registerFunction(
+         |    name,
+         |    (e: Seq[Expression]) => ScalaUdf(f$anyCast.call($anyParams), scalaType, e))
+         |}
+       """.stripMargin)
+   }
+
+  import java.io.File
+  import org.apache.spark.sql.catalyst.util.stringToFile
+  val directory = new File("sql/core/src/main/java/org/apache/spark/sql/api/java/")
+  (1 to 22).foreach { i =>
+    val typeArgs = (1 to i).map(i => s"T$i").mkString(", ")
+    val args = (1 to i).map(i => s"T$i t$i").mkString(", ")
+
+    val contents =
+      s"""/*
+         | * Licensed to the Apache Software Foundation (ASF) under one or more
+         | * contributor license agreements.  See the NOTICE file distributed with
+         | * this work for additional information regarding copyright ownership.
+         | * The ASF licenses this file to You under the Apache License, Version 2.0
+         | * (the "License"); you may not use this file except in compliance with
+         | * the License.  You may obtain a copy of the License at
+         | *
+         | *    http://www.apache.org/licenses/LICENSE-2.0
+         | *
+         | * Unless required by applicable law or agreed to in writing, software
+         | * distributed under the License is distributed on an "AS IS" BASIS,
+         | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+         | * See the License for the specific language governing permissions and
+         | * limitations under the License.
+         | */
+         |
+         |package org.apache.spark.sql.api.java;
+         |
+         |import java.io.Serializable;
+         |
+         |// **************************************************
+         |// THIS FILE IS AUTOGENERATED BY CODE IN
+         |// org.apache.spark.sql.api.java.UDFRegistration
+         |// **************************************************
+         |
+         |/**
+         | * A Spark SQL UDF that has $i arguments.
+         | */
+         |public interface UDF$i<$typeArgs, R> extends Serializable {
+         |  public R call($args) throws Exception;
+         |}
+         |""".stripMargin
+
+      stringToFile(new File(directory, s"UDF$i.java"), contents)
+  }
+
+  */
+
+  // scalastyle:off
+  // Every overload below follows the same shape:
+  //  - `dataType` is converted to its Scala-side type once, at registration time.
+  //  - A builder is registered under `name` in sqlContext.functionRegistry.
+  //  - The builder wraps the argument expressions `e` in a ScalaUdf.
+  //  - The ScalaUdf's function is `f.call` cast to take and return `Any`, turned into a Scala
+  //    function with placeholder syntax.
+  def registerFunction(name: String, f: UDF1[_, _], dataType: DataType) = {
+    val scalaType = asScalaDataType(dataType)
+    sqlContext.functionRegistry.registerFunction(
+      name,
+      (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF1[Any, Any]].call(_: 
Any), scalaType, e))
+  }
+
+  def registerFunction(name: String, f: UDF2[_, _, _], dataType: DataType) = {
+    val scalaType = asScalaDataType(dataType)
+    sqlContext.functionRegistry.registerFunction(
+      name,
+      (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF2[Any, Any, 
Any]].call(_: Any, _: Any), scalaType, e))
+  }
+
+  def registerFunction(name: String, f: UDF3[_, _, _, _], dataType: DataType) 
= {
+    val scalaType = asScalaDataType(dataType)
+    sqlContext.functionRegistry.registerFunction(
+      name,
+      (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF3[Any, Any, Any, 
Any]].call(_: Any, _: Any, _: Any), scalaType, e))
+  }
+
+  def registerFunction(name: String, f: UDF4[_, _, _, _, _], dataType: 
DataType) = {
+    val scalaType = asScalaDataType(dataType)
+    sqlContext.functionRegistry.registerFunction(
+      name,
+      (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF4[Any, Any, Any, Any, 
Any]].call(_: Any, _: Any, _: Any, _: Any), scalaType, e))
+  }
+
+  def registerFunction(name: String, f: UDF5[_, _, _, _, _, _], dataType: 
DataType) = {
+    val scalaType = asScalaDataType(dataType)
+    sqlContext.functionRegistry.registerFunction(
+      name,
+      (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF5[Any, Any, Any, Any, 
Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any), scalaType, e))
+  }
+
+  def registerFunction(name: String, f: UDF6[_, _, _, _, _, _, _], dataType: 
DataType) = {
+    val scalaType = asScalaDataType(dataType)
+    sqlContext.functionRegistry.registerFunction(
+      name,
+      (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF6[Any, Any, Any, Any, 
Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any), 
scalaType, e))
+  }
+
+  def registerFunction(name: String, f: UDF7[_, _, _, _, _, _, _, _], 
dataType: DataType) = {
+    val scalaType = asScalaDataType(dataType)
+    sqlContext.functionRegistry.registerFunction(
+      name,
+      (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF7[Any, Any, Any, Any, 
Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: 
Any), scalaType, e))
+  }
+
+  def registerFunction(name: String, f: UDF8[_, _, _, _, _, _, _, _, _], 
dataType: DataType) = {
+    val scalaType = asScalaDataType(dataType)
+    sqlContext.functionRegistry.registerFunction(
+      name,
+      (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF8[Any, Any, Any, Any, 
Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, 
_: Any, _: Any), scalaType, e))
+  }
+
+  def registerFunction(name: String, f: UDF9[_, _, _, _, _, _, _, _, _, _], 
dataType: DataType) = {
+    val scalaType = asScalaDataType(dataType)
+    sqlContext.functionRegistry.registerFunction(
+      name,
+      (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF9[Any, Any, Any, Any, 
Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: 
Any, _: Any, _: Any, _: Any), scalaType, e))
+  }
+
+  def registerFunction(name: String, f: UDF10[_, _, _, _, _, _, _, _, _, _, 
_], dataType: DataType) = {
+    val scalaType = asScalaDataType(dataType)
+    sqlContext.functionRegistry.registerFunction(
+      name,
+      (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF10[Any, Any, Any, 
Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, 
_: Any, _: Any, _: Any, _: Any, _: Any, _: Any), scalaType, e))
+  }
+
+  def registerFunction(name: String, f: UDF11[_, _, _, _, _, _, _, _, _, _, _, 
_], dataType: DataType) = {
+    val scalaType = asScalaDataType(dataType)
+    sqlContext.functionRegistry.registerFunction(
+      name,
+      (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF11[Any, Any, Any, 
Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: 
Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), scalaType, e))
+  }
+
+  def registerFunction(name: String, f: UDF12[_, _, _, _, _, _, _, _, _, _, _, 
_, _], dataType: DataType) = {
+    val scalaType = asScalaDataType(dataType)
+    sqlContext.functionRegistry.registerFunction(
+      name,
+      (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF12[Any, Any, Any, 
Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, 
_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), 
scalaType, e))
+  }
+
+  def registerFunction(name: String, f: UDF13[_, _, _, _, _, _, _, _, _, _, _, 
_, _, _], dataType: DataType) = {
+    val scalaType = asScalaDataType(dataType)
+    sqlContext.functionRegistry.registerFunction(
+      name,
+      (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF13[Any, Any, Any, 
Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: 
Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: 
Any), scalaType, e))
+  }
+
+  def registerFunction(name: String, f: UDF14[_, _, _, _, _, _, _, _, _, _, _, 
_, _, _, _], dataType: DataType) = {
+    val scalaType = asScalaDataType(dataType)
+    sqlContext.functionRegistry.registerFunction(
+      name,
+      (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF14[Any, Any, Any, 
Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: 
Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: 
Any, _: Any, _: Any), scalaType, e))
+  }
+
+  def registerFunction(name: String, f: UDF15[_, _, _, _, _, _, _, _, _, _, _, 
_, _, _, _, _], dataType: DataType) = {
+    val scalaType = asScalaDataType(dataType)
+    sqlContext.functionRegistry.registerFunction(
+      name,
+      (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF15[Any, Any, Any, 
Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, 
_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, 
_: Any, _: Any, _: Any, _: Any), scalaType, e))
+  }
+
+  def registerFunction(name: String, f: UDF16[_, _, _, _, _, _, _, _, _, _, _, 
_, _, _, _, _, _], dataType: DataType) = {
+    val scalaType = asScalaDataType(dataType)
+    sqlContext.functionRegistry.registerFunction(
+      name,
+      (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF16[Any, Any, Any, 
Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: 
Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: 
Any, _: Any, _: Any, _: Any, _: Any, _: Any), scalaType, e))
+  }
+
+  def registerFunction(name: String, f: UDF17[_, _, _, _, _, _, _, _, _, _, _, 
_, _, _, _, _, _, _], dataType: DataType) = {
+    val scalaType = asScalaDataType(dataType)
+    sqlContext.functionRegistry.registerFunction(
+      name,
+      (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF17[Any, Any, Any, 
Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, 
Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: 
Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), 
scalaType, e))
+  }
+
+  def registerFunction(name: String, f: UDF18[_, _, _, _, _, _, _, _, _, _, _, 
_, _, _, _, _, _, _, _], dataType: DataType) = {
+    val scalaType = asScalaDataType(dataType)
+    sqlContext.functionRegistry.registerFunction(
+      name,
+      (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF18[Any, Any, Any, 
Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, 
Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: 
Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), 
scalaType, e))
+  }
+
+  def registerFunction(name: String, f: UDF19[_, _, _, _, _, _, _, _, _, _, _, 
_, _, _, _, _, _, _, _, _], dataType: DataType) = {
+    val scalaType = asScalaDataType(dataType)
+    sqlContext.functionRegistry.registerFunction(
+      name,
+      (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF19[Any, Any, Any, 
Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, 
Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: 
Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: 
Any), scalaType, e))
+  }
+
+  def registerFunction(name: String, f: UDF20[_, _, _, _, _, _, _, _, _, _, _, 
_, _, _, _, _, _, _, _, _, _], dataType: DataType) = {
+    val scalaType = asScalaDataType(dataType)
+    sqlContext.functionRegistry.registerFunction(
+      name,
+      (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF20[Any, Any, Any, 
Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, 
Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, 
_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, 
_: Any, _: Any), scalaType, e))
+  }
+
+  def registerFunction(name: String, f: UDF21[_, _, _, _, _, _, _, _, _, _, _, 
_, _, _, _, _, _, _, _, _, _, _], dataType: DataType) = {
+    val scalaType = asScalaDataType(dataType)
+    sqlContext.functionRegistry.registerFunction(
+      name,
+      (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF21[Any, Any, Any, 
Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, 
Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: 
Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: 
Any, _: Any, _: Any, _: Any), scalaType, e))
+  }
+
+  def registerFunction(name: String, f: UDF22[_, _, _, _, _, _, _, _, _, _, _, 
_, _, _, _, _, _, _, _, _, _, _, _], dataType: DataType) = {
+    val scalaType = asScalaDataType(dataType)
+    sqlContext.functionRegistry.registerFunction(
+      name,
+      (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF22[Any, Any, Any, 
Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, 
Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: 
Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: 
Any, _: Any, _: Any, _: Any, _: Any, _: Any), scalaType, e))
+  }
+
+  // scalastyle:on
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/158ad0bb/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 8bec015..f0c958f 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -286,6 +286,8 @@ private[sql] abstract class SparkStrategies extends 
QueryPlanner[SparkPlan] {
         execution.ExistingRdd(Nil, singleRowRdd) :: Nil
       case logical.Repartition(expressions, child) =>
         execution.Exchange(HashPartitioning(expressions, numPartitions), 
planLater(child)) :: Nil
+      case e @ EvaluatePython(udf, child) =>
+        BatchPythonEvaluation(udf, e.output, planLater(child)) :: Nil
       case SparkLogicalPlan(existingPlan) => existingPlan :: Nil
       case _ => Nil
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/158ad0bb/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
new file mode 100644
index 0000000..b92091b
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
@@ -0,0 +1,177 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements.  See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License.  You may obtain a copy of the License at
+*
+*    http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+package org.apache.spark.sql.execution
+
+import java.util.{List => JList, Map => JMap}
+
+import net.razorvine.pickle.{Pickler, Unpickler}
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.api.python.PythonRDD
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.logical
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.types._
+import org.apache.spark.{Accumulator, Logging => SparkLogging}
+
+import scala.collection.JavaConversions._
+
+/**
+ * A serialized version of a Python lambda function.  Suitable for use in a [[PythonRDD]].
+ *
+ * This expression cannot be evaluated one row at a time: `eval` always throws. It has to be
+ * pulled out of the plan and run in batches through a PythonRDD instead. `command`,
+ * `envVars`, `pythonIncludes`, `pythonExec` and `accumulator` are the values handed to that
+ * PythonRDD.
+ *
+ * @param name     function name, shown by `toString`
+ * @param dataType the declared type of the function's result
+ * @param children the argument expressions passed to the Python function
+ */
+private[spark] case class PythonUDF(
+    name: String,
+    command: Array[Byte],
+    envVars: JMap[String, String],
+    pythonIncludes: JList[String],
+    pythonExec: String,
+    accumulator: Accumulator[JList[Array[Byte]]],
+    dataType: DataType,
+    children: Seq[Expression]) extends Expression with SparkLogging {
+
+  override def toString = "PythonUDF#" + name + "(" + children.mkString(",") + ")"
+
+  /** The result is always reported as nullable. */
+  def nullable: Boolean = true
+
+  /** Union of the attributes referenced by every argument expression. */
+  def references: Set[Attribute] =
+    children.foldLeft(Set.empty[Attribute])((seen, arg) => seen ++ arg.references)
+
+  override def eval(input: Row) =
+    throw new RuntimeException("PythonUDFs can not be directly evaluated.")
+}
+
+/**
+ * Extracts PythonUDFs from operators, rewriting the query plan so that the UDF can be evaluated
+ * alone in a batch.
+ *
+ * For an operator containing PythonUDFs, the first one found is handled as follows:
+ *  - An [[EvaluatePython]] node is placed above exactly one child whose output supplies all of
+ *    the UDF's input attributes.
+ *  - The UDF's occurrences in the operator are replaced by that node's result attribute.
+ *  - A [[logical.Project]] on top restores the operator's original output.
+ *
+ * This has the limitation that the input to a Python UDF is not allowed to include attributes
+ * from multiple child operators.
+ */
+private[spark] object ExtractPythonUdfs extends Rule[LogicalPlan] {
+  def apply(plan: LogicalPlan) = plan transform {
+    // Skip EvaluatePython nodes.
+    case p: EvaluatePython => p
+
+    case l: LogicalPlan =>
+      // Extract any PythonUDFs from the current operator.
+      val udfs = l.expressions.flatMap(_.collect { case udf: PythonUDF => udf })
+      if (udfs.isEmpty) {
+        // If there aren't any, we are done.
+        l
+      } else {
+        // Pick the UDF we are going to evaluate (TODO: Support evaluating multiple UDFs at a time)
+        // If there is more than one, we will add another evaluation operator in a subsequent pass.
+        val udf = udfs.head
+        val inputs = udf.references
+
+        // The UDF must be evaluable with only the input of a single child. A child that supplies
+        // some, but not all, of the inputs is rejected: that case is ambiguous or would require
+        // a cartesian product.
+        l.children.foreach { child =>
+          if (!inputs.subsetOf(child.outputSet) && inputs.intersect(child.outputSet).nonEmpty) {
+            sys.error(s"Invalid PythonUDF $udf, requires attributes from more than one child.")
+          }
+        }
+
+        // Evaluate the UDF above exactly one child. A UDF with no input attributes is covered by
+        // every child; wrapping only the first one keeps it from being run once per child.
+        val evalIndex = l.children.indexWhere(child => inputs.subsetOf(child.outputSet))
+        assert(evalIndex >= 0, "Unable to evaluate PythonUDF.  Missing input attributes.")
+
+        val evaluation = EvaluatePython(udf, l.children(evalIndex))
+        val newChildren = l.children.updated(evalIndex, evaluation)
+
+        // Trim away the new UDF value if it was only used for filtering or something.
+        logical.Project(
+          l.output,
+          l.transformExpressions {
+            case p: PythonUDF if p.id == udf.id => evaluation.resultAttribute
+          }.withNewChildren(newChildren))
+      }
+  }
+}
+
+/**
+ * :: DeveloperApi ::
+ * Evaluates a [[PythonUDF]], appending the result to the end of the input tuple.
+ *
+ * Output columns: everything the child produces, then `resultAttribute`. `resultAttribute` is
+ * a fresh nullable attribute named "pythonUDF" whose type is the UDF's data type.
+ */
+@DeveloperApi
+case class EvaluatePython(udf: PythonUDF, child: LogicalPlan) extends logical.UnaryNode {
+  /** Attribute carrying the UDF's value for each input row. */
+  val resultAttribute = AttributeReference("pythonUDF", udf.dataType, nullable = true)()
+
+  // NOTE(review): this node reports no references even though `udf` reads columns of `child`.
+  // This is presumably intentional; TODO confirm.
+  def references: Set[Attribute] = Set.empty
+
+  def output: Seq[Attribute] = child.output ++ Seq(resultAttribute)
+}
+
+/**
+ * :: DeveloperApi ::
+ * Uses PythonRDD to evaluate a [[PythonUDF]], one partition of tuples at a time.  The input
+ * data is cached and zipped with the result of the udf evaluation.
+ *
+ * A null result from Python (a Python `None`) is kept as null for every result data type.
+ *
+ * @param udf    the UDF to evaluate; its argument expressions are projected from each input row
+ * @param output the attributes this operator reports as its output
+ * @param child  the plan whose rows are fed to the UDF and then extended with its result
+ */
+@DeveloperApi
+case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child: SparkPlan)
+  extends SparkPlan {
+  def children = child :: Nil
+
+  def execute() = {
+    // The child's rows are consumed twice: once to feed Python and once to be zipped with the
+    // results. They are therefore copied and cached. The copy presumably guards against
+    // upstream operators reusing mutable row objects.
+    // TODO: Clean up after ourselves? (childResults is never unpersisted.)
+    val childResults = child.execute().map(_.copy()).cache()
+
+    // Project the UDF's arguments out of each row and pickle them in batches of 1000 rows.
+    val parent = childResults.mapPartitions { iter =>
+      val pickle = new Pickler
+      val currentRow = newMutableProjection(udf.children, child.output)()
+      iter.grouped(1000).map { inputRows =>
+        val toBePickled = inputRows.map(currentRow(_).toArray).toArray
+        pickle.dumps(toBePickled)
+      }
+    }
+
+    val pyRDD = new PythonRDD(
+      parent,
+      udf.command,
+      udf.envVars,
+      udf.pythonIncludes,
+      false,
+      udf.pythonExec,
+      Seq[Broadcast[Array[Byte]]](),
+      udf.accumulator
+    ).mapPartitions { iter =>
+      val unpickle = new Unpickler
+      // Each pickled record holds a batch of results; flatten it to one result per input row.
+      iter.flatMap { pickledResult =>
+        val unpickledBatch = unpickle.loads(pickledResult)
+        unpickledBatch.asInstanceOf[java.util.ArrayList[Any]]
+      }
+    }.mapPartitions { iter =>
+      // A single mutable row is reused for every result in the partition.
+      val row = new GenericMutableRow(1)
+      iter.map { result =>
+        row(0) = (udf.dataType, result) match {
+          // Python None is unpickled as null. Keep it null: calling toString on it, as the
+          // StringType case does, would throw a NullPointerException.
+          case (_, null) => null
+          case (StringType, value) => value.toString
+          case (_, value) => value
+        }
+        row: Row
+      }
+    }
+
+    // zip pairs rows by position. This relies on the Python side returning exactly one result
+    // per input row, in order, within each partition. `joinedRow` is reused across elements.
+    childResults.zip(pyRDD).mapPartitions { iter =>
+      val joinedRow = new JoinedRow()
+      iter.map {
+        case (row, udfResult) =>
+          joinedRow(row, udfResult)
+      }
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/158ad0bb/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaAPISuite.java
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaAPISuite.java 
b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaAPISuite.java
new file mode 100644
index 0000000..a9a1128
--- /dev/null
+++ b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaAPISuite.java
@@ -0,0 +1,90 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.api.java;
+
+import java.io.Serializable;
+
+import org.apache.spark.sql.api.java.UDF1;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runners.Suite;
+import org.junit.runner.RunWith;
+
+import org.apache.spark.api.java.JavaSparkContext;
+
+// The test suite itself is Serializable so that anonymous Function implementations can be
+// serialized, as an alternative to converting these anonymous classes to static inner classes;
+// see http://stackoverflow.com/questions/758570/.
+public class JavaAPISuite implements Serializable {
+  private transient JavaSparkContext sc;
+  private transient JavaSQLContext sqlContext;
+
+  @Before
+  public void setUp() {
+    sc = new JavaSparkContext("local", "JavaAPISuite");
+    sqlContext = new JavaSQLContext(sc);
+  }
+
+  @After
+  public void tearDown() {
+    sc.stop();
+    sc = null;
+    // Drop the SQL context too, so nothing built on the stopped SparkContext outlives a test.
+    sqlContext = null;
+  }
+
+  @SuppressWarnings("unchecked")
+  @Test
+  public void udf1Test() {
+    // With Java 8 lambdas:
+    // sqlContext.registerFunction(
+    //   "stringLengthTest", (String str) -> str.length(), DataType.IntegerType);
+
+    sqlContext.registerFunction("stringLengthTest", new UDF1<String, Integer>() {
+      @Override
+      public Integer call(String str) throws Exception {
+        return str.length();
+      }
+    }, DataType.IntegerType);
+
+    // TODO: Why do we need this cast?
+    Row result = (Row) sqlContext.sql("SELECT stringLengthTest('test')").first();
+    // Use JUnit's assertion: the `assert` keyword is skipped unless the JVM runs with -ea.
+    Assert.assertEquals(4, result.getInt(0));
+  }
+
+  @SuppressWarnings("unchecked")
+  @Test
+  public void udf2Test() {
+    // With Java 8 lambdas:
+    // sqlContext.registerFunction(
+    //   "stringLengthTest",
+    //   (String str1, String str2) -> str1.length() + str2.length(),
+    //   DataType.IntegerType);
+
+    sqlContext.registerFunction("stringLengthTest", new UDF2<String, String, Integer>() {
+      @Override
+      public Integer call(String str1, String str2) throws Exception {
+        return str1.length() + str2.length();
+      }
+    }, DataType.IntegerType);
+
+    // TODO: Why do we need this cast?
+    Row result = (Row) sqlContext.sql("SELECT stringLengthTest('test', 'test2')").first();
+    // Use JUnit's assertion: the `assert` keyword is skipped unless the JVM runs with -ea.
+    Assert.assertEquals(9, result.getInt(0));
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/158ad0bb/sql/core/src/test/scala/org/apache/spark/sql/InsertIntoSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/InsertIntoSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/InsertIntoSuite.scala
index 4f0b85f..23a711d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/InsertIntoSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/InsertIntoSuite.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql
 
-import java.io.File
+import _root_.java.io.File
 
 /* Implicits */
 import org.apache.spark.sql.test.TestSQLContext._

http://git-wip-us.apache.org/repos/asf/spark/blob/158ad0bb/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
new file mode 100644
index 0000000..76aa9b0
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.test._
+
+/* Implicits */
+import TestSQLContext._
+
+/** Registers Scala closures as SQL UDFs and calls them from SQL text. */
+class UDFSuite extends QueryTest {
+
+  test("Simple UDF") {
+    // One-argument closure: the string's length.
+    registerFunction("strLenScala", (str: String) => str.length)
+    val row = sql("SELECT strLenScala('test')").first()
+    assert(row.getInt(0) === 4)
+  }
+
+  test("TwoArgument UDF") {
+    // Two-argument closure registered under the same name as above: length plus an offset.
+    registerFunction("strLenScala", (str: String, offset: Int) => str.length + offset)
+    val row = sql("SELECT strLenScala('test', 1)").first()
+    assert(row.getInt(0) === 5)
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/158ad0bb/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
----------------------------------------------------------------------
diff --git 
a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala 
b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
index 2c7270d..3c70b3f 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
@@ -23,7 +23,7 @@ import java.util.{ArrayList => JArrayList}
 
 import scala.collection.JavaConversions._
 import scala.language.implicitConversions
-import scala.reflect.runtime.universe.TypeTag
+import scala.reflect.runtime.universe.{TypeTag, typeTag}
 
 import org.apache.hadoop.hive.conf.HiveConf
 import org.apache.hadoop.hive.ql.Driver
@@ -35,8 +35,9 @@ import org.apache.spark.SparkContext
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.ScalaReflection
-import org.apache.spark.sql.catalyst.analysis.{Analyzer, OverrideCatalog}
+import org.apache.spark.sql.catalyst.analysis.{OverrideFunctionRegistry, 
Analyzer, OverrideCatalog}
 import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.execution.ExtractPythonUdfs
 import org.apache.spark.sql.execution.QueryExecutionException
 import org.apache.spark.sql.execution.{Command => PhysicalCommand}
 import org.apache.spark.sql.hive.execution.DescribeHiveTableCommand
@@ -155,10 +156,14 @@ class HiveContext(sc: SparkContext) extends 
SQLContext(sc) {
     }
   }
 
+  // Note that HiveUDFs will be overridden by functions registered in this 
context.
+  override protected[sql] lazy val functionRegistry =
+    new HiveFunctionRegistry with OverrideFunctionRegistry
+
   /* An analyzer that uses the Hive metastore. */
   @transient
   override protected[sql] lazy val analyzer =
-    new Analyzer(catalog, HiveFunctionRegistry, caseSensitive = false)
+    new Analyzer(catalog, functionRegistry, caseSensitive = false)
 
   /**
    * Runs the specified SQL query using Hive.
@@ -250,7 +255,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
   protected[sql] abstract class QueryExecution extends super.QueryExecution {
     // TODO: Create mixin for the analyzer instead of overriding things here.
     override lazy val optimizedPlan =
-      optimizer(catalog.PreInsertionCasts(catalog.CreateTables(analyzed)))
+      
optimizer(ExtractPythonUdfs(catalog.PreInsertionCasts(catalog.CreateTables(analyzed))))
 
     override lazy val toRdd: RDD[Row] = executedPlan.execute().map(_.copy())
 

http://git-wip-us.apache.org/repos/asf/spark/blob/158ad0bb/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala 
b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala
index 728452a..c605e8a 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala
@@ -297,8 +297,8 @@ class TestHiveContext(sc: SparkContext) extends 
HiveContext(sc) {
   def reset() {
     try {
       // HACK: Hive is too noisy by default.
-      org.apache.log4j.LogManager.getCurrentLoggers.foreach { logger =>
-        
logger.asInstanceOf[org.apache.log4j.Logger].setLevel(org.apache.log4j.Level.WARN)
+      org.apache.log4j.LogManager.getCurrentLoggers.foreach { log =>
+        
log.asInstanceOf[org.apache.log4j.Logger].setLevel(org.apache.log4j.Level.WARN)
       }
 
       // It is important that we RESET first as broken hooks that might have 
been set could break

http://git-wip-us.apache.org/repos/asf/spark/blob/158ad0bb/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala 
b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
index d181921..179aac5 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
@@ -34,7 +34,8 @@ import 
org.apache.spark.util.Utils.getContextOrSparkClassLoader
 /* Implicit conversions */
 import scala.collection.JavaConversions._
 
-private[hive] object HiveFunctionRegistry extends analysis.FunctionRegistry 
with HiveInspectors {
+private[hive] abstract class HiveFunctionRegistry
+  extends analysis.FunctionRegistry with HiveInspectors {
 
   def getFunctionInfo(name: String) = FunctionRegistry.getFunctionInfo(name)
 
@@ -92,9 +93,8 @@ private[hive] abstract class HiveUdf extends Expression with 
Logging with HiveFu
 }
 
 private[hive] case class HiveSimpleUdf(functionClassName: String, children: 
Seq[Expression])
-  extends HiveUdf {
+  extends HiveUdf with HiveInspectors {
 
-  import org.apache.spark.sql.hive.HiveFunctionRegistry._
   type UDFType = UDF
 
   @transient

http://git-wip-us.apache.org/repos/asf/spark/blob/158ad0bb/sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala 
b/sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala
index 11d8b1f..95921c3 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -51,9 +51,9 @@ class QueryTest extends FunSuite {
         fail(
           s"""
             |Exception thrown while executing query:
-            |${rdd.logicalPlan}
+            |${rdd.queryExecution}
             |== Exception ==
-            |$e
+            |${stackTraceToString(e)}
           """.stripMargin)
     }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to