smaheshwar-pltr commented on code in PR #3220:
URL: https://github.com/apache/iceberg-python/pull/3220#discussion_r3268135212


##########
tests/catalog/test_catalog_behaviors.py:
##########
@@ -387,6 +387,298 @@ def test_load_table_from_self_identifier(
     assert table.metadata == loaded_table.metadata
 
 
+_SIMPLE_SCHEMA = Schema(
+    NestedField(field_id=1, name="id", field_type=LongType(), required=False),
+    NestedField(field_id=2, name="data", field_type=StringType(), 
required=False),
+)
+
+
+def _create_simple_table(
+    catalog: Catalog,
+    identifier: Identifier,
+    *,
+    schema: Schema = _SIMPLE_SCHEMA,
+    format_version: int = 2,
+    partition_spec: PartitionSpec = UNPARTITIONED_PARTITION_SPEC,
+    properties: dict[str, str] | None = None,
+) -> tuple[Identifier, Schema]:
+    namespace = Catalog.namespace_from(identifier)
+    catalog.create_namespace_if_not_exists(namespace)
+    merged_properties = {"format-version": str(format_version), **(properties 
or {})}
+    catalog.create_table(identifier, schema=schema, 
partition_spec=partition_spec, properties=merged_properties)
+    return identifier, schema
+
+
+def _simple_data(num_rows: int = 2) -> pa.Table:
+    return pa.Table.from_pydict(
+        {"id": list(range(num_rows)), "data": [chr(ord("a") + i) for i in 
range(num_rows)]},
+        schema=pa.schema([pa.field("id", pa.int64()), pa.field("data", 
pa.large_string())]),
+    )
+
+
+_REPLACE_SCHEMA = Schema(
+    NestedField(field_id=1, name="id", field_type=LongType(), required=False),
+    NestedField(field_id=2, name="data", field_type=StringType(), 
required=False),
+    NestedField(field_id=3, name="extra", field_type=BooleanType(), 
required=False),
+)
+
+
+def test_replace_transaction(catalog: Catalog, test_table_identifier: 
Identifier) -> None:
+    _, original_schema = _create_simple_table(catalog, test_table_identifier)
+    original = catalog.load_table(test_table_identifier)
+    original.append(_simple_data())
+    original = catalog.load_table(test_table_identifier)
+    old_snapshot_id = original.current_snapshot().snapshot_id  # type: 
ignore[union-attr]
+    snapshot_log_before = list(original.metadata.snapshot_log)
+    assert len(snapshot_log_before) == 1
+
+    catalog.replace_table(test_table_identifier, schema=_REPLACE_SCHEMA)
+    replaced = catalog.load_table(test_table_identifier)
+
+    # UUID + history preserved, current snapshot cleared, current schema 
swapped.
+    assert replaced.metadata.table_uuid == original.metadata.table_uuid
+    assert replaced.metadata.current_snapshot_id is None
+    assert {f.name for f in replaced.schema().fields} == {"id", "data", 
"extra"}
+    # Old snapshot kept by identity (not just count), and snapshot_log entries 
from before survive.
+    assert any(s.snapshot_id == old_snapshot_id for s in 
replaced.metadata.snapshots)
+    assert all(entry in replaced.metadata.snapshot_log for entry in 
snapshot_log_before)
+    # Old schema is still in the schemas list alongside the new one.
+    schema_ids = sorted(s.schema_id for s in replaced.metadata.schemas)
+    assert schema_ids == [0, 1]
+    assert replaced.metadata.current_schema_id == 1
+    # Time-travel back to the pre-replace snapshot returns the rows that were 
there before.
+    assert 
replaced.scan(snapshot_id=old_snapshot_id).to_arrow().equals(_simple_data())
+
+
+def test_complete_replace_transaction(catalog: Catalog, test_table_identifier: 
Identifier, tmp_path: Path) -> None:
+    _create_simple_table(catalog, test_table_identifier, properties={"keep": 
"yes", "override": "old"})
+    catalog.load_table(test_table_identifier).append(_simple_data())
+    original = catalog.load_table(test_table_identifier)
+    old_snapshot_id = original.current_snapshot().snapshot_id  # type: 
ignore[union-attr]
+    original_data = original.scan().to_arrow()
+
+    new_location = f"file://{tmp_path}/replaced"
+    new_schema = Schema(
+        NestedField(field_id=1, name="id", field_type=LongType(), 
required=False),
+        NestedField(field_id=2, name="data", field_type=StringType(), 
required=False),
+        NestedField(field_id=3, name="extra", field_type=BooleanType(), 
required=False),
+    )
+    new_spec = PartitionSpec(PartitionField(source_id=1, field_id=1000, 
name="id_part", transform=IdentityTransform()))
+    new_sort = SortOrder(SortField(source_id=1, transform=IdentityTransform(), 
direction=SortDirection.ASC))
+    new_data = pa.Table.from_pydict(
+        {"id": [10, 20], "data": ["alice", "bob"], "extra": [True, False]},
+        schema=pa.schema([pa.field("id", pa.int64()), pa.field("data", 
pa.large_string()), pa.field("extra", pa.bool_())]),
+    )
+
+    with catalog.replace_table_transaction(
+        test_table_identifier,
+        schema=new_schema,
+        partition_spec=new_spec,
+        sort_order=new_sort,
+        location=new_location,
+        properties={"override": "new", "added": "v"},
+    ) as txn:
+        txn.append(new_data)
+
+    replaced = catalog.load_table(test_table_identifier)
+
+    # Identity invariants.
+    assert replaced.metadata.table_uuid == original.metadata.table_uuid
+    assert replaced.metadata.location == new_location
+
+    # New schema / spec / sort applied; old entries retained in history.
+    assert {f.name for f in replaced.schema().fields} == {"id", "data", 
"extra"}
+    assert sorted(s.schema_id for s in replaced.metadata.schemas) == [0, 1]
+    assert replaced.spec().fields[0].source_id == 1
+    assert isinstance(replaced.spec().fields[0].transform, IdentityTransform)
+    assert {s.spec_id for s in replaced.metadata.partition_specs} == {0, 1}
+    assert replaced.sort_order().fields == new_sort.fields
+    assert {s.order_id for s in replaced.metadata.sort_orders} == {0, 
replaced.metadata.default_sort_order_id}
+
+    # Property merge: kept, overridden, added — and `format-version` does not 
leak.
+    assert replaced.properties["keep"] == "yes"
+    assert replaced.properties["override"] == "new"
+    assert replaced.properties["added"] == "v"
+    assert "format-version" not in replaced.properties
+
+    # RTAS atomicity: new snapshot exists, has no parent (fresh start), old 
snapshot is still
+    # in the snapshot list, and time-travel reads return the original rows.
+    new_snapshot = replaced.current_snapshot()
+    assert new_snapshot is not None
+    assert new_snapshot.snapshot_id != old_snapshot_id
+    assert new_snapshot.parent_snapshot_id is None
+    assert any(s.snapshot_id == old_snapshot_id for s in 
replaced.metadata.snapshots)
+    assert replaced.scan().to_arrow().num_rows == 2
+    # Time-travel back to before the replace returns the original rows from 
the old schema.
+    time_travel = replaced.scan(snapshot_id=old_snapshot_id).to_arrow()
+    assert time_travel.num_rows == original_data.num_rows
+    assert time_travel.column("id").to_pylist() == 
original_data.column("id").to_pylist()
+
+
+def test_replace_transaction_requires_table_exists(catalog: Catalog, 
test_table_identifier: Identifier) -> None:
+    schema = Schema(NestedField(field_id=1, name="id", field_type=LongType(), 
required=False))
+    with pytest.raises(NoSuchTableError):
+        catalog.replace_table(test_table_identifier, schema=schema)
+
+
+def test_replace_table_reuses_schema_id_when_identical(catalog: Catalog, 
test_table_identifier: Identifier) -> None:
+    _, base_schema = _create_simple_table(catalog, test_table_identifier)
+    replaced = catalog.replace_table(test_table_identifier, schema=base_schema)
+    # Identical shape -> no new schema appended, current points back at id 0.
+    assert [s.schema_id for s in replaced.metadata.schemas] == [0]
+    assert replaced.metadata.current_schema_id == 0
+    assert replaced.metadata.last_column_id == 2
+
+
+def test_replace_table_reuses_partition_spec_and_sort_order_when_identical(
+    catalog: Catalog, test_table_identifier: Identifier
+) -> None:
+    spec = PartitionSpec(PartitionField(source_id=1, field_id=1000, 
name="id_part", transform=IdentityTransform()))
+    sort = SortOrder(SortField(source_id=1, transform=IdentityTransform(), 
direction=SortDirection.ASC))
+    _, schema = _create_simple_table(catalog, test_table_identifier, 
partition_spec=spec)
+    # Introduce a sort order then replay both spec and sort — neither should 
append a new entry.
+    sorted_first = catalog.replace_table(test_table_identifier, schema=schema, 
partition_spec=spec, sort_order=sort)
+    sorted_order_id = sorted_first.metadata.default_sort_order_id
+    assert sorted_order_id != 0
+
+    replayed = catalog.replace_table(test_table_identifier, schema=schema, 
partition_spec=spec, sort_order=sort)
+    assert [s.spec_id for s in replayed.metadata.partition_specs] == [0]
+    assert replayed.metadata.default_spec_id == 0
+    assert replayed.metadata.default_sort_order_id == sorted_order_id
+
+    # Dropping the sort order falls back to the unsorted order_id 0 (also 
reused, not appended).
+    unsorted = catalog.replace_table(test_table_identifier, schema=schema, 
partition_spec=spec)
+    assert unsorted.sort_order().is_unsorted
+    assert unsorted.metadata.default_sort_order_id == 0
+
+
+@pytest.mark.parametrize("keep_identifier", [True, False], ids=["preserves", 
"drops"])
+def test_replace_table_identifier_field_ids(catalog: Catalog, 
test_table_identifier: Identifier, keep_identifier: bool) -> None:
+    schema = Schema(
+        NestedField(field_id=1, name="id", field_type=LongType(), 
required=True),
+        NestedField(field_id=2, name="data", field_type=StringType(), 
required=False),
+        identifier_field_ids=[1],
+    )
+    _create_simple_table(catalog, test_table_identifier, schema=schema)
+    new_schema = (
+        Schema(
+            NestedField(field_id=1, name="id", field_type=LongType(), 
required=True),
+            NestedField(field_id=2, name="data", field_type=StringType(), 
required=False),
+            NestedField(field_id=3, name="extra", field_type=BooleanType(), 
required=False),
+            identifier_field_ids=[1],
+        )
+        if keep_identifier
+        else Schema(
+            NestedField(field_id=1, name="id", field_type=LongType(), 
required=False),
+            NestedField(field_id=2, name="data", field_type=StringType(), 
required=False),
+        )
+    )
+    replaced = catalog.replace_table(test_table_identifier, schema=new_schema)
+    expected = [1] if keep_identifier else []
+    assert list(replaced.schema().identifier_field_ids) == expected
+
+
+@pytest.mark.parametrize(
+    "format_version, expect_void_carry_forward",
+    [(1, True), (2, False)],
+    ids=["v1-carries-forward", "v2-drops"],
+)
+def test_replace_table_partition_field_carry_forward(
+    catalog: Catalog,
+    test_table_identifier: Identifier,
+    format_version: int,
+    expect_void_carry_forward: bool,
+) -> None:
+    spec = PartitionSpec(PartitionField(source_id=1, field_id=1000, 
name="id_part", transform=IdentityTransform()))
+    _, schema = _create_simple_table(catalog, test_table_identifier, 
partition_spec=spec, format_version=format_version)
+    replaced = catalog.replace_table(test_table_identifier, schema=schema)
+    new_spec = replaced.spec()
+    if expect_void_carry_forward:
+        void_field = next(f for f in new_spec.fields if f.field_id == 1000)
+        assert isinstance(void_field.transform, VoidTransform)
+        assert void_field.source_id == 1
+        assert void_field.name == "id_part"
+    else:
+        assert new_spec.is_unpartitioned()
+
+
+def test_replace_table_upgrades_format_version(catalog: Catalog, 
test_table_identifier: Identifier) -> None:
+    _, schema = _create_simple_table(catalog, test_table_identifier, 
format_version=1)
+    assert catalog.load_table(test_table_identifier).format_version == 1
+
+    upgraded = catalog.replace_table(test_table_identifier, schema=schema, 
properties={"format-version": "2"})
+    assert upgraded.format_version == 2
+    # `format-version` is a control input, not a persisted property.
+    assert "format-version" not in upgraded.properties
+
+    # A follow-up replace stays at the upgraded version.
+    new_schema = Schema(*schema.fields, NestedField(field_id=3, name="extra", 
field_type=BooleanType(), required=False))
+    replayed = catalog.replace_table(test_table_identifier, schema=new_schema)
+    assert replayed.format_version == 2
+    assert {f.name for f in replayed.schema().fields} == {"id", "data", 
"extra"}
+
+
+def test_replace_table_rejects_format_version_downgrade(catalog: Catalog, 
test_table_identifier: Identifier) -> None:
+    _, schema = _create_simple_table(catalog, test_table_identifier, 
format_version=2)
+    with pytest.raises(ValueError, match="Cannot downgrade format-version"):
+        catalog.replace_table(test_table_identifier, schema=schema, 
properties={"format-version": "1"})
+
+
+@pytest.mark.parametrize("location_kind", ["inherit", "explicit", 
"trailing-slash"])
+def test_replace_table_location(catalog: Catalog, test_table_identifier: 
Identifier, tmp_path: Path, location_kind: str) -> None:
+    _, schema = _create_simple_table(catalog, test_table_identifier)
+    existing = catalog.load_table(test_table_identifier).metadata.location
+
+    if location_kind == "inherit":
+        replaced = catalog.replace_table(test_table_identifier, schema=schema)
+        assert replaced.metadata.location == existing
+    else:
+        bare = f"file://{tmp_path}/relocated"
+        arg = bare + "/" if location_kind == "trailing-slash" else bare
+        replaced = catalog.replace_table(test_table_identifier, schema=schema, 
location=arg)
+        assert replaced.metadata.location == bare
+
+
+def test_replace_table_transaction_rolls_back_on_failure(catalog: Catalog, 
test_table_identifier: Identifier) -> None:
+    _create_simple_table(catalog, test_table_identifier)
+    catalog.load_table(test_table_identifier).append(_simple_data())
+    before = catalog.load_table(test_table_identifier).metadata
+
+    def run_failing_replace() -> None:
+        with catalog.replace_table_transaction(test_table_identifier, 
schema=_REPLACE_SCHEMA):
+            raise RuntimeError("simulated failure inside replace transaction")
+
+    with pytest.raises(RuntimeError, match="simulated failure inside replace 
transaction"):
+        run_failing_replace()
+
+    after = catalog.load_table(test_table_identifier).metadata
+    assert after.table_uuid == before.table_uuid
+    assert after.current_snapshot_id == before.current_snapshot_id
+    assert after.current_schema_id == before.current_schema_id
+    assert len(after.schemas) == len(before.schemas)
+
+
+def test_concurrent_replace_transaction_schema_conflict(catalog: Catalog, 
test_table_identifier: Identifier) -> None:
+    _create_simple_table(catalog, test_table_identifier)
+    txn_a = catalog.replace_table_transaction(test_table_identifier, 
schema=_REPLACE_SCHEMA)
+    txn_b = catalog.replace_table_transaction(test_table_identifier, 
schema=_REPLACE_SCHEMA)
+
+    txn_a.commit_transaction()
+    with pytest.raises(CommitFailedException, match="last assigned field id"):
+        txn_b.commit_transaction()
+
+
+def test_concurrent_replace_transaction_partition_spec_conflict(catalog: 
Catalog, test_table_identifier: Identifier) -> None:

Review Comment:
   Mirrors Java's [`testConcurrentReplaceTransactionPartitionSpecConflict`](https://github.com/apache/iceberg/blob/2f6606a247e2b16be46ca6c02fc4cfc2e17691e6/core/src/test/java/org/apache/iceberg/catalog/CatalogTests.java#L2986-L3023).
 As with the schema case above, the non-conflicting partition-spec variants are deliberately not ported, for the same reason.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to