HonahX commented on code in PR #358:
URL: https://github.com/apache/iceberg-python/pull/358#discussion_r1477467681
##########
pyiceberg/io/pyarrow.py:
##########

```diff
@@ -1745,14 +1747,42 @@ def write_file(table: Table, tasks: Iterator[WriteTask]) -> Iterator[DataFile]:
         key_metadata=None,
     )
 
-    if len(collected_metrics) != 1:
-        # One file has been written
-        raise ValueError(f"Expected 1 entry, got: {collected_metrics}")
-
     fill_parquet_file_metadata(
         data_file=data_file,
-        parquet_metadata=collected_metrics[0],
+        parquet_metadata=writer.writer.metadata,
         stats_columns=compute_statistics_plan(table.schema(), table.properties),
         parquet_column_mapping=parquet_path_to_id_mapping(table.schema()),
     )
     return iter([data_file])
+
+
+def _get_parquet_writer_kwargs(table_properties: Properties) -> Dict[str, Any]:
+    def _get_int(key: str) -> Optional[int]:
+        if value := table_properties.get(key):
+            try:
+                return int(value)
+            except ValueError as e:
+                raise ValueError(f"Could not parse table property {key} to an integer: {value}") from e
+        else:
+            return None
+
+    for key_pattern in [
+        "write.parquet.row-group-size-bytes",
+        "write.parquet.page-row-limit",
+        "write.parquet.bloom-filter-max-bytes",
+        "write.parquet.bloom-filter-enabled.column.*",
+    ]:
+        if unsupported_keys := fnmatch.filter(table_properties, key_pattern):
+            raise NotImplementedError(f"Parquet writer option(s) {unsupported_keys} not implemented")
+
+    compression_codec = table_properties.get("write.parquet.compression-codec")
```

Review Comment:

```suggestion
    compression_codec = table_properties.get("write.parquet.compression-codec", "zstd")
```

How about adding the default value here? The RestCatalog backend and HiveCatalog explicitly set the default codec at the catalog level:
https://github.com/apache/iceberg-python/blob/02e64300aee376a76c175be253a29dcd7c31f0cc/pyiceberg/catalog/hive.py#L158
But other catalogs, such as `glue` and `sql`, do not set it explicitly when creating new tables. In general, for tables whose properties have no `write.parquet.compression-codec` key, we still want to use the default codec `zstd` when writing Parquet.
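A minimal sketch of the fallback behaviour described in the comment above, assuming a plain `dict` of table properties and a hypothetical `DEFAULT_PARQUET_COMPRESSION` constant rather than the actual pyiceberg helpers:

```python
from typing import Dict

# Hypothetical constant for illustration; mirrors the "zstd" default proposed above.
DEFAULT_PARQUET_COMPRESSION = "zstd"


def resolve_compression_codec(table_properties: Dict[str, str]) -> str:
    """Return the configured Parquet codec, falling back to zstd when the
    catalog never stamped write.parquet.compression-codec onto the table."""
    return table_properties.get("write.parquet.compression-codec", DEFAULT_PARQUET_COMPRESSION)


# Tables created by catalogs that do not set the property (e.g. glue, sql) still get zstd:
assert resolve_compression_codec({}) == "zstd"
assert resolve_compression_codec({"write.parquet.compression-codec": "snappy"}) == "snappy"
```

Routing the default through `dict.get` keeps writer behaviour consistent with catalogs such as the REST and Hive backends, which already set the property at table-creation time.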
##########
tests/integration/test_writes.py:
##########

```diff
@@ -490,15 +419,103 @@ def test_data_files(spark: SparkSession, session_catalog: Catalog, arrow_table_w
 @pytest.mark.integration
-def test_invalid_arguments(spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table) -> None:
-    identifier = "default.arrow_data_files"
+@pytest.mark.parametrize("format_version", ["1", "2"])
+@pytest.mark.parametrize(
+    "properties, expected_compression_name",
+    [
+        # REST catalog uses Zstandard by default: https://github.com/apache/iceberg/pull/8593
+        ({}, "ZSTD"),
+        ({"write.parquet.compression-codec": "uncompressed"}, "UNCOMPRESSED"),
+        ({"write.parquet.compression-codec": "gzip", "write.parquet.compression-level": "1"}, "GZIP"),
+        ({"write.parquet.compression-codec": "zstd", "write.parquet.compression-level": "1"}, "ZSTD"),
+        ({"write.parquet.compression-codec": "snappy"}, "SNAPPY"),
+    ],
+)
+def test_write_parquet_compression_properties(
+    spark: SparkSession,
+    session_catalog: Catalog,
+    arrow_table_with_null: pa.Table,
+    format_version: str,
+    properties: Dict[str, Any],
+    expected_compression_name: str,
+) -> None:
+    identifier = "default.write_parquet_compression_properties"
+
+    tbl = _create_table(session_catalog, identifier, {"format-version": format_version, **properties}, [arrow_table_with_null])
+
+    data_file_paths = [task.file.file_path for task in tbl.scan().plan_files()]
+
+    fs = S3FileSystem(
+        endpoint_override=session_catalog.properties["s3.endpoint"],
+        access_key=session_catalog.properties["s3.access-key-id"],
+        secret_key=session_catalog.properties["s3.secret-access-key"],
+    )
+    uri = urlparse(data_file_paths[0])
+    with fs.open_input_file(f"{uri.netloc}{uri.path}") as f:
+        parquet_metadata = pq.read_metadata(f)
+        compression = parquet_metadata.row_group(0).column(0).compression
+
+    assert compression == expected_compression_name
+
+
+@pytest.mark.integration
+@pytest.mark.integration
+@pytest.mark.parametrize(
+    "properties, expected_kwargs",
+    [
+        ({"write.parquet.page-size-bytes": "42"}, {"data_page_size": 42}),
+        ({"write.parquet.dict-size-bytes": "42"}, {"dictionary_pagesize_limit": 42}),
+    ],
+)
+def test_write_parquet_other_properties(
+    mocker: MockerFixture,
+    spark: SparkSession,
+    session_catalog: Catalog,
+    arrow_table_with_null: pa.Table,
+    properties: Dict[str, Any],
+    expected_kwargs: Dict[str, Any],
+) -> None:
+    print(type(mocker))
+    identifier = "default.test_write_parquet_other_properties"
+
+    # The properties we test cannot be checked on the resulting Parquet file, so we spy on the ParquetWriter call instead
+    ParquetWriter = mocker.spy(pq, "ParquetWriter")
+    _create_table(session_catalog, identifier, properties, [arrow_table_with_null])
+
+    call_kwargs = ParquetWriter.call_args[1]
+    for key, value in expected_kwargs.items():
+        assert call_kwargs.get(key) == value
+
+
+@pytest.mark.integration
+@pytest.mark.integration
+@pytest.mark.integration
```

Review Comment:

Seems we have 2 extra `@pytest.mark.integration`

##########
tests/integration/test_writes.py:
##########

```diff
@@ -490,15 +419,103 @@ def test_data_files(spark: SparkSession, session_catalog: Catalog, arrow_table_w
 @pytest.mark.integration
-def test_invalid_arguments(spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table) -> None:
-    identifier = "default.arrow_data_files"
+@pytest.mark.parametrize("format_version", ["1", "2"])
+@pytest.mark.parametrize(
+    "properties, expected_compression_name",
+    [
+        # REST catalog uses Zstandard by default: https://github.com/apache/iceberg/pull/8593
+        ({}, "ZSTD"),
+        ({"write.parquet.compression-codec": "uncompressed"}, "UNCOMPRESSED"),
+        ({"write.parquet.compression-codec": "gzip", "write.parquet.compression-level": "1"}, "GZIP"),
+        ({"write.parquet.compression-codec": "zstd", "write.parquet.compression-level": "1"}, "ZSTD"),
+        ({"write.parquet.compression-codec": "snappy"}, "SNAPPY"),
+    ],
+)
+def test_write_parquet_compression_properties(
+    spark: SparkSession,
+    session_catalog: Catalog,
+    arrow_table_with_null: pa.Table,
+    format_version: str,
+    properties: Dict[str, Any],
+    expected_compression_name: str,
+) -> None:
+    identifier = "default.write_parquet_compression_properties"
+
+    tbl = _create_table(session_catalog, identifier, {"format-version": format_version, **properties}, [arrow_table_with_null])
+
+    data_file_paths = [task.file.file_path for task in tbl.scan().plan_files()]
+
+    fs = S3FileSystem(
+        endpoint_override=session_catalog.properties["s3.endpoint"],
+        access_key=session_catalog.properties["s3.access-key-id"],
+        secret_key=session_catalog.properties["s3.secret-access-key"],
+    )
+    uri = urlparse(data_file_paths[0])
+    with fs.open_input_file(f"{uri.netloc}{uri.path}") as f:
+        parquet_metadata = pq.read_metadata(f)
+        compression = parquet_metadata.row_group(0).column(0).compression
+
+    assert compression == expected_compression_name
+
+
+@pytest.mark.integration
+@pytest.mark.integration
```

Review Comment:

```suggestion
```