jqin61 commented on code in PR #555: URL: https://github.com/apache/iceberg-python/pull/555#discussion_r1544770143
########## pyiceberg/table/__init__.py: ##########
@@ -3108,3 +3138,127 @@ def snapshots(self) -> "pa.Table":
             snapshots,
             schema=snapshots_schema,
         )
+
+
+@dataclass(frozen=True)
+class TablePartition:
+    """A contiguous slice of an Arrow table together with the partition key it belongs to."""
+
+    partition_key: PartitionKey
+    arrow_table_partition: pa.Table
+
+
+def _get_partition_sort_order(partition_columns: list[str], reverse: bool = False) -> dict[str, Any]:
+    """Build `sort_keys` / `null_placement` options for sorting by the given partition columns.
+
+    Ascending order (the default) places nulls at the end; `reverse=True` sorts descending
+    and places nulls at the start.
+    """
+    order = 'ascending' if not reverse else 'descending'
+    null_placement = 'at_start' if reverse else 'at_end'
+    return {'sort_keys': [(column_name, order) for column_name in partition_columns], 'null_placement': null_placement}
+
+
+def group_by_partition_scheme(
+    iceberg_table_metadata: TableMetadata, arrow_table: pa.Table, partition_columns: list[str]
+) -> pa.Table:
+    """Sort `arrow_table` by `partition_columns` so rows that share a partition value become contiguous.
+
+    Only identity transforms are supported: sorting by the raw source column groups rows
+    correctly only when the partition value is the column value itself.
+
+    Raises:
+        ValueError: If a partition field whose source column is in `partition_columns`
+            uses a transform other than IdentityTransform.
+    """
+    from pyiceberg.transforms import IdentityTransform
+
+    supported = {IdentityTransform}
+    schema = iceberg_table_metadata.schema()
+    # Match on the source column *name*: `partition_columns` holds strings, so testing the
+    # PartitionField objects themselves for membership would never match and silently skip this check.
+    unsupported = [
+        field.transform
+        for field in iceberg_table_metadata.spec().fields
+        if schema.find_column_name(field.source_id) in partition_columns and type(field.transform) not in supported
+    ]
+    if unsupported:
+        raise ValueError(f"Not all transforms are supported, got: {unsupported}.")
+
+    # only works for identity
+    sort_options = _get_partition_sort_order(partition_columns, reverse=False)
+    return arrow_table.sort_by(sorting=sort_options['sort_keys'], null_placement=sort_options['null_placement'])
+
+
+def get_partition_columns(iceberg_table_metadata: TableMetadata, arrow_table: pa.Table) -> list[str]:
+    """Return the source-column names of the current partition spec that are present in `arrow_table`.
+
+    Names are returned in partition-spec field order. Partition fields whose source column is
+    absent from `arrow_table` are skipped.
+
+    Raises:
+        ValueError: If a partition field's source column cannot be resolved in the table schema.
+    """
+    arrow_table_cols = set(arrow_table.column_names)
+    partition_cols = []
+    for transform_field in iceberg_table_metadata.spec().fields:
+        column_name = iceberg_table_metadata.schema().find_column_name(transform_field.source_id)
+        # find_column_name returns Optional[str]; this also narrows the type for mypy.
+        if not column_name:
+            raise ValueError(f"{transform_field=} could not be found in {iceberg_table_metadata.schema()}.")
+        if column_name not in arrow_table_cols:
+            continue
+        partition_cols.append(column_name)
+    return partition_cols
+
+
+def _get_table_partitions(
+    arrow_table: pa.Table,
+    partition_spec: PartitionSpec,
+    schema: Schema,
+    slice_instructions: list[dict[str, Any]],
+) -> list[TablePartition]:
+    """Slice `arrow_table` by `slice_instructions` and attach each slice's partition key.
+
+    Args:
+        arrow_table: Table whose rows for each partition are expected to be contiguous,
+            since each partition is taken as a single offset/length slice.
+        partition_spec: Spec whose fields make up the partition key.
+        schema: Table schema used to resolve each partition field's source column.
+        slice_instructions: Dicts of `Table.slice` keyword arguments
+            (e.g. `{'offset': 8, 'length': 1}`), one per partition.
+
+    Returns:
+        One TablePartition per slice instruction, ordered by offset.
+    """
+    sorted_slice_instructions = sorted(slice_instructions, key=lambda x: x['offset'])
+
+    partition_fields = partition_spec.fields
+
+    offsets = [inst["offset"] for inst in sorted_slice_instructions]
+    # `take(offsets)` yields one value per offset, so each list below is indexed by the
+    # position of the slice instruction, NOT by the row offset itself.
+    projected_and_filtered = {
+        partition_field.source_id: arrow_table[schema.find_field(name_or_id=partition_field.source_id).name]
+        .take(offsets)
+        .to_pylist()
+        for partition_field in partition_fields
+    }
+
+    table_partitions = []
+    for idx, inst in enumerate(sorted_slice_instructions):
+        partition_slice = arrow_table.slice(**inst)
+        fieldvalues = [
+            PartitionFieldValue(partition_field, projected_and_filtered[partition_field.source_id][idx])
+            for partition_field in partition_fields
+        ]
+        partition_key = PartitionKey(raw_partition_field_values=fieldvalues, partition_spec=partition_spec, schema=schema)
+        table_partitions.append(TablePartition(partition_key=partition_key, arrow_table_partition=partition_slice))
+
+    return table_partitions
+
+
+def
partition(iceberg_table_metadata: TableMetadata, arrow_table: pa.Table) -> Iterable[TablePartition]: + """Based on the iceberg table partition spec, slice the arrow table into partitions with their keys. + + Example: + Input: + An arrow table with partition key of ['n_legs', 'year'] and with data of + {'year': [2020, 2022, 2022, 2021, 2022, 2022, 2022, 2019, 2021], + 'n_legs': [2, 2, 2, 4, 4, 4, 4, 5, 100], + 'animal': ["Flamingo", "Parrot", "Parrot", "Dog", "Horse", "Horse", "Horse","Brittle stars", "Centipede"]}. + The algrithm: + Firstly we group the rows into partitions by sorting with sort order [('n_legs', 'descending'), ('year', 'descending')] + and null_placement of "at_end". + This gives the same table as raw input. + Then we sort_indices using reverse order of [('n_legs', 'descending'), ('year', 'descending')] + and null_placement : "at_start". + This gives: + [8, 7, 4, 5, 6, 3, 1, 2, 0] + Based on this we get partition groups of indices: + [{'offset': 8, 'length': 1}, {'offset': 7, 'length': 1}, {'offset': 4, 'length': 3}, {'offset': 3, 'length': 1}, {'offset': 1, 'length': 2}, {'offset': 0, 'length': 1}] + We then retrieve the partition keys by offsets. + And slice the arrow table by offsets and lengths of each partition. + """ + import pyarrow as pa + + partition_columns = get_partition_columns(iceberg_table_metadata, arrow_table) Review Comment: it will be more useful when there are hidden partition columns. And the check is also for mypy check because find_column_name returns optional[str] -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. 
To unsubscribe, e-mail: issues-unsubscr...@iceberg.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: issues-unsubscr...@iceberg.apache.org For additional commands, e-mail: issues-h...@iceberg.apache.org