mattmartin14 commented on code in PR #1534:
URL: https://github.com/apache/iceberg-python/pull/1534#discussion_r1951521796


##########
pyiceberg/table/upsert_util.py:
##########
@@ -0,0 +1,131 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import pyarrow as pa
+from pyarrow import Table as pyarrow_table
+from pyarrow import compute as pc
+
+from pyiceberg.expressions import (
+    And,
+    BooleanExpression,
+    EqualTo,
+    In,
+    Or,
+)
+
+
+def create_match_filter(df: pyarrow_table, join_cols: list[str]) -> BooleanExpression:
+    unique_keys = df.select(join_cols).group_by(join_cols).aggregate([])
+
+    if len(join_cols) == 1:
+        return In(join_cols[0], unique_keys[0].to_pylist())
+    else:
+        return Or(*[And(*[EqualTo(col, row[col]) for col in join_cols]) for row in unique_keys.to_pylist()])
+
+
+def has_duplicate_rows(df: pyarrow_table, join_cols: list[str]) -> bool:
+    """Check for duplicate rows in a PyArrow table based on the join 
columns."""
+    return len(df.select(join_cols).group_by(join_cols).aggregate([([], 
"count_all")]).filter(pc.field("count_all") > 1)) > 0
+
+
+def get_rows_to_update(source_table: pa.Table, target_table: pa.Table, join_cols: list[str]) -> pa.Table:
+    """
+    Return a table with rows that need to be updated in the target table based on the join columns.
+
+    When a row is matched, an additional scan is done to evaluate the non-key columns to detect if an actual change has occurred.
+    Only matched rows that have an actual change to a non-key column value will be returned in the final output.
+    """
+    all_columns = set(source_table.column_names)
+    join_cols_set = set(join_cols)
+
+    non_key_cols = list(all_columns - join_cols_set)
+
+    match_expr = None
+
+    for col in join_cols:
+        target_values = target_table.column(col).to_pylist()
+        expr = pc.field(col).isin(target_values)
+
+        if match_expr is None:
+            match_expr = expr
+        else:
+            match_expr = match_expr & expr
+
+    matching_source_rows = source_table.filter(match_expr)
+
+    rows_to_update = []
+
+    for index in range(matching_source_rows.num_rows):
+        source_row = matching_source_rows.slice(index, 1)
+
+        target_filter = None
+
+        for col in join_cols:
+            target_value = source_row.column(col)[0].as_py()
+            if target_filter is None:
+                target_filter = pc.field(col) == target_value
+            else:
+                target_filter = target_filter & (pc.field(col) == target_value)

Review Comment:
   When trying to add that, I'm getting an error in my unit tests saying:
   
   ```bash
   UnboundLocalError: cannot access local variable 'target_value' where it is not associated with a value
   ```
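   
   For reference, a minimal sketch of how that `UnboundLocalError` can surface here: `target_value` is only assigned inside the `for col in join_cols` loop, so if the loop body never runs (e.g. an empty `join_cols`), any later reference to it fails. The helper name and the empty-`join_cols` guard below are illustrative assumptions, not the change made in this PR:
   
   ```python
   import pyarrow.compute as pc
   
   def build_target_filter(source_row, join_cols):  # hypothetical helper, not part of the PR
       # target_value / target_filter are bound only inside the loop body;
       # with an empty join_cols they stay unbound, and a later reference
       # would raise UnboundLocalError, hence the illustrative guard here.
       if not join_cols:
           raise ValueError("join_cols must not be empty")
   
       target_filter = None
       for col in join_cols:
           target_value = source_row.column(col)[0].as_py()
           expr = pc.field(col) == target_value
           target_filter = expr if target_filter is None else target_filter & expr
       return target_filter
   ```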



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@iceberg.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

