jqin61 commented on code in PR #555:
URL: https://github.com/apache/iceberg-python/pull/555#discussion_r1544770143


##########
pyiceberg/table/__init__.py:
##########
@@ -3108,3 +3138,127 @@ def snapshots(self) -> "pa.Table":
             snapshots,
             schema=snapshots_schema,
         )
+
+
+@dataclass(frozen=True)
+class TablePartition:
+    partition_key: PartitionKey
+    arrow_table_partition: pa.Table
+
+
+def _get_partition_sort_order(partition_columns: list[str], reverse: bool = False) -> dict[str, Any]:
+    order = 'ascending' if not reverse else 'descending'
+    null_placement = 'at_start' if reverse else 'at_end'
+    return {'sort_keys': [(column_name, order) for column_name in partition_columns], 'null_placement': null_placement}
+
+
+def group_by_partition_scheme(
+    iceberg_table_metadata: TableMetadata, arrow_table: pa.Table, partition_columns: list[str]
+) -> pa.Table:
+    """Given a table sort it by current partition scheme with all transform 
functions supported."""
+    from pyiceberg.transforms import IdentityTransform
+
+    supported = {IdentityTransform}
+    if not all(
+        type(field.transform) in supported
+        for field in iceberg_table_metadata.spec().fields
+        if iceberg_table_metadata.schema().find_column_name(field.source_id) in partition_columns
+    ):
+        raise ValueError(
+            f"Not all transforms are supported, got: {[field.transform for field in iceberg_table_metadata.spec().fields]}."
+        )
+
+    # only works for identity
+    sort_options = _get_partition_sort_order(partition_columns, reverse=False)
+    sorted_arrow_table = arrow_table.sort_by(sorting=sort_options['sort_keys'], null_placement=sort_options['null_placement'])
+    return sorted_arrow_table
+
+
+def get_partition_columns(iceberg_table_metadata: TableMetadata, arrow_table: pa.Table) -> list[str]:
+    arrow_table_cols = set(arrow_table.column_names)
+    partition_cols = []
+    for transform_field in iceberg_table_metadata.spec().fields:
+        column_name = iceberg_table_metadata.schema().find_column_name(transform_field.source_id)
+        if not column_name:
+            raise ValueError(f"{transform_field=} could not be found in 
{iceberg_table_metadata.schema()}.")
+        if column_name not in arrow_table_cols:
+            continue
+        partition_cols.append(column_name)
+    return partition_cols
+
+
+def _get_table_partitions(
+    arrow_table: pa.Table,
+    partition_spec: PartitionSpec,
+    schema: Schema,
+    slice_instructions: list[dict[str, Any]],
+) -> list[TablePartition]:
+    sorted_slice_instructions = sorted(slice_instructions, key=lambda x: x['offset'])
+
+    partition_fields = partition_spec.fields
+
+    offsets = [inst["offset"] for inst in sorted_slice_instructions]
+    projected_and_filtered = {
+        partition_field.source_id: arrow_table[schema.find_field(name_or_id=partition_field.source_id).name]
+        .take(offsets)
+        .to_pylist()
+        for partition_field in partition_fields
+    }
+
+    table_partitions = []
+    for idx, inst in enumerate(sorted_slice_instructions):
+        partition_slice = arrow_table.slice(**inst)
+        fieldvalues = [
+            PartitionFieldValue(partition_field, projected_and_filtered[partition_field.source_id][idx])
+            for partition_field in partition_fields
+        ]
+        partition_key = PartitionKey(raw_partition_field_values=fieldvalues, partition_spec=partition_spec, schema=schema)
+        table_partitions.append(TablePartition(partition_key=partition_key, arrow_table_partition=partition_slice))
+
+    return table_partitions
+
+
+def partition(iceberg_table_metadata: TableMetadata, arrow_table: pa.Table) -> Iterable[TablePartition]:
+    """Based on the iceberg table partition spec, slice the arrow table into 
partitions with their keys.
+
+    Example:
+    Input:
+    An arrow table with partition key of ['n_legs', 'year'] and with data of
+    {'year': [2020, 2022, 2022, 2021, 2022, 2022, 2022, 2019, 2021],
+     'n_legs': [2, 2, 2, 4, 4, 4, 4, 5, 100],
+     'animal': ["Flamingo", "Parrot", "Parrot", "Dog", "Horse", "Horse", 
"Horse","Brittle stars", "Centipede"]}.
+    The algrithm:
+    Firstly we group the rows into partitions by sorting with sort order 
[('n_legs', 'descending'), ('year', 'descending')]
+    and null_placement of "at_end".
+    This gives the same table as raw input.
+    Then we sort_indices using reverse order of [('n_legs', 'descending'), 
('year', 'descending')]
+    and null_placement : "at_start".
+    This gives:
+    [8, 7, 4, 5, 6, 3, 1, 2, 0]
+    Based on this we get partition groups of indices:
+    [{'offset': 8, 'length': 1}, {'offset': 7, 'length': 1}, {'offset': 4, 
'length': 3}, {'offset': 3, 'length': 1}, {'offset': 1, 'length': 2}, 
{'offset': 0, 'length': 1}]
+    We then retrieve the partition keys by offsets.
+    And slice the arrow table by offsets and lengths of each partition.
+    """
+    import pyarrow as pa
+
+    partition_columns = get_partition_columns(iceberg_table_metadata, arrow_table)
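
A standalone sketch of the sort-then-slice grouping that the `partition` docstring above describes, rebuilt with plain pyarrow on the docstring's example data. This only illustrates the idea and is not the PR's exact code path (which derives the groups from reverse-order sort indices):

```python
import pyarrow as pa

arrow_table = pa.table({
    "year": [2020, 2022, 2022, 2021, 2022, 2022, 2022, 2019, 2021],
    "n_legs": [2, 2, 2, 4, 4, 4, 4, 5, 100],
    "animal": ["Flamingo", "Parrot", "Parrot", "Dog", "Horse", "Horse", "Horse", "Brittle stars", "Centipede"],
})
partition_columns = ["n_legs", "year"]

# 1. Sort so that rows belonging to the same partition become contiguous.
sorted_table = arrow_table.sort_by([(col, "descending") for col in partition_columns])

# 2. Walk the sorted partition columns and emit an {'offset': ..., 'length': ...}
#    slice instruction every time the partition key changes.
keys = list(zip(*(sorted_table[col].to_pylist() for col in partition_columns)))
slice_instructions = []
start = 0
for i in range(1, len(keys) + 1):
    if i == len(keys) or keys[i] != keys[start]:
        slice_instructions.append({"offset": start, "length": i - start})
        start = i

# 3. Each instruction yields one partition slice of the sorted table.
partitions = [sorted_table.slice(**inst) for inst in slice_instructions]
print(slice_instructions)  # 6 groups for the example data
```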

Review Comment:
   It will be more useful when there are hidden partition columns. The check is also there for mypy, since find_column_name returns Optional[str].
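
   To make the mypy point concrete, here is a minimal sketch with a hypothetical `find_column_name` stand-in that mimics the Optional[str] return type (not the PR's code):

   ```python
   from typing import Optional

   # Hypothetical stand-in for Schema.find_column_name, typed as returning Optional[str].
   def find_column_name(source_id: int) -> Optional[str]:
       return {1: "year", 2: "n_legs"}.get(source_id)

   partition_cols: list[str] = []
   column_name = find_column_name(source_id=2)
   if not column_name:
       raise ValueError("source_id=2 could not be found in the schema.")
   # After the guard, mypy narrows column_name from Optional[str] to str,
   # so appending to a list[str] type-checks without a cast.
   partition_cols.append(column_name)
   ```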



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@iceberg.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

