Fokko commented on code in PR #6437:
URL: https://github.com/apache/iceberg/pull/6437#discussion_r1056874376


##########
python/pyiceberg/io/pyarrow.py:
##########
@@ -437,3 +465,198 @@ def visit_or(self, left_result: pc.Expression, right_result: pc.Expression) -> p
 
 def expression_to_pyarrow(expr: BooleanExpression) -> pc.Expression:
     """Convert an Iceberg BooleanExpression into a PyArrow compute expression.
 
     The conversion is done by visiting ``expr`` with ``_ConvertToArrowExpression``.
     """
     return boolean_expression_visit(expr, _ConvertToArrowExpression())
+
+
+def project_table(
+    files: Iterable[FileScanTask], table: Table, row_filter: 
BooleanExpression, projected_schema: Schema, case_sensitive: bool
+) -> pa.Table:
+    """Resolves the right columns based on the identifier
+
+    Args:
+        files(Iterable[FileScanTask]): A URI or a path to a local file
+        table(Table): The table that's being queried
+        row_filter(BooleanExpression): The expression for filtering rows
+        projected_schema(Schema): The output schema
+        case_sensitive(bool): Case sensitivity when looking up column names
+
+    Raises:
+        ResolveException: When an incompatible query is done
+    """
+
+    if isinstance(table.io, PyArrowFileIO):
+        scheme, path = PyArrowFileIO.parse_location(table.location())
+        fs = table.io.get_fs(scheme)
+    else:
+        raise ValueError(f"Expected PyArrowFileIO, got: {table.io}")
+
+    bound_row_filter = bind(table.schema(), row_filter, 
case_sensitive=case_sensitive)
+
+    projected_field_ids = {
+        id for id in projected_schema.field_ids if not 
isinstance(projected_schema.find_type(id), (MapType, ListType))
+    }.union(extract_field_ids(bound_row_filter))
+
+    tables = []
+    for task in files:
+        _, path = PyArrowFileIO.parse_location(task.file.file_path)
+
+        # Get the schema
+        with fs.open_input_file(path) as fout:
+            parquet_schema = pq.read_schema(fout)
+            schema_raw = parquet_schema.metadata.get(ICEBERG_SCHEMA)
+            file_schema = Schema.parse_raw(schema_raw)
+
+        pyarrow_filter = None
+        if row_filter is not AlwaysTrue():
+            translated_row_filter = translate_column_names(bound_row_filter, 
file_schema, case_sensitive=case_sensitive)
+            bound_row_filter = bind(file_schema, translated_row_filter, 
case_sensitive=case_sensitive)
+            pyarrow_filter = expression_to_pyarrow(bound_row_filter)
+
+        file_project_schema = prune_columns(file_schema, projected_field_ids, 
select_full_types=False)
+
+        if file_schema is None:
+            raise ValueError(f"Missing Iceberg schema in Metadata for file: 
{path}")
+
+        # Prune the stuff that we don't need anyway
+        file_project_schema_arrow = schema_to_pyarrow(file_project_schema)
+
+        arrow_table = ds.dataset(
+            source=[path], schema=file_project_schema_arrow, 
format=ds.ParquetFileFormat(), filesystem=fs
+        ).to_table(filter=pyarrow_filter)
+
+        tables.append(to_requested_schema(projected_schema, 
file_project_schema, arrow_table))
+
+    if len(tables) > 1:
+        return pa.concat_tables(tables)
+    else:
+        return tables[0]
+
+
+def to_requested_schema(requested_schema: Schema, file_schema: Schema, table: 
pa.Table) -> pa.Table:
+    return VisitWithArrow(requested_schema, file_schema, table).visit()
+
+
+class VisitWithArrow:
+    requested_schema: Schema
+    file_schema: Schema
+    table: pa.Table
+
+    def __init__(self, requested_schema: Schema, file_schema: Schema, table: 
pa.Table) -> None:
+        self.requested_schema = requested_schema
+        self.file_schema = file_schema
+        self.table = table
+
+    def visit(self) -> pa.Table:
+        return self.visit_with_arrow(self.requested_schema, self.file_schema)
+
+    @singledispatchmethod
+    def visit_with_arrow(self, requested_schema: Union[Schema, IcebergType], 
file_schema: Union[Schema, IcebergType]) -> pa.Table:
+        """A generic function for applying a schema visitor to any point 
within a schema
+
+        The function traverses the schema in post-order fashion
+
+        Args:
+            obj(Schema | IcebergType): An instance of a Schema or an 
IcebergType
+            visitor (VisitWithArrow[T]): An instance of an implementation of 
the generic VisitWithArrow base class
+
+        Raises:
+            NotImplementedError: If attempting to visit an unrecognized object 
type
+        """
+        raise NotImplementedError(f"Cannot visit non-type: {requested_schema}")
+
+    @visit_with_arrow.register(Schema)
+    def _(self, requested_schema: Schema, file_schema: Schema) -> pa.Table:
+        """Visit a Schema with a concrete SchemaVisitorWithPartner"""
+        struct_result = self.visit_with_arrow(requested_schema.as_struct(), 
file_schema.as_struct())
+        pyarrow_schema = schema_to_pyarrow(requested_schema)
+        return pa.Table.from_arrays(struct_result.flatten(), 
schema=pyarrow_schema)
+
+    def _get_field_by_id(self, field_id: int) -> Optional[NestedField]:
+        try:
+            return self.file_schema.find_field(field_id)
+        except ValueError:
+            # Field is not in the file
+            return None
+
+    @visit_with_arrow.register(StructType)
+    def _(self, requested_struct: StructType, file_struct: 
Optional[IcebergType]) -> pa.Array:  # pylint: disable=unused-argument
+        """Visit a StructType with a concrete SchemaVisitorWithPartner"""
+        results = []
+
+        for requested_field in requested_struct.fields:
+            file_field = self._get_field_by_id(requested_field.field_id)
+
+            if file_field is None and requested_field.required:
+                raise ResolveException(f"Field is required, and could not be 
found in the file: {requested_field}")
+
+            results.append(self.visit_with_arrow(requested_field.field_type, 
file_field))
+
+        pyarrow_schema = schema_to_pyarrow(requested_struct)
+        return pa.StructArray.from_arrays(arrays=results, 
fields=pyarrow_schema)
+
+    @visit_with_arrow.register(ListType)
+    def _(self, requested_list: ListType, file_field: Optional[NestedField]) 
-> pa.Array:
+        """Produce the column for a list field
+
+        When the list is present in the file, the requested element type is visited with the
+        file's element field as partner. When it is absent, a column of nulls typed as a list
+        of the requested element type is returned, one null per row of ``self.table``.
+
+        Raises:
+            ValueError: If the file's field with this id is not a list
+        """
+
+        if file_field is not None:
+            if not isinstance(file_field.field_type, ListType):
+                raise ValueError(f"Expected list, got: {file_field}")
+
+            # NOTE(review): the element-level result is returned as-is and is not wrapped back
+            # into a list array here — TODO confirm the caller expects that shape
+            return self.visit_with_arrow(requested_list.element_type, 
self._get_field_by_id(file_field.field_type.element_id))
+        else:
+            # Not in the file, fill in with nulls
+            return pa.nulls(len(self.table), 
type=pa.list_(schema_to_pyarrow(requested_list.element_type)))
+
+    @visit_with_arrow.register(MapType)
+    def _(self, requested_map: MapType, file_map: Optional[NestedField]) -> 
pa.Array:
+        """Produce the column for a map field
+
+        When the map is present in the file, only the requested key type is visited (with the
+        file's key field as partner). When it is absent, a column of nulls typed as the
+        requested map is returned, one null per row of ``self.table``.
+
+        Raises:
+            ValueError: If the file's field with this id is not a map
+        """
+
+        if file_map is not None:
+            if not isinstance(file_map.field_type, MapType):
+                raise ValueError(f"Expected map, got: {file_map}")
+
+            # NOTE(review): requested_map.value_type is never visited, so the returned array comes
+            # from the key side alone rather than being a map array — TODO confirm this is WIP
+            key = self._get_field_by_id(file_map.field_type.key_id)
+            return self.visit_with_arrow(requested_map.key_type, key)
+        else:
+            # Not in the file, fill in with nulls
+            return pa.nulls(
+                len(self.table),
+                type=pa.map_(schema_to_pyarrow(requested_map.key_type), 
schema_to_pyarrow(requested_map.value_type)),
+            )
+
+    def _get_column_data(self, file_field: NestedField) -> pa.Array:
+        column_name = self.file_schema.find_column_name(file_field.field_id)
+        column_data = self.table
+        struct_schema = self.table.schema
+
+        if column_name is None:
+            # Should not happen
+            raise ValueError(f"Could not find column: {column_name}")
+
+        column_parts = list(reversed(column_name.split(".")))
+        while len(column_parts) > 1:
+            part = column_parts.pop()
+            column_data = column_data.column(part)
+            struct_schema = 
struct_schema[struct_schema.get_field_index(part)].type
+
+        if not isinstance(struct_schema, (pa.ListType, pa.MapType)):
+            # PyArrow does not have an element
+            idx = struct_schema.get_field_index(column_parts.pop())
+            column_data = column_data.flatten()[idx]
+
+        return column_data.combine_chunks()

Review Comment:
   I have the same concern. I think we should be able to avoid this using the 
lower-level buffer API, but I still have to dig into that. I wanted to get it 
working first, and then we can make it fast :)



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@iceberg.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: issues-unsubscr...@iceberg.apache.org
For additional commands, e-mail: issues-h...@iceberg.apache.org

Reply via email to