amogh-jahagirdar commented on code in PR #245:
URL: https://github.com/apache/iceberg-python/pull/245#discussion_r1465805383


##########
pyiceberg/table/__init__.py:
##########
@@ -2271,3 +2317,244 @@ def commit(self) -> Snapshot:
             )
 
         return snapshot
+
+
+class UpdateSpec:
+    _table: Table
+    _schema: Schema
+    _spec: PartitionSpec
+    _name_to_field: Dict[str, PartitionField] = {}
+    _name_to_added_field: Dict[str, PartitionField] = {}
+    _transform_to_field: Dict[Tuple[int, str], PartitionField] = {}
+    _transform_to_added_field: Dict[Tuple[int, str], PartitionField] = {}
+    _renames: Dict[str, str] = {}
+    _added_time_fields: Dict[int, PartitionField] = {}
+    _case_sensitive: bool
+    _adds: List[PartitionField]
+    _deletes: Set[int]
+    _last_assigned_partition_id: int
+    _transaction: Optional[Transaction]
+    _unassigned_field_name = 'unassigned_field_name'
+
+    def __init__(self, table: Table, transaction: Optional[Transaction] = 
None, case_sensitive: bool = True) -> None:
+        self._table = table
+        self._schema = table.schema()
+        self._spec = table.spec()
+        self._name_to_field = {field.name: field for field in 
self._spec.fields}
+        self._name_to_added_field = {}
+        self._transform_to_field = {(field.source_id, repr(field.transform)): 
field for field in self._spec.fields}
+        self._transform_to_added_field = {}
+        self._adds = []
+        self._deletes = set()
+        if len(table.specs()) == 1:
+            self._last_assigned_partition_id = PARTITION_FIELD_ID_START - 1
+        else:
+            self._last_assigned_partition_id = 
table.spec().last_assigned_field_id
+        self._renames = {}
+        self._transaction = transaction
+        self._case_sensitive = case_sensitive
+        self._added_time_fields = {}
+
+    def add_field(

Review Comment:
   Need to add API docs



##########
pyiceberg/partitioning.py:
##########
@@ -85,6 +91,20 @@ def __str__(self) -> str:
         """Return the string representation of the PartitionField class."""
         return f"{self.field_id}: {self.name}: 
{self.transform}({self.source_id})"
 
+    def __hash__(self) -> int:
+        """Return the hash of the partition field."""
+        return hash((self.name, self.source_id, self.field_id, 
repr(self.transform)))

Review Comment:
   I actually don't think we need the `__hash__` and `__eq__` implementations. When I was 
initially implementing the duplicate-detection logic, I had a dictionary keyed on the 
whole field, but we can do everything just via the field IDs, names, and the 
transform name as parts of various tuples. That should remove the need to 
implement these methods for now. Let me know if that makes sense.



##########
pyiceberg/partitioning.py:
##########
@@ -85,6 +91,20 @@ def __str__(self) -> str:
         """Return the string representation of the PartitionField class."""
         return f"{self.field_id}: {self.name}: 
{self.transform}({self.source_id})"
 
+    def __hash__(self) -> int:
+        """Return the hash of the partition field."""
+        return hash((self.name, self.source_id, self.field_id, 
repr(self.transform)))
+
+    def __eq__(self, other: Any) -> bool:
+        """Return True if two partition fields are considered equal, False 
otherwise."""
+        return (
+            isinstance(other, PartitionField)
+            and other.field_id == self.field_id
+            and other.name == self.name
+            and other.source_id == self.source_id
+            and repr(other.transform) == repr(self.transform)

Review Comment:
   See the comment above; I don't think we need `__eq__` and `__hash__`.



##########
pyiceberg/table/__init__.py:
##########
@@ -2271,3 +2317,244 @@ def commit(self) -> Snapshot:
             )
 
         return snapshot
+
+
+class UpdateSpec:
+    _table: Table
+    _schema: Schema
+    _spec: PartitionSpec
+    _name_to_field: Dict[str, PartitionField] = {}
+    _name_to_added_field: Dict[str, PartitionField] = {}
+    _transform_to_field: Dict[Tuple[int, str], PartitionField] = {}
+    _transform_to_added_field: Dict[Tuple[int, str], PartitionField] = {}
+    _renames: Dict[str, str] = {}
+    _added_time_fields: Dict[int, PartitionField] = {}
+    _case_sensitive: bool
+    _adds: List[PartitionField]
+    _deletes: Set[int]
+    _last_assigned_partition_id: int
+    _transaction: Optional[Transaction]
+    _unassigned_field_name = 'unassigned_field_name'
+
+    def __init__(self, table: Table, transaction: Optional[Transaction] = 
None, case_sensitive: bool = True) -> None:
+        self._table = table
+        self._schema = table.schema()
+        self._spec = table.spec()
+        self._name_to_field = {field.name: field for field in 
self._spec.fields}
+        self._name_to_added_field = {}
+        self._transform_to_field = {(field.source_id, repr(field.transform)): 
field for field in self._spec.fields}
+        self._transform_to_added_field = {}
+        self._adds = []
+        self._deletes = set()
+        if len(table.specs()) == 1:
+            self._last_assigned_partition_id = PARTITION_FIELD_ID_START - 1
+        else:
+            self._last_assigned_partition_id = 
table.spec().last_assigned_field_id
+        self._renames = {}
+        self._transaction = transaction
+        self._case_sensitive = case_sensitive
+        self._added_time_fields = {}
+
+    def add_field(
+        self, partition_field_name: Optional[str], source_column_name: str, 
transform: Transform[Any, Any]
+    ) -> UpdateSpec:
+        ref = Reference(source_column_name)
+        bound_ref = ref.bind(self._schema, self._case_sensitive)
+        # verify transform can actually bind it
+        output_type = bound_ref.field.field_type
+        if not transform.can_transform(output_type):
+            raise ValueError(f"{transform} cannot transform {output_type} 
values from {bound_ref.field.name}")
+
+        transform_key = (bound_ref.field.field_id, repr(transform))
+        existing_partition_field = self._transform_to_field.get(transform_key)
+        if existing_partition_field and 
self._is_duplicate_partition(transform, existing_partition_field):
+            raise ValueError(f"Duplicate partition field for 
${ref.name}=${ref}, ${existing_partition_field} already exists")
+
+        added = self._transform_to_added_field.get(transform_key)
+        if added:
+            raise ValueError(f"Already added partition {added.name}")
+
+        new_field = self._partition_field((bound_ref.field.field_id, 
transform), partition_field_name)
+        if new_field.name == self._unassigned_field_name:
+            name = _visit_field(self._schema, new_field, 
_PartitionNameGenerator())
+            new_field = PartitionField(new_field.source_id, 
new_field.field_id, new_field.transform, name)
+
+        if new_field.name in self._name_to_added_field:
+            raise ValueError(f"Already added partition field with name: 
{new_field.name}")
+
+        self._redundant_time_partition(new_field)
+        self._transform_to_added_field[transform_key] = new_field
+
+        existing_partition_field = self._name_to_field.get(new_field.name)
+        if existing_partition_field and new_field.field_id not in 
self._deletes:
+            if isinstance(existing_partition_field.transform, VoidTransform):
+                self.rename_field(
+                    existing_partition_field.name, 
existing_partition_field.name + "_" + str(existing_partition_field.field_id)
+                )
+            else:
+                raise ValueError(f"Cannot add duplicate partition field name: 
{existing_partition_field.name}")
+
+        self._name_to_added_field[new_field.name] = new_field
+        self._adds.append(new_field)
+        return self
+
+    def add_identity(self, source_column_name: str) -> UpdateSpec:
+        return self.add_field(self._unassigned_field_name, source_column_name, 
IdentityTransform())
+
+    def remove_field(self, name: str) -> UpdateSpec:
+        added = self._name_to_added_field.get(name)
+        if added:
+            raise ValueError(f"Cannot delete newly added field {name}")
+        renamed = self._renames.get(name)
+        if renamed:
+            raise ValueError(f"Cannot rename and delete field {name}")
+        field = self._name_to_field.get(name)
+        if not field:
+            raise ValueError(f"No such partition field: {name}")
+
+        self._deletes.add(field.field_id)
+        return self
+
+    def rename_field(self, name: str, new_name: str) -> UpdateSpec:
+        existing_field = self._name_to_field.get(new_name)
+        if existing_field and isinstance(existing_field.transform, 
VoidTransform):
+            return self.rename_field(name, name + "_" + 
str(existing_field.field_id))
+        added = self._name_to_added_field.get(name)
+        if added:
+            raise ValueError("Cannot rename recently added partitions")
+        field = self._name_to_field.get(name)
+        if not field:
+            raise ValueError(f"Cannot find partition field {name}")
+        if field.field_id in self._deletes:
+            raise ValueError(f"Cannot delete and rename partition field 
{name}")
+        self._renames[name] = new_name
+        return self
+
+    def commit(self) -> None:
+        new_spec = self._apply()
+        updates = []
+        requirements = []
+        if self._table.metadata.default_spec_id != new_spec.spec_id:
+            if new_spec.spec_id not in self._table.specs():
+                spec_update = AddPartitionSpecUpdate(spec=new_spec)
+                updates.append(spec_update)
+                if len(self._table.specs()) == 1:
+                    required_last_assigned_partitioned_id = 
PARTITION_FIELD_ID_START - 1
+                else:
+                    required_last_assigned_partitioned_id = 
self._table.spec().last_assigned_field_id
+                requirements.append(
+                    
AssertLastAssignedPartitionId(last_assigned_partition_id=required_last_assigned_partitioned_id)
+                )
+            if self._transaction:
+                self._transaction._append_updates(*updates)  # pylint: 
disable=W0212
+                self._transaction._append_requirements(*requirements)  # 
pylint: disable=W0212
+            else:
+                updates.append(SetDefaultSpecUpdate(spec_id=new_spec.spec_id))

Review Comment:
   Ooh, good catch, this is a bug. I'm missing a test for the transaction path, which 
would've caught this.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@iceberg.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: issues-unsubscr...@iceberg.apache.org
For additional commands, e-mail: issues-h...@iceberg.apache.org

Reply via email to