rdblue commented on code in PR #8174:
URL: https://github.com/apache/iceberg/pull/8174#discussion_r1303393150
##########
python/pyiceberg/table/__init__.py:
##########
@@ -839,3 +887,253 @@ def to_ray(self) -> ray.data.dataset.Dataset:
import ray
return ray.data.from_arrow(self.to_arrow())
+
+
class UpdateSchema:
    """Accumulate schema changes for a table and commit them as a new schema.

    Changes are recorded by methods such as ``add_column`` and applied by
    ``commit``. When a transaction is supplied, ``commit`` appends the updates
    and requirements to that transaction; otherwise it commits directly through
    the table's catalog. Used as a context manager, the changes are committed
    when the ``with`` block exits without raising.
    """

    _table: Table
    _schema: Schema
    _last_column_id: itertools.count[int]
    _identifier_field_names: List[str]
    _adds: Dict[int, List[NestedField]]
    _added_name_to_id: Dict[str, int]
    _id_to_parent: Dict[int, str]
    _allow_incompatible_changes: bool
    _case_sensitive: bool
    _transaction: Optional[Transaction]

    def __init__(self, schema: Schema, table: Table, transaction: Optional[Transaction] = None):
        self._table = table
        self._schema = schema
        # New field ids continue after the highest id already used by the schema.
        self._last_column_id = itertools.count(schema.highest_field_id + 1)
        # Only the schema's identifier fields have to survive the update -- not every column.
        self._identifier_field_names = self._current_identifier_field_names(schema)
        # Pending additions keyed by the id of the parent struct (TABLE_ROOT_ID for top-level columns).
        self._adds = {}
        self._added_name_to_id = {}
        self._id_to_parent = {}
        self._allow_incompatible_changes = False
        self._case_sensitive = True
        self._transaction = transaction

    @staticmethod
    def _current_identifier_field_names(schema: Schema) -> List[str]:
        """Resolve the schema's identifier field ids to their full column names.

        Raises:
            ValueError: If an identifier field id does not resolve to a column of ``schema``.
        """
        names: List[str] = []
        for field_id in schema.identifier_field_ids:
            column_name = schema.find_column_name(field_id)
            if column_name is None:
                raise ValueError(f"Could not find identifier column id: {field_id}")
            names.append(column_name)
        return names

    def __exit__(self, _: Any, value: Any, traceback: Any) -> None:
        """Commit the pending changes, unless the ``with`` block raised an exception."""
        # Committing a half-built update after a failure would publish partial changes;
        # returning None lets the original exception propagate.
        if value is None:
            self.commit()

    def __enter__(self) -> UpdateSchema:
        """Return this builder so changes can be made inside a ``with`` block."""
        return self

    def case_sensitive(self, case_sensitive: bool) -> UpdateSchema:
        """Determine whether case is considered when comparing column names.

        Args:
            case_sensitive: When false, case is not considered in column name comparisons.

        Returns:
            This for method chaining
        """
        self._case_sensitive = case_sensitive
        return self

    def add_column(
        self, name: str, type_var: IcebergType, doc: Optional[str] = None, parent: Optional[str] = None, required: bool = False
    ) -> UpdateSchema:
        """Add a new top-level column, or a new column inside a nested struct.

        Args:
            name: Name for the new column; must not contain ".".
            type_var: Type for the new column.
            doc: Documentation string for the new column.
            parent: Name of the parent struct the column is added to. For a map or list
                parent, the column is added to the map's value / list's element struct.
            required: Whether the new column is required.

        Returns:
            This for method chaining

        Raises:
            ValueError: If the name contains ".", if a required column is added without
                ``allow_incompatible_changes()``, or if the parent/name is invalid.
        """
        if "." in name:
            raise ValueError(f"Cannot add column with ambiguous name: {name}")

        if required and not self._allow_incompatible_changes:
            # Table format version 1 and 2 cannot add required column because there is no initial value
            raise ValueError(f"Incompatible change: cannot add required column: {name}")

        self._internal_add_column(parent, name, not required, type_var, doc)
        return self

    def allow_incompatible_changes(self) -> UpdateSchema:
        """Allow incompatible changes to the schema.

        Returns:
            This for method chaining
        """
        self._allow_incompatible_changes = True
        return self

    def commit(self) -> None:
        """Apply the pending changes and commit them.

        With a transaction, the schema updates and the current-schema-id requirement are
        appended to it. Without one, they are sent to the table's catalog immediately, and
        the table's metadata and metadata location are refreshed from the response.
        """
        new_schema = self._apply()
        updates = [
            AddSchemaUpdate(schema=new_schema, last_column_id=new_schema.highest_field_id),
            # NOTE(review): -1 presumably means "the schema added by this update" -- confirm.
            SetCurrentSchemaUpdate(schema_id=-1),
        ]
        # Fail the commit if the table's schema changed since this update was started.
        requirements = [AssertCurrentSchemaId(current_schema_id=self._schema.schema_id)]

        if self._transaction is not None:
            self._transaction._append_updates(*updates)  # pylint: disable=W0212
            self._transaction._append_requirements(*requirements)  # pylint: disable=W0212
        else:
            table_update_response = self._table.catalog._commit_table(  # pylint: disable=W0212
                CommitTableRequest(identifier=self._table.identifier[1:], updates=updates, requirements=requirements)
            )
            self._table.metadata = table_update_response.metadata
            self._table.metadata_location = table_update_response.metadata_location

    def _apply(self) -> Schema:
        """Apply the pending changes to the original schema and return the result.

        Returns:
            the result Schema when all pending updates are applied
        """
        return _apply_changes(self._schema, self._adds, self._identifier_field_names)

    def _internal_add_column(
        self, parent: Optional[str], name: str, is_optional: bool, type_var: IcebergType, doc: Optional[str]
    ) -> None:
        """Validate and record a new column under ``parent`` (or at the top level).

        Raises:
            ValueError: If the parent is not a struct (after unwrapping a map value or
                list element), or if a column with the same name already exists.
        """
        full_name: str = name
        parent_id: int = TABLE_ROOT_ID

        exist_field: Optional[NestedField] = None
        if parent:
            parent_field = self._schema.find_field(parent, self._case_sensitive)
            parent_type = parent_field.field_type
            # A column added to a map/list goes into the map's value / list's element.
            if isinstance(parent_type, MapType):
                parent_field = parent_type.value_field
            elif isinstance(parent_type, ListType):
                parent_field = parent_type.element_field

            if not parent_field.field_type.is_struct:
                raise ValueError(f"Cannot add column to non-struct type: {parent}")

            parent_id = parent_field.field_id

            try:
                exist_field = self._schema.find_field(parent + "." + name, self._case_sensitive)
            except ValueError:
                # Not found: the name is free to use.
                pass

            if exist_field is not None:
                raise ValueError(f"Cannot add column, name already exists: {parent}.{name}")

            # Use the parent path the caller supplied (the same key used for the existence
            # check above). parent_field.name is only the leaf name -- for a map/list parent it
            # is the value/element field's own name -- so it would drop the enclosing path.
            full_name = parent + "." + name

        else:
            try:
                exist_field = self._schema.find_field(name, self._case_sensitive)
            except ValueError:
                # Not found: the name is free to use.
                pass

            if exist_field is not None:
                raise ValueError(f"Cannot add column, name already exists: {name}")

        # assign new IDs in order: the column's own id first, then ids for any nested fields
        new_id = self.assign_new_column_id()

        # update tracking for moves
        self._added_name_to_id[full_name] = new_id

        new_type = assign_fresh_schema_ids(type_var, self.assign_new_column_id)
        field = NestedField(new_id, name, new_type, not is_optional, doc)

        self._adds.setdefault(parent_id, []).append(field)

    def assign_new_column_id(self) -> int:
        """Return the next unused field id."""
        return next(self._last_column_id)
+
+
def _apply_changes(schema_: Schema, adds: Dict[int, List[NestedField]], identifier_field_names: List[str]) -> Schema:
    """Build a new Schema from ``schema_`` with the pending ``adds`` applied.

    Args:
        schema_: The schema to start from.
        adds: New fields keyed by the id of the struct they are added to.
        identifier_field_names: Column names that must still resolve after the changes.

    Returns:
        A new Schema made of the updated top-level fields.

    Raises:
        ValueError: If an identifier field name is not present in the updated struct.
    """
    updated_struct = visit(schema_, _ApplyChanges(adds))
    known_names: Dict[str, int] = index_by_name(updated_struct)

    for name in identifier_field_names:
        if name in known_names:
            continue
        raise ValueError(f"Cannot add field {name} as an identifier field: not found in current schema or added columns")

    # NOTE(review): only the fields are passed on; schema_id and identifier_field_ids of
    # ``schema_`` are not forwarded to the new Schema -- confirm this is intended.
    return Schema(*updated_struct.fields)
+
+
+class _ApplyChanges(SchemaVisitor[IcebergType]):
+ def __init__(self, adds: Dict[int, List[NestedField]]):
+ self.adds = adds
+
+ def schema(self, schema: Schema, struct_result: IcebergType) ->
IcebergType:
+ fields = _ApplyChanges.add_fields(schema.as_struct().fields,
self.adds.get(TABLE_ROOT_ID))
+ if len(fields) > 0:
+ return StructType(*fields)
+
+ return struct_result
+
+ def struct(self, struct: StructType, field_results: List[IcebergType]) ->
IcebergType:
+ has_change = False
+ new_fields: List[NestedField] = []
+ for i in range(len(field_results)):
+ type_: Optional[IcebergType] = field_results[i]
+ if type_ is None:
+ has_change = True
+ continue
+
+ field: NestedField = struct.fields[i]
+ new_fields.append(field)
+
+ if has_change:
+ return StructType(*new_fields)
+
+ return struct
+
+ def field(self, field: NestedField, field_result: IcebergType) ->
IcebergType:
+ field_id: int = field.field_id
+ if field_id in self.adds:
+ new_fields = self.adds[field_id]
+ if len(new_fields) > 0:
+ fields = _ApplyChanges.add_fields(field_result.fields,
new_fields)
+ if len(fields) > 0:
+ return StructType(*fields)
+
+ return field_result
+
+ def list(self, list_type: ListType, element_result: IcebergType) ->
IcebergType:
+ element_field: NestedField = list_type.element_field
+ element_type = self.field(element_field, element_result)
+ if element_type is None:
+ raise ValueError(f"Cannot delete element type from list:
{element_field}")
+
+ is_element_optional = not list_type.element_required
+
+ if is_element_optional == element_field.required and
list_type.element_type == element_type:
+ return list_type
+
+ return ListType(list_type.element_id, element_type,
is_element_optional)
+
+ def map(self, map_type: MapType, key_result: IcebergType, value_result:
IcebergType) -> IcebergType:
+ key_id: int = map_type.key_field.field_id
+ if key_id in self.adds:
+ raise ValueError(f"Cannot add fields to map keys: {map_type}")
+
+ value_field: NestedField = map_type.value_field
+ value_type = self.field(value_field, value_result)
+ if value_type is None:
+ raise ValueError(f"Cannot delete value type from map:
{value_field}")
+
+ is_value_optional = not map_type.value_required
+
+ if is_value_optional != value_field.required and map_type.value_type
== value_type:
+ return map_type
+
+ return MapType(map_type.key_id, map_type.key_field, map_type.value_id,
value_type, not is_value_optional)
+
+ def primitive(self, primitive: PrimitiveType) -> IcebergType:
+ return primitive
+
+ @staticmethod
+ def add_fields(fields: Tuple[NestedField, ...], adds:
Optional[List[NestedField]]) -> List[NestedField]:
Review Comment:
If adds is None, then I think this should return None.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]