This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 9cea0fb663c [SPARK-39979][SQL] Add option to use large variable width vectors for arrow UDF operations
9cea0fb663c is described below
commit 9cea0fb663caa0ff13e07b2424cabeb56e6b9dbd
Author: Adam Binford <[email protected]>
AuthorDate: Mon May 29 09:05:24 2023 +0900
[SPARK-39979][SQL] Add option to use large variable width vectors for arrow UDF operations
### What changes were proposed in this pull request?
Adds a new config that uses the `LargeUtf8` and `LargeBinary` Arrow types for Arrow-based UDF operations. These types make Arrow use `LargeVarCharVector` and `LargeVarBinaryVector` instead of the regular `VarCharVector` and `VarBinaryVector`, respectively. The config is disabled by default to maintain the current behavior.
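For reference, the corresponding PyArrow types look like this (an illustrative snippet, not part of this change):

```python
import pyarrow as pa

# Regular variable width types: 32-bit value offsets, ~2 GiB per vector.
print(pa.string(), pa.binary())              # string binary
# Large variable width types: 64-bit value offsets, no 2 GiB limit.
print(pa.large_string(), pa.large_binary())  # large_string large_binary
```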
### Why are the changes needed?
`VarCharVector` and `VarBinaryVector` have a size limit of 2 GiB for a single vector, because they use 4-byte integers to track the offset of each value in the vector. Certain operations can hit this limit. The case we have run into most often is an `applyInPandas` operation, since the entire group is sent as a single RecordBatch and there is no way to chunk the data into anything smaller than the entire group. However, other map and UDF operations can [...]
The large vector types use an 8 byte long to track value offsets, removing
the 2 GiB total size limit.
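As a rough usage sketch (only the config key `spark.sql.execution.arrow.useLargeVarTypes` comes from this PR; the rest of the snippet is a hypothetical example), a job whose groups exceed 2 GiB of string/binary data could opt in like this:

```python
import pandas as pd
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
# Opt in to LargeVarChar/LargeVarBinary vectors for Arrow-based UDFs.
spark.conf.set("spark.sql.execution.arrow.useLargeVarTypes", True)

df = spark.createDataFrame([(1, "a" * 1000), (1, "b" * 1000)], "id int, payload string")

def concat_payloads(pdf: pd.DataFrame) -> pd.DataFrame:
    # Each group arrives as a single Arrow RecordBatch; with the config enabled,
    # its string/binary columns use 8-byte offsets instead of 4-byte offsets.
    return pd.DataFrame({"id": [pdf["id"].iloc[0]], "payload": ["".join(pdf["payload"])]})

df.groupBy("id").applyInPandas(concat_payloads, "id int, payload string").show()
```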
### Does this PR introduce _any_ user-facing change?
Adds an option that can help users work around what currently surfaces as an `IndexOutOfBoundsException`. Raising that exception is itself a bug that has been fixed in Arrow; in the next release it will instead be an `OversizedAllocationException`, which suggests using the large variable width types.
### How was this patch tested?
A few new tests are added. I also enabled the setting by default for a full
CI run and all existing tests passed. I can add more tests if needed.
Closes #39572 from Kimahriman/large-binary-vector.
Authored-by: Adam Binford <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
python/pyspark/sql/pandas/types.py | 4 ++
python/pyspark/sql/tests/pandas/test_pandas_map.py | 19 +++++++-
python/pyspark/sql/tests/test_arrow_map.py | 15 ++++++
.../spark/sql/vectorized/ArrowColumnVector.java | 44 +++++++++++++++++
.../spark/sql/execution/arrow/ArrowWriter.scala | 30 ++++++++++++
.../org/apache/spark/sql/internal/SQLConf.scala | 12 +++++
.../org/apache/spark/sql/util/ArrowUtils.scala | 37 +++++++++-----
.../execution/python/AggregateInPandasExec.scala | 2 +
.../ApplyInPandasWithStatePythonRunner.scala | 2 +
.../sql/execution/python/ArrowEvalPythonExec.scala | 2 +
.../sql/execution/python/ArrowPythonRunner.scala | 1 +
.../python/FlatMapGroupsInPandasExec.scala | 2 +
.../sql/execution/python/MapInBatchExec.scala | 3 ++
.../sql/execution/python/PythonArrowInput.scala | 5 +-
.../sql/execution/python/WindowInPandasExec.scala | 2 +
.../sql/execution/arrow/ArrowWriterSuite.scala | 8 +++-
.../sql/vectorized/ArrowColumnVectorSuite.scala | 56 +++++++++++++++++++++-
17 files changed, 229 insertions(+), 15 deletions(-)
diff --git a/python/pyspark/sql/pandas/types.py b/python/pyspark/sql/pandas/types.py
index ae7c25e0828..757deff6130 100644
--- a/python/pyspark/sql/pandas/types.py
+++ b/python/pyspark/sql/pandas/types.py
@@ -166,8 +166,12 @@ def from_arrow_type(at: "pa.DataType", prefer_timestamp_ntz: bool = False) -> Da
spark_type = DecimalType(precision=at.precision, scale=at.scale)
elif types.is_string(at):
spark_type = StringType()
+ elif types.is_large_string(at):
+ spark_type = StringType()
elif types.is_binary(at):
spark_type = BinaryType()
+ elif types.is_large_binary(at):
+ spark_type = BinaryType()
elif types.is_date32(at):
spark_type = DateType()
elif types.is_timestamp(at) and prefer_timestamp_ntz and at.tz is None:
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_map.py b/python/pyspark/sql/tests/pandas/test_pandas_map.py
index 2f6f3f0df57..3d9a90bc81c 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_map.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_map.py
@@ -22,7 +22,7 @@ import unittest
from typing import cast
from pyspark.sql import Row
-from pyspark.sql.functions import lit
+from pyspark.sql.functions import col, encode, lit
from pyspark.errors import PythonException
from pyspark.testing.sqlutils import (
ReusedSQLTestCase,
@@ -68,6 +68,23 @@ class MapInPandasTestsMixin:
expected = df.collect()
self.assertEqual(actual, expected)
+ def test_large_variable_types(self):
+ with self.sql_conf({"spark.sql.execution.arrow.useLargeVarTypes": True}):
+
+ def func(iterator):
+ for pdf in iterator:
+ assert isinstance(pdf, pd.DataFrame)
+ yield pdf
+
+ df = (
+ self.spark.range(10, numPartitions=3)
+ .select(col("id").cast("string").alias("str"))
+ .withColumn("bin", encode(col("str"), "utf8"))
+ )
+ actual = df.mapInPandas(func, "str string, bin binary").collect()
+ expected = df.collect()
+ self.assertEqual(actual, expected)
+
def test_different_output_length(self):
def func(iterator):
for _ in iterator:
diff --git a/python/pyspark/sql/tests/test_arrow_map.py b/python/pyspark/sql/tests/test_arrow_map.py
index ff3d9b96b6b..050f2c32665 100644
--- a/python/pyspark/sql/tests/test_arrow_map.py
+++ b/python/pyspark/sql/tests/test_arrow_map.py
@@ -64,6 +64,21 @@ class MapInArrowTestsMixin(object):
expected = df.collect()
self.assertEqual(actual, expected)
+ def test_large_variable_width_types(self):
+ with self.sql_conf({"spark.sql.execution.arrow.useLargeVarTypes": True}):
+ data = [("foo", b"foo"), (None, None), ("bar", b"bar")]
+ df = self.spark.createDataFrame(data, "a string, b binary")
+
+ def func(iterator):
+ for batch in iterator:
+ assert isinstance(batch, pa.RecordBatch)
+ assert batch.schema.types == [pa.large_string(), pa.large_binary()]
+ yield batch
+
+ actual = df.mapInArrow(func, df.schema).collect()
+ expected = df.collect()
+ self.assertEqual(actual, expected)
+
def test_different_output_length(self):
def func(iterator):
for _ in iterator:
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java
index 742cf511395..635ad9994cb 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java
@@ -19,6 +19,7 @@ package org.apache.spark.sql.vectorized;
import org.apache.arrow.vector.*;
import org.apache.arrow.vector.complex.*;
+import org.apache.arrow.vector.holders.NullableLargeVarCharHolder;
import org.apache.arrow.vector.holders.NullableVarCharHolder;
import org.apache.spark.annotation.DeveloperApi;
@@ -160,8 +161,12 @@ public class ArrowColumnVector extends ColumnVector {
accessor = new DecimalAccessor((DecimalVector) vector);
} else if (vector instanceof VarCharVector) {
accessor = new StringAccessor((VarCharVector) vector);
+ } else if (vector instanceof LargeVarCharVector) {
+ accessor = new LargeStringAccessor((LargeVarCharVector) vector);
} else if (vector instanceof VarBinaryVector) {
accessor = new BinaryAccessor((VarBinaryVector) vector);
+ } else if (vector instanceof LargeVarBinaryVector) {
+ accessor = new LargeBinaryAccessor((LargeVarBinaryVector) vector);
} else if (vector instanceof DateDayVector) {
accessor = new DateAccessor((DateDayVector) vector);
} else if (vector instanceof TimeStampMicroTZVector) {
@@ -406,6 +411,30 @@ public class ArrowColumnVector extends ColumnVector {
}
}
+ static class LargeStringAccessor extends ArrowVectorAccessor {
+
+ private final LargeVarCharVector accessor;
+ private final NullableLargeVarCharHolder stringResult = new NullableLargeVarCharHolder();
+
+ LargeStringAccessor(LargeVarCharVector vector) {
+ super(vector);
+ this.accessor = vector;
+ }
+
+ @Override
+ final UTF8String getUTF8String(int rowId) {
+ accessor.get(rowId, stringResult);
+ if (stringResult.isSet == 0) {
+ return null;
+ } else {
+ return UTF8String.fromAddress(null,
+ stringResult.buffer.memoryAddress() + stringResult.start,
+ // A single string cannot be larger than the max integer size, so the conversion is safe
+ (int)(stringResult.end - stringResult.start));
+ }
+ }
+ }
+
static class BinaryAccessor extends ArrowVectorAccessor {
private final VarBinaryVector accessor;
@@ -421,6 +450,21 @@ public class ArrowColumnVector extends ColumnVector {
}
}
+ static class LargeBinaryAccessor extends ArrowVectorAccessor {
+
+ private final LargeVarBinaryVector accessor;
+
+ LargeBinaryAccessor(LargeVarBinaryVector vector) {
+ super(vector);
+ this.accessor = vector;
+ }
+
+ @Override
+ final byte[] getBinary(int rowId) {
+ return accessor.getObject(rowId);
+ }
+ }
+
static class DateAccessor extends ArrowVectorAccessor {
private final DateDayVector accessor;
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala
index efdbc583207..a55e4f0cfcd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala
@@ -60,7 +60,9 @@ object ArrowWriter {
case (DecimalType.Fixed(precision, scale), vector: DecimalVector) =>
new DecimalWriter(vector, precision, scale)
case (StringType, vector: VarCharVector) => new StringWriter(vector)
+ case (StringType, vector: LargeVarCharVector) => new LargeStringWriter(vector)
case (BinaryType, vector: VarBinaryVector) => new BinaryWriter(vector)
+ case (BinaryType, vector: LargeVarBinaryVector) => new LargeBinaryWriter(vector)
case (DateType, vector: DateDayVector) => new DateWriter(vector)
case (TimestampType, vector: TimeStampMicroTZVector) => new TimestampWriter(vector)
case (TimestampNTZType, vector: TimeStampMicroVector) => new TimestampNTZWriter(vector)
@@ -255,6 +257,21 @@ private[arrow] class StringWriter(val valueVector: VarCharVector) extends ArrowF
}
}
+private[arrow] class LargeStringWriter(
+ val valueVector: LargeVarCharVector) extends ArrowFieldWriter {
+
+ override def setNull(): Unit = {
+ valueVector.setNull(count)
+ }
+
+ override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
+ val utf8 = input.getUTF8String(ordinal)
+ val utf8ByteBuffer = utf8.getByteBuffer
+ // todo: for off-heap UTF8String, how to pass in to arrow without copy?
+ valueVector.setSafe(count, utf8ByteBuffer, utf8ByteBuffer.position(), utf8.numBytes())
+ }
+}
+
private[arrow] class BinaryWriter(
val valueVector: VarBinaryVector) extends ArrowFieldWriter {
@@ -268,6 +285,19 @@ private[arrow] class BinaryWriter(
}
}
+private[arrow] class LargeBinaryWriter(
+ val valueVector: LargeVarBinaryVector) extends ArrowFieldWriter {
+
+ override def setNull(): Unit = {
+ valueVector.setNull(count)
+ }
+
+ override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
+ val bytes = input.getBinary(ordinal)
+ valueVector.setSafe(count, bytes, 0, bytes.length)
+ }
+}
+
private[arrow] class DateWriter(val valueVector: DateDayVector) extends ArrowFieldWriter {
override def setNull(): Unit = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index b1e0285e6ae..e8185202a7e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -2825,6 +2825,16 @@ object SQLConf {
.intConf
.createWithDefault(10000)
+ val ARROW_EXECUTION_USE_LARGE_VAR_TYPES =
+ buildConf("spark.sql.execution.arrow.useLargeVarTypes")
+ .doc("When using Apache Arrow, use large variable width vectors for
string and binary " +
+ "types. Regular string and binary types have a 2GiB limit for a column
in a single " +
+ "record batch. Large variable types remove this limitation at the cost
of higher memory " +
+ "usage per value.")
+ .version("3.5.0")
+ .booleanConf
+ .createWithDefault(false)
+
val PANDAS_UDF_BUFFER_SIZE =
buildConf("spark.sql.execution.pandas.udf.buffer.size")
.doc(
@@ -4890,6 +4900,8 @@ class SQLConf extends Serializable with Logging {
def arrowMaxRecordsPerBatch: Int = getConf(ARROW_EXECUTION_MAX_RECORDS_PER_BATCH)
+ def arrowUseLargeVarTypes: Boolean = getConf(ARROW_EXECUTION_USE_LARGE_VAR_TYPES)
+
def pandasUDFBufferSize: Int = getConf(PANDAS_UDF_BUFFER_SIZE)
def pandasStructHandlingMode: String = getConf(PANDAS_STRUCT_HANDLING_MODE)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala
index 719691a338f..e880e973176 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala
@@ -37,7 +37,8 @@ private[sql] object ArrowUtils {
// todo: support more types.
/** Maps data type from Spark to Arrow. NOTE: timeZoneId required for TimestampTypes */
- def toArrowType(dt: DataType, timeZoneId: String): ArrowType = dt match {
+ def toArrowType(
+ dt: DataType, timeZoneId: String, largeVarTypes: Boolean = false): ArrowType = dt match {
case BooleanType => ArrowType.Bool.INSTANCE
case ByteType => new ArrowType.Int(8, true)
case ShortType => new ArrowType.Int(8 * 2, true)
@@ -45,8 +46,10 @@ private[sql] object ArrowUtils {
case LongType => new ArrowType.Int(8 * 8, true)
case FloatType => new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE)
case DoubleType => new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE)
- case StringType => ArrowType.Utf8.INSTANCE
- case BinaryType => ArrowType.Binary.INSTANCE
+ case StringType if !largeVarTypes => ArrowType.Utf8.INSTANCE
+ case BinaryType if !largeVarTypes => ArrowType.Binary.INSTANCE
+ case StringType if largeVarTypes => ArrowType.LargeUtf8.INSTANCE
+ case BinaryType if largeVarTypes => ArrowType.LargeBinary.INSTANCE
case DecimalType.Fixed(precision, scale) => new ArrowType.Decimal(precision, scale)
case DateType => new ArrowType.Date(DateUnit.DAY)
case TimestampType if timeZoneId == null =>
@@ -73,6 +76,8 @@ private[sql] object ArrowUtils {
if float.getPrecision() == FloatingPointPrecision.DOUBLE => DoubleType
case ArrowType.Utf8.INSTANCE => StringType
case ArrowType.Binary.INSTANCE => BinaryType
+ case ArrowType.LargeUtf8.INSTANCE => StringType
+ case ArrowType.LargeBinary.INSTANCE => BinaryType
case d: ArrowType.Decimal => DecimalType(d.getPrecision, d.getScale)
case date: ArrowType.Date if date.getUnit == DateUnit.DAY => DateType
case ts: ArrowType.Timestamp
@@ -86,17 +91,22 @@ private[sql] object ArrowUtils {
/** Maps field from Spark to Arrow. NOTE: timeZoneId required for TimestampType */
def toArrowField(
- name: String, dt: DataType, nullable: Boolean, timeZoneId: String): Field = {
+ name: String,
+ dt: DataType,
+ nullable: Boolean,
+ timeZoneId: String,
+ largeVarTypes: Boolean = false): Field = {
dt match {
case ArrayType(elementType, containsNull) =>
val fieldType = new FieldType(nullable, ArrowType.List.INSTANCE, null)
new Field(name, fieldType,
- Seq(toArrowField("element", elementType, containsNull,
timeZoneId)).asJava)
+ Seq(toArrowField("element", elementType, containsNull, timeZoneId,
+ largeVarTypes)).asJava)
case StructType(fields) =>
val fieldType = new FieldType(nullable, ArrowType.Struct.INSTANCE, null)
new Field(name, fieldType,
fields.map { field =>
- toArrowField(field.name, field.dataType, field.nullable, timeZoneId)
+ toArrowField(field.name, field.dataType, field.nullable, timeZoneId, largeVarTypes)
}.toSeq.asJava)
case MapType(keyType, valueType, valueContainsNull) =>
val mapType = new FieldType(nullable, new ArrowType.Map(false), null)
@@ -107,10 +117,13 @@ private[sql] object ArrowUtils {
.add(MapVector.KEY_NAME, keyType, nullable = false)
.add(MapVector.VALUE_NAME, valueType, nullable = valueContainsNull),
nullable = false,
- timeZoneId)).asJava)
- case udt: UserDefinedType[_] => toArrowField(name, udt.sqlType, nullable, timeZoneId)
+ timeZoneId,
+ largeVarTypes)).asJava)
+ case udt: UserDefinedType[_] =>
+ toArrowField(name, udt.sqlType, nullable, timeZoneId, largeVarTypes)
case dataType =>
- val fieldType = new FieldType(nullable, toArrowType(dataType, timeZoneId), null)
+ val fieldType = new FieldType(nullable, toArrowType(dataType, timeZoneId,
+ largeVarTypes), null)
new Field(name, fieldType, Seq.empty[Field].asJava)
}
}
@@ -140,13 +153,15 @@ private[sql] object ArrowUtils {
def toArrowSchema(
schema: StructType,
timeZoneId: String,
- errorOnDuplicatedFieldNames: Boolean): Schema = {
+ errorOnDuplicatedFieldNames: Boolean,
+ largeVarTypes: Boolean = false): Schema = {
new Schema(schema.map { field =>
toArrowField(
field.name,
deduplicateFieldNames(field.dataType, errorOnDuplicatedFieldNames),
field.nullable,
- timeZoneId)
+ timeZoneId,
+ largeVarTypes)
}.asJava)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala
index a9a9679bb36..c51a3a5cce3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala
@@ -101,6 +101,7 @@ case class AggregateInPandasExec(
val inputRDD = child.execute()
val sessionLocalTimeZone = conf.sessionLocalTimeZone
+ val largeVarTypes = conf.arrowUseLargeVarTypes
val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf)
val (pyFuncs, inputs) = udfExpressions.map(collectFunctions).unzip
@@ -167,6 +168,7 @@ case class AggregateInPandasExec(
argOffsets,
aggInputSchema,
sessionLocalTimeZone,
+ largeVarTypes,
pythonRunnerConf,
pythonMetrics).compute(projectedRowIter, context.partitionId(), context)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala
index ac73e53266d..35676406f14 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala
@@ -78,6 +78,8 @@ class ApplyInPandasWithStatePythonRunner(
override val simplifiedTraceback: Boolean = sqlConf.pysparkSimplifiedTraceback
+ override protected val largeVarTypes: Boolean = sqlConf.arrowUseLargeVarTypes
+
override val bufferSize: Int = {
val configuredSize = sqlConf.pandasUDFBufferSize
if (configuredSize < 4) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
index b11dd4947af..86a5d13aed0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
@@ -65,6 +65,7 @@ case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], resultAttrs: Seq[Attribute]
private val batchSize = conf.arrowMaxRecordsPerBatch
private val sessionLocalTimeZone = conf.sessionLocalTimeZone
+ private val largeVarTypes = conf.arrowUseLargeVarTypes
private val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf)
protected override def evaluate(
@@ -85,6 +86,7 @@ case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], resultAttrs: Seq[Attribute]
argOffsets,
schema,
sessionLocalTimeZone,
+ largeVarTypes,
pythonRunnerConf,
pythonMetrics).compute(batchIter, context.partitionId(), context)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
index d727c1b5ca0..175d67e9043 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
@@ -33,6 +33,7 @@ class ArrowPythonRunner(
argOffsets: Array[Array[Int]],
protected override val schema: StructType,
protected override val timeZoneId: String,
+ protected override val largeVarTypes: Boolean,
protected override val workerConf: Map[String, String],
val pythonMetrics: Map[String, SQLMetric])
extends BasePythonRunner[Iterator[InternalRow], ColumnarBatch](funcs, evalType, argOffsets)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala
index 271ccdb6b27..8da53cc6c99 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala
@@ -53,6 +53,7 @@ case class FlatMapGroupsInPandasExec(
extends SparkPlan with UnaryExecNode with PythonSQLMetrics {
private val sessionLocalTimeZone = conf.sessionLocalTimeZone
+ private val largeVarTypes = conf.arrowUseLargeVarTypes
private val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf)
private val pandasFunction = func.asInstanceOf[PythonUDF].func
private val chainedFunc = Seq(ChainedPythonFunctions(Seq(pandasFunction)))
@@ -89,6 +90,7 @@ case class FlatMapGroupsInPandasExec(
Array(argOffsets),
StructType.fromAttributes(dedupAttributes),
sessionLocalTimeZone,
+ largeVarTypes,
pythonRunnerConf,
pythonMetrics)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala
index 0fe3acb14e8..8281435ca92 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala
@@ -49,6 +49,8 @@ trait MapInBatchExec extends UnaryExecNode with PythonSQLMetrics {
private val batchSize = conf.arrowMaxRecordsPerBatch
+ private val largeVarTypes = conf.arrowUseLargeVarTypes
+
override def outputPartitioning: Partitioning = child.outputPartitioning
override protected def doExecute(): RDD[InternalRow] = {
@@ -77,6 +79,7 @@ trait MapInBatchExec extends UnaryExecNode with PythonSQLMetrics {
argOffsets,
StructType(Array(StructField("struct", outputTypes))),
sessionLocalTimeZone,
+ largeVarTypes,
pythonRunnerConf,
pythonMetrics).compute(batchIter, context.partitionId(), context)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala
index 26ce10b6aae..c78ea564f18 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala
@@ -44,6 +44,8 @@ private[python] trait PythonArrowInput[IN] { self: BasePythonRunner[IN, _] =>
protected val errorOnDuplicatedFieldNames: Boolean
+ protected val largeVarTypes: Boolean
+
protected def pythonMetrics: Map[String, SQLMetric]
protected def writeIteratorToArrowStream(
@@ -75,7 +77,8 @@ private[python] trait PythonArrowInput[IN] { self: BasePythonRunner[IN, _] =>
}
protected override def writeIteratorToStream(dataOut: DataOutputStream): Unit = {
- val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId, errorOnDuplicatedFieldNames)
+ val arrowSchema = ArrowUtils.toArrowSchema(
+ schema, timeZoneId, errorOnDuplicatedFieldNames, largeVarTypes)
val allocator = ArrowUtils.rootAllocator.newChildAllocator(
s"stdout writer for $pythonExec", 0, Long.MaxValue)
val root = VectorSchemaRoot.create(arrowSchema, allocator)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala
index c5493079e40..e6a65dd61dc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala
@@ -185,6 +185,7 @@ case class WindowInPandasExec(
val inMemoryThreshold = conf.windowExecBufferInMemoryThreshold
val spillThreshold = conf.windowExecBufferSpillThreshold
val sessionLocalTimeZone = conf.sessionLocalTimeZone
+ val largeVarTypes = conf.arrowUseLargeVarTypes
// Extract window expressions and window functions
val windowExpressions = expressions.flatMap(_.collect { case e: WindowExpression => e })
@@ -385,6 +386,7 @@ case class WindowInPandasExec(
argOffsets,
pythonInputSchema,
sessionLocalTimeZone,
+ largeVarTypes,
pythonRunnerConf,
pythonMetrics).compute(pythonInput, context.partitionId(), context)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala
index a88f423ae01..86a961137f4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala
@@ -27,7 +27,11 @@ import org.apache.spark.unsafe.types.UTF8String
class ArrowWriterSuite extends SparkFunSuite {
test("simple") {
- def check(dt: DataType, data: Seq[Any], timeZoneId: String = null): Unit = {
+ def check(
+ dt: DataType,
+ data: Seq[Any],
+ timeZoneId: String = null,
+ largeVarTypes: Boolean = false): Unit = {
val datatype = dt match {
case _: DayTimeIntervalType => DayTimeIntervalType()
case _: YearMonthIntervalType => YearMonthIntervalType()
@@ -77,7 +81,9 @@ class ArrowWriterSuite extends SparkFunSuite {
check(DoubleType, Seq(1.0d, 2.0d, null, 4.0d))
check(DecimalType.SYSTEM_DEFAULT, Seq(Decimal(1), Decimal(2), null, Decimal(4)))
check(StringType, Seq("a", "b", null, "d").map(UTF8String.fromString))
+ check(StringType, Seq("a", "b", null, "d").map(UTF8String.fromString), null, true)
check(BinaryType, Seq("a".getBytes(), "b".getBytes(), null, "d".getBytes()))
+ check(BinaryType, Seq("a".getBytes(), "b".getBytes(), null, "d".getBytes()), null, true)
check(DateType, Seq(0, 1, 2, null, 4))
check(TimestampType, Seq(0L, 3.6e9.toLong, null, 8.64e10.toLong), "America/Los_Angeles")
check(TimestampNTZType, Seq(0L, 3.6e9.toLong, null, 8.64e10.toLong))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/vectorized/ArrowColumnVectorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/vectorized/ArrowColumnVectorSuite.scala
index 25beda99cd6..436cea50ad9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/vectorized/ArrowColumnVectorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/vectorized/ArrowColumnVectorSuite.scala
@@ -250,9 +250,36 @@ class ArrowColumnVectorSuite extends SparkFunSuite {
allocator.close()
}
+ test("large_string") {
+ val allocator = ArrowUtils.rootAllocator.newChildAllocator("string", 0, Long.MaxValue)
+ val vector = ArrowUtils.toArrowField("string", StringType, nullable = true, null, true)
+ .createVector(allocator).asInstanceOf[LargeVarCharVector]
+ vector.allocateNew()
+
+ (0 until 10).foreach { i =>
+ val utf8 = s"str$i".getBytes("utf8")
+ vector.setSafe(i, utf8, 0, utf8.length)
+ }
+ vector.setNull(10)
+ vector.setValueCount(11)
+
+ val columnVector = new ArrowColumnVector(vector)
+ assert(columnVector.dataType === StringType)
+ assert(columnVector.hasNull)
+ assert(columnVector.numNulls === 1)
+
+ (0 until 10).foreach { i =>
+ assert(columnVector.getUTF8String(i) === UTF8String.fromString(s"str$i"))
+ }
+ assert(columnVector.isNullAt(10))
+
+ columnVector.close()
+ allocator.close()
+ }
+
test("binary") {
val allocator = ArrowUtils.rootAllocator.newChildAllocator("binary", 0, Long.MaxValue)
- val vector = ArrowUtils.toArrowField("binary", BinaryType, nullable = true, null)
+ val vector = ArrowUtils.toArrowField("binary", BinaryType, nullable = true, null, false)
.createVector(allocator).asInstanceOf[VarBinaryVector]
vector.allocateNew()
@@ -277,6 +304,33 @@ class ArrowColumnVectorSuite extends SparkFunSuite {
allocator.close()
}
+ test("large_binary") {
+ val allocator = ArrowUtils.rootAllocator.newChildAllocator("binary", 0, Long.MaxValue)
+ val vector = ArrowUtils.toArrowField("binary", BinaryType, nullable = true, null, true)
+ .createVector(allocator).asInstanceOf[LargeVarBinaryVector]
+ vector.allocateNew()
+
+ (0 until 10).foreach { i =>
+ val utf8 = s"str$i".getBytes("utf8")
+ vector.setSafe(i, utf8, 0, utf8.length)
+ }
+ vector.setNull(10)
+ vector.setValueCount(11)
+
+ val columnVector = new ArrowColumnVector(vector)
+ assert(columnVector.dataType === BinaryType)
+ assert(columnVector.hasNull)
+ assert(columnVector.numNulls === 1)
+
+ (0 until 10).foreach { i =>
+ assert(columnVector.getBinary(i) === s"str$i".getBytes("utf8"))
+ }
+ assert(columnVector.isNullAt(10))
+
+ columnVector.close()
+ allocator.close()
+ }
+
test("array") {
val allocator = ArrowUtils.rootAllocator.newChildAllocator("array", 0, Long.MaxValue)
val vector = ArrowUtils.toArrowField("array", ArrayType(IntegerType), nullable = true, null)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]