This is an automated email from the ASF dual-hosted git repository.
holden pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 3f124c30ddd3 [SPARK-46168][PS] Add axis argument for idxmax
3f124c30ddd3 is described below
commit 3f124c30ddd3a169e6c2534f685cf006fae6fda1
Author: Devin Petersohn <[email protected]>
AuthorDate: Mon Feb 23 12:41:34 2026 -0800
[SPARK-46168][PS] Add axis argument for idxmax
### What changes were proposed in this pull request?
Add support for axis argument for idxmax.
### Why are the changes needed?
To support a missing API parameter
### Does this PR introduce _any_ user-facing change?
Yes, a new `axis` parameter value (`1` / `'columns'`) is now accepted by `DataFrame.idxmax`, which previously supported only `axis=0`.
### How was this patch tested?
CI, including the new unit tests added in `pyspark/pandas/tests/computation/test_idxmax_idxmin.py` and its Spark Connect parity test.
### Was this patch authored or co-authored using generative AI tooling?
Co-authored-by: Claude Sonnet 4.5
Closes #54044 from devin-petersohn/devin/idxmax_axis.
Authored-by: Devin Petersohn <[email protected]>
Signed-off-by: Holden Karau <[email protected]>
---
dev/sparktestsupport/modules.py | 2 +
python/pyspark/pandas/frame.py | 92 ++++++++---
.../pandas/tests/computation/test_idxmax_idxmin.py | 170 +++++++++++++++++++++
.../computation/test_parity_idxmax_idxmin.py | 34 +++++
4 files changed, 281 insertions(+), 17 deletions(-)
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index d62bf4414ffd..e1851602d020 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -874,6 +874,7 @@ pyspark_pandas = Module(
"pyspark.pandas.tests.computation.test_cumulative",
"pyspark.pandas.tests.computation.test_describe",
"pyspark.pandas.tests.computation.test_eval",
+ "pyspark.pandas.tests.computation.test_idxmax_idxmin",
"pyspark.pandas.tests.computation.test_melt",
"pyspark.pandas.tests.computation.test_missing_data",
"pyspark.pandas.tests.computation.test_pivot",
@@ -1322,6 +1323,7 @@ pyspark_pandas_connect = Module(
"pyspark.pandas.tests.connect.computation.test_parity_cumulative",
"pyspark.pandas.tests.connect.computation.test_parity_describe",
"pyspark.pandas.tests.connect.computation.test_parity_eval",
+ "pyspark.pandas.tests.connect.computation.test_parity_idxmax_idxmin",
"pyspark.pandas.tests.connect.computation.test_parity_melt",
"pyspark.pandas.tests.connect.computation.test_parity_missing_data",
"pyspark.pandas.tests.connect.computation.test_parity_pivot",
diff --git a/python/pyspark/pandas/frame.py b/python/pyspark/pandas/frame.py
index 2fa90e8e15cf..5609f76cd719 100644
--- a/python/pyspark/pandas/frame.py
+++ b/python/pyspark/pandas/frame.py
@@ -12311,7 +12311,6 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
return self._apply_series_op(op, should_resolve=True)
- # TODO(SPARK-46168): axis = 1
def idxmax(self, axis: Axis = 0) -> "Series":
"""
Return index of first occurrence of maximum over requested axis.
@@ -12322,8 +12321,8 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
Parameters
----------
- axis : 0 or 'index'
- Can only be set to 0 now.
+ axis : {0 or 'index', 1 or 'columns'}, default 0
+ The axis to use. 0 or 'index' for row-wise, 1 or 'columns' for
column-wise.
Returns
-------
@@ -12351,6 +12350,15 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
c 2
dtype: int64
+ For axis=1, return the column label of the maximum value in each row:
+
+ >>> psdf.idxmax(axis=1)
+ 0 c
+ 1 c
+ 2 c
+ 3 c
+ dtype: object
+
For Multi-column Index
>>> psdf = ps.DataFrame({'a': [1, 2, 3, 2],
@@ -12371,23 +12379,73 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
c z 2
dtype: int64
"""
- max_cols = map(lambda scol: F.max(scol),
self._internal.data_spark_columns)
- sdf_max = self._internal.spark_frame.select(*max_cols).head()
- # `sdf_max` looks like below
- # +------+------+------+
- # |(a, x)|(b, y)|(c, z)|
- # +------+------+------+
- # | 3| 4.0| 400|
- # +------+------+------+
+ axis = validate_axis(axis)
+ if axis == 0:
+ max_cols = map(lambda scol: F.max(scol),
self._internal.data_spark_columns)
+ sdf_max = self._internal.spark_frame.select(*max_cols).head()
+ # `sdf_max` looks like below
+ # +------+------+------+
+ # |(a, x)|(b, y)|(c, z)|
+ # +------+------+------+
+ # | 3| 4.0| 400|
+ # +------+------+------+
+
+ conds = (
+ scol == max_val for scol, max_val in
zip(self._internal.data_spark_columns, sdf_max)
+ )
+ cond = reduce(lambda x, y: x | y, conds)
- conds = (
- scol == max_val for scol, max_val in
zip(self._internal.data_spark_columns, sdf_max)
- )
- cond = reduce(lambda x, y: x | y, conds)
+ psdf: DataFrame = DataFrame(self._internal.with_filter(cond))
- psdf: DataFrame = DataFrame(self._internal.with_filter(cond))
+ return cast(ps.Series,
ps.from_pandas(psdf._to_internal_pandas().idxmax()))
+ else:
+ from pyspark.pandas.series import first_series
+
+ column_labels = self._internal.column_labels
+
+ if len(column_labels) == 0:
+ # Check if DataFrame has rows - if yes, raise error; if no,
return empty Series
+ # to match pandas behavior
+ if len(self) > 0:
+ raise ValueError("attempt to get argmax of an empty
sequence")
+ else:
+ return ps.Series([], dtype=np.int64)
+
+ if self._internal.column_labels_level > 1:
+ raise NotImplementedError(
+ "idxmax with axis=1 does not support MultiIndex columns
yet"
+ )
+
+ max_value = F.greatest(
+ *[
+ F.coalesce(self._internal.spark_column_for(label),
F.lit(float("-inf")))
+ for label in column_labels
+ ],
+ F.lit(float("-inf")),
+ )
+
+ result = None
+ # Iterate over the column labels in reverse order to get the first
occurrence of the
+ # maximum value.
+ for label in reversed(column_labels):
+ scol = self._internal.spark_column_for(label)
+ label_value = label[0] if len(label) == 1 else label
+ condition = (scol == max_value) & scol.isNotNull()
+
+ result = (
+ F.when(condition, F.lit(label_value))
+ if result is None
+ else F.when(condition,
F.lit(label_value)).otherwise(result)
+ )
+
+ result = F.when(max_value == float("-inf"),
F.lit(None)).otherwise(result)
+
+ internal = self._internal.with_new_columns(
+ [result.alias(SPARK_DEFAULT_SERIES_NAME)],
+ column_labels=[None],
+ )
- return cast(ps.Series,
ps.from_pandas(psdf._to_internal_pandas().idxmax()))
+ return first_series(DataFrame(internal))
# TODO(SPARK-46168): axis = 1
def idxmin(self, axis: Axis = 0) -> "Series":
diff --git a/python/pyspark/pandas/tests/computation/test_idxmax_idxmin.py
b/python/pyspark/pandas/tests/computation/test_idxmax_idxmin.py
new file mode 100644
index 000000000000..2de14fccbc43
--- /dev/null
+++ b/python/pyspark/pandas/tests/computation/test_idxmax_idxmin.py
@@ -0,0 +1,170 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import pandas as pd
+
+from pyspark import pandas as ps
+from pyspark.testing.pandasutils import PandasOnSparkTestCase
+from pyspark.testing.sqlutils import SQLTestUtils
+
+
+class FrameIdxMaxMinMixin:
+ def test_idxmax(self):
+ # Test basic axis=0 (default)
+ pdf = pd.DataFrame(
+ {
+ "a": [1, 2, 3, 2],
+ "b": [4.0, 2.0, 3.0, 1.0],
+ "c": [300, 200, 400, 200],
+ }
+ )
+ psdf = ps.from_pandas(pdf)
+
+ self.assert_eq(psdf.idxmax(), pdf.idxmax())
+ self.assert_eq(psdf.idxmax(axis=0), pdf.idxmax(axis=0))
+ self.assert_eq(psdf.idxmax(axis="index"), pdf.idxmax(axis="index"))
+
+ # Test axis=1
+ self.assert_eq(psdf.idxmax(axis=1), pdf.idxmax(axis=1))
+ self.assert_eq(psdf.idxmax(axis="columns"), pdf.idxmax(axis="columns"))
+
+ # Test with NAs
+ pdf = pd.DataFrame(
+ {
+ "a": [1.0, None, 3.0],
+ "b": [None, 2.0, None],
+ "c": [3.0, 4.0, None],
+ }
+ )
+ psdf = ps.from_pandas(pdf)
+
+ self.assert_eq(psdf.idxmax(), pdf.idxmax())
+ self.assert_eq(psdf.idxmax(axis=0), pdf.idxmax(axis=0))
+ self.assert_eq(psdf.idxmax(axis=1), pdf.idxmax(axis=1))
+
+ # Test with all-NA row
+ pdf = pd.DataFrame(
+ {
+ "a": [1.0, None],
+ "b": [2.0, None],
+ }
+ )
+ psdf = ps.from_pandas(pdf)
+
+ self.assert_eq(psdf.idxmax(axis=1), pdf.idxmax(axis=1))
+
+ # Test with ties (first occurrence should win)
+ pdf = pd.DataFrame(
+ {
+ "a": [3, 2, 1],
+ "b": [3, 5, 1],
+ "c": [1, 5, 1],
+ }
+ )
+ psdf = ps.from_pandas(pdf)
+
+ self.assert_eq(psdf.idxmax(axis=1), pdf.idxmax(axis=1))
+
+ # Test with single column
+ pdf = pd.DataFrame({"a": [1, 2, 3]})
+ psdf = ps.from_pandas(pdf)
+
+ self.assert_eq(psdf.idxmax(axis=1), pdf.idxmax(axis=1))
+
+ # Test with empty DataFrame
+ pdf = pd.DataFrame({})
+ psdf = ps.from_pandas(pdf)
+
+ self.assert_eq(psdf.idxmax(axis=1), pdf.idxmax(axis=1))
+
+ # Test with different data types
+ pdf = pd.DataFrame(
+ {
+ "int_col": [1, 2, 3],
+ "float_col": [1.5, 2.5, 0.5],
+ "negative": [-5, -10, -1],
+ }
+ )
+ psdf = ps.from_pandas(pdf)
+
+ self.assert_eq(psdf.idxmax(axis=1), pdf.idxmax(axis=1))
+
+ # Test with custom index
+ pdf = pd.DataFrame(
+ {
+ "a": [1, 2, 3],
+ "b": [4, 5, 6],
+ "c": [7, 8, 9],
+ },
+ index=["row1", "row2", "row3"],
+ )
+ psdf = ps.from_pandas(pdf)
+
+ self.assert_eq(psdf.idxmax(axis=1), pdf.idxmax(axis=1))
+
+ def test_idxmax_multiindex_columns(self):
+ # Test that MultiIndex columns raise NotImplementedError for axis=1
+ pdf = pd.DataFrame(
+ {
+ "a": [1, 2, 3],
+ "b": [4, 5, 6],
+ "c": [7, 8, 9],
+ }
+ )
+ pdf.columns = pd.MultiIndex.from_tuples([("x", "a"), ("y", "b"), ("z",
"c")])
+ psdf = ps.from_pandas(pdf)
+
+ # axis=0 should work fine (it uses pandas internally)
+ self.assert_eq(psdf.idxmax(axis=0), pdf.idxmax(axis=0))
+
+ # axis=1 should raise NotImplementedError
+ with self.assertRaises(NotImplementedError):
+ psdf.idxmax(axis=1)
+
+ def test_idxmax_empty_dataframe(self):
+ # Test empty DataFrame with no rows and no columns - should return
empty Series
+ pdf = pd.DataFrame({})
+ psdf = ps.from_pandas(pdf)
+
+ self.assert_eq(psdf.idxmax(axis=1), pdf.idxmax(axis=1))
+
+ # Test empty DataFrame with rows but no columns - should raise
ValueError
+ pdf = pd.DataFrame(index=range(3))
+ psdf = ps.from_pandas(pdf)
+
+ with self.assertRaises(ValueError) as pdf_context:
+ pdf.idxmax(axis=1)
+
+ with self.assertRaises(ValueError) as psdf_context:
+ psdf.idxmax(axis=1)
+
+ # Verify both raise the same error message
+ self.assertEqual(str(pdf_context.exception),
str(psdf_context.exception))
+
+
+class FrameIdxMaxMinTests(
+ FrameIdxMaxMinMixin,
+ PandasOnSparkTestCase,
+ SQLTestUtils,
+):
+ pass
+
+
+if __name__ == "__main__":
+ from pyspark.testing import main
+
+ main()
diff --git
a/python/pyspark/pandas/tests/connect/computation/test_parity_idxmax_idxmin.py
b/python/pyspark/pandas/tests/connect/computation/test_parity_idxmax_idxmin.py
new file mode 100644
index 000000000000..06e723d39708
--- /dev/null
+++
b/python/pyspark/pandas/tests/connect/computation/test_parity_idxmax_idxmin.py
@@ -0,0 +1,34 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from pyspark.pandas.tests.computation.test_idxmax_idxmin import
FrameIdxMaxMinMixin
+from pyspark.testing.connectutils import ReusedConnectTestCase
+from pyspark.testing.pandasutils import PandasOnSparkTestUtils
+
+
+class FrameParityIdxMaxMinTests(
+ FrameIdxMaxMinMixin,
+ PandasOnSparkTestUtils,
+ ReusedConnectTestCase,
+):
+ pass
+
+
+if __name__ == "__main__":
+ from pyspark.testing import main
+
+ main()
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]