HonahX commented on code in PR #296: URL: https://github.com/apache/iceberg-python/pull/296#discussion_r1465867474
########## pyiceberg/table/__init__.py: ########## @@ -1995,6 +2020,156 @@ def primitive(self, primitive: PrimitiveType) -> Optional[IcebergType]: return primitive +class UnionByNameVisitor(SchemaWithPartnerVisitor[int, bool]): + update_schema: UpdateSchema + new_schema: Schema + case_sensitive: bool + + def __init__(self, update_schema: UpdateSchema, new_schema: Schema, case_sensitive: bool) -> None: + self.update_schema = update_schema + self.new_schema = new_schema + self.case_sensitive = case_sensitive + + def schema(self, schema: Schema, partner_id: Optional[int], struct_result: bool) -> bool: + return struct_result + + def struct(self, struct: StructType, partner_id: Optional[int], missing_positions: List[bool]) -> bool: + if partner_id is None: + return True + + fields = struct.fields + partner_struct = self._find_field_type(partner_id) + + for pos, missing in enumerate(missing_positions): + if missing: + self._add_column(partner_id, fields[pos]) + else: + field = fields[pos] + if nested_field := partner_struct.field_by_name(field.name, case_sensitive=self.case_sensitive): + self._update_column(field, nested_field) + + return False + + def _add_column(self, parent_id: int, field: NestedField) -> None: + if parent_name := self.new_schema.find_column_name(parent_id): + path: Tuple[str, ...] 
= (parent_name, field.name) + else: + path = (field.name,) + + self.update_schema.add_column(path=path, field_type=field.field_type, required=field.required, doc=field.doc) + + def _update_column(self, field: NestedField, existing_field: NestedField) -> None: + full_name = self.new_schema.find_column_name(existing_field.field_id) + + if full_name is None: + raise ValueError(f"Could not find field: {existing_field}") + + if field.optional and existing_field.required: + self.update_schema.make_column_optional(full_name) + + if field.field_type.is_primitive and field.field_type != existing_field.field_type: + self.update_schema.update_column(full_name, field_type=field.field_type) + + if field.doc is not None and not field.doc != existing_field.doc: + self.update_schema.update_column(full_name, doc=field.doc) + + def _find_field_type(self, field_id: int) -> IcebergType: + if field_id == -1: + return self.new_schema.as_struct() + else: + return self.new_schema.find_field(field_id).field_type + + def field(self, field: NestedField, field_partner: Optional[int], field_result: bool) -> bool: Review Comment: ```suggestion def field(self, field: NestedField, partner_id: Optional[int], field_result: bool) -> bool: ``` Shall we name the second argument as `partner_id` to better reveal its content? We already did this in `schema` and `struct`. 
Same applied for `*_partner` arguments below ########## pyiceberg/table/__init__.py: ########## @@ -1995,6 +2020,156 @@ def primitive(self, primitive: PrimitiveType) -> Optional[IcebergType]: return primitive +class UnionByNameVisitor(SchemaWithPartnerVisitor[int, bool]): + update_schema: UpdateSchema + new_schema: Schema + case_sensitive: bool + + def __init__(self, update_schema: UpdateSchema, new_schema: Schema, case_sensitive: bool) -> None: + self.update_schema = update_schema + self.new_schema = new_schema + self.case_sensitive = case_sensitive + + def schema(self, schema: Schema, partner_id: Optional[int], struct_result: bool) -> bool: + return struct_result + + def struct(self, struct: StructType, partner_id: Optional[int], missing_positions: List[bool]) -> bool: + if partner_id is None: + return True + + fields = struct.fields + partner_struct = self._find_field_type(partner_id) + + for pos, missing in enumerate(missing_positions): + if missing: + self._add_column(partner_id, fields[pos]) + else: + field = fields[pos] + if nested_field := partner_struct.field_by_name(field.name, case_sensitive=self.case_sensitive): + self._update_column(field, nested_field) + + return False + + def _add_column(self, parent_id: int, field: NestedField) -> None: + if parent_name := self.new_schema.find_column_name(parent_id): + path: Tuple[str, ...] 
= (parent_name, field.name) + else: + path = (field.name,) + + self.update_schema.add_column(path=path, field_type=field.field_type, required=field.required, doc=field.doc) + + def _update_column(self, field: NestedField, existing_field: NestedField) -> None: + full_name = self.new_schema.find_column_name(existing_field.field_id) + + if full_name is None: + raise ValueError(f"Could not find field: {existing_field}") + + if field.optional and existing_field.required: + self.update_schema.make_column_optional(full_name) + + if field.field_type.is_primitive and field.field_type != existing_field.field_type: + self.update_schema.update_column(full_name, field_type=field.field_type) + + if field.doc is not None and not field.doc != existing_field.doc: + self.update_schema.update_column(full_name, doc=field.doc) + + def _find_field_type(self, field_id: int) -> IcebergType: + if field_id == -1: + return self.new_schema.as_struct() + else: + return self.new_schema.find_field(field_id).field_type + + def field(self, field: NestedField, field_partner: Optional[int], field_result: bool) -> bool: + return field_partner is None + + def list(self, list_type: ListType, list_partner: Optional[int], element_missing: bool) -> bool: + if list_partner is None: + return False Review Comment: ```suggestion return True ``` I think this should return `True` ########## pyiceberg/table/__init__.py: ########## @@ -1995,6 +2020,156 @@ def primitive(self, primitive: PrimitiveType) -> Optional[IcebergType]: return primitive +class UnionByNameVisitor(SchemaWithPartnerVisitor[int, bool]): + update_schema: UpdateSchema + new_schema: Schema + case_sensitive: bool + + def __init__(self, update_schema: UpdateSchema, new_schema: Schema, case_sensitive: bool) -> None: + self.update_schema = update_schema + self.new_schema = new_schema + self.case_sensitive = case_sensitive + + def schema(self, schema: Schema, partner_id: Optional[int], struct_result: bool) -> bool: + return struct_result + + def 
struct(self, struct: StructType, partner_id: Optional[int], missing_positions: List[bool]) -> bool: + if partner_id is None: + return True + + fields = struct.fields + partner_struct = self._find_field_type(partner_id) + + for pos, missing in enumerate(missing_positions): + if missing: + self._add_column(partner_id, fields[pos]) + else: + field = fields[pos] + if nested_field := partner_struct.field_by_name(field.name, case_sensitive=self.case_sensitive): + self._update_column(field, nested_field) + + return False + + def _add_column(self, parent_id: int, field: NestedField) -> None: + if parent_name := self.new_schema.find_column_name(parent_id): + path: Tuple[str, ...] = (parent_name, field.name) + else: + path = (field.name,) + + self.update_schema.add_column(path=path, field_type=field.field_type, required=field.required, doc=field.doc) + + def _update_column(self, field: NestedField, existing_field: NestedField) -> None: + full_name = self.new_schema.find_column_name(existing_field.field_id) + + if full_name is None: + raise ValueError(f"Could not find field: {existing_field}") + + if field.optional and existing_field.required: + self.update_schema.make_column_optional(full_name) + + if field.field_type.is_primitive and field.field_type != existing_field.field_type: + self.update_schema.update_column(full_name, field_type=field.field_type) + + if field.doc is not None and not field.doc != existing_field.doc: + self.update_schema.update_column(full_name, doc=field.doc) + + def _find_field_type(self, field_id: int) -> IcebergType: + if field_id == -1: + return self.new_schema.as_struct() + else: + return self.new_schema.find_field(field_id).field_type + + def field(self, field: NestedField, field_partner: Optional[int], field_result: bool) -> bool: + return field_partner is None + + def list(self, list_type: ListType, list_partner: Optional[int], element_missing: bool) -> bool: + if list_partner is None: + return False + + if element_missing: + raise 
ValueError("Error traversing schemas: element is missing, but list is present") + + partner_list_type = self._find_field_type(list_partner) + if not isinstance(partner_list_type, ListType): + raise ValueError(f"Expected list-type, got: {partner_list_type}") + + self._update_column(list_type.element_field, partner_list_type.element_field) + + return False + + def map(self, map_type: MapType, map_partner: Optional[int], key_missing: bool, value_missing: bool) -> bool: + if map_partner is None: + return False Review Comment: ```suggestion return True ``` ########## pyiceberg/table/__init__.py: ########## @@ -1398,14 +1401,23 @@ class UpdateSchema: def __init__( self, - table: Table, + table: Optional[Table], Review Comment: Is this change primarily for easier testing? ########## pyiceberg/table/__init__.py: ########## @@ -1995,6 +2020,156 @@ def primitive(self, primitive: PrimitiveType) -> Optional[IcebergType]: return primitive +class UnionByNameVisitor(SchemaWithPartnerVisitor[int, bool]): + update_schema: UpdateSchema + new_schema: Schema + case_sensitive: bool + + def __init__(self, update_schema: UpdateSchema, new_schema: Schema, case_sensitive: bool) -> None: + self.update_schema = update_schema + self.new_schema = new_schema + self.case_sensitive = case_sensitive + + def schema(self, schema: Schema, partner_id: Optional[int], struct_result: bool) -> bool: + return struct_result + + def struct(self, struct: StructType, partner_id: Optional[int], missing_positions: List[bool]) -> bool: + if partner_id is None: + return True + + fields = struct.fields + partner_struct = self._find_field_type(partner_id) Review Comment: Shall we check if `isinstance(partner_struct, StructType)`/`partner_struct.is_struct` and throw an exception if not. 
########## pyiceberg/table/__init__.py: ########## @@ -1449,6 +1461,15 @@ def case_sensitive(self, case_sensitive: bool) -> UpdateSchema: self._case_sensitive = case_sensitive return self + def union_by_name(self, new_schema: Schema) -> UpdateSchema: + visit_with_partner( + new_schema, + -1, + UnionByNameVisitor(update_schema=self, new_schema=self._schema, case_sensitive=self._case_sensitive), # type: ignore Review Comment: It seems the names are a little bit confusing here. We use `new_schema` to represent both the new schema to union and the existing schema in the current update_schema. Shall we rename the one in `UnionByNameVisitor` to `existing_schema`? ########## pyiceberg/table/__init__.py: ########## @@ -1995,6 +2020,156 @@ def primitive(self, primitive: PrimitiveType) -> Optional[IcebergType]: return primitive +class UnionByNameVisitor(SchemaWithPartnerVisitor[int, bool]): + update_schema: UpdateSchema + new_schema: Schema + case_sensitive: bool + + def __init__(self, update_schema: UpdateSchema, new_schema: Schema, case_sensitive: bool) -> None: + self.update_schema = update_schema + self.new_schema = new_schema + self.case_sensitive = case_sensitive + + def schema(self, schema: Schema, partner_id: Optional[int], struct_result: bool) -> bool: + return struct_result + + def struct(self, struct: StructType, partner_id: Optional[int], missing_positions: List[bool]) -> bool: + if partner_id is None: + return True + + fields = struct.fields + partner_struct = self._find_field_type(partner_id) + + for pos, missing in enumerate(missing_positions): + if missing: + self._add_column(partner_id, fields[pos]) + else: + field = fields[pos] + if nested_field := partner_struct.field_by_name(field.name, case_sensitive=self.case_sensitive): + self._update_column(field, nested_field) + + return False + + def _add_column(self, parent_id: int, field: NestedField) -> None: + if parent_name := self.new_schema.find_column_name(parent_id): + path: Tuple[str, ...] 
= (parent_name, field.name) + else: + path = (field.name,) + + self.update_schema.add_column(path=path, field_type=field.field_type, required=field.required, doc=field.doc) + + def _update_column(self, field: NestedField, existing_field: NestedField) -> None: + full_name = self.new_schema.find_column_name(existing_field.field_id) + + if full_name is None: + raise ValueError(f"Could not find field: {existing_field}") + + if field.optional and existing_field.required: + self.update_schema.make_column_optional(full_name) + + if field.field_type.is_primitive and field.field_type != existing_field.field_type: + self.update_schema.update_column(full_name, field_type=field.field_type) + + if field.doc is not None and not field.doc != existing_field.doc: + self.update_schema.update_column(full_name, doc=field.doc) + + def _find_field_type(self, field_id: int) -> IcebergType: + if field_id == -1: + return self.new_schema.as_struct() + else: + return self.new_schema.find_field(field_id).field_type + + def field(self, field: NestedField, field_partner: Optional[int], field_result: bool) -> bool: + return field_partner is None + + def list(self, list_type: ListType, list_partner: Optional[int], element_missing: bool) -> bool: + if list_partner is None: + return False + + if element_missing: + raise ValueError("Error traversing schemas: element is missing, but list is present") + + partner_list_type = self._find_field_type(list_partner) + if not isinstance(partner_list_type, ListType): + raise ValueError(f"Expected list-type, got: {partner_list_type}") + + self._update_column(list_type.element_field, partner_list_type.element_field) + + return False + + def map(self, map_type: MapType, map_partner: Optional[int], key_missing: bool, value_missing: bool) -> bool: + if map_partner is None: + return False + + if key_missing: + raise ValueError("Error traversing schemas: key is missing, but map is present") + + if value_missing: + raise ValueError("Error traversing schemas: value is 
missing, but map is present") + + partner_map_type = self._find_field_type(map_partner) + if not isinstance(partner_map_type, MapType): + raise ValueError(f"Expected map-type, got: {partner_map_type}") + + self._update_column(map_type.key_field, partner_map_type.key_field) + self._update_column(map_type.value_field, partner_map_type.value_field) + + return False + + def primitive(self, primitive: PrimitiveType, primitive_partner: Optional[int]) -> bool: + return primitive_partner is None + + +class PartnerIdByNameAccessor(PartnerAccessor[int]): + partner_schema: Schema + case_sensitive: bool + + def __init__(self, partner_schema: Schema, case_sensitive: bool) -> None: + self.partner_schema = partner_schema + self.case_sensitive = case_sensitive + + def schema_partner(self, partner: Optional[int]) -> Optional[int]: + return -1 + + def field_partner(self, partner_field_id: Optional[int], field_id: int, field_name: str) -> Optional[int]: + if partner_field_id is not None: + if partner_field_id == -1: + struct = self.partner_schema.as_struct() + else: + struct = self.partner_schema.find_field(partner_field_id).field_type + if not struct.is_struct: + raise ValueError(f"Expected StructType: {struct}") + + if field := struct.field_by_name(name=field_name, case_sensitive=self.case_sensitive): + return field.field_id + + return None + + def list_element_partner(self, partner_list: Optional[int]) -> Optional[int]: Review Comment: ```suggestion def list_element_partner(self, partner_list_id: Optional[int]) -> Optional[int]: ``` Shall we add `_id` to the name. 
Same applies to map_key/value_partner ########## tests/test_schema.py: ########## @@ -912,3 +914,668 @@ def test_promotion(file_type: IcebergType, read_type: IcebergType) -> None: else: with pytest.raises(ResolveError): promote(file_type, read_type) + + +@pytest.fixture() +def primitive_fields() -> List[NestedField]: + return [ + NestedField(field_id=1, name=str(primitive_type), field_type=primitive_type, required=False) + for primitive_type in TEST_PRIMITIVE_TYPES + ] + + +def test_add_top_level_primitives(primitive_fields: NestedField) -> None: + for primitive_field in primitive_fields: + new_schema = Schema(primitive_field) + applied = UpdateSchema(None, schema=Schema()).union_by_name(new_schema)._apply() + assert applied == new_schema + + +def test_add_top_level_list_of_primitives(primitive_fields: NestedField) -> None: + for primitive_type in TEST_PRIMITIVE_TYPES: + new_schema = Schema( + NestedField( + field_id=1, + name="aList", + field_type=ListType(element_id=2, element_type=primitive_type, element_required=False), + required=False, + ) + ) + applied = UpdateSchema(None, schema=Schema()).union_by_name(new_schema)._apply() + assert applied.as_struct() == new_schema.as_struct() + + +def test_add_top_level_map_of_primitives(primitive_fields: NestedField) -> None: + for primitive_type in TEST_PRIMITIVE_TYPES: + new_schema = Schema( + NestedField( + field_id=1, + name="aMap", + field_type=MapType( + key_id=2, key_type=primitive_type, value_id=3, value_type=primitive_type, value_required=False + ), + required=False, + ) + ) + applied = UpdateSchema(None, schema=Schema()).union_by_name(new_schema)._apply() + assert applied.as_struct() == new_schema.as_struct() + + +def test_add_top_struct_of_primitives(primitive_fields: NestedField) -> None: + for primitive_type in TEST_PRIMITIVE_TYPES: + new_schema = Schema( + NestedField( + field_id=1, + name="aStruct", + field_type=StructType(NestedField(field_id=2, name="primitive", field_type=primitive_type, 
required=False)), + required=False, + ) + ) + applied = UpdateSchema(None, schema=Schema()).union_by_name(new_schema)._apply() + assert applied.as_struct() == new_schema.as_struct() + + +def test_add_nested_primitive(primitive_fields: NestedField) -> None: + for primitive_type in TEST_PRIMITIVE_TYPES: + current_schema = Schema(NestedField(field_id=1, name="aStruct", field_type=StructType(), required=False)) + new_schema = Schema( + NestedField( + field_id=1, + name="aStruct", + field_type=StructType(NestedField(field_id=2, name="primitive", field_type=primitive_type, required=False)), + required=False, + ) + ) + applied = UpdateSchema(None, schema=current_schema).union_by_name(new_schema)._apply() + assert applied.as_struct() == new_schema.as_struct() + + +def _primitive_fields(types: List[PrimitiveType], start_id: int = 0) -> List[NestedField]: + fields = [] + for iceberg_type in types: + fields.append(NestedField(field_id=start_id, name=str(iceberg_type), field_type=iceberg_type, required=False)) + start_id = start_id + 1 + + return fields + + +def test_add_nested_primitives(primitive_fields: NestedField) -> None: + current_schema = Schema(NestedField(field_id=1, name="aStruct", field_type=StructType(), required=False)) + new_schema = Schema( + NestedField( + field_id=1, name="aStruct", field_type=StructType(*_primitive_fields(TEST_PRIMITIVE_TYPES, 2)), required=False + ) + ) + applied = UpdateSchema(None, schema=current_schema).union_by_name(new_schema)._apply() + assert applied.as_struct() == new_schema.as_struct() + + +def test_add_nested_lists(primitive_fields: NestedField) -> None: + new_schema = Schema( + NestedField( + field_id=1, + name="aList", + type=ListType( + element_id=2, + element_type=ListType( + element_id=3, Review Comment: Out of curiosity, did you manually format the schema in this style or the `make lint` did the job? 
This looks nice, especially in the case of multi-level nested schemas 😄 -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: issues-unsubscr...@iceberg.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: issues-unsubscr...@iceberg.apache.org For additional commands, e-mail: issues-h...@iceberg.apache.org