Fokko commented on code in PR #1534:
URL: https://github.com/apache/iceberg-python/pull/1534#discussion_r1940000258


##########
pyiceberg/table/merge_rows_util.py:
##########
@@ -0,0 +1,165 @@
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+
+from datafusion import SessionContext

Review Comment:
   We cannot import this at the top of the file. When someone who doesn't have `datafusion` installed imports this file, it will raise a `ModuleNotFoundError` right away.
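   
   A minimal sketch of deferring the import (mirroring the pattern `merge_rows` itself already uses; the helper name is just an illustration), so the module stays importable without `datafusion`:
   
   ```python
   def _get_datafusion_session_context():
       # Import lazily so that importing merge_rows_util does not require
       # datafusion; raise a helpful error only when it is actually used.
       try:
           from datafusion import SessionContext
       except ModuleNotFoundError as e:
           raise ModuleNotFoundError("For merge_rows, DataFusion needs to be installed") from e
       return SessionContext()
   ```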



##########
pyiceberg/table/__init__.py:
##########
@@ -1064,6 +1065,110 @@ def name_mapping(self) -> Optional[NameMapping]:
         """Return the table's field-id NameMapping."""
         return self.metadata.name_mapping()
 
+    def merge_rows(self, df: pa.Table, join_cols: list
+                   , when_matched_update_all: bool = True
+                   , when_not_matched_insert_all: bool = True
+                ) -> Dict:
+        """
+        Shorthand API for performing an upsert/merge to an iceberg table.
+        
+        Args:
+            df: The input dataframe to merge with the table's data.
+            join_cols: The columns to join on.
+            when_matched_update_all: Bool indicating to update rows that are matched but require an update due to a value in a non-key column changing
+            when_not_matched_insert_all: Bool indicating new rows to be inserted that do not match any existing rows in the table
+
+        Returns:
+            A dictionary containing the number of rows updated and inserted.
+        """
+
+        from pyiceberg.table import merge_rows_util
+
+        try:
+            from datafusion import SessionContext
+        except ModuleNotFoundError as e:
+            raise ModuleNotFoundError("For merge_rows, DataFusion needs to be 
installed") from e
+        
+        try:
+            from pyarrow import dataset as ds
+        except ModuleNotFoundError as e:
+            raise ModuleNotFoundError("For merge_rows, PyArrow needs to be 
installed") from e
+
+        source_table_name = "source"
+        target_table_name = "target"
+
+        if when_matched_update_all == False and when_not_matched_insert_all == False:
+            return {'rows_updated': 0, 'rows_inserted': 0, 'msg': 'no merge options selected...exiting'}

Review Comment:
   Should we return a `@dataclass` with known properties? Otherwise users would need to look into the source to find out which keys to expect in the return dictionary (`rows_{updated,inserted}`, `msg`, `error_msgs`, etc.).
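   
   A minimal sketch (the `MergeResult` name is just an illustration):
   
   ```python
   from dataclasses import dataclass
   
   @dataclass
   class MergeResult:
       rows_updated: int = 0
       rows_inserted: int = 0
       msg: str = ""
   ```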



##########
pyiceberg/table/__init__.py:
##########
@@ -1064,6 +1065,110 @@ def name_mapping(self) -> Optional[NameMapping]:
         """Return the table's field-id NameMapping."""
         return self.metadata.name_mapping()
 
+    def merge_rows(self, df: pa.Table, join_cols: list
+                   , when_matched_update_all: bool = True
+                   , when_not_matched_insert_all: bool = True
+                ) -> Dict:
+        """
+        Shorthand API for performing an upsert/merge to an iceberg table.
+        
+        Args:
+            df: The input dataframe to merge with the table's data.
+            join_cols: The columns to join on.
+            when_matched_update_all: Bool indicating to update rows that are matched but require an update due to a value in a non-key column changing
+            when_not_matched_insert_all: Bool indicating new rows to be inserted that do not match any existing rows in the table
+
+        Returns:
+            A dictionary containing the number of rows updated and inserted.
+        """
+
+        from pyiceberg.table import merge_rows_util
+
+        try:
+            from datafusion import SessionContext
+        except ModuleNotFoundError as e:
+            raise ModuleNotFoundError("For merge_rows, DataFusion needs to be 
installed") from e
+        
+        try:
+            from pyarrow import dataset as ds
+        except ModuleNotFoundError as e:
+            raise ModuleNotFoundError("For merge_rows, PyArrow needs to be 
installed") from e
+
+        source_table_name = "source"
+        target_table_name = "target"
+
+        if when_matched_update_all == False and when_not_matched_insert_all == False:
+            return {'rows_updated': 0, 'rows_inserted': 0, 'msg': 'no merge options selected...exiting'}
+
+        missing_columns = merge_rows_util.do_join_columns_exist(df, self, join_cols)
+        
+        if missing_columns['source'] or missing_columns['target']:
+
+            return {'error_msgs': f"Join columns missing in tables: Source table columns missing: {missing_columns['source']}, Target table columns missing: {missing_columns['target']}"}

Review Comment:
   I think we can re-use the check for the columns from the `append` method: 
https://github.com/apache/iceberg-python/blob/5018efc203b6393898a000d98e6f305f287e42a7/pyiceberg/table/__init__.py#L448-L451
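   
   Roughly, a sketch of failing fast on missing join columns instead of returning an error dict (assuming `df.column_names` and `self.schema().column_names` here; the exact check could be shared with `append`):
   
   ```python
   # Validate the join columns against both sides up front.
   missing_source = [c for c in join_cols if c not in df.column_names]
   missing_target = [c for c in join_cols if c not in self.schema().column_names]
   if missing_source or missing_target:
       raise ValueError(
           f"Join columns missing: source={missing_source}, target={missing_target}"
       )
   ```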



##########
pyiceberg/table/merge_rows_util.py:
##########
@@ -0,0 +1,165 @@
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+
+from datafusion import SessionContext
+from pyarrow import Table as pyarrow_table
+import pyarrow as pa
+from pyiceberg import table as pyiceberg_table
+
+from pyiceberg.expressions import (
+    BooleanExpression,
+    And,
+    EqualTo,
+    Or,
+    In,
+)
+
+def get_filter_list(df: pyarrow_table, join_cols: list) -> BooleanExpression:
+
+    unique_keys = df.select(join_cols).group_by(join_cols).aggregate([])
+
+    pred = None
+
+    if len(join_cols) == 1:
+        pred = In(join_cols[0], unique_keys[0].to_pylist())
+    else:
+        pred = Or(*[
+            And(*[
+                EqualTo(col, row[col])
+                for col in join_cols
+            ])
+            for row in unique_keys.to_pylist()
+        ])
+
+    return pred
+
+
+def get_table_column_list(connection: SessionContext, table_name: str) -> list:
+    """
+    This function retrieves the column names and their data types for the specified table.
+    It returns a list of tuples where each tuple contains the column name and its data type.
+    
+    Args:
+        connection: DataFusion SessionContext.
+        table_name: The name of the table for which to retrieve column information.
+    
+    Returns:
+        A list of tuples containing column names and their corresponding data types.
+    """
+    # DataFusion logic
+    res = connection.sql(f"""
+        SELECT column_name, data_type 
+        FROM information_schema.columns 
+        WHERE table_name = '{table_name}'
+    """).collect()
+
+    column_names = res[0][0].to_pylist()  # Extract the first column (column names)
+    data_types = res[0][1].to_pylist()    # Extract the second column (data types)
+
+    return list(zip(column_names, data_types))  # Combine into list of tuples
+
+def dups_check_in_source(source_table: str, join_cols: list, connection: SessionContext) -> bool:

Review Comment:
   How would you feel about using plain PyArrow for this:
   
   ```python
   import pyarrow as pa
   import pyarrow.compute as pc
   
   join_cols = ["animals"]
   
   df = pa.table([
       pa.array([2, 4, 4]),
       pa.array(["Parrot", "Dog", "Dog"])
   ], names=['n_legs', 'animals'])
   
   source_dup_count = len(
       df.select(join_cols)
       .group_by(join_cols)
       .aggregate([([], "count_all")])
       .filter(pc.field("count_all") > 1)
   )
   ```
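   
   With the sample table above this yields `1` (the duplicated `Dog` key), so the boolean the function returns would just be `source_dup_count > 0`.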



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@iceberg.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

