This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 919e60f84ca2 [SPARK-51739][PYTHON] Validate Arrow schema from mapInArrow & mapInPandas & DataSource
919e60f84ca2 is described below

commit 919e60f84ca2bc357355c04fc7b2603ef61d818e
Author: Haoyu Weng <wengh...@gmail.com>
AuthorDate: Mon Apr 14 07:51:32 2025 +0900

    [SPARK-51739][PYTHON] Validate Arrow schema from mapInArrow & mapInPandas & DataSource
    
    ### What changes were proposed in this pull request?
    
    Check the actual Arrow batch schema against the declared schema in `MapInBatchEvaluator`, throwing an error if they don't match.
    
    Also fix the pandas-to-Arrow conversion in `ArrowStreamPandasUDFSerializer` to respect the nullability of output schema fields.
    
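    As an illustration of the nullability part (a minimal sketch, not part of this patch): PyArrow's `StructArray.from_arrays` produces nullable fields when given only `names=`, while passing explicit `fields=` keeps the declared nullability, which is why the serializer now builds the struct array from `fields=list(arrow_struct_type)`:
    ```py
    import pyarrow as pa

    arrays = [pa.array([1, 2], type=pa.int64())]

    # names= drops nullability information; the resulting field is nullable
    by_name = pa.StructArray.from_arrays(arrays, names=["a"])
    print(by_name.type)   # struct<a: int64>

    # fields= keeps the declared (non-nullable) field as-is
    fields = [pa.field("a", pa.int64(), nullable=False)]
    by_field = pa.StructArray.from_arrays(arrays, fields=fields)
    print(by_field.type)  # struct<a: int64 not null>
    ```
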
    ### Why are the changes needed?
    
    To improve error messages and reject suspicious usage.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes.
    
    #### Behaviour change
    
    1. Some schema mismatches that were previously suspicious but accepted are now rejected.
    
        This includes:
        - extraneous fields (previously ignored)
        - fields of the same type in the wrong order (previously accepted, producing columns in the wrong order)
        - an expected non-nullable field that is actually nullable (previously ignored)
    
        Example:
        ```py
        from pyspark.sql.datasource import DataSource, DataSourceReader
        from pyspark.sql.pandas.types import to_arrow_schema
        from pyspark.sql.types import StructType
        import pyarrow as pa
    
        expected = StructType.fromDDL("a int, b int")
        actual = StructType.fromDDL("b int, a int")  # wrong order of fields
    
        class TestDataSource(DataSource):
            def schema(self):
                return expected
            def reader(self, schema):
                return TestReader()
    
        class TestReader(DataSourceReader):
            def read(self, partition):
                schema = to_arrow_schema(actual)
                yield pa.record_batch([[1], [2]], schema=schema)
    
        spark.dataSource.register(TestDataSource)
        spark.read.format("TestDataSource").load().show()
        ```
        Before:
        ```
        +---+---+
        |  a|  b|
        +---+---+
        |  1|  2|
        +---+---+
        ```
        Now:
        ```
        org.apache.spark.SparkException: [ARROW_TYPE_MISMATCH] Invalid schema from pandas_udf(): expected StructType(StructField(a,IntegerType,true),StructField(b,IntegerType,true)), got StructType(StructField(b,LongType,true),StructField(a,LongType,true)). SQLSTATE: 42K0G
        ```
    
    2. For other schema mismatches, the error changes from an internal error to a clearer `ARROW_TYPE_MISMATCH` error.
    
        This includes:
        - wrong field types
        - fewer fields than expected
    
        Example:
        ```py
        from pyspark.sql.pandas.types import to_arrow_schema
        from pyspark.sql.types import StructType, StructField, IntegerType
        import pyarrow as pa
    
        expected = StructType([StructField("a", IntegerType()), StructField("b", IntegerType())])
        actual = StructType([StructField("a", IntegerType())])  # missing a column
    
        def fun(iterator):
            for batch in iterator:
                schema = to_arrow_schema(actual)
                yield pa.record_batch([[1]], schema=schema)
    
        spark.range(2).mapInArrow(fun, expected).show()
        ```
        Before:
        ```
        java.lang.ArrayIndexOutOfBoundsException: Index 1 out of bounds for length 1
            at org.apache.spark.sql.vectorized.ArrowColumnVector.getChild(ArrowColumnVector.java:134)
            at org.apache.spark.sql.execution.python.MapInBatchEvaluatorFactory$MapInBatchEvaluator.$anonfun$eval$3(MapInBatchEvaluatorFactory.scala:82)
            ...
        ```
        Now:
        ```
        org.apache.spark.SparkException: [ARROW_TYPE_MISMATCH] Invalid schema from pandas_udf(): expected StructType(StructField(a,IntegerType,true),StructField(b,IntegerType,true)), got StructType(StructField(a,IntegerType,true)). SQLSTATE: 42K0G
        ```
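
    Both behaviour changes above are gated behind the `spark.sql.execution.arrow.pyspark.validateSchema.enabled` conf added in this patch (default `true`). A sketch of how the previous lenient behaviour could presumably be restored by disabling it (not separately verified here):
    ```py
    # Disabling the new conf should skip the Arrow schema validation
    # added by this patch, based on the SQLConf entry in the diff below.
    spark.conf.set("spark.sql.execution.arrow.pyspark.validateSchema.enabled", False)
    ```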
    
    ### How was this patch tested?
    
    End-to-end tests in `python/pyspark/sql/tests/arrow/test_arrow_map.py` and `python/pyspark/sql/tests/pandas/test_pandas_map.py`
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #50531 from wengh/validate-arrow-type.
    
    Authored-by: Haoyu Weng <wengh...@gmail.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 python/pyspark/sql/pandas/serializers.py           | 17 ++++++---
 python/pyspark/sql/tests/arrow/test_arrow_map.py   | 40 +++++++++++++++++++++-
 python/pyspark/sql/tests/pandas/test_pandas_map.py | 25 ++++++++++++++
 .../org/apache/spark/sql/internal/SQLConf.scala    | 11 ++++++
 .../v2/python/UserDefinedPythonDataSource.scala    |  1 +
 .../python/MapInBatchEvaluatorFactory.scala        | 23 ++++++++++---
 .../sql/execution/python/MapInBatchExec.scala      |  1 +
 7 files changed, 109 insertions(+), 9 deletions(-)

diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py
index de9020f633c5..e036add0210d 100644
--- a/python/pyspark/sql/pandas/serializers.py
+++ b/python/pyspark/sql/pandas/serializers.py
@@ -20,6 +20,7 @@ Serializers for PyArrow and pandas conversions. See `pyspark.serializers` for mo
 """
 
 from itertools import groupby
+from typing import TYPE_CHECKING, Optional
 
 import pyspark
 from pyspark.errors import PySparkRuntimeError, PySparkTypeError, PySparkValueError
@@ -48,6 +49,10 @@ from pyspark.sql.types import (
     IntegerType,
 )
 
+if TYPE_CHECKING:
+    import pandas as pd
+    import pyarrow as pa
+
 
 class SpecialLengths:
     END_OF_DATA_SECTION = -1
@@ -472,7 +477,12 @@ class ArrowStreamPandasUDFSerializer(ArrowStreamPandasSerializer):
             )
         return s
 
-    def _create_struct_array(self, df, arrow_struct_type, spark_type=None):
+    def _create_struct_array(
+        self,
+        df: "pd.DataFrame",
+        arrow_struct_type: "pa.StructType",
+        spark_type: Optional[StructType] = None,
+    ):
         """
         Create an Arrow StructArray from the given pandas.DataFrame and arrow struct type.
 
@@ -480,7 +490,7 @@ class ArrowStreamPandasUDFSerializer(ArrowStreamPandasSerializer):
         ----------
         df : pandas.DataFrame
             A pandas DataFrame
-        arrow_struct_type : pyarrow.DataType
+        arrow_struct_type : pyarrow.StructType
             pyarrow struct type
 
         Returns
@@ -518,8 +528,7 @@ class ArrowStreamPandasUDFSerializer(ArrowStreamPandasSerializer):
                 for i, field in enumerate(arrow_struct_type)
             ]
 
-        struct_names = [field.name for field in arrow_struct_type]
-        return pa.StructArray.from_arrays(struct_arrs, struct_names)
+        return pa.StructArray.from_arrays(struct_arrs, fields=list(arrow_struct_type))
 
     def _create_batch(self, series):
         """
diff --git a/python/pyspark/sql/tests/arrow/test_arrow_map.py b/python/pyspark/sql/tests/arrow/test_arrow_map.py
index fb0d7b751e53..fa2ce69c4fa5 100644
--- a/python/pyspark/sql/tests/arrow/test_arrow_map.py
+++ b/python/pyspark/sql/tests/arrow/test_arrow_map.py
@@ -124,7 +124,7 @@ class MapInArrowTestsMixin(object):
         def empty_rows(_):
             return iter([pa.RecordBatch.from_pandas(pd.DataFrame({"a": []}))])
 
-        self.assertEqual(self.spark.range(10).mapInArrow(empty_rows, "a int").count(), 0)
+        self.assertEqual(self.spark.range(10).mapInArrow(empty_rows, "a double").count(), 0)
 
     def test_chain_map_in_arrow(self):
         def func(iterator):
@@ -175,6 +175,44 @@ class MapInArrowTestsMixin(object):
             with self.sql_conf({"spark.sql.execution.arrow.maxRecordsPerBatch": batch_size}):
                 MapInArrowTests.test_map_in_arrow(self)
 
+    def test_nested_extraneous_field(self):
+        def func(iterator):
+            for _ in iterator:
+                struct_arr = pa.StructArray.from_arrays([[1, 2], [3, 4]], names=["a", "b"])
+                yield pa.RecordBatch.from_arrays([struct_arr], ["x"])
+
+        df = self.spark.range(1)
+        with self.assertRaisesRegex(Exception, r"ARROW_TYPE_MISMATCH.*SQL_MAP_ARROW_ITER_UDF"):
+            df.mapInArrow(func, "x struct<b:int>").collect()
+
+    def test_top_level_wrong_order(self):
+        def func(iterator):
+            for _ in iterator:
+                yield pa.RecordBatch.from_arrays([[1], [2]], ["b", "a"])
+
+        df = self.spark.range(1)
+        with self.assertRaisesRegex(Exception, r"ARROW_TYPE_MISMATCH.*SQL_MAP_ARROW_ITER_UDF"):
+            df.mapInArrow(func, "a int, b int").collect()
+
+    def test_nullability_widen(self):
+        def func(iterator):
+            for _ in iterator:
+                yield pa.RecordBatch.from_arrays([[1]], ["a"])
+
+        df = self.spark.range(1)
+        with self.assertRaisesRegex(Exception, r"ARROW_TYPE_MISMATCH.*SQL_MAP_ARROW_ITER_UDF"):
+            df.mapInArrow(func, "a int not null").collect()
+
+    def test_nullability_narrow(self):
+        def func(iterator):
+            for _ in iterator:
+                yield pa.RecordBatch.from_arrays(
+                    [[1]], pa.schema([pa.field("a", pa.int32(), nullable=False)])
+                )
+
+        df = self.spark.range(1)
+        df.mapInArrow(func, "a int").collect()
+
 
 class MapInArrowTests(MapInArrowTestsMixin, ReusedSQLTestCase):
     @classmethod
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_map.py b/python/pyspark/sql/tests/pandas/test_pandas_map.py
index 692f9705411e..7e2221fc1a77 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_map.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_map.py
@@ -24,6 +24,8 @@ from typing import cast
 from pyspark.sql import Row
 from pyspark.sql.functions import col, encode, lit
 from pyspark.errors import PythonException
+from pyspark.sql.session import SparkSession
+from pyspark.sql.types import StructType
 from pyspark.testing.sqlutils import (
     ReusedSQLTestCase,
     have_pandas,
@@ -42,6 +44,8 @@ if have_pandas:
     cast(str, pandas_requirement_message or pyarrow_requirement_message),
 )
 class MapInPandasTestsMixin:
+    spark: SparkSession
+
     @staticmethod
     def identity_dataframes_iter(*columns: str):
         def func(iterator):
@@ -128,6 +132,27 @@ class MapInPandasTestsMixin:
         expected = df.collect()
         self.assertEqual(actual, expected)
 
+    def test_not_null(self):
+        def func(iterator):
+            for _ in iterator:
+                yield pd.DataFrame({"a": [1, 2]})
+
+        schema = "a long not null"
+        df = self.spark.range(1).mapInPandas(func, schema)
+        self.assertEqual(df.schema, StructType.fromDDL(schema))
+        self.assertEqual(df.collect(), [Row(1), Row(2)])
+
+    def test_violate_not_null(self):
+        def func(iterator):
+            for _ in iterator:
+                yield pd.DataFrame({"a": [1, None]})
+
+        schema = "a long not null"
+        df = self.spark.range(1).mapInPandas(func, schema)
+        self.assertEqual(df.schema, StructType.fromDDL(schema))
+        with self.assertRaisesRegex(Exception, "is null"):
+            df.collect()
+
     def test_different_output_length(self):
         def func(iterator):
             for _ in iterator:
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 653dbfcf330c..76d2dcd91c32 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -3666,6 +3666,15 @@ object SQLConf {
       // show full stacktrace in tests but hide in production by default.
       .createWithDefault(!Utils.isTesting)
 
+  val PYSPARK_ARROW_VALIDATE_SCHEMA =
+    buildConf("spark.sql.execution.arrow.pyspark.validateSchema.enabled")
+      .doc(
+        "When true, validate the schema of Arrow batches returned by 
mapInArrow, mapInPandas " +
+        "and DataSource against the expected schema to ensure that they are 
compatible.")
+      .version("4.1.0")
+      .booleanConf
+      .createWithDefault(true)
+
   val PYTHON_UDF_ARROW_ENABLED =
     buildConf("spark.sql.execution.pythonUDF.arrow.enabled")
       .doc("Enable Arrow optimization in regular Python UDFs. This 
optimization " +
@@ -6583,6 +6592,8 @@ class SQLConf extends Serializable with Logging with SqlApiConf {
 
   def pysparkSimplifiedTraceback: Boolean = getConf(PYSPARK_SIMPLIFIED_TRACEBACK)
 
+  def pysparkArrowValidateSchema: Boolean = getConf(PYSPARK_ARROW_VALIDATE_SCHEMA)
+
   def pandasGroupedMapAssignColumnsByName: Boolean =
     getConf(SQLConf.PANDAS_GROUPED_MAP_ASSIGN_COLUMNS_BY_NAME)
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/UserDefinedPythonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/UserDefinedPythonDataSource.scala
index 4664e957ab31..14aeba92dafe 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/UserDefinedPythonDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/UserDefinedPythonDataSource.scala
@@ -163,6 +163,7 @@ case class UserDefinedPythonDataSource(dataSourceCls: PythonFunction) {
       toAttributes(outputSchema),
       Seq((ChainedPythonFunctions(Seq(pythonUDF.func)), pythonUDF.resultId.id)),
       inputSchema,
+      outputSchema,
       conf.arrowMaxRecordsPerBatch,
       pythonEvalType,
       conf.sessionLocalTimeZone,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchEvaluatorFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchEvaluatorFactory.scala
index 88b63f3b2dd0..9e3e8610ed37 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchEvaluatorFactory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchEvaluatorFactory.scala
@@ -20,17 +20,20 @@ package org.apache.spark.sql.execution.python
 import scala.jdk.CollectionConverters._
 
 import org.apache.spark.{PartitionEvaluator, PartitionEvaluatorFactory, TaskContext}
-import org.apache.spark.api.python.ChainedPythonFunctions
+import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.execution.metric.SQLMetric
-import org.apache.spark.sql.types.{StructField, StructType}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.{DataType, StructField, StructType}
 import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch}
 
 class MapInBatchEvaluatorFactory(
     output: Seq[Attribute],
     chainedFunc: Seq[(ChainedPythonFunctions, Long)],
-    outputTypes: StructType,
+    inputSchema: StructType,
+    outputSchema: DataType,
     batchSize: Int,
     pythonEvalType: Int,
     sessionLocalTimeZone: String,
@@ -63,7 +66,7 @@ class MapInBatchEvaluatorFactory(
         chainedFunc,
         pythonEvalType,
         argOffsets,
-        StructType(Array(StructField("struct", outputTypes))),
+        StructType(Array(StructField("struct", inputSchema))),
         sessionLocalTimeZone,
         largeVarTypes,
         pythonRunnerConf,
@@ -75,6 +78,18 @@ class MapInBatchEvaluatorFactory(
       val unsafeProj = UnsafeProjection.create(output, output)
 
       columnarBatchIter.flatMap { batch =>
+        if (SQLConf.get.pysparkArrowValidateSchema) {
+          // Ensure the schema matches the expected schema, but allowing nullable fields in the
+          // output schema to become non-nullable in the actual schema.
+          val actualSchema = batch.column(0).dataType()
+          val isCompatible =
+            DataType.equalsIgnoreCompatibleNullability(from = actualSchema, to = outputSchema)
+          if (!isCompatible) {
+            throw QueryExecutionErrors.arrowDataTypeMismatchError(
+              PythonEvalType.toString(pythonEvalType), Seq(outputSchema), Seq(actualSchema))
+          }
+        }
+
         // Scalar Iterator UDF returns a StructType column in ColumnarBatch, select
         // the children here
         val structVector = batch.column(0).asInstanceOf[ArrowColumnVector]
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala
index 096e9d7d1642..c003d503c7ca 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala
@@ -56,6 +56,7 @@ trait MapInBatchExec extends UnaryExecNode with PythonSQLMetrics {
       output,
       chainedFunc,
       child.schema,
+      pythonUDF.dataType,
       conf.arrowMaxRecordsPerBatch,
       pythonEvalType,
       conf.sessionLocalTimeZone,


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
