kevinjqliu commented on code in PR #1878:
URL: https://github.com/apache/iceberg-python/pull/1878#discussion_r2051555688
##########
pyiceberg/table/upsert_util.py:
##########
@@ -82,14 +82,54 @@ def get_rows_to_update(source_table: pa.Table, target_table: pa.Table, join_cols
         ],
     )
 
-    return (
-        source_table
-        # We already know that the schema is compatible, this is to fix large_ types
-        .cast(target_table.schema)
-        .join(target_table, keys=list(join_cols_set), join_type="inner", left_suffix="-lhs", right_suffix="-rhs")
-        .filter(diff_expr)
-        .drop_columns([f"{col}-rhs" for col in non_key_cols])
-        .rename_columns({f"{col}-lhs" if col not in join_cols else col: col for col in source_table.column_names})
-        # Finally cast to the original schema since it doesn't carry nullability:
-        # https://github.com/apache/arrow/issues/45557
-    ).cast(target_table.schema)
+    try:
+        return (
+            source_table
+            # We already know that the schema is compatible, this is to fix large_ types
+            .cast(target_table.schema)
+            .join(target_table, keys=list(join_cols_set), join_type="inner", left_suffix="-lhs", right_suffix="-rhs")
+            .filter(diff_expr)
+            .drop_columns([f"{col}-rhs" for col in non_key_cols])
+            .rename_columns({f"{col}-lhs" if col not in join_cols else col: col for col in source_table.column_names})
+            # Finally cast to the original schema since it doesn't carry nullability:
+            # https://github.com/apache/arrow/issues/45557
+        ).cast(target_table.schema)
+    except pa.ArrowInvalid:
+        # When we are not able to compare (e.g. due to unsupported types),
+        # fall back to selecting only rows in the source table that do NOT already exist in the target.
+        # See: https://github.com/apache/arrow/issues/35785
+        MARKER_COLUMN_NAME = "__from_target"
+        INDEX_COLUMN_NAME = "__source_index"
+
+        if MARKER_COLUMN_NAME in join_cols_set or INDEX_COLUMN_NAME in join_cols_set:
+            raise ValueError(
+                f"{MARKER_COLUMN_NAME} and {INDEX_COLUMN_NAME} are reserved for joining "
+                f"DataFrames, and cannot be used as column names"
+            ) from None
+
+        # Step 1: Prepare source index with join keys and a marker index
+        # Cast to target table schema, so we can do the join
+        # See: https://github.com/apache/arrow/issues/37542
+        source_index = (
+            source_table.cast(target_table.schema)
+            .select(join_cols_set)
+            .append_column(INDEX_COLUMN_NAME, pa.array(range(len(source_table))))
+        )
+
+        # Step 2: Prepare target index with join keys and a marker
+        target_index = target_table.select(join_cols_set).append_column(MARKER_COLUMN_NAME, pa.repeat(True, len(target_table)))
+
+        # Step 3: Perform a left outer join to find which rows from source exist in target
+        joined = source_index.join(target_index, keys=list(join_cols_set), join_type="left outer")
+
+        # Step 4: Restore original source order
+        joined = joined.sort_by(INDEX_COLUMN_NAME)
+
+        # Step 5: Create a boolean mask for rows that do exist in the target
+        # i.e., where marker column is true after the join
+        to_update_mask = pc.invert(pc.is_null(joined[MARKER_COLUMN_NAME]))
+
+        # Step 6: Filter source table using the mask and cast to target schema
+        filtered = source_table.filter(to_update_mask)
+
+        return filtered

Review Comment:
   does this also need `.cast(target_table.schema)` for nullability?
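For reference, the suggested change would amount to something like this (a sketch of the reviewer's suggestion, not code from the PR, reusing the variables from the fallback path above):

```python
# Step 6, amended: filter the source table and, as in the happy path, cast
# to the target schema so the result carries the target's nullability
# (https://github.com/apache/arrow/issues/45557).
filtered = source_table.filter(to_update_mask)
return filtered.cast(target_table.schema)
```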
##########
tests/table/test_upsert.py:
##########
@@ -511,6 +511,76 @@ def test_upsert_without_identifier_fields(catalog: Catalog) -> None:
         tbl.upsert(df)
 
 
+def test_upsert_struct_field_fails_in_join(catalog: Catalog) -> None:
+    identifier = "default.test_upsert_struct_field_fails"
+    _drop_table(catalog, identifier)
+
+    schema = Schema(
+        NestedField(1, "id", IntegerType(), required=True),
+        NestedField(
+            2,
+            "nested_type",
+            # Struct<type: string, coordinates: list<double>>
+            StructType(
+                NestedField(3, "sub1", StringType(), required=True),
+                NestedField(4, "sub2", StringType(), required=True),
+            ),
+            required=False,
+        ),
+        identifier_field_ids=[1],
+    )
+
+    tbl = catalog.create_table(identifier, schema=schema)
+
+    arrow_schema = pa.schema(
+        [
+            pa.field("id", pa.int32(), nullable=False),
+            pa.field(
+                "nested_type",
+                pa.struct(
+                    [
+                        pa.field("sub1", pa.large_string(), nullable=False),
+                        pa.field("sub2", pa.large_string(), nullable=False),
+                    ]
+                ),
+                nullable=True,
+            ),
+        ]
+    )
+
+    initial_data = pa.Table.from_pylist(
+        [
+            {
+                "id": 1,
+                "nested_type": {"sub1": "bla1", "sub2": "bla"},
+            }
+        ],
+        schema=arrow_schema,
+    )
+    tbl.append(initial_data)
+
+    update_data = pa.Table.from_pylist(
+        [
+            {
+                "id": 2,
+                "nested_type": {"sub1": "bla1", "sub2": "bla"},
+            },
+            {
+                "id": 1,
+                "nested_type": {"sub1": "bla1", "sub2": "bla"},
+            },
+        ],
+        schema=arrow_schema,
+    )
+
+    upd = tbl.upsert(update_data, join_cols=["id"])
+
+    # Row needs to be updated even tho it's not changed.
+    # When pyarrow isn't able to compare rows, just update everything
+    assert upd.rows_updated == 1
+    assert upd.rows_inserted == 1
+

Review Comment:
   nit: can we add an example where upsert is not allowed? i.e. the join key specifies a column with a complex type

##########
tests/table/test_upsert.py:
##########
@@ -511,6 +511,76 @@ def test_upsert_without_identifier_fields(catalog: Catalog) -> None:
         tbl.upsert(df)
 
 
+def test_upsert_struct_field_fails_in_join(catalog: Catalog) -> None:
+    identifier = "default.test_upsert_struct_field_fails"
+    _drop_table(catalog, identifier)
+
+    schema = Schema(
+        NestedField(1, "id", IntegerType(), required=True),
+        NestedField(
+            2,
+            "nested_type",
+            # Struct<type: string, coordinates: list<double>>

Review Comment:
   nit: this comment is no longer true
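The negative-path test asked for in the first comment above might look roughly like this (a hypothetical sketch reusing the schema and data from the test under review; `Exception` is a placeholder, since the PR does not yet define which error a complex-typed join key should raise):

```python
def test_upsert_with_complex_type_join_key_not_allowed(catalog: Catalog) -> None:
    # Hypothetical test: joining on the struct column itself should be
    # rejected outright rather than silently taking the fallback path.
    identifier = "default.test_upsert_complex_join_key"
    _drop_table(catalog, identifier)

    # `schema`, `initial_data`, and `update_data` as in the test above.
    tbl = catalog.create_table(identifier, schema=schema)
    tbl.append(initial_data)

    # `Exception` is a placeholder; the real test should assert whatever
    # specific error the gate ends up raising.
    with pytest.raises(Exception):
        tbl.upsert(update_data, join_cols=["nested_type"])
```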
##########
pyiceberg/table/upsert_util.py:
##########
@@ -82,14 +82,54 @@ def get_rows_to_update(source_table: pa.Table, target_table: pa.Table, join_cols
         ],
     )
 
-    return (
-        source_table
-        # We already know that the schema is compatible, this is to fix large_ types
-        .cast(target_table.schema)
-        .join(target_table, keys=list(join_cols_set), join_type="inner", left_suffix="-lhs", right_suffix="-rhs")
-        .filter(diff_expr)
-        .drop_columns([f"{col}-rhs" for col in non_key_cols])
-        .rename_columns({f"{col}-lhs" if col not in join_cols else col: col for col in source_table.column_names})
-        # Finally cast to the original schema since it doesn't carry nullability:
-        # https://github.com/apache/arrow/issues/45557
-    ).cast(target_table.schema)
+    try:
+        return (
+            source_table
+            # We already know that the schema is compatible, this is to fix large_ types
+            .cast(target_table.schema)
+            .join(target_table, keys=list(join_cols_set), join_type="inner", left_suffix="-lhs", right_suffix="-rhs")
+            .filter(diff_expr)
+            .drop_columns([f"{col}-rhs" for col in non_key_cols])
+            .rename_columns({f"{col}-lhs" if col not in join_cols else col: col for col in source_table.column_names})
+            # Finally cast to the original schema since it doesn't carry nullability:
+            # https://github.com/apache/arrow/issues/45557
+        ).cast(target_table.schema)
+    except pa.ArrowInvalid:
+        # When we are not able to compare (e.g. due to unsupported types),
+        # fall back to selecting only rows in the source table that do NOT already exist in the target.
+        # See: https://github.com/apache/arrow/issues/35785
+        MARKER_COLUMN_NAME = "__from_target"
+        INDEX_COLUMN_NAME = "__source_index"
+
+        if MARKER_COLUMN_NAME in join_cols_set or INDEX_COLUMN_NAME in join_cols_set:
+            raise ValueError(
+                f"{MARKER_COLUMN_NAME} and {INDEX_COLUMN_NAME} are reserved for joining "
+                f"DataFrames, and cannot be used as column names"
+            ) from None
+

Review Comment:
   nit: how do we gate for complex type columns in the join keys?
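One way such a gate could look, as an illustration only (not code from the PR; assumes the function's `join_cols_set` and `target_table`):

```python
import pyarrow as pa

# Illustrative guard: reject struct/list/map join keys up front, so the
# failure is explicit instead of an ArrowInvalid raised deep inside the join.
for col in join_cols_set:
    field_type = target_table.schema.field(col).type
    if pa.types.is_nested(field_type):
        raise ValueError(f"Join key {col!r} has unsupported complex type {field_type}")
```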
+        # See: https://github.com/apache/arrow/issues/35785
+        MARKER_COLUMN_NAME = "__from_target"
+        INDEX_COLUMN_NAME = "__source_index"
+
+        if MARKER_COLUMN_NAME in join_cols_set or INDEX_COLUMN_NAME in join_cols_set:
+            raise ValueError(
+                f"{MARKER_COLUMN_NAME} and {INDEX_COLUMN_NAME} are reserved for joining "
+                f"DataFrames, and cannot be used as column names"
+            ) from None
+
+        # Step 1: Prepare source index with join keys and a marker index
+        # Cast to target table schema, so we can do the join
+        # See: https://github.com/apache/arrow/issues/37542
+        source_index = (
+            source_table.cast(target_table.schema)
+            .select(join_cols_set)
+            .append_column(INDEX_COLUMN_NAME, pa.array(range(len(source_table))))
+        )
+
+        # Step 2: Prepare target index with join keys and a marker
+        target_index = target_table.select(join_cols_set).append_column(MARKER_COLUMN_NAME, pa.repeat(True, len(target_table)))
+
+        # Step 3: Perform a left outer join to find which rows from source exist in target
+        joined = source_index.join(target_index, keys=list(join_cols_set), join_type="left outer")
+
+        # Step 4: Restore original source order
+        joined = joined.sort_by(INDEX_COLUMN_NAME)
+
+        # Step 5: Create a boolean mask for rows that do exist in the target
+        # i.e., where marker column is true after the join
+        to_update_mask = pc.invert(pc.is_null(joined[MARKER_COLUMN_NAME]))

Review Comment:
   nit: instead of `not is_null`, can we just use `is True`? comparing nulls gets super weird, see the Kleene logic above
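In pyarrow compute terms the suggestion could be expressed like this (a sketch; since the marker column is only ever True or null after the left outer join, filling nulls with False gives a plain boolean mask without any three-valued null comparisons):

```python
import pyarrow.compute as pc

# Alternative to invert(is_null(...)): treat a null marker as False directly,
# so no Kleene logic is involved anywhere in the mask construction.
to_update_mask = pc.fill_null(joined[MARKER_COLUMN_NAME], False)
```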
##########
tests/table/test_upsert.py:
##########
@@ -511,6 +511,76 @@ def test_upsert_without_identifier_fields(catalog: Catalog) -> None:
         tbl.upsert(df)
 
 
+def test_upsert_struct_field_fails_in_join(catalog: Catalog) -> None:
+    identifier = "default.test_upsert_struct_field_fails"
+    _drop_table(catalog, identifier)
+
+    schema = Schema(
+        NestedField(1, "id", IntegerType(), required=True),
+        NestedField(
+            2,
+            "nested_type",
+            # Struct<type: string, coordinates: list<double>>
+            StructType(
+                NestedField(3, "sub1", StringType(), required=True),
+                NestedField(4, "sub2", StringType(), required=True),
+            ),
+            required=False,
+        ),
+        identifier_field_ids=[1],
+    )
+
+    tbl = catalog.create_table(identifier, schema=schema)
+
+    arrow_schema = pa.schema(
+        [
+            pa.field("id", pa.int32(), nullable=False),
+            pa.field(
+                "nested_type",
+                pa.struct(
+                    [
+                        pa.field("sub1", pa.large_string(), nullable=False),
+                        pa.field("sub2", pa.large_string(), nullable=False),
+                    ]
+                ),
+                nullable=True,
+            ),
+        ]
+    )
+
+    initial_data = pa.Table.from_pylist(
+        [
+            {
+                "id": 1,
+                "nested_type": {"sub1": "bla1", "sub2": "bla"},
+            }
+        ],
+        schema=arrow_schema,
+    )
+    tbl.append(initial_data)
+
+    update_data = pa.Table.from_pylist(
+        [
+            {
+                "id": 2,
+                "nested_type": {"sub1": "bla1", "sub2": "bla"},
+            },
+            {
+                "id": 1,
+                "nested_type": {"sub1": "bla1", "sub2": "bla"},
+            },
+        ],
+        schema=arrow_schema,
+    )
+
+    upd = tbl.upsert(update_data, join_cols=["id"])
+
+    # Row needs to be updated even tho it's not changed.
+    # When pyarrow isn't able to compare rows, just update everything
+    assert upd.rows_updated == 1

Review Comment:
   this _might_ cause confusion. perhaps we can throw a warning or something
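Surfacing the fallback could look something like this at the top of the `except pa.ArrowInvalid:` branch (an illustrative sketch, not code from the PR):

```python
import warnings

# Hypothetical warning so callers learn that matched rows will be rewritten
# even when their values are unchanged, because the row-level comparison
# could not be performed on this schema.
warnings.warn(
    "Rows contain types that cannot be compared (e.g. nested types); "
    "all source rows with a matching key will be treated as updates."
)
```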