vaultah commented on code in PR #13720:
URL: https://github.com/apache/iceberg/pull/13720#discussion_r2293737793
##########
spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/actions/RewriteTablePathSparkAction.java:
##########
@@ -494,36 +483,60 @@ public RewriteContentFileResult appendDeleteFile(RewriteResult<DeleteFile> r1) {
}
}
- /** Rewrite manifest files in a distributed manner and return rewritten data files path pairs. */
- private RewriteContentFileResult rewriteManifests(
+ /**
+ * Rewrite manifest files in a distributed manner and return the resulting manifests and content
+ * files selected for rewriting.
+ */
+ private Map<String, RewriteContentFileResult> rewriteManifests(
Set<Snapshot> deltaSnapshots, TableMetadata tableMetadata, Set<ManifestFile> toRewrite) {
if (toRewrite.isEmpty()) {
- return new RewriteContentFileResult();
+ return Maps.newHashMap();
}
Encoder<ManifestFile> manifestFileEncoder = Encoders.javaSerialization(ManifestFile.class);
+ Encoder<RewriteContentFileResult> manifestResultEncoder = Encoders.javaSerialization(RewriteContentFileResult.class);
+ Encoder<Tuple2<String, RewriteContentFileResult>> tupleEncoder = Encoders.tuple(Encoders.STRING(), manifestResultEncoder);
+
Dataset<ManifestFile> manifestDS =
spark().createDataset(Lists.newArrayList(toRewrite), manifestFileEncoder);
Set<Long> deltaSnapshotIds =
deltaSnapshots.stream().map(Snapshot::snapshotId).collect(Collectors.toSet());
- return manifestDS
- .repartition(toRewrite.size())
- .map(
- toManifests(
- tableBroadcast(),
- sparkContext().broadcast(deltaSnapshotIds),
- stagingDir,
- tableMetadata.formatVersion(),
- sourcePrefix,
- targetPrefix),
- Encoders.bean(RewriteContentFileResult.class))
- // duplicates are expected here as the same data file can have different statuses
- // (e.g. added and deleted)
- .reduce((ReduceFunction<RewriteContentFileResult>) RewriteContentFileResult::append);
- }
-
- private static MapFunction<ManifestFile, RewriteContentFileResult> toManifests(
+ Iterator<Tuple2<String, RewriteContentFileResult>> resultIterator =
+ manifestDS
+ .repartition(toRewrite.size())
+ .map(
+ toManifests(
+ tableBroadcast(),
+ sparkContext().broadcast(deltaSnapshotIds),
+ stagingDir,
+ tableMetadata.formatVersion(),
+ sourcePrefix,
+ targetPrefix),
+ tupleEncoder)
+ .toLocalIterator();
Review Comment:
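You're right that the driver still has to process all O(N) records and build the final map in memory. My thinking with `toLocalIterator` was primarily to avoid the large intermediate `List` that `.collect()` would create, which felt like the more immediate memory risk: `toLocalIterator` fetches one partition at a time, so the driver holds at most the largest partition plus the map being built, rather than the full result list plus the map.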
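A minimal plain-Java sketch (not Spark code) of the driver-side loop that consumes the iterator may make the trade-off concrete. It assumes a hypothetical record shape, a manifest path plus the file paths selected for rewriting, standing in for `Tuple2<String, RewriteContentFileResult>`; `DriverFold`, `Entry`, and `fold` are illustrative names, and merging duplicate keys by appending is an assumption. Records are pulled one at a time, so no full intermediate list is built, yet the returned map still grows with N.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;

// Hypothetical plain-Java model (no Spark) of the driver loop that consumes toLocalIterator().
final class DriverFold {

  // Hypothetical stand-in for Tuple2<String, RewriteContentFileResult>:
  // a manifest path plus the file paths selected for rewriting from it.
  record Entry(String manifestPath, List<String> rewrittenPaths) {}

  // Pulls records one at a time, so no List of all records is materialized,
  // but the returned map still holds every record's data (O(N) driver memory).
  static Map<String, List<String>> fold(Iterator<Entry> records) {
    Map<String, List<String>> byManifest = new HashMap<>();
    while (records.hasNext()) {
      Entry e = records.next();
      // Assumed merge policy: values for a repeated key are appended in arrival order.
      byManifest
          .computeIfAbsent(e.manifestPath(), k -> new ArrayList<>())
          .addAll(e.rewrittenPaths());
    }
    return byManifest;
  }

  public static void main(String[] args) {
    List<Entry> records =
        List.of(
            new Entry("m1.avro", List.of("a.parquet")),
            new Entry("m2.avro", List.of("b.parquet")),
            new Entry("m1.avro", List.of("c.parquet")));
    Map<String, List<String>> result = fold(records.iterator());
    System.out.println(result.get("m1.avro")); // [a.parquet, c.parquet]
  }
}
```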
However, I agree that the sequential aggregation on the driver is still a scalability concern. A better approach might be to use `.mapPartitions()` to pre-aggregate the results into one map per partition on the executors, and then `reduce` those few maps on the driver. That way the driver merges one partial map per partition instead of handling every record. Separately, I noticed the current implementation uses `.repartition(toRewrite.size())`, which creates roughly one partition (and one task) per manifest. I think that might also become a performance bottleneck for tables with many manifests, and it would defeat per-partition pre-aggregation, since each partition would hold only about one manifest. A refactor to `.mapPartitions` with fewer partitions could be a good opportunity to solve both issues at once.
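The proposed pattern can be sketched the same way, again in plain Java under the same hypothetical record shape rather than with Spark's API. Each inner list stands in for one partition, `aggregatePartition` plays the role of the `mapPartitions` function running on an executor, and `mergeMaps` is the `reduce` step; all names here are illustrative, not existing Iceberg or Spark code.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical plain-Java model (no Spark) of "pre-aggregate per partition, then reduce".
final class PartitionPreAggregation {

  // Hypothetical record shape: manifest path plus file paths selected for rewriting.
  record Entry(String manifestPath, List<String> rewrittenPaths) {}

  // Executor side (the mapPartitions function): fold one partition into a single map.
  static Map<String, List<String>> aggregatePartition(List<Entry> partition) {
    Map<String, List<String>> acc = new HashMap<>();
    for (Entry e : partition) {
      acc.computeIfAbsent(e.manifestPath(), k -> new ArrayList<>()).addAll(e.rewrittenPaths());
    }
    return acc;
  }

  // Reduce step: merge two partial maps into a new map, appending values for shared keys.
  // Neither input is mutated, so a shared identity map stays empty.
  static Map<String, List<String>> mergeMaps(
      Map<String, List<String>> left, Map<String, List<String>> right) {
    Map<String, List<String>> out = new HashMap<>();
    left.forEach((k, v) -> out.put(k, new ArrayList<>(v)));
    right.forEach((k, v) -> out.computeIfAbsent(k, x -> new ArrayList<>()).addAll(v));
    return out;
  }

  // Driver side: one partial map per partition is merged, not one record per manifest.
  static Map<String, List<String>> run(List<List<Entry>> partitions) {
    return partitions.stream()
        .map(PartitionPreAggregation::aggregatePartition)
        .reduce(new HashMap<>(), PartitionPreAggregation::mergeMaps);
  }

  public static void main(String[] args) {
    // Two partitions holding three records: the driver merges two partial maps.
    List<List<Entry>> partitions =
        List.of(
            List.of(
                new Entry("m1.avro", List.of("a.parquet")),
                new Entry("m2.avro", List.of("b.parquet"))),
            List.of(new Entry("m1.avro", List.of("c.parquet"))));
    System.out.println(run(partitions).get("m1.avro")); // [a.parquet, c.parquet]
  }
}
```

If every partition holds a single record, which is roughly what `.repartition(toRewrite.size())` produces, `run` still merges one map per manifest, so the pre-aggregation only pays off with fewer, larger partitions.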
If you agree, I'm happy to refactor it to that pattern.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]