Fokko commented on code in PR #245:
URL: https://github.com/apache/iceberg-python/pull/245#discussion_r1441727718


##########
pyiceberg/table/__init__.py:
##########
@@ -1904,3 +1913,200 @@ def _generate_snapshot_id() -> int:
     snapshot_id = snapshot_id if snapshot_id >= 0 else snapshot_id * -1
 
     return snapshot_id
+
+class UpdateSpec:
+    _table: Table
+    _schema: Schema
+    _spec: PartitionSpec
+    _name_to_field: Dict[str, PartitionField]
+    _name_to_added_field: Dict[str, PartitionField]
+    _transform_to_field: Dict[Tuple[int, str], PartitionField]
+    _transform_to_added_field: Dict[Tuple[int, str], PartitionField]
+    _case_sensitive: bool
+    _adds: List[PartitionField]
+    _deletes: Set[int]
+    _last_assigned_partition_id: int
+    _renames = Dict[str, str]
+
+    def __init__(
+        self,
+        table: Table,
+        transaction: Optional[Transaction] = None,
+        case_sensitive = True
+    ) -> None:
+        self._table = table
+        self._schema = table.schema()
+        self._spec = table.spec()
+        self._name_to_field = {field.name:field for field in self._spec.fields}
+        self._name_to_added_field = {}
+        self._transform_to_field = {(field.sourceId, repr(field.transform)): 
field for field in self._spec.fields}
+        self._transform_to_added_field = {}
+        self._adds = []
+        self._deletes = {}
+        self._last_assigned_partition_id = table.spec().last_assigned_field_id
+        self._renames = {}
+        self._transaction = transaction
+        self._case_sensitive = case_sensitive
+
+
+    def add_field(self, name: str, transform: Transform) -> UpdateSpec:
+        ref = Reference(name)
+        bound_ref = ref.bind(self._schema, self._case_sensitive)
+        # verify transform can actually bind it
+        output_type = bound_ref.field.field_type
+        if not transform.can_transform(output_type):
+            raise ValueError(f"{transform} cannot transform {output_type} 
values from {bound_ref.field.name}")
+        transform_key = (bound_ref.field.field_id, transform)
+        existing_partition_field = self._transform_to_field.get(transform)
+        if existing_partition_field and 
self._is_duplicate_partition(transform_key[1], existing_partition_field):
+            raise ValueError(f"Duplicate partition field for 
${ref.name}=${ref}, ${existing_partition_field} already exists")
+        added = self._transform_to_added_field.get(transform_key)
+        if added:
+            raise ValueError(f"Already added partition {added}")
+        new_field = self._partition_field(transform_key, name)
+        if not new_field.name:
+            new_field.name = _visit(self._schema, new_field, 
_PartitionNameGenerator())
+
+        self._check_redundant_partitions(new_field)
+        self._transform_to_added_field[transform_key] = new_field
+        
+        existing_partition_field = self._name_to_field.get(new_field.name)
+        if existing_partition_field and not new_field.field_id in 
self._deletes:
+            if isinstance(existing_partition_field.transform, VoidTransform):
+                self.rename_field(existing_partition_field.name, 
existing_partition_field.name + "_" + existing_partition_field.field_id)
+            else:
+                raise ValueError(f"Cannot add duplicate partition field name: 
{existing_partition_field.name}")
+
+        self._name_to_added_field[new_field.name] = new_field
+        self._adds.append(new_field)
+        return self
+    
+    def remove_field(self, name: str) -> UpdateSpec:
+        added = self._name_to_added_field.get(name)
+        if added:
+            raise ValueError(f"Cannot delete newly added field {name}")
+        renamed = self._renames.get(name)
+        if renamed:
+            raise ValueError(f"Cannot rename and delete field: {name}")
+        field = self._name_to_field.get(name)
+        if not field:
+            raise ValueError(f"No such partition field: {name}")
+
+        self._deletes.add(field.field_id)
+        return self
+
+    def rename_field(self, name: str, new_name: str) -> UpdateSpec:
+        existing_field = self._name_to_field.get(new_name)
+        if existing_field and isinstance(existing_field.transform, 
VoidTransform):
+            return self.rename_field(name, name + "_" + 
existing_field.field_id)
+        added = self._name_to_added_field.get(name)
+        if added:
+            raise ValueError(f"Cannot rename recently added partitions")
+        field = self._name_to_field.get(name)
+        if not field:
+            raise ValueError(f"Cannot find partition field {name}")
+        if field.field_id in self._deletes:
+            raise ValueError(f"Cannot delete and rename partition field 
{name}")
+        self._renames[name] = new_name
+        return self
+    
+    def commit(self):
+        new_spec = self._apply()
+        updates = []
+        requirements = []
+        if self._table.metadata.default_spec_id != new_spec.spec_id:
+            if new_spec.spec_id not in self._table.specs():
+                spec_update = AddPartitionSpecUpdate(spec=new_spec)
+                updates.append(spec_update)
+                
requirements.append(AssertLastAssignedPartitionId(last_assigned_partition_id=self._table.metadata.last_partition_id))
 
+
+            updates.append(SetDefaultSpecUpdate(spec_id=new_spec.spec_id))
+            
requirements.append(AssertDefaultSpecId(default_spec_id=self._table.metadata.default_spec_id))
        
+            self._table._do_commit(updates=updates, requirements=requirements) 
 # pylint: disable=W0212
+            
+
+    def _apply(self) -> PartitionSpec:
+        last_assigned_field_id = PARTITION_FIELD_ID_START - 1
+        partition_fields = []
+        partition_names = set([])
+
+        def _check_and_add_partition_name(schema: Schema, name: str, 
source_id: int, partition_names: Set[int]):
+            field = schema.find_field(name)
+            if source_id and field and field.field_id != source_id:
+                raise ValueError(f"Cannot create identity partition from a 
different field in the schema {name}")
+            elif field:
+                raise ValueError(f"Cannot create partition from name that 
exists in schema {name}")
+            if not name:
+                raise ValueError(f"Undefined empty/none name")
+            if name in partition_names:
+                raise ValueError(f"Cannot use partition name more than once 
{name}")
+            partition_names.add(name)
+
+        def _add_new_field(source_id: int, field_id: int, name: str, 
transform: Transform, partition_names: Set[int]):
+            _check_and_add_partition_name(name, source_id, partition_names)
+            return PartitionField(source_id, field_id, name, transform)
+            
+
+        for field in self._spec.fields:
+            if field.field_id not in self._deletes:
+                renamed = self._renames.get(field.name)
+                if renamed:
+                    new_field = _add_new_field(field.source_id, 
field.field_id, renamed, field.transform, partition_names)
+                else:
+                    new_field = _add_new_field(field.source_id, 
field.field_id, field.name, field.transform, partition_names)
+                last_assigned_field_id = max(last_assigned_field_id, 
new_field.field_id)
+                partition_fields.append(new_field)
+            elif self._table.format_version == 1:
+                renamed = self._renames.get(field.name)
+                if renamed:
+                    new_field = _add_new_field(field.source_id, 
field.field_id, renamed, VoidTransform(), partition_names)
+                else:
+                    new_field = _add_new_field(field.source_id, 
field.field_id, field.name, VoidTransform(), partition_names)
+
+                last_assigned_field_id = max(last_assigned_field_id, 
new_field.field_id)
+                partition_fields.append(new_field)
+        
+        for added_field in self._adds:
+            new_field = PartitionField(source_id=added_field.source_id, 
field_id=added_field.field_id, transform=added_field.transform, 
name=added_field.name)
+            partition_fields.append(new_field)
+        # Reuse spec id or create a new one.
+        new_spec = PartitionSpec(*partition_fields)
+        new_spec_id = INITIAL_PARTITION_SPEC_ID
+        for spec in self._table.specs().values():
+            if new_spec.compatible_with(spec):
+                new_spec_id = spec.spec_id
+                break
+            elif new_spec_id <= spec.spec_id:
+                new_spec_id = spec.spec_id + 1
+
+        return PartitionSpec(*partition_fields, spec_id=new_spec_id)
+    
+    def _check_redundant_partitions(self, field: PartitionField):
+        if isinstance(field.transform, TimeTransform):
+            existing_time_field = self._added_time_fields.get(field.source_id)
+            if existing_time_field:
+                raise ValueError(f"Cannot add time partition field: 
{field.name} conflicts with {existing_time_field.name}")
+            self._added_time_fields[field.source_id] = field
+    
+    def _partition_field(self, transform_key: Tuple[int, Transform], name: 
str) -> PartitionField:
+        if self._table.metadata.format_version == 2 and self._table.metadata:

Review Comment:
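   I don't think `metadata` can be `None` here. Even if it could, the `self._table.metadata` check runs only after `self._table.metadata.format_version` has already been accessed, so it would never guard anything: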
   ```suggestion
           if self._table.metadata.format_version == 2:
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@iceberg.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: issues-unsubscr...@iceberg.apache.org
For additional commands, e-mail: issues-h...@iceberg.apache.org

Reply via email to