jo-migo opened a new issue, #45467:
URL: https://github.com/apache/arrow/issues/45467

   ### Describe the bug, including details regarding any error messages, version, and platform.
   
   TL;DR: it looks like `pyarrow.dataset.write_dataset` buffers roughly an entire output file in memory while writing it. Is this expected, and can it be tuned somehow?
   
   ## Problem Setup
   I have a Parquet dataset with the following characteristics (a synthetic stand-in with roughly the same shape is sketched after this list):
   - Total size: 50GB
   - Total rows: 25_000_000
   - 25 files, each ~2GB
   - 1_000_000 rows per file, in 10 row groups of 100_000 rows each
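
   In case it helps anyone reproduce this without my data, a dataset with roughly the same shape can be generated along these lines (a sketch only: the path and column layout are made up, and on-disk sizes will differ after Parquet encoding/compression):

```python
import os

import numpy as np
import pyarrow as pa
import pyarrow.parquet as pq

os.makedirs("input_dataset", exist_ok=True)  # hypothetical path

schema = pa.schema([("id", pa.int64()), ("payload", pa.binary())])

for file_idx in range(25):
    ids = pa.array(np.arange(1_000_000, dtype=np.int64))
    # ~2KB per row -> roughly 2GB of in-memory data per file
    # (on-disk size will differ after Parquet encoding/compression).
    payload = pa.array([b"x" * 2_000] * 1_000_000, type=pa.binary())
    table = pa.Table.from_arrays([ids, payload], schema=schema)
    pq.write_table(
        table,
        f"input_dataset/part-{file_idx:02d}.parquet",
        row_group_size=100_000,  # 10 row groups of 100_000 rows per file
    )
```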
   
   I am re-writing the dataset to a new location with larger files (but the same row group size). To do that, I pass a scanner over the original dataset to `write_dataset`. An unimportant implementation detail: the scanner yields `RecordBatch`es from a custom Rust-based iterator over `pyarrow.RecordBatch`es (via Python-Rust bindings that use the [pyarrow crate](https://docs.rs/arrow/latest/arrow/pyarrow/)). I rule this out as the cause below (see the scanner-only benchmarks); a pure-pyarrow equivalent of the operation is sketched right after this paragraph.
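
   For reference, the same repartitioning expressed purely in pyarrow (without the custom iterator) would look roughly like this; the paths are hypothetical, and the batch size and row-group/file limits match my benchmarks:

```python
import pyarrow.dataset as ds

# Hypothetical paths; limits match the benchmark parameters discussed below.
src = ds.dataset("input_dataset", format="parquet")

ds.write_dataset(
    src.scanner(batch_size=1_000),
    base_dir="output_dataset",
    format="parquet",
    max_rows_per_file=10_000_000,
    min_rows_per_group=100_000,
    max_rows_per_group=100_000,
    existing_data_behavior="overwrite_or_ignore",
)
```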
   
   ## Problem
   Just scanning the dataset is not _too_ memory-hungry, but RAM utilization shoots surprisingly high when actually writing a new dataset with `write_dataset`. One trick that reduces the memory load a little is to completely reset the default memory pool after every scanned batch (shown in isolation just below, and used in the full script further down).
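
   In isolation, the reset amounts to these two calls (the same ones used in the script below):

```python
import pyarrow as pa

# Hand back whatever the current default pool can release, then install a
# fresh system pool so later allocations don't keep growing the old one.
pa.default_memory_pool().release_unused()
pa.set_memory_pool(pa.system_memory_pool())
```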
   
   All of the following plots show additional RSS (`current RSS - RSS at program start`) during scanning or writing.
   
   Calling `write_dataset` with `max_rows_per_file=10_000_000` **without** 
resetting the memory pool after every batch (i.e. with `release_memory=False` 
in the code below) resulted in a peak RSS of almost **26GB** greater than the 
starting RAM:
   
   
![Image](https://github.com/user-attachments/assets/c1f6ab77-03a9-426d-85c7-3d3c993831c9)
   
   The output dataset for the above operation contains 3 files: two of about 20GB and a third of about 5GB. So the plot makes it _seem_ as though each ~20GB file is essentially buffered entirely in memory while it is being written (see the two peaks).
   
   Performing the same `write_dataset` operation, likewise with `max_rows_per_file=10_000_000` but **with** resetting the memory pool after every batch (i.e. with `release_memory=True` in the code below), resulted in a peak of only about 2GB and generally lower RSS throughout:
   
   
![Image](https://github.com/user-attachments/assets/03dea8ce-2e0b-411a-859d-8dacbdca6fe0)
   
   And just to make sure that the scanner itself does not account for much of the memory footprint, here is the RAM utilization from **just** scanning the dataset with a batch size of 1_000, without writing. This uses about 140MB of RAM. This run does not reset the memory pool during iteration:
   
   
![Image](https://github.com/user-attachments/assets/d10570a7-4e14-4ce6-a77e-f7bcefe20b6d)
   
   **With** the memory pool replaced after each record batch, RAM usage is consistently cut roughly in half, to ~70MB:
   
   
![Image](https://github.com/user-attachments/assets/c142c20f-6e67-4363-b34f-fa1ff5f12faf)
   
   My question is: is there any way to further tune the `write_dataset` operation (and maybe even the scanner)? Ideally, I need to be able to run this dataset repartitioning on workers with less than 1GB of RAM available. What is being buffered in memory while writing, and is there any way to reduce it?
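
   In case it helps narrow things down, a probe along these lines could separate Arrow-pool allocations from the rest of the RSS (a sketch; `log_pool_stats` is just a made-up helper that could be called from the batch loop):

```python
import pyarrow as pa

def log_pool_stats(tag: str) -> None:
    # Hypothetical helper: compare what the Arrow default pool reports against
    # process RSS; a large gap would point at allocations made outside this pool.
    pool = pa.default_memory_pool()
    print(
        f"[{tag}] backend={pool.backend_name} "
        f"allocated={pool.bytes_allocated() / 1_000_000:.1f}MB "
        f"peak={pool.max_memory() / 1_000_000:.1f}MB"
    )
```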
   
   Here is my benchmarking script – the parameters passed in my benchmarks 
correspond to the values shown in the line plots’ titles.
   
   ## Code
```python
import os
from pathlib import Path
from typing import Iterator

import psutil

import pyarrow as pa
from pyarrow import RecordBatch, set_cpu_count, set_io_thread_count
from pyarrow.dataset import Scanner, dataset as PyArrowDataset, write_dataset

process = psutil.Process(os.getpid())

# `source`, `output_path`, and `ParquetDataset` (the Rust-based reader class)
# are defined elsewhere in my setup.


def scan_or_repartition_dataset(
    initial_rss: int,
    batch_size: int = 1_000,
    output_row_group_size: int = 100_000,
    output_file_size: int = 1_000_000,
    release_memory: bool = True,
    write: bool = True,
) -> None:
    run_id = "-".join(
        map(str, [batch_size, output_row_group_size, output_file_size, release_memory, write])
    )
    print(f"Starting run ID: {run_id}, initial RSS: {initial_rss / 1_000_000}MB")
    pa.set_memory_pool(pa.system_memory_pool())
    dataset = ParquetDataset(source=source)  # Rust-based class

    def get_next_batch() -> Iterator[RecordBatch]:
        record_count = 0
        # Rust-based iterator over pyarrow.RecordBatch
        for record_batch in dataset.read_row_batches(batch_size=batch_size):
            yield record_batch
            record_count += batch_size
            if release_memory:
                # Reset the default memory pool after every batch.
                pa.default_memory_pool().release_unused()
                pa.set_memory_pool(pa.system_memory_pool())
            if record_count % 1_000_000 == 0:
                memory_usage = process.memory_info().rss
                print(f"Intermediate RSS difference: {(memory_usage - initial_rss) / 1_000_000}MB")

    scanner = Scanner.from_batches(
        source=get_next_batch(),
        schema=PyArrowDataset(source).schema,
        batch_size=batch_size,
    )
    if write:
        write_dataset(
            data=scanner,
            format="parquet",
            base_dir=output_path,
            max_rows_per_group=output_row_group_size,
            max_rows_per_file=output_file_size,
            min_rows_per_group=output_row_group_size,
            existing_data_behavior="overwrite_or_ignore",
        )
    else:
        # Scanner-only benchmark: just drain the batches.
        for _ in scanner.to_batches():
            pass
```
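
   For context, a single benchmark run is driven roughly like this (a sketch; the parameter combination mirrors the plot titles above and reuses `process` and the function defined in the script):

```python
# Sketch of how one benchmark run is driven; parameter combinations mirror
# the plot titles above.
if __name__ == "__main__":
    initial_rss = process.memory_info().rss
    scan_or_repartition_dataset(
        initial_rss=initial_rss,
        batch_size=1_000,
        output_row_group_size=100_000,
        output_file_size=10_000_000,
        release_memory=False,  # True for the "reset pool after each batch" runs
        write=True,            # False for the scanner-only runs
    )
```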
   
   ## Environment
   
   - python 3.10.16
   - pyarrow 19.0.0
   - Debian GNU/Linux 12 (bookworm)
   
   ## Additional things
   I have also tried this with the jemalloc and mimalloc memory pools (the swap is sketched below); unfortunately, those runs hit OOM significantly faster.
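
   The pool swap amounts to something like this (a sketch; availability of jemalloc/mimalloc depends on how pyarrow was built):

```python
import pyarrow as pa

# Swap the global pool before scanning/writing; jemalloc/mimalloc
# availability depends on the pyarrow build.
pa.set_memory_pool(pa.jemalloc_memory_pool())
# or: pa.set_memory_pool(pa.mimalloc_memory_pool())
```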
   
   ### Component(s)
   
   Parquet, Python

