This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new d29370b7e92 Fix SqlToS3Operator _partition_dataframe for proper polars support (#54588)
d29370b7e92 is described below

commit d29370b7e922cccbb615a9da0308f4a5f46fc607
Author: Guan Ming(Wesley) Chiu <[email protected]>
AuthorDate: Sat Aug 23 02:51:38 2025 +0800

    Fix SqlToS3Operator _partition_dataframe for proper polars support (#54588)
    
    * Fix _partition_dataframe for proper polars support
    
    * Refactor to reduce code duplication
    
    Co-authored-by: Ranuga Disansa <[email protected]>
    
    ---------
    
    Co-authored-by: Ranuga Disansa <[email protected]>
---
 .../providers/amazon/aws/transfers/sql_to_s3.py    | 48 +++++++++++++++-------
 .../unit/amazon/aws/transfers/test_sql_to_s3.py    |  6 ++-
 2 files changed, 38 insertions(+), 16 deletions(-)

diff --git 
a/providers/amazon/src/airflow/providers/amazon/aws/transfers/sql_to_s3.py 
b/providers/amazon/src/airflow/providers/amazon/aws/transfers/sql_to_s3.py
index e13d5ca6b4e..45224059e99 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/transfers/sql_to_s3.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/transfers/sql_to_s3.py
@@ -261,38 +261,56 @@ class SqlToS3Operator(BaseOperator):
                 file_obj=buf, key=object_key, bucket_name=self.s3_bucket, 
replace=self.replace
             )
 
-    def _partition_dataframe(self, df: pd.DataFrame | pl.DataFrame) -> 
Iterable[tuple[str, pd.DataFrame]]:
-        """Partition dataframe using pandas groupby() method."""
+    def _partition_dataframe(
+        self, df: pd.DataFrame | pl.DataFrame
+    ) -> Iterable[tuple[str, pd.DataFrame | pl.DataFrame]]:
+        """Partition dataframe using pandas or polars groupby() method."""
         try:
             import secrets
             import string
 
             import numpy as np
+            import pandas as pd
             import polars as pl
         except ImportError:
             pass
 
-        if isinstance(df, pl.DataFrame):
-            df = df.to_pandas()
-
         # if max_rows_per_file argument is specified, a temporary column with 
a random unusual name will be
         # added to the dataframe. This column is used to dispatch the 
dataframe into smaller ones using groupby()
-
-        random_column_name = ""
+        random_column_name = None
         if self.max_rows_per_file and not self.groupby_kwargs:
             random_column_name = "".join(secrets.choice(string.ascii_letters) 
for _ in range(20))
-            df[random_column_name] = np.arange(len(df)) // 
self.max_rows_per_file
             self.groupby_kwargs = {"by": random_column_name}
+
+        if random_column_name:
+            if isinstance(df, pd.DataFrame):
+                df[random_column_name] = np.arange(len(df)) // 
self.max_rows_per_file
+            elif isinstance(df, pl.DataFrame):
+                df = df.with_columns(
+                    (pl.int_range(pl.len()) // 
self.max_rows_per_file).alias(random_column_name)
+                )
+
         if not self.groupby_kwargs:
             yield "", df
             return
-        for group_label in (grouped_df := 
df.groupby(**self.groupby_kwargs)).groups:
-            yield (
-                cast("str", group_label),
-                grouped_df.get_group(group_label)
-                .drop(random_column_name, axis=1, errors="ignore")
-                .reset_index(drop=True),
-            )
+
+        if isinstance(df, pd.DataFrame):
+            for group_label in (grouped_df := 
df.groupby(**self.groupby_kwargs)).groups:
+                group_df = grouped_df.get_group(group_label)
+                if random_column_name:
+                    group_df = group_df.drop(random_column_name, axis=1, 
errors="ignore")
+                yield (
+                    cast("str", group_label[0] if isinstance(group_label, 
tuple) else group_label),
+                    group_df.reset_index(drop=True),
+                )
+        elif isinstance(df, pl.DataFrame):
+            for group_label, group_df in df.group_by(**self.groupby_kwargs):  
# type: ignore[assignment]
+                if random_column_name:
+                    group_df = group_df.drop(random_column_name)
+                yield (
+                    cast("str", group_label[0] if isinstance(group_label, 
tuple) else group_label),
+                    group_df,
+                )
 
     def _get_hook(self) -> DbApiHook:
         self.log.debug("Get connection for %s", self.sql_conn_id)
diff --git a/providers/amazon/tests/unit/amazon/aws/transfers/test_sql_to_s3.py 
b/providers/amazon/tests/unit/amazon/aws/transfers/test_sql_to_s3.py
index 402803640e6..1a1e424928f 100644
--- a/providers/amazon/tests/unit/amazon/aws/transfers/test_sql_to_s3.py
+++ b/providers/amazon/tests/unit/amazon/aws/transfers/test_sql_to_s3.py
@@ -22,6 +22,7 @@ import io
 from unittest import mock
 
 import pandas as pd
+import polars as pl
 import pytest
 
 from airflow.exceptions import AirflowException, 
AirflowProviderDeprecationWarning
@@ -442,7 +443,10 @@ class TestSqlToS3Operator:
 
         assert len(partitions) == 2
         for group_name, df in partitions:
-            assert isinstance(df, pd.DataFrame)
+            if df_type == "polars":
+                assert isinstance(df, pl.DataFrame)
+            else:
+                assert isinstance(df, pd.DataFrame)
             assert group_name in ["A", "B"]
 
     @pytest.mark.parametrize(

Reply via email to