rdblue commented on code in PR #8174:
URL: https://github.com/apache/iceberg/pull/8174#discussion_r1303342083
##########
python/pyiceberg/table/__init__.py:
##########
@@ -839,3 +887,253 @@ def to_ray(self) -> ray.data.dataset.Dataset:
import ray
return ray.data.from_arrow(self.to_arrow())
+
+
+class UpdateSchema:
+ _table: Table
+ _schema: Schema
+ _last_column_id: itertools.count[int]
+ _identifier_field_names: List[str]
+ _adds: Dict[int, List[NestedField]]
+ _added_name_to_id: Dict[str, int]
+ _id_to_parent: Dict[int, str]
+ _allow_incompatible_changes: bool
+ _case_sensitive: bool
+ _transaction: Optional[Transaction]
+
+ def __init__(self, schema: Schema, table: Table, transaction:
Optional[Transaction] = None):
+ self._table = table
+ self._schema = schema
+ self._last_column_id = itertools.count(schema.highest_field_id + 1)
+ self._identifier_field_names = schema.column_names
+ self._adds = {}
+ self._added_name_to_id = {}
+ self._id_to_parent = {}
+ self._allow_incompatible_changes = False
+ self._case_sensitive = True
+ self._transaction = transaction
+
+ def __exit__(self, _: Any, value: Any, traceback: Any) -> None:
+ """Closes and commits the change."""
+ return self.commit()
+
+ def __enter__(self) -> UpdateSchema:
+ """Update the table."""
+ return self
+
+ def case_sensitive(self, case_sensitive: bool) -> UpdateSchema:
+ """Determines if the case of schema needs to be considered when
comparing column names.
+
+ Args:
+ case_sensitive: When false case is not considered in column name
comparisons.
+
+ Returns:
+ This for method chaining
+ """
+ self._case_sensitive = case_sensitive
+ return self
+
+ def add_column(
+ self, name: str, type_var: IcebergType, doc: Optional[str] = None,
parent: Optional[str] = None, required: bool = False
+ ) -> UpdateSchema:
+ """Add a new column to a nested struct or Add a new top-level column.
+
+ Args:
+ name: Name for the new column.
+ type_var: Type for the new column.
+ doc: Documentation string for the new column.
+ parent: Name of the parent struct to the column will be added to.
+ required: Whether the new column is required.
+
+ Returns:
+ This for method chaining
+ """
+ if "." in name:
+ raise ValueError(f"Cannot add column with ambiguous name: {name}")
+
+ if required and not self._allow_incompatible_changes:
+ # Table format version 1 and 2 cannot add required column because
there is no initial value
+ raise ValueError(f"Incompatible change: cannot add required
column: {name}")
+
+ self._internal_add_column(parent, name, not required, type_var, doc)
+ return self
+
+ def allow_incompatible_changes(self) -> UpdateSchema:
+ """Allow incompatible changes to the schema.
+
+ Returns:
+ This for method chaining
+ """
+ self._allow_incompatible_changes = True
+ return self
+
+ def commit(self) -> None:
+ """Apply the pending changes and commit."""
+ new_schema = self._apply()
+ updates = [
+ AddSchemaUpdate(schema=new_schema,
last_column_id=new_schema.highest_field_id),
+ SetCurrentSchemaUpdate(schema_id=-1),
+ ]
+ requirements =
[AssertCurrentSchemaId(current_schema_id=self._schema.schema_id)]
+
+ if self._transaction is not None:
+ self._transaction._append_updates(*updates) # pylint:
disable=W0212
+ self._transaction._append_requirements(*requirements) # pylint:
disable=W0212
+ else:
+ table_update_response = self._table.catalog._commit_table( #
pylint: disable=W0212
+ CommitTableRequest(identifier=self._table.identifier[1:],
updates=updates, requirements=requirements)
+ )
+ self._table.metadata = table_update_response.metadata
+ self._table.metadata_location =
table_update_response.metadata_location
+
+ def _apply(self) -> Schema:
+ """Apply the pending changes to the original schema and returns the
result.
+
+ Returns:
+ the result Schema when all pending updates are applied
+ """
+ return _apply_changes(self._schema, self._adds,
self._identifier_field_names)
+
+ def _internal_add_column(
+ self, parent: Optional[str], name: str, is_optional: bool, type_var:
IcebergType, doc: Optional[str]
+ ) -> None:
+ full_name: str = name
+ parent_id: int = TABLE_ROOT_ID
+
+ exist_field: Optional[NestedField] = None
+ if parent:
+ parent_field = self._schema.find_field(parent,
self._case_sensitive)
+ parent_type = parent_field.field_type
+ if isinstance(parent_type, MapType):
+ parent_field = parent_type.value_field
+ elif isinstance(parent_type, ListType):
+ parent_field = parent_type.element_field
+
+ if not parent_field.field_type.is_struct:
+ raise ValueError(f"Cannot add column to non-struct type:
{parent}")
+
+ parent_id = parent_field.field_id
+
+ try:
+ exist_field = self._schema.find_field(parent + "." + name,
self._case_sensitive)
+ except ValueError:
+ pass
+
+ if exist_field:
+ raise ValueError(f"Cannot add column, name already exists:
{parent}.{name}")
+
+ full_name = parent_field.name + "." + name
+
+ else:
+ try:
+ exist_field = self._schema.find_field(name,
self._case_sensitive)
+ except ValueError:
+ pass
+
+ if exist_field:
+ raise ValueError(f"Cannot add column, name already exists:
{name}")
+
+ # assign new IDs in order
+ new_id = self.assign_new_column_id()
+
+ # update tracking for moves
+ self._added_name_to_id[full_name] = new_id
+
+ new_type = assign_fresh_schema_ids(type_var, self.assign_new_column_id)
+ field = NestedField(new_id, name, new_type, not is_optional, doc)
+
+ self._adds.setdefault(parent_id, []).append(field)
Review Comment:
@Fokko, it looks like this doesn't update `_id_to_parent`, but the Java
implementation does.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]