stevenzwu commented on code in PR #10331: URL: https://github.com/apache/iceberg/pull/10331#discussion_r1599242761
########## flink/v1.19/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/AggregatedStatistics.java: ########## @@ -19,53 +19,87 @@ package org.apache.iceberg.flink.sink.shuffle; import java.io.Serializable; -import org.apache.flink.api.common.typeutils.TypeSerializer; +import java.util.Arrays; +import java.util.Map; +import org.apache.iceberg.SortKey; import org.apache.iceberg.relocated.com.google.common.base.MoreObjects; -import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.base.Objects; /** * AggregatedStatistics is used by {@link DataStatisticsCoordinator} to collect {@link * DataStatistics} from {@link DataStatisticsOperator} subtasks for specific checkpoint. It stores * the merged {@link DataStatistics} result from all reported subtasks. */ -class AggregatedStatistics<D extends DataStatistics<D, S>, S> implements Serializable { - +class AggregatedStatistics implements Serializable { private final long checkpointId; - private final DataStatistics<D, S> dataStatistics; - - AggregatedStatistics(long checkpoint, TypeSerializer<DataStatistics<D, S>> statisticsSerializer) { - this.checkpointId = checkpoint; - this.dataStatistics = statisticsSerializer.createInstance(); - } + private final StatisticsType type; + private final Map<SortKey, Long> keyFrequency; Review Comment: Combining both Map and Sketch stats in the same aggregated statistics object would allow a run-time switch from `Map` stats to `Sketch` stats. ########## flink/v1.19/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/StatisticsUtil.java: ########## @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.flink.sink.shuffle; + +import java.io.IOException; +import java.io.UncheckedIOException; +import javax.annotation.Nullable; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.core.memory.DataInputDeserializer; +import org.apache.flink.core.memory.DataOutputSerializer; + +/** + * DataStatisticsUtil is the utility to serialize and deserialize {@link DataStatistics} and {@link + * AggregatedStatistics} + */ +class StatisticsUtil { Review Comment: renamed from `DataStatisticsUtil` ########## flink/v1.19/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsSerializer.java: ########## @@ -0,0 +1,204 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.flink.sink.shuffle; + +import java.io.IOException; +import java.util.Map; +import java.util.Objects; +import org.apache.datasketches.memory.Memory; +import org.apache.datasketches.sampling.ReservoirItemsSketch; +import org.apache.flink.annotation.Internal; +import org.apache.flink.api.common.typeutils.CompositeTypeSerializerSnapshot; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot; +import org.apache.flink.api.common.typeutils.base.EnumSerializer; +import org.apache.flink.api.common.typeutils.base.LongSerializer; +import org.apache.flink.api.common.typeutils.base.MapSerializer; +import org.apache.flink.core.memory.DataInputView; +import org.apache.flink.core.memory.DataOutputView; +import org.apache.iceberg.SortKey; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; + +@Internal +class DataStatisticsSerializer extends TypeSerializer<DataStatistics> { Review Comment: this single serializer can handle both map and sketch stats type ########## flink/v1.19/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/AggregatedStatisticsTracker.java: ########## @@ -104,30 +144,135 @@ AggregatedStatistics<D, S> updateAndCheckCompletion( subtask, checkpointId); } else { - inProgressStatistics.mergeDataStatistic( + merge(dataStatistics); + LOG.debug( + "Merge data statistics from operator {} subtask {} for checkpoint {}.", operatorName, - event.checkpointId(), - DataStatisticsUtil.deserializeDataStatistics( - event.statisticsBytes(), statisticsSerializer)); + subtask, + checkpointId); } + // This should be the happy path where all subtasks reports are received if (inProgressSubtaskSet.size() == parallelism) { - completedStatistics = inProgressStatistics; + completedStatistics = completedStatistics(); + resetAggregates(); LOG.info( - 
"Received data statistics from all {} operators {} for checkpoint {}. Return last completed aggregator {}.", + "Received data statistics from all {} operators {} for checkpoint {}. Return last completed aggregator.", parallelism, operatorName, - inProgressStatistics.checkpointId(), - completedStatistics.dataStatistics()); - inProgressStatistics = new AggregatedStatistics<>(checkpointId + 1, statisticsSerializer); - inProgressSubtaskSet.clear(); + inProgressCheckpointId); } return completedStatistics; } + private boolean inProgress() { + return inProgressCheckpointId != CheckpointStoreUtil.INVALID_CHECKPOINT_ID; + } + + private AggregatedStatistics completedStatistics() { + if (coordinatorStatisticsType == StatisticsType.Map) { + LOG.info( + "Completed map statistics aggregation with {} keys", coordinatorMapStatistics.size()); + return AggregatedStatistics.fromKeyFrequency( + inProgressCheckpointId, coordinatorMapStatistics); + } else { + ReservoirItemsSketch<SortKey> sketch = coordinatorSketchStatistics.getResult(); + LOG.info( + "Completed sketch statistics aggregation: " + + "reservoir size = {}, number of items seen = {}, number of samples = {}", + sketch.getK(), + sketch.getN(), + sketch.getNumSamples()); + return AggregatedStatistics.fromRangeBounds( + inProgressCheckpointId, + SketchUtil.rangeBounds(downstreamParallelism, comparator, sketch)); + } + } + + private void initializeAggregates(long checkpointId, DataStatistics taskStatistics) { + LOG.info("Starting a new statistics aggregation for checkpoint {}", checkpointId); + this.inProgressCheckpointId = checkpointId; + this.coordinatorStatisticsType = taskStatistics.type(); + + if (coordinatorStatisticsType == StatisticsType.Map) { + this.coordinatorMapStatistics = Maps.newHashMap(); + this.coordinatorSketchStatistics = null; + } else { + this.coordinatorMapStatistics = null; + this.coordinatorSketchStatistics = + ReservoirItemsUnion.newInstance( + 
SketchUtil.determineCoordinatorReservoirSize(downstreamParallelism)); + } + } + + private void resetAggregates() { + inProgressSubtaskSet.clear(); + this.inProgressCheckpointId = CheckpointStoreUtil.INVALID_CHECKPOINT_ID; + this.coordinatorMapStatistics = null; + this.coordinatorSketchStatistics = null; + } + + @SuppressWarnings("unchecked") + private void merge(DataStatistics taskStatistics) { Review Comment: this method shows the stats type migration from `Map` to `Sketch` ########## flink/v1.19/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsOperator.java: ########## @@ -47,151 +48,182 @@ * distribution to downstream subtasks. */ @Internal -class DataStatisticsOperator<D extends DataStatistics<D, S>, S> - extends AbstractStreamOperator<DataStatisticsOrRecord<D, S>> - implements OneInputStreamOperator<RowData, DataStatisticsOrRecord<D, S>>, OperatorEventHandler { +public class DataStatisticsOperator extends AbstractStreamOperator<StatisticsOrRecord> + implements OneInputStreamOperator<RowData, StatisticsOrRecord>, OperatorEventHandler { private static final long serialVersionUID = 1L; private final String operatorName; private final RowDataWrapper rowDataWrapper; private final SortKey sortKey; private final OperatorEventGateway operatorEventGateway; - private final TypeSerializer<DataStatistics<D, S>> statisticsSerializer; - private transient volatile DataStatistics<D, S> localStatistics; - private transient volatile DataStatistics<D, S> globalStatistics; - private transient ListState<DataStatistics<D, S>> globalStatisticsState; + private final int downstreamParallelism; + private final StatisticsType statisticsType; + private final TypeSerializer<DataStatistics> taskStatisticsSerializer; + private final TypeSerializer<AggregatedStatistics> aggregatedStatisticsSerializer; + + private transient int parallelism; + private transient int subtaskIndex; + private transient ListState<AggregatedStatistics> globalStatisticsState; + // current 
statistics type may be different from the config due to possible + // migration from Map statistics to Sketch statistics when high cardinality detected + private transient volatile StatisticsType taskStatisticsType; + private transient volatile DataStatistics localStatistics; + private transient volatile AggregatedStatistics globalStatistics; DataStatisticsOperator( String operatorName, Schema schema, SortOrder sortOrder, OperatorEventGateway operatorEventGateway, - TypeSerializer<DataStatistics<D, S>> statisticsSerializer) { + int downstreamParallelism, + StatisticsType statisticsType) { this.operatorName = operatorName; this.rowDataWrapper = new RowDataWrapper(FlinkSchemaUtil.convert(schema), schema.asStruct()); this.sortKey = new SortKey(schema, sortOrder); this.operatorEventGateway = operatorEventGateway; - this.statisticsSerializer = statisticsSerializer; + this.downstreamParallelism = downstreamParallelism; + this.statisticsType = statisticsType; + + SortKeySerializer sortKeySerializer = new SortKeySerializer(schema, sortOrder); + this.taskStatisticsSerializer = new DataStatisticsSerializer(sortKeySerializer); + this.aggregatedStatisticsSerializer = new AggregatedStatisticsSerializer(sortKeySerializer); } @Override public void initializeState(StateInitializationContext context) throws Exception { - localStatistics = statisticsSerializer.createInstance(); - globalStatisticsState = + this.parallelism = getRuntimeContext().getTaskInfo().getNumberOfParallelSubtasks(); + this.subtaskIndex = getRuntimeContext().getTaskInfo().getIndexOfThisSubtask(); + this.globalStatisticsState = context .getOperatorStateStore() .getUnionListState( - new ListStateDescriptor<>("globalStatisticsState", statisticsSerializer)); + new ListStateDescriptor<>("globalStatisticsState", aggregatedStatisticsSerializer)); if (context.isRestored()) { - int subtaskIndex = getRuntimeContext().getIndexOfThisSubtask(); if (globalStatisticsState.get() == null || 
!globalStatisticsState.get().iterator().hasNext()) { LOG.warn( "Operator {} subtask {} doesn't have global statistics state to restore", operatorName, subtaskIndex); - globalStatistics = statisticsSerializer.createInstance(); } else { LOG.info( - "Restoring operator {} global statistics state for subtask {}", - operatorName, - subtaskIndex); - globalStatistics = globalStatisticsState.get().iterator().next(); + "Operator {} subtask {} restoring global statistics state", operatorName, subtaskIndex); + this.globalStatistics = globalStatisticsState.get().iterator().next(); } - } else { - globalStatistics = statisticsSerializer.createInstance(); } + + this.taskStatisticsType = StatisticsUtil.collectType(statisticsType, globalStatistics); + this.localStatistics = + StatisticsUtil.createTaskStatistics(taskStatisticsType, parallelism, downstreamParallelism); } @Override public void open() throws Exception { - if (!globalStatistics.isEmpty()) { - output.collect( - new StreamRecord<>(DataStatisticsOrRecord.fromDataStatistics(globalStatistics))); + if (globalStatistics != null) { + output.collect(new StreamRecord<>(StatisticsOrRecord.fromDataStatistics(globalStatistics))); } } @Override - @SuppressWarnings("unchecked") public void handleOperatorEvent(OperatorEvent event) { - int subtaskIndex = getRuntimeContext().getIndexOfThisSubtask(); Preconditions.checkArgument( - event instanceof DataStatisticsEvent, + event instanceof StatisticsEvent, String.format( "Operator %s subtask %s received unexpected operator event %s", operatorName, subtaskIndex, event.getClass())); - DataStatisticsEvent<D, S> statisticsEvent = (DataStatisticsEvent<D, S>) event; + StatisticsEvent statisticsEvent = (StatisticsEvent) event; LOG.info( - "Operator {} received global data event from coordinator checkpoint {}", + "Operator {} subtask {} received global data event from coordinator checkpoint {}", operatorName, + subtaskIndex, statisticsEvent.checkpointId()); globalStatistics = - 
DataStatisticsUtil.deserializeDataStatistics( - statisticsEvent.statisticsBytes(), statisticsSerializer); - output.collect(new StreamRecord<>(DataStatisticsOrRecord.fromDataStatistics(globalStatistics))); + StatisticsUtil.deserializeAggregatedStatistics( + statisticsEvent.statisticsBytes(), aggregatedStatisticsSerializer); + output.collect(new StreamRecord<>(StatisticsOrRecord.fromDataStatistics(globalStatistics))); } + @SuppressWarnings("unchecked") @Override public void processElement(StreamRecord<RowData> streamRecord) { RowData record = streamRecord.getValue(); StructLike struct = rowDataWrapper.wrap(record); sortKey.wrap(struct); localStatistics.add(sortKey); - output.collect(new StreamRecord<>(DataStatisticsOrRecord.fromRecord(record))); + + if (localStatistics.type() == StatisticsType.Map) { + Map<SortKey, Long> mapStatistics = (Map<SortKey, Long>) localStatistics.result(); + if (statisticsType == StatisticsType.Auto Review Comment: this is stats migration (Map -> Sketch) at operator side during collection phase. ########## flink/v1.19/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/AggregatedStatisticsSerializer.java: ########## @@ -0,0 +1,175 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.flink.sink.shuffle; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import org.apache.flink.api.common.typeutils.CompositeTypeSerializerSnapshot; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot; +import org.apache.flink.api.common.typeutils.base.EnumSerializer; +import org.apache.flink.api.common.typeutils.base.ListSerializer; +import org.apache.flink.api.common.typeutils.base.LongSerializer; +import org.apache.flink.api.common.typeutils.base.MapSerializer; +import org.apache.flink.core.memory.DataInputView; +import org.apache.flink.core.memory.DataOutputView; +import org.apache.iceberg.SortKey; + +public class AggregatedStatisticsSerializer extends TypeSerializer<AggregatedStatistics> { + private final TypeSerializer<SortKey> sortKeySerializer; + private final EnumSerializer<StatisticsType> statisticsTypeSerializer; + private final MapSerializer<SortKey, Long> keyFrequencySerializer; + private final ListSerializer<SortKey> rangeBoundsSerializer; + + AggregatedStatisticsSerializer(TypeSerializer<SortKey> sortKeySerializer) { + this.sortKeySerializer = sortKeySerializer; + this.statisticsTypeSerializer = new EnumSerializer<>(StatisticsType.class); + this.keyFrequencySerializer = new MapSerializer<>(sortKeySerializer, LongSerializer.INSTANCE); + this.rangeBoundsSerializer = new ListSerializer<>(sortKeySerializer); + } + + @Override + public boolean isImmutableType() { + return false; + } + + @Override + public TypeSerializer<AggregatedStatistics> duplicate() { + return new AggregatedStatisticsSerializer(sortKeySerializer); + } + + @Override + public AggregatedStatistics createInstance() { + return new AggregatedStatistics(0, StatisticsType.Map, Collections.emptyMap(), null); + } + + @Override + public AggregatedStatistics 
copy(AggregatedStatistics from) { + return new AggregatedStatistics( + from.checkpointId(), from.type(), from.keyFrequency(), from.rangeBounds()); + } + + @Override + public AggregatedStatistics copy(AggregatedStatistics from, AggregatedStatistics reuse) { + // no benefit of reuse + return copy(from); + } + + @Override + public int getLength() { + return -1; + } + + @Override + public void serialize(AggregatedStatistics record, DataOutputView target) throws IOException { + target.writeLong(record.checkpointId()); + statisticsTypeSerializer.serialize(record.type(), target); + if (record.type() == StatisticsType.Map) { + keyFrequencySerializer.serialize(record.keyFrequency(), target); + } else { + rangeBoundsSerializer.serialize(Arrays.asList(record.rangeBounds()), target); Review Comment: Reused the list serializer from Flink, paying a small penalty for the array-to-list conversion. ########## flink/v1.19/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsCoordinator.java: ########## @@ -44,51 +46,76 @@ import org.slf4j.LoggerFactory; /** - * DataStatisticsCoordinator receives {@link DataStatisticsEvent} from {@link - * DataStatisticsOperator} every subtask and then merge them together. Once aggregation for all - * subtasks data statistics completes, DataStatisticsCoordinator will send the aggregated data - * statistics back to {@link DataStatisticsOperator}. In the end a custom partitioner will - * distribute traffic based on the aggregated data statistics to improve data clustering. + * DataStatisticsCoordinator receives {@link StatisticsEvent} from {@link DataStatisticsOperator} + * every subtask and then merge them together. Once aggregation for all subtasks data statistics + * completes, DataStatisticsCoordinator will send the aggregated data statistics back to {@link + * DataStatisticsOperator}. In the end a custom partitioner will distribute traffic based on the + * aggregated data statistics to improve data clustering. 
*/ @Internal -class DataStatisticsCoordinator<D extends DataStatistics<D, S>, S> implements OperatorCoordinator { +class DataStatisticsCoordinator implements OperatorCoordinator { private static final Logger LOG = LoggerFactory.getLogger(DataStatisticsCoordinator.class); private final String operatorName; + private final OperatorCoordinator.Context context; + private final Schema schema; + private final SortOrder sortOrder; + private final int downstreamParallelism; + private final StatisticsType statisticsType; + private final ExecutorService coordinatorExecutor; - private final OperatorCoordinator.Context operatorCoordinatorContext; private final SubtaskGateways subtaskGateways; private final CoordinatorExecutorThreadFactory coordinatorThreadFactory; - private final TypeSerializer<DataStatistics<D, S>> statisticsSerializer; - private final transient AggregatedStatisticsTracker<D, S> aggregatedStatisticsTracker; - private volatile AggregatedStatistics<D, S> completedStatistics; - private volatile boolean started; + private final TypeSerializer<AggregatedStatistics> aggregatedStatisticsSerializer; + + private transient boolean started; + private transient AggregatedStatisticsTracker aggregatedStatisticsTracker; + private transient AggregatedStatistics completedStatistics; DataStatisticsCoordinator( String operatorName, OperatorCoordinator.Context context, - TypeSerializer<DataStatistics<D, S>> statisticsSerializer) { + Schema schema, + SortOrder sortOrder, + int downstreamParallelism, Review Comment: We need to know the downstream write operator parallelism (a.k.a. the number of partitions) to determine the sketch reservoir size and calculate the range bounds array. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. 
To unsubscribe, e-mail: issues-unsubscr...@iceberg.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: issues-unsubscr...@iceberg.apache.org For additional commands, e-mail: issues-h...@iceberg.apache.org