This is an automated email from the ASF dual-hosted git repository.
ueshin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 5633e9312137 [SPARK-52288][PS] Avoid INVALID_ARRAY_INDEX in `split`/`rsplit` when ANSI mode is on
5633e9312137 is described below
commit 5633e9312137ed648609023e6af5ccae56b88986
Author: Xinrong Meng <[email protected]>
AuthorDate: Fri Jun 6 11:22:39 2025 -0700
[SPARK-52288][PS] Avoid INVALID_ARRAY_INDEX in `split`/`rsplit` when ANSI mode is on
### What changes were proposed in this pull request?
Avoid the `INVALID_ARRAY_INDEX` error in `split`/`rsplit` with `expand=True` when ANSI mode is on, by using `F.try_element_at` (which returns NULL for out-of-range indices) instead of direct array indexing.
### Why are the changes needed?
Ensure pandas on Spark works well with ANSI mode on.
Part of https://issues.apache.org/jira/browse/SPARK-52169.
### Does this PR introduce _any_ user-facing change?
Yes. `split`/`rsplit` no longer fail with `INVALID_ARRAY_INDEX` when ANSI mode is on.
```py
>>> spark.conf.get("spark.sql.ansi.enabled")
'true'
>>> import pandas as pd
>>> import pyspark.pandas as ps
>>> pser = pd.Series(["hello-world", "short"])
>>> psser = ps.from_pandas(pser)
```
FROM
```py
>>> psser.str.split("-", n=1, expand=True)
25/05/28 14:52:10 ERROR Executor: Exception in task 10.0 in stage 2.0 (TID 15)
org.apache.spark.SparkArrayIndexOutOfBoundsException: [INVALID_ARRAY_INDEX] The index 1 is out of bounds. The array has 1 elements. Use the SQL function `get()` to tolerate accessing element at invalid index and return NULL instead. SQLSTATE: 22003
== DataFrame ==
"__getitem__" was called from
<stdin>:1
...
```
TO
```py
>>> psser.str.split("-", n=1, expand=True)
       0      1
0  hello  world
1  short   None
```
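For context, the fix relies on `F.try_element_at`, which returns NULL for an out-of-range index instead of raising under ANSI mode, whereas plain element access (`col[i]`) throws `INVALID_ARRAY_INDEX`. A minimal standalone sketch of that behavior (assuming an active `spark` session with ANSI enabled; the DataFrame and column names are illustrative, and NULL rendering follows Spark 4.x `show()`):
```py
>>> from pyspark.sql import functions as F
>>> df = spark.createDataFrame([("hello-world",), ("short",)], ["s"])
>>> arr = F.split("s", "-", 2)
>>> # try_element_at is NULL-safe: "short" splits into a 1-element array,
>>> # so asking for the 2nd element yields NULL rather than an error.
>>> df.select(F.try_element_at(arr, F.lit(2)).alias("second")).show()
+------+
|second|
+------+
| world|
|  NULL|
+------+
```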
### How was this patch tested?
Unit tests
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #51006 from xinrong-meng/arr_idx_enable.
Authored-by: Xinrong Meng <[email protected]>
Signed-off-by: Takuya Ueshin <[email protected]>
---
python/pyspark/pandas/strings.py | 17 +++++++++++++++--
.../pyspark/pandas/tests/series/test_string_ops_adv.py | 7 ++++---
2 files changed, 19 insertions(+), 5 deletions(-)
diff --git a/python/pyspark/pandas/strings.py b/python/pyspark/pandas/strings.py
index 7e572bd1fae3..dc1544d8be39 100644
--- a/python/pyspark/pandas/strings.py
+++ b/python/pyspark/pandas/strings.py
@@ -32,6 +32,7 @@ from typing import (
 import numpy as np
 import pandas as pd
 
+from pyspark.pandas.utils import is_ansi_mode_enabled
 from pyspark.sql.types import StringType, BinaryType, ArrayType, LongType, MapType
 from pyspark.sql import functions as F
 from pyspark.sql.functions import pandas_udf
@@ -2031,7 +2032,13 @@ class StringMethods:
         if expand:
             psdf = psser.to_frame()
             scol = psdf._internal.data_spark_columns[0]
-            spark_columns = [scol[i].alias(str(i)) for i in range(n + 1)]
+            spark_session = self._data._internal.spark_frame.sparkSession
+            if is_ansi_mode_enabled(spark_session):
+                spark_columns = [
+                    F.try_element_at(scol, F.lit(i + 1)).alias(str(i)) for i in range(n + 1)
+                ]
+            else:
+                spark_columns = [scol[i].alias(str(i)) for i in range(n + 1)]
             column_labels = [(i,) for i in range(n + 1)]
             internal = psdf._internal.with_new_columns(
                 spark_columns,
@@ -2178,7 +2185,13 @@ class StringMethods:
         if expand:
             psdf = psser.to_frame()
             scol = psdf._internal.data_spark_columns[0]
-            spark_columns = [scol[i].alias(str(i)) for i in range(n + 1)]
+            spark_session = self._data._internal.spark_frame.sparkSession
+            if is_ansi_mode_enabled(spark_session):
+                spark_columns = [
+                    F.try_element_at(scol, F.lit(i + 1)).alias(str(i)) for i in range(n + 1)
+                ]
+            else:
+                spark_columns = [scol[i].alias(str(i)) for i in range(n + 1)]
             column_labels = [(i,) for i in range(n + 1)]
             internal = psdf._internal.with_new_columns(
                 spark_columns,
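Note the `i + 1` in both hunks: `try_element_at` uses SQL-style 1-based indexing, while `scol[i]` (`Column.__getitem__`) is 0-based, so the two branches address the same element. A quick illustration with a hypothetical one-row DataFrame (continuing the `spark` session and `F` import from the sketch above):
```py
>>> df = spark.createDataFrame([(["a", "b"],)], ["arr"])
>>> # Index 0 via __getitem__ and index 1 via try_element_at hit the same slot.
>>> df.select(
...     df.arr[0].alias("zero_based"),
...     F.try_element_at(df.arr, F.lit(1)).alias("one_based"),
... ).show()
+----------+---------+
|zero_based|one_based|
+----------+---------+
|         a|        a|
+----------+---------+
```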
diff --git a/python/pyspark/pandas/tests/series/test_string_ops_adv.py b/python/pyspark/pandas/tests/series/test_string_ops_adv.py
index e00252110dae..b0e4c69a35ea 100644
--- a/python/pyspark/pandas/tests/series/test_string_ops_adv.py
+++ b/python/pyspark/pandas/tests/series/test_string_ops_adv.py
@@ -22,7 +22,6 @@ import re
 from pyspark import pandas as ps
 from pyspark.testing.pandasutils import PandasOnSparkTestCase
 from pyspark.testing.sqlutils import SQLTestUtils
-from pyspark.testing.utils import is_ansi_mode_test, ansi_mode_not_supported_message
 
 
 class SeriesStringOpsAdvMixin:
@@ -174,7 +173,6 @@ class SeriesStringOpsAdvMixin:
         self.check_func(lambda x: x.str.slice_replace(stop=2, repl="X"))
         self.check_func(lambda x: x.str.slice_replace(start=1, stop=3, repl="X"))
 
-    @unittest.skipIf(is_ansi_mode_test, ansi_mode_not_supported_message)
     def test_string_split(self):
         self.check_func_on_series(lambda x: repr(x.str.split()), self.pser[:-1])
         self.check_func_on_series(lambda x: repr(x.str.split(r"p*")), self.pser[:-1])
@@ -185,7 +183,8 @@ class SeriesStringOpsAdvMixin:
         with self.assertRaises(NotImplementedError):
             self.check_func(lambda x: x.str.split(expand=True))
 
-    @unittest.skipIf(is_ansi_mode_test, ansi_mode_not_supported_message)
+        self.check_func_on_series(lambda x: repr(x.str.split("-", n=1, expand=True)), pser)
+
     def test_string_rsplit(self):
         self.check_func_on_series(lambda x: repr(x.str.rsplit()), self.pser[:-1])
         self.check_func_on_series(lambda x: repr(x.str.rsplit(r"p*")), self.pser[:-1])
@@ -196,6 +195,8 @@ class SeriesStringOpsAdvMixin:
         with self.assertRaises(NotImplementedError):
             self.check_func(lambda x: x.str.rsplit(expand=True))
 
+        self.check_func_on_series(lambda x: repr(x.str.rsplit("-", n=1, expand=True)), pser)
+
     def test_string_translate(self):
         m = str.maketrans({"a": "X", "e": "Y", "i": None})
         self.check_func(lambda x: x.str.translate(m))
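The new assertions run through `check_func_on_series`, which applies the same lambda to a pandas Series and to its pandas-on-Spark counterpart and compares the results; conceptually, the two backends now agree under ANSI mode (a sketch reusing the `pser`/`psser` pair from the example above):
```py
>>> pser.str.rsplit("-", n=1, expand=True)   # plain pandas
       0      1
0  hello  world
1  short   None
>>> psser.str.rsplit("-", n=1, expand=True)  # pandas API on Spark, ANSI mode on
       0      1
0  hello  world
1  short   None
```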
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]