Repository: spark
Updated Branches:
  refs/heads/branch-1.5 8ece4ccda -> 064ba906a


[SPARK-9683] [SQL] copy UTF8String when convert unsafe array/map to safe

When we convert an unsafe row to a safe row, we copy a column if it is of struct
or string type. However, strings nested inside unsafe arrays/maps are not copied,
which may cause problems.

Author: Wenchen Fan <[email protected]>

Closes #7990 from cloud-fan/copy and squashes the following commits:

c13d1e3 [Wenchen Fan] change test name
fe36294 [Wenchen Fan] we should deep copy UTF8String when convert unsafe row to 
safe row

(cherry picked from commit e57d6b56137bf3557efe5acea3ad390c1987b257)
Signed-off-by: Davies Liu <[email protected]>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/064ba906
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/064ba906
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/064ba906

Branch: refs/heads/branch-1.5
Commit: 064ba906a5992b376ff5b5bfea258e9eb879c5ea
Parents: 8ece4cc
Author: Wenchen Fan <[email protected]>
Authored: Fri Aug 7 00:00:43 2015 -0700
Committer: Davies Liu <[email protected]>
Committed: Fri Aug 7 00:00:53 2015 -0700

----------------------------------------------------------------------
 .../sql/catalyst/expressions/FromUnsafe.scala   |  3 ++
 .../execution/RowFormatConvertersSuite.scala    | 38 +++++++++++++++++++-
 2 files changed, 40 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/064ba906/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/FromUnsafe.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/FromUnsafe.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/FromUnsafe.scala
index 3caf0fb..9b960b1 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/FromUnsafe.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/FromUnsafe.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
 import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
 
 case class FromUnsafe(child: Expression) extends UnaryExpression
   with ExpectsInputTypes with CodegenFallback {
@@ -52,6 +53,8 @@ case class FromUnsafe(child: Expression) extends 
UnaryExpression
       }
       new GenericArrayData(result)
 
+    case StringType => value.asInstanceOf[UTF8String].clone()
+
     case MapType(kt, vt, _) =>
       val map = value.asInstanceOf[UnsafeMapData]
       val safeKeyArray = convert(map.keys, 
ArrayType(kt)).asInstanceOf[GenericArrayData]

http://git-wip-us.apache.org/repos/asf/spark/blob/064ba906/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala
index 707cd9c..8208b25 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala
@@ -17,9 +17,13 @@
 
 package org.apache.spark.sql.execution
 
+import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.Row
-import org.apache.spark.sql.catalyst.expressions.{Literal, IsNull}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, 
Attribute, Literal, IsNull}
 import org.apache.spark.sql.test.TestSQLContext
+import org.apache.spark.sql.types.{GenericArrayData, ArrayType, StructType, 
StringType}
+import org.apache.spark.unsafe.types.UTF8String
 
 class RowFormatConvertersSuite extends SparkPlanTest {
 
@@ -87,4 +91,36 @@ class RowFormatConvertersSuite extends SparkPlanTest {
       input.map(Row.fromTuple)
     )
   }
+
+  test("SPARK-9683: copy UTF8String when convert unsafe array/map to safe") {
+    SparkPlan.currentContext.set(TestSQLContext)
+    // Expected values, in order: "1" .. "100".
+    val expected = (1 to 100).map(_.toString)
+    // One input row per value, each holding a single-element array<string> column.
+    val inputRows = expected.map { s =>
+      InternalRow(new GenericArrayData(Array[Any](UTF8String.fromString(s))))
+    }
+    val column = AttributeReference("t", ArrayType(StringType))()
+    val scan = LocalTableScan(Seq(column), inputRows)
+
+    // Round-trip through the unsafe format and back to safe rows, then let DummyPlan
+    // hold on to every extracted string before emitting it; the strings must still
+    // read back correctly, i.e. they must have been deep-copied.
+    val plan = DummyPlan(ConvertToSafe(ConvertToUnsafe(scan)))
+    val actual = plan.execute().collect().map(_.getUTF8String(0).toString)
+    assert(actual === expected)
+  }
+}
+
+/**
+ * Test-only plan that drains each partition of its child completely, keeping the first
+ * string of every row's array column (ordinal 0), and only afterwards emits one
+ * single-column row per kept string. Holding all strings until the partition has been
+ * fully consumed is what lets the SPARK-9683 test detect a UTF8String that was not
+ * deep-copied by ConvertToSafe (presumably it would still reference reused unsafe memory).
+ */
+case class DummyPlan(child: SparkPlan) extends UnaryNode {
+
+  // Built once so the plan reports a stable attribute (and expression ID) on every call,
+  // instead of minting a fresh AttributeReference each time `output` is invoked.
+  override lazy val output: Seq[Attribute] = Seq(AttributeReference("a", StringType)())
+
+  override protected def doExecute(): RDD[InternalRow] = {
+    child.execute().mapPartitions { iter =>
+      // Materialize eagerly (toArray): the whole partition must be consumed before any
+      // string is emitted, otherwise a missing deep copy would go unnoticed.
+      val strings = iter.map(_.getArray(0).getUTF8String(0)).toArray
+      strings.iterator.map(InternalRow(_))
+    }
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to