This is an automated email from the ASF dual-hosted git repository.

ueshin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new f88addf04a34 [SPARK-55662][PS] Implementation of idxmin Axis argument
f88addf04a34 is described below

commit f88addf04a34b166d85646bb669963962c47dd60
Author: Devin Petersohn <[email protected]>
AuthorDate: Thu Feb 26 11:54:30 2026 -0800

    [SPARK-55662][PS] Implementation of idxmin Axis argument
    
    ### What changes were proposed in this pull request?
    Add axis=1 support for DataFrame.idxmin, matching the existing idxmax axis=1 implementation.
    
    ### Why are the changes needed?
    Implements missing API parameter
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, new parameter
    
    ### How was this patch tested?
    CI
    
    ### Was this patch authored or co-authored using generative AI tooling?
    Co-authored-by: Claude Opus 4
    
    Closes #54455 from devin-petersohn/devin/idxmin_axis_v2.
    
    Authored-by: Devin Petersohn <[email protected]>
    Signed-off-by: Takuya Ueshin <[email protected]>
---
 python/pyspark/pandas/frame.py                     |  80 +++++++++++++--
 .../pandas/tests/computation/test_idxmax_idxmin.py | 112 +++++++++++++++++++++
 2 files changed, 181 insertions(+), 11 deletions(-)

diff --git a/python/pyspark/pandas/frame.py b/python/pyspark/pandas/frame.py
index a9ffcbc9d59d..4e85c6b73301 100644
--- a/python/pyspark/pandas/frame.py
+++ b/python/pyspark/pandas/frame.py
@@ -12450,7 +12450,6 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
 
             return first_series(DataFrame(internal))
 
-    # TODO(SPARK-46168): axis = 1
     def idxmin(self, axis: Axis = 0) -> "Series":
         """
         Return index of first occurrence of minimum over requested axis.
@@ -12461,8 +12460,8 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
 
         Parameters
         ----------
-        axis : 0 or 'index'
-            Can only be set to 0 now.
+        axis : {0 or 'index', 1 or 'columns'}, default 0
+            The axis to use. 0 or 'index' for row-wise, 1 or 'columns' for column-wise.
 
         Returns
         -------
@@ -12490,6 +12489,15 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
         c    1
         dtype: int64
 
+        For axis=1, return the column label of the minimum value in each row:
+
+        >>> psdf.idxmin(axis=1)
+        0    a
+        1    a
+        2    a
+        3    b
+        dtype: object
+
         For Multi-column Index
 
         >>> psdf = ps.DataFrame({'a': [1, 2, 3, 2],
@@ -12510,17 +12518,67 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
         c  z    1
         dtype: int64
         """
-        min_cols = map(lambda scol: F.min(scol), self._internal.data_spark_columns)
-        sdf_min = self._internal.spark_frame.select(*min_cols).head()
+        axis = validate_axis(axis)
+        if axis == 0:
+            min_cols = map(lambda scol: F.min(scol), self._internal.data_spark_columns)
+            sdf_min = self._internal.spark_frame.select(*min_cols).head()
 
-        conds = (
-            scol == min_val for scol, min_val in zip(self._internal.data_spark_columns, sdf_min)
-        )
-        cond = reduce(lambda x, y: x | y, conds)
+            conds = (
+                scol == min_val for scol, min_val in zip(self._internal.data_spark_columns, sdf_min)
+            )
+            cond = reduce(lambda x, y: x | y, conds)
 
-        psdf: DataFrame = DataFrame(self._internal.with_filter(cond))
+            psdf: DataFrame = DataFrame(self._internal.with_filter(cond))
 
-        return cast(ps.Series, ps.from_pandas(psdf._to_internal_pandas().idxmin()))
+            return cast(ps.Series, ps.from_pandas(psdf._to_internal_pandas().idxmin()))
+        else:
+            from pyspark.pandas.series import first_series
+
+            column_labels = self._internal.column_labels
+
+            if len(column_labels) == 0:
+                # Check if DataFrame has rows - if yes, raise error; if no, return empty Series
+                # to match pandas behavior
+                if len(self) > 0:
+                    raise ValueError("attempt to get argmin of an empty sequence")
+                else:
+                    return ps.Series([], dtype=np.int64)
+
+            if self._internal.column_labels_level > 1:
+                raise NotImplementedError(
+                    "idxmin with axis=1 does not support MultiIndex columns yet"
+                )
+
+            min_value = F.least(
+                *[
+                    F.coalesce(self._internal.spark_column_for(label), F.lit(float("inf")))
+                    for label in column_labels
+                ],
+                F.lit(float("inf")),
+            )
+
+            result = None
+            # Iterate over the column labels in reverse order to get the first occurrence of the
+            # minimum value.
+            for label in reversed(column_labels):
+                scol = self._internal.spark_column_for(label)
+                label_value = label[0] if len(label) == 1 else label
+                condition = (scol == min_value) & scol.isNotNull()
+
+                result = (
+                    F.when(condition, F.lit(label_value))
+                    if result is None
+                    else F.when(condition, F.lit(label_value)).otherwise(result)
+                )
+
+            result = F.when(min_value == float("inf"), F.lit(None)).otherwise(result)
+
+            internal = self._internal.with_new_columns(
+                [result.alias(SPARK_DEFAULT_SERIES_NAME)],
+                column_labels=[None],
+            )
+
+            return first_series(DataFrame(internal))
 
     def info(
         self,
diff --git a/python/pyspark/pandas/tests/computation/test_idxmax_idxmin.py b/python/pyspark/pandas/tests/computation/test_idxmax_idxmin.py
index 2de14fccbc43..6c33b4fd3261 100644
--- a/python/pyspark/pandas/tests/computation/test_idxmax_idxmin.py
+++ b/python/pyspark/pandas/tests/computation/test_idxmax_idxmin.py
@@ -155,6 +155,118 @@ class FrameIdxMaxMinMixin:
         # Verify both raise the same error message
         self.assertEqual(str(pdf_context.exception), str(psdf_context.exception))
 
+    def test_idxmin(self):
+        # Test basic axis=0 (default)
+        pdf = pd.DataFrame(
+            {
+                "a": [1, 2, 3, 2],
+                "b": [4.0, 2.0, 3.0, 1.0],
+                "c": [300, 200, 400, 200],
+            }
+        )
+        psdf = ps.from_pandas(pdf)
+
+        self.assert_eq(psdf.idxmin(), pdf.idxmin())
+        self.assert_eq(psdf.idxmin(axis=0), pdf.idxmin(axis=0))
+        self.assert_eq(psdf.idxmin(axis="index"), pdf.idxmin(axis="index"))
+
+        # Test axis=1
+        self.assert_eq(psdf.idxmin(axis=1), pdf.idxmin(axis=1))
+        self.assert_eq(psdf.idxmin(axis="columns"), pdf.idxmin(axis="columns"))
+
+        # Test with NAs
+        pdf = pd.DataFrame(
+            {
+                "a": [1.0, None, 3.0],
+                "b": [None, 2.0, None],
+                "c": [3.0, 4.0, None],
+            }
+        )
+        psdf = ps.from_pandas(pdf)
+
+        self.assert_eq(psdf.idxmin(), pdf.idxmin())
+        self.assert_eq(psdf.idxmin(axis=0), pdf.idxmin(axis=0))
+        self.assert_eq(psdf.idxmin(axis=1), pdf.idxmin(axis=1))
+
+        # Test with all-NA row
+        pdf = pd.DataFrame(
+            {
+                "a": [1.0, None],
+                "b": [2.0, None],
+            }
+        )
+        psdf = ps.from_pandas(pdf)
+
+        self.assert_eq(psdf.idxmin(axis=1), pdf.idxmin(axis=1))
+
+        # Test with ties (first occurrence should win)
+        pdf = pd.DataFrame(
+            {
+                "a": [3, 2, 1],
+                "b": [3, 5, 1],
+                "c": [1, 5, 1],
+            }
+        )
+        psdf = ps.from_pandas(pdf)
+
+        self.assert_eq(psdf.idxmin(axis=1), pdf.idxmin(axis=1))
+
+        # Test with single column
+        pdf = pd.DataFrame({"a": [1, 2, 3]})
+        psdf = ps.from_pandas(pdf)
+
+        self.assert_eq(psdf.idxmin(axis=1), pdf.idxmin(axis=1))
+
+        # Test with empty DataFrame
+        pdf = pd.DataFrame({})
+        psdf = ps.from_pandas(pdf)
+
+        self.assert_eq(psdf.idxmin(axis=1), pdf.idxmin(axis=1))
+
+        # Test with different data types
+        pdf = pd.DataFrame(
+            {
+                "int_col": [1, 2, 3],
+                "float_col": [1.5, 2.5, 0.5],
+                "negative": [-5, -10, -1],
+            }
+        )
+        psdf = ps.from_pandas(pdf)
+
+        self.assert_eq(psdf.idxmin(axis=1), pdf.idxmin(axis=1))
+
+        # Test with custom index
+        pdf = pd.DataFrame(
+            {
+                "a": [1, 2, 3],
+                "b": [4, 5, 6],
+                "c": [7, 8, 9],
+            },
+            index=["row1", "row2", "row3"],
+        )
+        psdf = ps.from_pandas(pdf)
+
+        self.assert_eq(psdf.idxmin(axis=1), pdf.idxmin(axis=1))
+
+    def test_idxmin_multiindex_columns(self):
+        # Test that MultiIndex columns raise NotImplementedError for axis=1
+        pdf = pd.DataFrame(
+            {
+                "a": [1, 2, 3],
+                "b": [4, 5, 6],
+                "c": [7, 8, 9],
+            }
+        )
+        pdf.columns = pd.MultiIndex.from_tuples([("x", "a"), ("y", "b"), ("z", "c")])
+        psdf = ps.from_pandas(pdf)
+
+        # axis=0 should work fine (it uses pandas internally)
+        self.assert_eq(psdf.idxmin(axis=0), pdf.idxmin(axis=0))
+
+        # axis=1 should raise NotImplementedError
+        with self.assertRaises(NotImplementedError):
+            psdf.idxmin(axis=1)
+
 
 class FrameIdxMaxMinTests(
     FrameIdxMaxMinMixin,


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to