JonasJ-ap commented on code in PR #6997:
URL: https://github.com/apache/iceberg/pull/6997#discussion_r1179568693


##########
python/pyiceberg/io/pyarrow.py:
##########
@@ -486,6 +499,195 @@ def expression_to_pyarrow(expr: BooleanExpression) -> pc.Expression:
     return boolean_expression_visit(expr, _ConvertToArrowExpression())
 
 
+def pyarrow_to_schema(schema: pa.Schema) -> Schema:
+    visitor = _ConvertToIceberg()
+    return visit_pyarrow(schema, visitor)
+
+
+@singledispatch
+def visit_pyarrow(obj: pa.DataType | pa.Schema, visitor: 
PyArrowSchemaVisitor[T]) -> T:
+    """A generic function for applying a pyarrow schema visitor to any point 
within a schema
+
+    The function traverses the schema in post-order fashion
+
+    Args:
+        obj(pa.DataType): An instance of a Schema or an IcebergType
+        visitor (PyArrowSchemaVisitor[T]): An instance of an implementation of 
the generic PyarrowSchemaVisitor base class
+
+    Raises:
+        NotImplementedError: If attempting to visit an unrecognized object type
+    """
+    raise NotImplementedError("Cannot visit non-type: %s" % obj)
+
+
+@visit_pyarrow.register(pa.Schema)
+def _(obj: pa.Schema, visitor: PyArrowSchemaVisitor[T]) -> Optional[T]:
+    struct_results: List[Optional[T]] = []
+    for field in obj:
+        visitor.before_field(field)
+        struct_result = visit_pyarrow(field.type, visitor)
+        visitor.after_field(field)
+        struct_results.append(struct_result)
+
+    return visitor.schema(obj, struct_results)
+
+
+@visit_pyarrow.register(pa.StructType)
+def _(obj: pa.StructType, visitor: PyArrowSchemaVisitor[T]) -> Optional[T]:
+    struct_results: List[Optional[T]] = []
+    for field in obj:
+        visitor.before_field(field)
+        struct_result = visit_pyarrow(field.type, visitor)
+        visitor.after_field(field)
+        struct_results.append(struct_result)
+
+    return visitor.struct(obj, struct_results)
+
+
+@visit_pyarrow.register(pa.ListType)
+def _(obj: pa.ListType, visitor: PyArrowSchemaVisitor[T]) -> Optional[T]:
+    visitor.before_field(obj.value_field)
+    list_result = visit_pyarrow(obj.value_field.type, visitor)
+    visitor.after_field(obj.value_field)
+    return visitor.list(obj, list_result)
+
+
+@visit_pyarrow.register(pa.MapType)
+def _(obj: pa.MapType, visitor: PyArrowSchemaVisitor[T]) -> Optional[T]:
+    visitor.before_field(obj.key_field)
+    key_result = visit_pyarrow(obj.key_field.type, visitor)
+    visitor.after_field(obj.key_field)
+    visitor.before_field(obj.item_field)
+    value_result = visit_pyarrow(obj.item_field.type, visitor)
+    visitor.after_field(obj.item_field)
+    return visitor.map(obj, key_result, value_result)
+
+
+@visit_pyarrow.register(pa.DataType)
+def _(obj: pa.DataType, visitor: PyArrowSchemaVisitor[T]) -> Optional[T]:
+    if pa.types.is_nested(obj):
+        raise TypeError(f"Expected primitive type, got: {type(obj)}")
+    return visitor.primitive(obj)
+
+
+class PyArrowSchemaVisitor(Generic[T], ABC):
+    def before_field(self, field: pa.Field) -> None:
+        """Override this method to perform an action immediately before 
visiting a field."""
+
+    def after_field(self, field: pa.Field) -> None:
+        """Override this method to perform an action immediately after 
visiting a field."""
+
+    @abstractmethod
+    def schema(self, schema: pa.Schema, field_results: List[Optional[T]]) -> 
Optional[T]:
+        """visit a schema"""
+
+    @abstractmethod
+    def struct(self, struct: pa.StructType, field_results: List[Optional[T]]) 
-> Optional[T]:
+        """visit a struct"""
+
+    @abstractmethod
+    def list(self, list_type: pa.ListType, element_result: Optional[T]) -> 
Optional[T]:
+        """visit a list"""
+
+    @abstractmethod
+    def map(self, map_type: pa.MapType, key_result: Optional[T], value_result: 
Optional[T]) -> Optional[T]:
+        """visit a map"""
+
+    @abstractmethod
+    def primitive(self, primitive: pa.DataType) -> Optional[T]:
+        """visit a primitive type"""
+
+
+def _get_field_id(field: pa.Field) -> Optional[int]:
+    for pyarrow_field_id_key in PYARROW_FIELD_ID_KEYS:
+        if field_id_str := field.metadata.get(pyarrow_field_id_key):
+            return int(field_id_str.decode())
+    return None
+
+
+def _get_field_doc(field: pa.Field) -> Optional[str]:
+    for pyarrow_doc_key in PYARROW_FIELD_DOC_KEYS:
+        if doc_str := field.metadata.get(pyarrow_doc_key):
+            return doc_str.decode()
+    return None
+
+
+class _ConvertToIceberg(PyArrowSchemaVisitor[Union[IcebergType, Schema]]):
+    def schema(self, schema: pa.Schema, field_results: 
List[Optional[IcebergType]]) -> Schema:

Review Comment:
   Thank you very much for your suggestions and the test. I've added it to the code.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to