This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 919e60f84ca2 [SPARK-51739][PYTHON] Validate Arrow schema from mapInArrow & mapInPandas & DataSource
919e60f84ca2 is described below

commit 919e60f84ca2bc357355c04fc7b2603ef61d818e
Author: Haoyu Weng <wengh...@gmail.com>
AuthorDate: Mon Apr 14 07:51:32 2025 +0900

[SPARK-51739][PYTHON] Validate Arrow schema from mapInArrow & mapInPandas & DataSource

### What changes were proposed in this pull request?

Check the actual Arrow batch schema against the declared schema in `MapInBatchEvaluator`, and throw an error if they don't match. Also fix the pandas-to-Arrow conversion in `ArrowStreamPandasUDFSerializer` to respect the nullability of output schema fields.

### Why are the changes needed?

To improve error messages and reject suspicious usage.

### Does this PR introduce _any_ user-facing change?

Yes.

#### Behaviour change

1. Some previously suspicious but accepted schema mismatches are no longer valid. This includes:
   - extraneous fields (previously ignored)
   - fields of the same type in the wrong order (previously accepted, with the columns bound in the wrong order)
   - a field declared non-nullable whose data is actually nullable (previously ignored)

Example:

```py
from pyspark.sql.datasource import DataSource, DataSourceReader
from pyspark.sql.pandas.types import to_arrow_schema
from pyspark.sql.types import StructType
import pyarrow as pa

expected = StructType.fromDDL("a int, b int")
actual = StructType.fromDDL("b int, a int")  # wrong order of fields


class TestDataSource(DataSource):
    def schema(self):
        return expected

    def reader(self, schema):
        return TestReader()


class TestReader(DataSourceReader):
    def read(self, partition):
        schema = to_arrow_schema(actual)
        yield pa.record_batch([[1], [2]], schema=schema)


spark.dataSource.register(TestDataSource)
spark.read.format("TestDataSource").load().show()
```

Before:

```
+---+---+
|  a|  b|
+---+---+
|  1|  2|
+---+---+
```

Now:

```
org.apache.spark.SparkException: [ARROW_TYPE_MISMATCH] Invalid schema from pandas_udf(): expected StructType(StructField(a,IntegerType,true),StructField(b,IntegerType,true)), got StructType(StructField(b,LongType,true),StructField(a,LongType,true)). SQLSTATE: 42K0G
```

2. For other schema mismatches, the error changes from an internal error to a clearer `ARROW_TYPE_MISMATCH` error. This includes:
   - wrong field types
   - fewer fields than expected

Example:

```py
from pyspark.sql.pandas.types import to_arrow_schema
from pyspark.sql.types import StructType, StructField, IntegerType
import pyarrow as pa

expected = StructType([StructField("a", IntegerType()), StructField("b", IntegerType())])
actual = StructType([StructField("a", IntegerType())])  # missing a column


def fun(iterator):
    for batch in iterator:
        schema = to_arrow_schema(actual)
        yield pa.record_batch([[1]], schema=schema)


spark.range(2).mapInArrow(fun, expected).show()
```

Before:

```
java.lang.ArrayIndexOutOfBoundsException: Index 1 out of bounds for length 1
    at org.apache.spark.sql.vectorized.ArrowColumnVector.getChild(ArrowColumnVector.java:134)
    at org.apache.spark.sql.execution.python.MapInBatchEvaluatorFactory$MapInBatchEvaluator.$anonfun$eval$3(MapInBatchEvaluatorFactory.scala:82)
    ...
```

Now:

```
org.apache.spark.SparkException: [ARROW_TYPE_MISMATCH] Invalid schema from pandas_udf(): expected StructType(StructField(a,IntegerType,true),StructField(b,IntegerType,true)), got StructType(StructField(a,IntegerType,true)). SQLSTATE: 42K0G
```
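Jobs that relied on the previous lenient behaviour can turn the new check off via the conf added by this patch, `spark.sql.execution.arrow.pyspark.validateSchema.enabled` (default `true`). A minimal sketch, assuming an active `spark` session; note this only disables the schema validation, not the separate nullability fix in `ArrowStreamPandasUDFSerializer`:

```py
# Disable the Arrow schema validation introduced by this patch (on by default),
# falling back to the lenient handling of batches from mapInArrow/mapInPandas/DataSource.
spark.conf.set("spark.sql.execution.arrow.pyspark.validateSchema.enabled", "false")

# ... run the affected query ...

# Turn the validation back on.
spark.conf.set("spark.sql.execution.arrow.pyspark.validateSchema.enabled", "true")
```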
### How was this patch tested?

End-to-end tests in `python/pyspark/sql/tests/arrow/test_arrow_map.py` and `python/pyspark/sql/tests/pandas/test_pandas_map.py`.

### Was this patch authored or co-authored using generative AI tooling?

No

Closes #50531 from wengh/validate-arrow-type.

Authored-by: Haoyu Weng <wengh...@gmail.com>
Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 python/pyspark/sql/pandas/serializers.py           | 17 ++++++---
 python/pyspark/sql/tests/arrow/test_arrow_map.py   | 40 +++++++++++++++++++++-
 python/pyspark/sql/tests/pandas/test_pandas_map.py | 25 ++++++++++++++
 .../org/apache/spark/sql/internal/SQLConf.scala    | 11 ++++++
 .../v2/python/UserDefinedPythonDataSource.scala    |  1 +
 .../python/MapInBatchEvaluatorFactory.scala        | 23 ++++++++++---
 .../sql/execution/python/MapInBatchExec.scala      |  1 +
 7 files changed, 109 insertions(+), 9 deletions(-)
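As background for the `serializers.py` change below: when `pa.StructArray.from_arrays` is given only field names, pyarrow infers the fields from the arrays and marks them all nullable, whereas passing `fields=` keeps the declared nullability of the output schema. A standalone pyarrow sketch (illustration only, not code from the patch):

```py
import pyarrow as pa

arr = pa.array([1, 2], type=pa.int64())

# With names=, the struct fields are inferred from the arrays and default to nullable.
by_name = pa.StructArray.from_arrays([arr], names=["a"])
assert by_name.type.field(0).nullable

# With fields=, the declared (non-)nullability is preserved in the resulting struct type.
non_null_field = pa.field("a", pa.int64(), nullable=False)
by_field = pa.StructArray.from_arrays([arr], fields=[non_null_field])
assert not by_field.type.field(0).nullable
```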
diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py
index de9020f633c5..e036add0210d 100644
--- a/python/pyspark/sql/pandas/serializers.py
+++ b/python/pyspark/sql/pandas/serializers.py
@@ -20,6 +20,7 @@ Serializers for PyArrow and pandas conversions. See `pyspark.serializers` for mo
 """
 
 from itertools import groupby
+from typing import TYPE_CHECKING, Optional
 
 import pyspark
 from pyspark.errors import PySparkRuntimeError, PySparkTypeError, PySparkValueError
@@ -48,6 +49,10 @@ from pyspark.sql.types import (
     IntegerType,
 )
 
+if TYPE_CHECKING:
+    import pandas as pd
+    import pyarrow as pa
+
 
 class SpecialLengths:
     END_OF_DATA_SECTION = -1
@@ -472,7 +477,12 @@ class ArrowStreamPandasUDFSerializer(ArrowStreamPandasSerializer):
             )
         return s
 
-    def _create_struct_array(self, df, arrow_struct_type, spark_type=None):
+    def _create_struct_array(
+        self,
+        df: "pd.DataFrame",
+        arrow_struct_type: "pa.StructType",
+        spark_type: Optional[StructType] = None,
+    ):
         """
         Create an Arrow StructArray from the given pandas.DataFrame and arrow struct type.
@@ -480,7 +490,7 @@ class ArrowStreamPandasUDFSerializer(ArrowStreamPandasSerializer):
         ----------
         df : pandas.DataFrame
             A pandas DataFrame
-        arrow_struct_type : pyarrow.DataType
+        arrow_struct_type : pyarrow.StructType
             pyarrow struct type
 
         Returns
@@ -518,8 +528,7 @@ class ArrowStreamPandasUDFSerializer(ArrowStreamPandasSerializer):
                 for i, field in enumerate(arrow_struct_type)
             ]
 
-        struct_names = [field.name for field in arrow_struct_type]
-        return pa.StructArray.from_arrays(struct_arrs, struct_names)
+        return pa.StructArray.from_arrays(struct_arrs, fields=list(arrow_struct_type))
 
     def _create_batch(self, series):
         """
diff --git a/python/pyspark/sql/tests/arrow/test_arrow_map.py b/python/pyspark/sql/tests/arrow/test_arrow_map.py
index fb0d7b751e53..fa2ce69c4fa5 100644
--- a/python/pyspark/sql/tests/arrow/test_arrow_map.py
+++ b/python/pyspark/sql/tests/arrow/test_arrow_map.py
@@ -124,7 +124,7 @@ class MapInArrowTestsMixin(object):
         def empty_rows(_):
             return iter([pa.RecordBatch.from_pandas(pd.DataFrame({"a": []}))])
 
-        self.assertEqual(self.spark.range(10).mapInArrow(empty_rows, "a int").count(), 0)
+        self.assertEqual(self.spark.range(10).mapInArrow(empty_rows, "a double").count(), 0)
 
     def test_chain_map_in_arrow(self):
         def func(iterator):
@@ -175,6 +175,44 @@ class MapInArrowTestsMixin(object):
         with self.sql_conf({"spark.sql.execution.arrow.maxRecordsPerBatch": batch_size}):
             MapInArrowTests.test_map_in_arrow(self)
 
+    def test_nested_extraneous_field(self):
+        def func(iterator):
+            for _ in iterator:
+                struct_arr = pa.StructArray.from_arrays([[1, 2], [3, 4]], names=["a", "b"])
+                yield pa.RecordBatch.from_arrays([struct_arr], ["x"])
+
+        df = self.spark.range(1)
+        with self.assertRaisesRegex(Exception, r"ARROW_TYPE_MISMATCH.*SQL_MAP_ARROW_ITER_UDF"):
+            df.mapInArrow(func, "x struct<b:int>").collect()
+
+    def test_top_level_wrong_order(self):
+        def func(iterator):
+            for _ in iterator:
+                yield pa.RecordBatch.from_arrays([[1], [2]], ["b", "a"])
+
+        df = self.spark.range(1)
+        with self.assertRaisesRegex(Exception, r"ARROW_TYPE_MISMATCH.*SQL_MAP_ARROW_ITER_UDF"):
+            df.mapInArrow(func, "a int, b int").collect()
+
+    def test_nullability_widen(self):
+        def func(iterator):
+            for _ in iterator:
+                yield pa.RecordBatch.from_arrays([[1]], ["a"])
+
+        df = self.spark.range(1)
+        with self.assertRaisesRegex(Exception, r"ARROW_TYPE_MISMATCH.*SQL_MAP_ARROW_ITER_UDF"):
+            df.mapInArrow(func, "a int not null").collect()
+
+    def test_nullability_narrow(self):
+        def func(iterator):
+            for _ in iterator:
+                yield pa.RecordBatch.from_arrays(
+                    [[1]], pa.schema([pa.field("a", pa.int32(), nullable=False)])
+                )
+
+        df = self.spark.range(1)
+        df.mapInArrow(func, "a int").collect()
+
 
 class MapInArrowTests(MapInArrowTestsMixin, ReusedSQLTestCase):
     @classmethod
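A note on the `empty_rows` change above: with the declared schema now validated, the test has to declare `a double`, because `pa.RecordBatch.from_pandas(pd.DataFrame({"a": []}))` typically yields a `double` field (an empty pandas column defaults to a float64 dtype), which no longer silently passes for `a int`. A small illustrative sketch, assuming current pandas/pyarrow defaults:

```py
import pandas as pd
import pyarrow as pa

# An empty column has no values to infer a type from; pandas/NumPy default it to float64 ...
empty_df = pd.DataFrame({"a": []})
print(empty_df.dtypes["a"])  # float64 (with the defaults assumed here)

# ... so the Arrow batch carries a double field, which must now match the declared schema.
batch = pa.RecordBatch.from_pandas(empty_df)
print(batch.schema.field("a").type)  # double
```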
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_map.py b/python/pyspark/sql/tests/pandas/test_pandas_map.py
index 692f9705411e..7e2221fc1a77 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_map.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_map.py
@@ -24,6 +24,8 @@ from typing import cast
 from pyspark.sql import Row
 from pyspark.sql.functions import col, encode, lit
 from pyspark.errors import PythonException
+from pyspark.sql.session import SparkSession
+from pyspark.sql.types import StructType
 from pyspark.testing.sqlutils import (
     ReusedSQLTestCase,
     have_pandas,
@@ -42,6 +44,8 @@ if have_pandas:
     cast(str, pandas_requirement_message or pyarrow_requirement_message),
 )
 class MapInPandasTestsMixin:
+    spark: SparkSession
+
     @staticmethod
     def identity_dataframes_iter(*columns: str):
         def func(iterator):
@@ -128,6 +132,27 @@ class MapInPandasTestsMixin:
             expected = df.collect()
             self.assertEqual(actual, expected)
 
+    def test_not_null(self):
+        def func(iterator):
+            for _ in iterator:
+                yield pd.DataFrame({"a": [1, 2]})
+
+        schema = "a long not null"
+        df = self.spark.range(1).mapInPandas(func, schema)
+        self.assertEqual(df.schema, StructType.fromDDL(schema))
+        self.assertEqual(df.collect(), [Row(1), Row(2)])
+
+    def test_violate_not_null(self):
+        def func(iterator):
+            for _ in iterator:
+                yield pd.DataFrame({"a": [1, None]})
+
+        schema = "a long not null"
+        df = self.spark.range(1).mapInPandas(func, schema)
+        self.assertEqual(df.schema, StructType.fromDDL(schema))
+        with self.assertRaisesRegex(Exception, "is null"):
+            df.collect()
+
     def test_different_output_length(self):
         def func(iterator):
             for _ in iterator:
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 653dbfcf330c..76d2dcd91c32 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -3666,6 +3666,15 @@ object SQLConf {
       // show full stacktrace in tests but hide in production by default.
       .createWithDefault(!Utils.isTesting)
 
+  val PYSPARK_ARROW_VALIDATE_SCHEMA =
+    buildConf("spark.sql.execution.arrow.pyspark.validateSchema.enabled")
+      .doc(
+        "When true, validate the schema of Arrow batches returned by mapInArrow, mapInPandas " +
+          "and DataSource against the expected schema to ensure that they are compatible.")
+      .version("4.1.0")
+      .booleanConf
+      .createWithDefault(true)
+
   val PYTHON_UDF_ARROW_ENABLED =
     buildConf("spark.sql.execution.pythonUDF.arrow.enabled")
       .doc("Enable Arrow optimization in regular Python UDFs. This optimization " +
@@ -6583,6 +6592,8 @@ class SQLConf extends Serializable with Logging with SqlApiConf {
 
   def pysparkSimplifiedTraceback: Boolean = getConf(PYSPARK_SIMPLIFIED_TRACEBACK)
 
+  def pysparkArrowValidateSchema: Boolean = getConf(PYSPARK_ARROW_VALIDATE_SCHEMA)
+
   def pandasGroupedMapAssignColumnsByName: Boolean =
     getConf(SQLConf.PANDAS_GROUPED_MAP_ASSIGN_COLUMNS_BY_NAME)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/UserDefinedPythonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/UserDefinedPythonDataSource.scala
index 4664e957ab31..14aeba92dafe 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/UserDefinedPythonDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/UserDefinedPythonDataSource.scala
@@ -163,6 +163,7 @@ case class UserDefinedPythonDataSource(dataSourceCls: PythonFunction) {
       toAttributes(outputSchema),
       Seq((ChainedPythonFunctions(Seq(pythonUDF.func)), pythonUDF.resultId.id)),
       inputSchema,
+      outputSchema,
       conf.arrowMaxRecordsPerBatch,
       pythonEvalType,
       conf.sessionLocalTimeZone,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchEvaluatorFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchEvaluatorFactory.scala
index 88b63f3b2dd0..9e3e8610ed37 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchEvaluatorFactory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchEvaluatorFactory.scala
@@ -20,17 +20,20 @@ package org.apache.spark.sql.execution.python
 import scala.jdk.CollectionConverters._
 
 import org.apache.spark.{PartitionEvaluator, PartitionEvaluatorFactory, TaskContext}
-import org.apache.spark.api.python.ChainedPythonFunctions
+import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.execution.metric.SQLMetric
-import org.apache.spark.sql.types.{StructField, StructType}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.{DataType, StructField, StructType}
 import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch}
 
 class MapInBatchEvaluatorFactory(
     output: Seq[Attribute],
     chainedFunc: Seq[(ChainedPythonFunctions, Long)],
-    outputTypes: StructType,
+    inputSchema: StructType,
+    outputSchema: DataType,
     batchSize: Int,
     pythonEvalType: Int,
     sessionLocalTimeZone: String,
@@ -63,7 +66,7 @@ class MapInBatchEvaluatorFactory(
         chainedFunc,
         pythonEvalType,
         argOffsets,
-        StructType(Array(StructField("struct", outputTypes))),
+        StructType(Array(StructField("struct", inputSchema))),
         sessionLocalTimeZone,
         largeVarTypes,
         pythonRunnerConf,
@@ -75,6 +78,18 @@ class MapInBatchEvaluatorFactory(
       val unsafeProj = UnsafeProjection.create(output, output)
 
       columnarBatchIter.flatMap { batch =>
+        if (SQLConf.get.pysparkArrowValidateSchema) {
+          // Ensure the schema matches the expected schema, but allowing nullable fields in the
+          // output schema to become non-nullable in the actual schema.
+          val actualSchema = batch.column(0).dataType()
+          val isCompatible =
+            DataType.equalsIgnoreCompatibleNullability(from = actualSchema, to = outputSchema)
+          if (!isCompatible) {
+            throw QueryExecutionErrors.arrowDataTypeMismatchError(
+              PythonEvalType.toString(pythonEvalType), Seq(outputSchema), Seq(actualSchema))
+          }
+        }
+
         // Scalar Iterator UDF returns a StructType column in ColumnarBatch, select
         // the children here
         val structVector = batch.column(0).asInstanceOf[ArrowColumnVector]
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala
index 096e9d7d1642..c003d503c7ca 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala
@@ -56,6 +56,7 @@ trait MapInBatchExec extends UnaryExecNode with PythonSQLMetrics {
       output,
       chainedFunc,
       child.schema,
+      pythonUDF.dataType,
       conf.arrowMaxRecordsPerBatch,
       pythonEvalType,
       conf.sessionLocalTimeZone,

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org