This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new d29370b7e92 Fix SqlToS3Operator _partition_dataframe for proper polars support (#54588)
d29370b7e92 is described below
commit d29370b7e922cccbb615a9da0308f4a5f46fc607
Author: Guan Ming(Wesley) Chiu <[email protected]>
AuthorDate: Sat Aug 23 02:51:38 2025 +0800
Fix SqlToS3Operator _partition_dataframe for proper polars support (#54588)
* Fix _partition_dataframe for proper polars support
* Refactor to reduce code duplication
Co-authored-by: Ranuga Disansa <[email protected]>
---------
Co-authored-by: Ranuga Disansa <[email protected]>
---
.../providers/amazon/aws/transfers/sql_to_s3.py | 48 +++++++++++++++-------
.../unit/amazon/aws/transfers/test_sql_to_s3.py | 6 ++-
2 files changed, 38 insertions(+), 16 deletions(-)
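
[Reviewer's note, not part of the patch: the tuple-unpacking introduced below exists because pandas and polars expose group labels differently. A minimal sketch of that difference, assuming recent pandas and polars versions; the column name "cat" is illustrative only:

    import pandas as pd
    import polars as pl

    pdf = pd.DataFrame({"cat": ["A", "A", "B"], "x": [1, 2, 3]})
    for label in pdf.groupby(by="cat").groups:
        print(label)  # 'A', 'B' -- pandas exposes plain scalar labels

    pldf = pl.DataFrame({"cat": ["A", "A", "B"], "x": [1, 2, 3]})
    for label, group in pldf.group_by("cat"):
        print(label)  # ('A',), ('B',) -- recent polars yields tuple keys (order may vary)

This is why both branches of the patched code yield group_label[0] when the label is a tuple, and the label itself otherwise.]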
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/transfers/sql_to_s3.py b/providers/amazon/src/airflow/providers/amazon/aws/transfers/sql_to_s3.py
index e13d5ca6b4e..45224059e99 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/transfers/sql_to_s3.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/transfers/sql_to_s3.py
@@ -261,38 +261,56 @@ class SqlToS3Operator(BaseOperator):
file_obj=buf, key=object_key, bucket_name=self.s3_bucket, replace=self.replace
)
- def _partition_dataframe(self, df: pd.DataFrame | pl.DataFrame) -> Iterable[tuple[str, pd.DataFrame]]:
- """Partition dataframe using pandas groupby() method."""
+ def _partition_dataframe(
+ self, df: pd.DataFrame | pl.DataFrame
+ ) -> Iterable[tuple[str, pd.DataFrame | pl.DataFrame]]:
+ """Partition dataframe using pandas or polars groupby() method."""
try:
import secrets
import string
import numpy as np
+ import pandas as pd
import polars as pl
except ImportError:
pass
- if isinstance(df, pl.DataFrame):
- df = df.to_pandas()
-
# if max_rows_per_file argument is specified, a temporary column with a random unusual name will be
# added to the dataframe. This column is used to dispatch the dataframe into smaller ones using groupby()
-
- random_column_name = ""
+ random_column_name = None
if self.max_rows_per_file and not self.groupby_kwargs:
random_column_name = "".join(secrets.choice(string.ascii_letters) for _ in range(20))
- df[random_column_name] = np.arange(len(df)) // self.max_rows_per_file
self.groupby_kwargs = {"by": random_column_name}
+
+ if random_column_name:
+ if isinstance(df, pd.DataFrame):
+ df[random_column_name] = np.arange(len(df)) // self.max_rows_per_file
+ elif isinstance(df, pl.DataFrame):
+ df = df.with_columns(
+ (pl.int_range(pl.len()) // self.max_rows_per_file).alias(random_column_name)
+ )
+
if not self.groupby_kwargs:
yield "", df
return
- for group_label in (grouped_df := df.groupby(**self.groupby_kwargs)).groups:
- yield (
- cast("str", group_label),
- grouped_df.get_group(group_label)
- .drop(random_column_name, axis=1, errors="ignore")
- .reset_index(drop=True),
- )
+
+ if isinstance(df, pd.DataFrame):
+ for group_label in (grouped_df := df.groupby(**self.groupby_kwargs)).groups:
+ group_df = grouped_df.get_group(group_label)
+ if random_column_name:
+ group_df = group_df.drop(random_column_name, axis=1, errors="ignore")
+ yield (
+ cast("str", group_label[0] if isinstance(group_label,
tuple) else group_label),
+ group_df.reset_index(drop=True),
+ )
+ elif isinstance(df, pl.DataFrame):
+ for group_label, group_df in df.group_by(**self.groupby_kwargs):  # type: ignore[assignment]
+ if random_column_name:
+ group_df = group_df.drop(random_column_name)
+ yield (
+ cast("str", group_label[0] if isinstance(group_label,
tuple) else group_label),
+ group_df,
+ )
def _get_hook(self) -> DbApiHook:
self.log.debug("Get connection for %s", self.sql_conn_id)
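
[Reviewer's note, an illustration rather than part of the patch: the max_rows_per_file path above synthesizes a partition key by integer division of the row index, in both engines. A minimal sketch assuming recent numpy/pandas/polars, with "part" standing in for the randomly named temporary column:

    import numpy as np
    import pandas as pd
    import polars as pl

    max_rows_per_file = 2

    pdf = pd.DataFrame({"x": range(5)})
    pdf["part"] = np.arange(len(pdf)) // max_rows_per_file  # [0, 0, 1, 1, 2]

    pldf = pl.DataFrame({"x": range(5)})
    pldf = pldf.with_columns((pl.int_range(pl.len()) // max_rows_per_file).alias("part"))

Grouping on "part" then produces chunks of at most max_rows_per_file rows each, which is exactly what the groupby()/group_by() loops above rely on.]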
diff --git a/providers/amazon/tests/unit/amazon/aws/transfers/test_sql_to_s3.py b/providers/amazon/tests/unit/amazon/aws/transfers/test_sql_to_s3.py
index 402803640e6..1a1e424928f 100644
--- a/providers/amazon/tests/unit/amazon/aws/transfers/test_sql_to_s3.py
+++ b/providers/amazon/tests/unit/amazon/aws/transfers/test_sql_to_s3.py
@@ -22,6 +22,7 @@ import io
from unittest import mock
import pandas as pd
+import polars as pl
import pytest
from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
@@ -442,7 +443,10 @@ class TestSqlToS3Operator:
assert len(partitions) == 2
for group_name, df in partitions:
- assert isinstance(df, pd.DataFrame)
+ if df_type == "polars":
+ assert isinstance(df, pl.DataFrame)
+ else:
+ assert isinstance(df, pd.DataFrame)
assert group_name in ["A", "B"]
@pytest.mark.parametrize(