This is an automated email from the ASF dual-hosted git repository.
ueshin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new f88addf04a34 [SPARK-55662][PS] Implementation of idxmin Axis argument
f88addf04a34 is described below
commit f88addf04a34b166d85646bb669963962c47dd60
Author: Devin Petersohn <[email protected]>
AuthorDate: Thu Feb 26 11:54:30 2026 -0800
[SPARK-55662][PS] Implementation of idxmin Axis argument
### What changes were proposed in this pull request?
Add axis=1 support for DataFrame.idxmin, matching the existing idxmax
axis=1 implementation.
### Why are the changes needed?
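`DataFrame.idxmin` previously accepted only `axis=0`; this implements the
missing `axis=1` behavior so that it matches `DataFrame.idxmax`.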
### Does this PR introduce _any_ user-facing change?
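Yes. `DataFrame.idxmin` now accepts `axis=1` (or `'columns'`) and returns, for
each row, the column label of the minimum value. MultiIndex columns are not
yet supported with `axis=1` and raise `NotImplementedError`.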
### How was this patch tested?
CI, plus new unit tests added in
python/pyspark/pandas/tests/computation/test_idxmax_idxmin.py.
### Was this patch authored or co-authored using generative AI tooling?
Co-authored-by: Claude Opus 4
Closes #54455 from devin-petersohn/devin/idxmin_axis_v2.
Authored-by: Devin Petersohn <[email protected]>
Signed-off-by: Takuya Ueshin <[email protected]>
---
python/pyspark/pandas/frame.py | 80 +++++++++++++--
.../pandas/tests/computation/test_idxmax_idxmin.py | 112 +++++++++++++++++++++
2 files changed, 181 insertions(+), 11 deletions(-)
diff --git a/python/pyspark/pandas/frame.py b/python/pyspark/pandas/frame.py
index a9ffcbc9d59d..4e85c6b73301 100644
--- a/python/pyspark/pandas/frame.py
+++ b/python/pyspark/pandas/frame.py
@@ -12450,7 +12450,6 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
return first_series(DataFrame(internal))
- # TODO(SPARK-46168): axis = 1
def idxmin(self, axis: Axis = 0) -> "Series":
"""
Return index of first occurrence of minimum over requested axis.
@@ -12461,8 +12460,8 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
Parameters
----------
- axis : 0 or 'index'
- Can only be set to 0 now.
+ axis : {0 or 'index', 1 or 'columns'}, default 0
+ The axis to use. 0 or 'index' for row-wise, 1 or 'columns' for
column-wise.
Returns
-------
@@ -12490,6 +12489,15 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
c 1
dtype: int64
+ For axis=1, return the column label of the minimum value in each row:
+
+ >>> psdf.idxmin(axis=1)
+ 0 a
+ 1 a
+ 2 a
+ 3 b
+ dtype: object
+
For Multi-column Index
>>> psdf = ps.DataFrame({'a': [1, 2, 3, 2],
@@ -12510,17 +12518,67 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
c z 1
dtype: int64
"""
- min_cols = map(lambda scol: F.min(scol),
self._internal.data_spark_columns)
- sdf_min = self._internal.spark_frame.select(*min_cols).head()
+ axis = validate_axis(axis)
+ if axis == 0:
+ min_cols = map(lambda scol: F.min(scol),
self._internal.data_spark_columns)
+ sdf_min = self._internal.spark_frame.select(*min_cols).head()
- conds = (
- scol == min_val for scol, min_val in
zip(self._internal.data_spark_columns, sdf_min)
- )
- cond = reduce(lambda x, y: x | y, conds)
+ conds = (
+ scol == min_val for scol, min_val in
zip(self._internal.data_spark_columns, sdf_min)
+ )
+ cond = reduce(lambda x, y: x | y, conds)
- psdf: DataFrame = DataFrame(self._internal.with_filter(cond))
+ psdf: DataFrame = DataFrame(self._internal.with_filter(cond))
- return cast(ps.Series,
ps.from_pandas(psdf._to_internal_pandas().idxmin()))
+ return cast(ps.Series,
ps.from_pandas(psdf._to_internal_pandas().idxmin()))
+ else:
+ from pyspark.pandas.series import first_series
+
+ column_labels = self._internal.column_labels
+
+ if len(column_labels) == 0:
+ # Check if DataFrame has rows - if yes, raise error; if no,
return empty Series
+ # to match pandas behavior
+ if len(self) > 0:
+ raise ValueError("attempt to get argmin of an empty
sequence")
+ else:
+ return ps.Series([], dtype=np.int64)
+
+ if self._internal.column_labels_level > 1:
+ raise NotImplementedError(
+ "idxmin with axis=1 does not support MultiIndex columns
yet"
+ )
+
+ min_value = F.least(
+ *[
+ F.coalesce(self._internal.spark_column_for(label),
F.lit(float("inf")))
+ for label in column_labels
+ ],
+ F.lit(float("inf")),
+ )
+
+ result = None
+ # Iterate over the column labels in reverse order to get the first
occurrence of the
+ # minimum value.
+ for label in reversed(column_labels):
+ scol = self._internal.spark_column_for(label)
+ label_value = label[0] if len(label) == 1 else label
+ condition = (scol == min_value) & scol.isNotNull()
+
+ result = (
+ F.when(condition, F.lit(label_value))
+ if result is None
+ else F.when(condition,
F.lit(label_value)).otherwise(result)
+ )
+
+ result = F.when(min_value == float("inf"),
F.lit(None)).otherwise(result)
+
+ internal = self._internal.with_new_columns(
+ [result.alias(SPARK_DEFAULT_SERIES_NAME)],
+ column_labels=[None],
+ )
+
+ return first_series(DataFrame(internal))
def info(
self,
diff --git a/python/pyspark/pandas/tests/computation/test_idxmax_idxmin.py
b/python/pyspark/pandas/tests/computation/test_idxmax_idxmin.py
index 2de14fccbc43..6c33b4fd3261 100644
--- a/python/pyspark/pandas/tests/computation/test_idxmax_idxmin.py
+++ b/python/pyspark/pandas/tests/computation/test_idxmax_idxmin.py
@@ -155,6 +155,118 @@ class FrameIdxMaxMinMixin:
# Verify both raise the same error message
self.assertEqual(str(pdf_context.exception),
str(psdf_context.exception))
+ def test_idxmin(self):
+ # Test basic axis=0 (default)
+ pdf = pd.DataFrame(
+ {
+ "a": [1, 2, 3, 2],
+ "b": [4.0, 2.0, 3.0, 1.0],
+ "c": [300, 200, 400, 200],
+ }
+ )
+ psdf = ps.from_pandas(pdf)
+
+ self.assert_eq(psdf.idxmin(), pdf.idxmin())
+ self.assert_eq(psdf.idxmin(axis=0), pdf.idxmin(axis=0))
+ self.assert_eq(psdf.idxmin(axis="index"), pdf.idxmin(axis="index"))
+
+ # Test axis=1
+ self.assert_eq(psdf.idxmin(axis=1), pdf.idxmin(axis=1))
+ self.assert_eq(psdf.idxmin(axis="columns"), pdf.idxmin(axis="columns"))
+
+ # Test with NAs
+ pdf = pd.DataFrame(
+ {
+ "a": [1.0, None, 3.0],
+ "b": [None, 2.0, None],
+ "c": [3.0, 4.0, None],
+ }
+ )
+ psdf = ps.from_pandas(pdf)
+
+ self.assert_eq(psdf.idxmin(), pdf.idxmin())
+ self.assert_eq(psdf.idxmin(axis=0), pdf.idxmin(axis=0))
+ self.assert_eq(psdf.idxmin(axis=1), pdf.idxmin(axis=1))
+
+ # Test with all-NA row
+ pdf = pd.DataFrame(
+ {
+ "a": [1.0, None],
+ "b": [2.0, None],
+ }
+ )
+ psdf = ps.from_pandas(pdf)
+
+ self.assert_eq(psdf.idxmin(axis=1), pdf.idxmin(axis=1))
+
+ # Test with ties (first occurrence should win)
+ pdf = pd.DataFrame(
+ {
+ "a": [3, 2, 1],
+ "b": [3, 5, 1],
+ "c": [1, 5, 1],
+ }
+ )
+ psdf = ps.from_pandas(pdf)
+
+ self.assert_eq(psdf.idxmin(axis=1), pdf.idxmin(axis=1))
+
+ # Test with single column
+ pdf = pd.DataFrame({"a": [1, 2, 3]})
+ psdf = ps.from_pandas(pdf)
+
+ self.assert_eq(psdf.idxmin(axis=1), pdf.idxmin(axis=1))
+
+ # Test with empty DataFrame
+ pdf = pd.DataFrame({})
+ psdf = ps.from_pandas(pdf)
+
+ self.assert_eq(psdf.idxmin(axis=1), pdf.idxmin(axis=1))
+
+ # Test with different data types
+ pdf = pd.DataFrame(
+ {
+ "int_col": [1, 2, 3],
+ "float_col": [1.5, 2.5, 0.5],
+ "negative": [-5, -10, -1],
+ }
+ )
+ psdf = ps.from_pandas(pdf)
+
+ self.assert_eq(psdf.idxmin(axis=1), pdf.idxmin(axis=1))
+
+ # Test with custom index
+ pdf = pd.DataFrame(
+ {
+ "a": [1, 2, 3],
+ "b": [4, 5, 6],
+ "c": [7, 8, 9],
+ },
+ index=["row1", "row2", "row3"],
+ )
+ psdf = ps.from_pandas(pdf)
+
+ self.assert_eq(psdf.idxmin(axis=1), pdf.idxmin(axis=1))
+
+ def test_idxmin_multiindex_columns(self):
+ # Test that MultiIndex columns raise NotImplementedError for axis=1
+ pdf = pd.DataFrame(
+ {
+ "a": [1, 2, 3],
+ "b": [4, 5, 6],
+ "c": [7, 8, 9],
+ }
+ )
+ pdf.columns = pd.MultiIndex.from_tuples([("x", "a"), ("y", "b"), ("z",
"c")])
+ psdf = ps.from_pandas(pdf)
+
+ # axis=0 should work fine (it uses pandas internally)
+ self.assert_eq(psdf.idxmin(axis=0), pdf.idxmin(axis=0))
+
+ # axis=1 should raise NotImplementedError
+ with self.assertRaises(NotImplementedError):
+ psdf.idxmin(axis=1)
+
class FrameIdxMaxMinTests(
FrameIdxMaxMinMixin,
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]