stevenzwu commented on code in PR #10331:
URL: https://github.com/apache/iceberg/pull/10331#discussion_r1599242761


##########
flink/v1.19/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/AggregatedStatistics.java:
##########
@@ -19,53 +19,87 @@
 package org.apache.iceberg.flink.sink.shuffle;
 
 import java.io.Serializable;
-import org.apache.flink.api.common.typeutils.TypeSerializer;
+import java.util.Arrays;
+import java.util.Map;
+import org.apache.iceberg.SortKey;
 import org.apache.iceberg.relocated.com.google.common.base.MoreObjects;
-import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
+import org.apache.iceberg.relocated.com.google.common.base.Objects;
 
 /**
  * AggregatedStatistics is used by {@link DataStatisticsCoordinator} to collect {@link
  * DataStatistics} from {@link DataStatisticsOperator} subtasks for specific checkpoint. It stores
  * the merged {@link DataStatistics} result from all reported subtasks.
  */
-class AggregatedStatistics<D extends DataStatistics<D, S>, S> implements Serializable {
-
+class AggregatedStatistics implements Serializable {
   private final long checkpointId;
-  private final DataStatistics<D, S> dataStatistics;
-
-  AggregatedStatistics(long checkpoint, TypeSerializer<DataStatistics<D, S>> statisticsSerializer) {
-    this.checkpointId = checkpoint;
-    this.dataStatistics = statisticsSerializer.createInstance();
-  }
+  private final StatisticsType type;
+  private final Map<SortKey, Long> keyFrequency;

Review Comment:
   Combining both Map and Sketch statistics in the same aggregated statistics object allows a run-time switch from `Map` statistics to `Sketch`.
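A minimal sketch of the idea, assuming simplified stand-ins: `String` replaces `SortKey`, and `AggStats` / `AggStats.Type` are hypothetical names rather than the PR's classes. One immutable object carries a type tag plus both payload slots, and exactly one slot is populated per instance:

```java
import java.util.Arrays;
import java.util.Collections;
import java.util.Map;

// Hypothetical, simplified stand-in for AggregatedStatistics: String replaces SortKey.
final class AggStats {
  enum Type { Map, Sketch }

  private final long checkpointId;
  private final Type type;
  private final Map<String, Long> keyFrequency; // set only when type == Type.Map
  private final String[] rangeBounds; // set only when type == Type.Sketch

  private AggStats(
      long checkpointId, Type type, Map<String, Long> keyFrequency, String[] rangeBounds) {
    this.checkpointId = checkpointId;
    this.type = type;
    this.keyFrequency = keyFrequency;
    this.rangeBounds = rangeBounds;
  }

  // Exact per-key counts, used while cardinality is low.
  static AggStats fromKeyFrequency(long checkpointId, Map<String, Long> keyFrequency) {
    return new AggStats(checkpointId, Type.Map, Collections.unmodifiableMap(keyFrequency), null);
  }

  // Split keys derived from a sample, used once statistics switch to Sketch.
  static AggStats fromRangeBounds(long checkpointId, String[] rangeBounds) {
    return new AggStats(
        checkpointId, Type.Sketch, null, Arrays.copyOf(rangeBounds, rangeBounds.length));
  }

  long checkpointId() {
    return checkpointId;
  }

  Type type() {
    return type;
  }

  Map<String, Long> keyFrequency() {
    return keyFrequency;
  }

  String[] rangeBounds() {
    return rangeBounds == null ? null : rangeBounds.clone();
  }
}
```

Because the tag travels with the value, a consumer branches on `type()` per checkpoint instead of needing a different statistics class for each type.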



##########
flink/v1.19/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/StatisticsUtil.java:
##########
@@ -0,0 +1,101 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.flink.sink.shuffle;
+
+import java.io.IOException;
+import java.io.UncheckedIOException;
+import javax.annotation.Nullable;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.core.memory.DataInputDeserializer;
+import org.apache.flink.core.memory.DataOutputSerializer;
+
+/**
+ * DataStatisticsUtil is the utility to serialize and deserialize {@link DataStatistics} and {@link
+ * AggregatedStatistics}
+ */
+class StatisticsUtil {

Review Comment:
   Renamed from `DataStatisticsUtil`.



##########
flink/v1.19/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsSerializer.java:
##########
@@ -0,0 +1,204 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.flink.sink.shuffle;
+
+import java.io.IOException;
+import java.util.Map;
+import java.util.Objects;
+import org.apache.datasketches.memory.Memory;
+import org.apache.datasketches.sampling.ReservoirItemsSketch;
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeutils.CompositeTypeSerializerSnapshot;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot;
+import org.apache.flink.api.common.typeutils.base.EnumSerializer;
+import org.apache.flink.api.common.typeutils.base.LongSerializer;
+import org.apache.flink.api.common.typeutils.base.MapSerializer;
+import org.apache.flink.core.memory.DataInputView;
+import org.apache.flink.core.memory.DataOutputView;
+import org.apache.iceberg.SortKey;
+import org.apache.iceberg.relocated.com.google.common.collect.Maps;
+
+@Internal
+class DataStatisticsSerializer extends TypeSerializer<DataStatistics> {

Review Comment:
   This single serializer handles both the Map and Sketch statistics types.
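The type-tag approach can be sketched with plain `java.io` streams instead of Flink's `TypeSerializer` and `DataOutputView` (a deliberate simplification; `TaggedStatsCodec` is a hypothetical name and `String` stands in for `SortKey`). The tag is written first, and the reader uses it to decide which payload to decode:

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

// Hypothetical codec: one format covers both statistics shapes via a leading type tag.
final class TaggedStatsCodec {
  enum Type { Map, Sketch }

  static byte[] serialize(Type type, Map<String, Long> counts, List<String> samples) {
    ByteArrayOutputStream bytes = new ByteArrayOutputStream();
    try (DataOutputStream out = new DataOutputStream(bytes)) {
      out.writeInt(type.ordinal()); // type tag first
      if (type == Type.Map) {
        out.writeInt(counts.size());
        for (Map.Entry<String, Long> entry : counts.entrySet()) {
          out.writeUTF(entry.getKey());
          out.writeLong(entry.getValue());
        }
      } else {
        out.writeInt(samples.size());
        for (String sample : samples) {
          out.writeUTF(sample);
        }
      }
    } catch (IOException e) {
      throw new UncheckedIOException(e);
    }
    return bytes.toByteArray();
  }

  // Returns a Map<String, Long> or a List<String>, depending on the tag.
  static Object deserialize(byte[] data) {
    try (DataInputStream in = new DataInputStream(new ByteArrayInputStream(data))) {
      Type type = Type.values()[in.readInt()];
      int size = in.readInt();
      if (type == Type.Map) {
        Map<String, Long> counts = new LinkedHashMap<>();
        for (int i = 0; i < size; i++) {
          counts.put(in.readUTF(), in.readLong());
        }
        return counts;
      }
      List<String> samples = new ArrayList<>(size);
      for (int i = 0; i < size; i++) {
        samples.add(in.readUTF());
      }
      return samples;
    } catch (IOException e) {
      throw new UncheckedIOException(e);
    }
  }
}
```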



##########
flink/v1.19/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/AggregatedStatisticsTracker.java:
##########
@@ -104,30 +144,135 @@ AggregatedStatistics<D, S> updateAndCheckCompletion(
           subtask,
           checkpointId);
     } else {
-      inProgressStatistics.mergeDataStatistic(
+      merge(dataStatistics);
+      LOG.debug(
+          "Merge data statistics from operator {} subtask {} for checkpoint {}.",
           operatorName,
-          event.checkpointId(),
-          DataStatisticsUtil.deserializeDataStatistics(
-              event.statisticsBytes(), statisticsSerializer));
+          subtask,
+          checkpointId);
     }
 
+    // This should be the happy path where all subtasks reports are received
     if (inProgressSubtaskSet.size() == parallelism) {
-      completedStatistics = inProgressStatistics;
+      completedStatistics = completedStatistics();
+      resetAggregates();
       LOG.info(
-          "Received data statistics from all {} operators {} for checkpoint {}. Return last completed aggregator {}.",
+          "Received data statistics from all {} operators {} for checkpoint {}. Return last completed aggregator.",
           parallelism,
           operatorName,
-          inProgressStatistics.checkpointId(),
-          completedStatistics.dataStatistics());
-      inProgressStatistics = new AggregatedStatistics<>(checkpointId + 1, statisticsSerializer);
-      inProgressSubtaskSet.clear();
+          inProgressCheckpointId);
     }
 
     return completedStatistics;
   }
 
+  private boolean inProgress() {
+    return inProgressCheckpointId != CheckpointStoreUtil.INVALID_CHECKPOINT_ID;
+  }
+
+  private AggregatedStatistics completedStatistics() {
+    if (coordinatorStatisticsType == StatisticsType.Map) {
+      LOG.info(
+          "Completed map statistics aggregation with {} keys", coordinatorMapStatistics.size());
+      return AggregatedStatistics.fromKeyFrequency(
+          inProgressCheckpointId, coordinatorMapStatistics);
+    } else {
+      ReservoirItemsSketch<SortKey> sketch = coordinatorSketchStatistics.getResult();
+      LOG.info(
+          "Completed sketch statistics aggregation: "
+              + "reservoir size = {}, number of items seen = {}, number of samples = {}",
+          sketch.getK(),
+          sketch.getN(),
+          sketch.getNumSamples());
+      return AggregatedStatistics.fromRangeBounds(
+          inProgressCheckpointId,
+          SketchUtil.rangeBounds(downstreamParallelism, comparator, sketch));
+    }
+  }
+
+  private void initializeAggregates(long checkpointId, DataStatistics taskStatistics) {
+    LOG.info("Starting a new statistics aggregation for checkpoint {}", checkpointId);
+    this.inProgressCheckpointId = checkpointId;
+    this.coordinatorStatisticsType = taskStatistics.type();
+
+    if (coordinatorStatisticsType == StatisticsType.Map) {
+      this.coordinatorMapStatistics = Maps.newHashMap();
+      this.coordinatorSketchStatistics = null;
+    } else {
+      this.coordinatorMapStatistics = null;
+      this.coordinatorSketchStatistics =
+          ReservoirItemsUnion.newInstance(
+              SketchUtil.determineCoordinatorReservoirSize(downstreamParallelism));
+    }
+  }
+
+  private void resetAggregates() {
+    inProgressSubtaskSet.clear();
+    this.inProgressCheckpointId = CheckpointStoreUtil.INVALID_CHECKPOINT_ID;
+    this.coordinatorMapStatistics = null;
+    this.coordinatorSketchStatistics = null;
+  }
+
+  @SuppressWarnings("unchecked")
+  private void merge(DataStatistics taskStatistics) {

Review Comment:
   This method shows the statistics type migration from `Map` to `Sketch`.
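A simplified sketch of coordinator-side migration, under stated assumptions: keys are `String`s, and a plain Algorithm R reservoir (`List<String>`) stands in for the DataSketches `ReservoirItemsUnion` used in the PR. Unlike a real reservoir union, it treats every incoming task sample as one item and ignores how many items each task actually saw. `CoordinatorAggregator` and its methods are hypothetical. Once any subtask reports sketch statistics, the accumulated map is replayed into the reservoir and all later merges go to the sketch:

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;

// Hypothetical coordinator aggregator illustrating a one-way Map -> Sketch switch.
final class CoordinatorAggregator {
  private final int reservoirSize;
  private final Random random = new Random(42);
  private Map<String, Long> mapStats = new HashMap<>(); // non-null while in Map mode
  private List<String> reservoir; // non-null once switched to Sketch mode
  private long itemsSeen;

  CoordinatorAggregator(int reservoirSize) {
    this.reservoirSize = reservoirSize;
  }

  boolean isSketch() {
    return reservoir != null;
  }

  Map<String, Long> counts() {
    return mapStats;
  }

  List<String> samples() {
    return reservoir;
  }

  // A subtask reported exact key counts.
  void mergeMap(Map<String, Long> taskCounts) {
    if (isSketch()) {
      // Already migrated: replay every key occurrence into the reservoir.
      taskCounts.forEach(this::offerRepeated);
    } else {
      taskCounts.forEach((key, count) -> mapStats.merge(key, count, Long::sum));
    }
  }

  // A subtask reported sampled keys; this forces the Map -> Sketch migration.
  void mergeSketch(List<String> taskSamples) {
    if (!isSketch()) {
      Map<String, Long> accumulated = mapStats;
      mapStats = null;
      reservoir = new ArrayList<>(reservoirSize);
      accumulated.forEach(this::offerRepeated);
    }
    // Simplification: each task sample counts as a single item (no per-task weighting).
    taskSamples.forEach(this::offer);
  }

  private void offerRepeated(String key, Long count) {
    for (long i = 0; i < count; i++) {
      offer(key);
    }
  }

  // Algorithm R: keep the first k items, then overwrite a random slot with probability k / n.
  private void offer(String key) {
    itemsSeen++;
    if (reservoir.size() < reservoirSize) {
      reservoir.add(key);
    } else {
      long slot = (long) (random.nextDouble() * itemsSeen);
      if (slot < reservoirSize) {
        reservoir.set((int) slot, key);
      }
    }
  }
}
```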



##########
flink/v1.19/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsOperator.java:
##########
@@ -47,151 +48,182 @@
  * distribution to downstream subtasks.
  */
 @Internal
-class DataStatisticsOperator<D extends DataStatistics<D, S>, S>
-    extends AbstractStreamOperator<DataStatisticsOrRecord<D, S>>
-    implements OneInputStreamOperator<RowData, DataStatisticsOrRecord<D, S>>, OperatorEventHandler {
+public class DataStatisticsOperator extends AbstractStreamOperator<StatisticsOrRecord>
+    implements OneInputStreamOperator<RowData, StatisticsOrRecord>, OperatorEventHandler {
 
   private static final long serialVersionUID = 1L;
 
   private final String operatorName;
   private final RowDataWrapper rowDataWrapper;
   private final SortKey sortKey;
   private final OperatorEventGateway operatorEventGateway;
-  private final TypeSerializer<DataStatistics<D, S>> statisticsSerializer;
-  private transient volatile DataStatistics<D, S> localStatistics;
-  private transient volatile DataStatistics<D, S> globalStatistics;
-  private transient ListState<DataStatistics<D, S>> globalStatisticsState;
+  private final int downstreamParallelism;
+  private final StatisticsType statisticsType;
+  private final TypeSerializer<DataStatistics> taskStatisticsSerializer;
+  private final TypeSerializer<AggregatedStatistics> aggregatedStatisticsSerializer;
+
+  private transient int parallelism;
+  private transient int subtaskIndex;
+  private transient ListState<AggregatedStatistics> globalStatisticsState;
+  // current statistics type may be different from the config due to possible
+  // migration from Map statistics to Sketch statistics when high cardinality detected
+  private transient volatile StatisticsType taskStatisticsType;
+  private transient volatile DataStatistics localStatistics;
+  private transient volatile AggregatedStatistics globalStatistics;
 
   DataStatisticsOperator(
       String operatorName,
       Schema schema,
       SortOrder sortOrder,
       OperatorEventGateway operatorEventGateway,
-      TypeSerializer<DataStatistics<D, S>> statisticsSerializer) {
+      int downstreamParallelism,
+      StatisticsType statisticsType) {
     this.operatorName = operatorName;
     this.rowDataWrapper = new RowDataWrapper(FlinkSchemaUtil.convert(schema), schema.asStruct());
     this.sortKey = new SortKey(schema, sortOrder);
     this.operatorEventGateway = operatorEventGateway;
-    this.statisticsSerializer = statisticsSerializer;
+    this.downstreamParallelism = downstreamParallelism;
+    this.statisticsType = statisticsType;
+
+    SortKeySerializer sortKeySerializer = new SortKeySerializer(schema, sortOrder);
+    this.taskStatisticsSerializer = new DataStatisticsSerializer(sortKeySerializer);
+    this.aggregatedStatisticsSerializer = new AggregatedStatisticsSerializer(sortKeySerializer);
   }
 
   @Override
   public void initializeState(StateInitializationContext context) throws Exception {
-    localStatistics = statisticsSerializer.createInstance();
-    globalStatisticsState =
+    this.parallelism = getRuntimeContext().getTaskInfo().getNumberOfParallelSubtasks();
+    this.subtaskIndex = getRuntimeContext().getTaskInfo().getIndexOfThisSubtask();
+    this.globalStatisticsState =
         context
             .getOperatorStateStore()
             .getUnionListState(
-                new ListStateDescriptor<>("globalStatisticsState", statisticsSerializer));
+                new ListStateDescriptor<>("globalStatisticsState", aggregatedStatisticsSerializer));
 
     if (context.isRestored()) {
-      int subtaskIndex = getRuntimeContext().getIndexOfThisSubtask();
       if (globalStatisticsState.get() == null
           || !globalStatisticsState.get().iterator().hasNext()) {
         LOG.warn(
             "Operator {} subtask {} doesn't have global statistics state to restore",
             operatorName,
             subtaskIndex);
-        globalStatistics = statisticsSerializer.createInstance();
       } else {
         LOG.info(
-            "Restoring operator {} global statistics state for subtask {}",
-            operatorName,
-            subtaskIndex);
-        globalStatistics = globalStatisticsState.get().iterator().next();
+            "Operator {} subtask {} restoring global statistics state", operatorName, subtaskIndex);
+        this.globalStatistics = globalStatisticsState.get().iterator().next();
       }
-    } else {
-      globalStatistics = statisticsSerializer.createInstance();
     }
+
+    this.taskStatisticsType = StatisticsUtil.collectType(statisticsType, globalStatistics);
+    this.localStatistics =
+        StatisticsUtil.createTaskStatistics(taskStatisticsType, parallelism, downstreamParallelism);
   }
 
   @Override
   public void open() throws Exception {
-    if (!globalStatistics.isEmpty()) {
-      output.collect(
-          new StreamRecord<>(DataStatisticsOrRecord.fromDataStatistics(globalStatistics)));
+    if (globalStatistics != null) {
+      output.collect(new StreamRecord<>(StatisticsOrRecord.fromDataStatistics(globalStatistics)));
     }
   }
 
   @Override
-  @SuppressWarnings("unchecked")
   public void handleOperatorEvent(OperatorEvent event) {
-    int subtaskIndex = getRuntimeContext().getIndexOfThisSubtask();
     Preconditions.checkArgument(
-        event instanceof DataStatisticsEvent,
+        event instanceof StatisticsEvent,
         String.format(
             "Operator %s subtask %s received unexpected operator event %s",
             operatorName, subtaskIndex, event.getClass()));
-    DataStatisticsEvent<D, S> statisticsEvent = (DataStatisticsEvent<D, S>) event;
+    StatisticsEvent statisticsEvent = (StatisticsEvent) event;
     LOG.info(
-        "Operator {} received global data event from coordinator checkpoint {}",
+        "Operator {} subtask {} received global data event from coordinator checkpoint {}",
         operatorName,
+        subtaskIndex,
         statisticsEvent.checkpointId());
     globalStatistics =
-        DataStatisticsUtil.deserializeDataStatistics(
-            statisticsEvent.statisticsBytes(), statisticsSerializer);
-    output.collect(new StreamRecord<>(DataStatisticsOrRecord.fromDataStatistics(globalStatistics)));
+        StatisticsUtil.deserializeAggregatedStatistics(
+            statisticsEvent.statisticsBytes(), aggregatedStatisticsSerializer);
+    output.collect(new StreamRecord<>(StatisticsOrRecord.fromDataStatistics(globalStatistics)));
   }
 
+  @SuppressWarnings("unchecked")
   @Override
   public void processElement(StreamRecord<RowData> streamRecord) {
     RowData record = streamRecord.getValue();
     StructLike struct = rowDataWrapper.wrap(record);
     sortKey.wrap(struct);
     localStatistics.add(sortKey);
-    output.collect(new StreamRecord<>(DataStatisticsOrRecord.fromRecord(record)));
+
+    if (localStatistics.type() == StatisticsType.Map) {
+      Map<SortKey, Long> mapStatistics = (Map<SortKey, Long>) localStatistics.result();
+      if (statisticsType == StatisticsType.Auto

Review Comment:
   This is the statistics migration (Map -> Sketch) on the operator side during the collection phase.
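An operator-side counterpart, again hypothetical: the `Auto` mode mirrors the `StatisticsType.Auto` check visible in the hunk, while the cardinality threshold, the reservoir size, and the class name `TaskCollector` are illustrative assumptions. A plain reservoir replaces the DataSketches `ReservoirItemsSketch`. Exact counts are kept until the number of distinct keys exceeds the threshold, and only `Auto` mode is allowed to migrate:

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;

// Hypothetical per-subtask collector; threshold and reservoir size are illustrative.
final class TaskCollector {
  enum Mode { Map, Sketch, Auto }

  private final Mode configured;
  private final int switchThreshold;
  private final int reservoirSize;
  private final Random random = new Random(7);
  private Map<String, Long> counts; // non-null while collecting exact counts
  private List<String> reservoir; // non-null once sampling
  private long itemsSeen;

  TaskCollector(Mode configured, int switchThreshold, int reservoirSize) {
    this.configured = configured;
    this.switchThreshold = switchThreshold;
    this.reservoirSize = reservoirSize;
    if (configured == Mode.Sketch) {
      this.reservoir = new ArrayList<>(reservoirSize);
    } else {
      this.counts = new HashMap<>();
    }
  }

  boolean isSketch() {
    return reservoir != null;
  }

  List<String> samples() {
    return reservoir;
  }

  void add(String key) {
    if (reservoir != null) {
      offer(key);
      return;
    }
    counts.merge(key, 1L, Long::sum);
    // Only Auto may migrate; an explicit Map configuration keeps exact counts.
    if (configured == Mode.Auto && counts.size() > switchThreshold) {
      Map<String, Long> accumulated = counts;
      counts = null;
      reservoir = new ArrayList<>(reservoirSize);
      accumulated.forEach(
          (k, c) -> {
            for (long i = 0; i < c; i++) {
              offer(k);
            }
          });
    }
  }

  // Algorithm R reservoir sampling.
  private void offer(String key) {
    itemsSeen++;
    if (reservoir.size() < reservoirSize) {
      reservoir.add(key);
    } else {
      long slot = (long) (random.nextDouble() * itemsSeen);
      if (slot < reservoirSize) {
        reservoir.set((int) slot, key);
      }
    }
  }
}
```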



##########
flink/v1.19/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/AggregatedStatisticsSerializer.java:
##########
@@ -0,0 +1,175 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.flink.sink.shuffle;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import org.apache.flink.api.common.typeutils.CompositeTypeSerializerSnapshot;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot;
+import org.apache.flink.api.common.typeutils.base.EnumSerializer;
+import org.apache.flink.api.common.typeutils.base.ListSerializer;
+import org.apache.flink.api.common.typeutils.base.LongSerializer;
+import org.apache.flink.api.common.typeutils.base.MapSerializer;
+import org.apache.flink.core.memory.DataInputView;
+import org.apache.flink.core.memory.DataOutputView;
+import org.apache.iceberg.SortKey;
+
+public class AggregatedStatisticsSerializer extends TypeSerializer<AggregatedStatistics> {
+  private final TypeSerializer<SortKey> sortKeySerializer;
+  private final EnumSerializer<StatisticsType> statisticsTypeSerializer;
+  private final MapSerializer<SortKey, Long> keyFrequencySerializer;
+  private final ListSerializer<SortKey> rangeBoundsSerializer;
+
+  AggregatedStatisticsSerializer(TypeSerializer<SortKey> sortKeySerializer) {
+    this.sortKeySerializer = sortKeySerializer;
+    this.statisticsTypeSerializer = new EnumSerializer<>(StatisticsType.class);
+    this.keyFrequencySerializer = new MapSerializer<>(sortKeySerializer, LongSerializer.INSTANCE);
+    this.rangeBoundsSerializer = new ListSerializer<>(sortKeySerializer);
+  }
+
+  @Override
+  public boolean isImmutableType() {
+    return false;
+  }
+
+  @Override
+  public TypeSerializer<AggregatedStatistics> duplicate() {
+    return new AggregatedStatisticsSerializer(sortKeySerializer);
+  }
+
+  @Override
+  public AggregatedStatistics createInstance() {
+    return new AggregatedStatistics(0, StatisticsType.Map, Collections.emptyMap(), null);
+  }
+
+  @Override
+  public AggregatedStatistics copy(AggregatedStatistics from) {
+    return new AggregatedStatistics(
+        from.checkpointId(), from.type(), from.keyFrequency(), from.rangeBounds());
+  }
+
+  @Override
+  public AggregatedStatistics copy(AggregatedStatistics from, AggregatedStatistics reuse) {
+    // no benefit of reuse
+    return copy(from);
+  }
+
+  @Override
+  public int getLength() {
+    return -1;
+  }
+
+  @Override
+  public void serialize(AggregatedStatistics record, DataOutputView target) throws IOException {
+    target.writeLong(record.checkpointId());
+    statisticsTypeSerializer.serialize(record.type(), target);
+    if (record.type() == StatisticsType.Map) {
+      keyFrequencySerializer.serialize(record.keyFrequency(), target);
+    } else {
+      rangeBoundsSerializer.serialize(Arrays.asList(record.rangeBounds()), target);

Review Comment:
   Reused Flink's `ListSerializer`, paying a small penalty for the array-to-list conversion.
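Why the penalty is small can be shown with the JDK alone: `Arrays.asList` returns a fixed-size `List` view backed by the array rather than a copy, so the write path allocates only a thin wrapper. If the read path converts the deserialized list back into an array (an assumption here; the deserializer is not shown in this hunk), that direction does copy. `ArrayListView` is a hypothetical helper:

```java
import java.util.Arrays;
import java.util.List;

// Hypothetical helper illustrating the cost of the array <-> list conversion.
final class ArrayListView {
  // Write path: Arrays.asList wraps the array in a fixed-size List view; no elements are copied.
  static List<String> toListView(String[] rangeBounds) {
    return Arrays.asList(rangeBounds);
  }

  // Read path (assumed): toArray allocates a new array and copies the elements into it.
  static String[] toArrayCopy(List<String> rangeBounds) {
    return rangeBounds.toArray(new String[0]);
  }
}
```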



##########
flink/v1.19/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsCoordinator.java:
##########
@@ -44,51 +46,76 @@
 import org.slf4j.LoggerFactory;
 
 /**
- * DataStatisticsCoordinator receives {@link DataStatisticsEvent} from {@link
- * DataStatisticsOperator} every subtask and then merge them together. Once aggregation for all
- * subtasks data statistics completes, DataStatisticsCoordinator will send the aggregated data
- * statistics back to {@link DataStatisticsOperator}. In the end a custom partitioner will
- * distribute traffic based on the aggregated data statistics to improve data clustering.
+ * DataStatisticsCoordinator receives {@link StatisticsEvent} from {@link DataStatisticsOperator}
+ * every subtask and then merge them together. Once aggregation for all subtasks data statistics
+ * completes, DataStatisticsCoordinator will send the aggregated data statistics back to {@link
+ * DataStatisticsOperator}. In the end a custom partitioner will distribute traffic based on the
+ * aggregated data statistics to improve data clustering.
  */
 @Internal
-class DataStatisticsCoordinator<D extends DataStatistics<D, S>, S> implements OperatorCoordinator {
+class DataStatisticsCoordinator implements OperatorCoordinator {
   private static final Logger LOG = LoggerFactory.getLogger(DataStatisticsCoordinator.class);
 
   private final String operatorName;
+  private final OperatorCoordinator.Context context;
+  private final Schema schema;
+  private final SortOrder sortOrder;
+  private final int downstreamParallelism;
+  private final StatisticsType statisticsType;
+
   private final ExecutorService coordinatorExecutor;
-  private final OperatorCoordinator.Context operatorCoordinatorContext;
   private final SubtaskGateways subtaskGateways;
   private final CoordinatorExecutorThreadFactory coordinatorThreadFactory;
-  private final TypeSerializer<DataStatistics<D, S>> statisticsSerializer;
-  private final transient AggregatedStatisticsTracker<D, S> aggregatedStatisticsTracker;
-  private volatile AggregatedStatistics<D, S> completedStatistics;
-  private volatile boolean started;
+  private final TypeSerializer<AggregatedStatistics> aggregatedStatisticsSerializer;
+
+  private transient boolean started;
+  private transient AggregatedStatisticsTracker aggregatedStatisticsTracker;
+  private transient AggregatedStatistics completedStatistics;
 
   DataStatisticsCoordinator(
       String operatorName,
       OperatorCoordinator.Context context,
-      TypeSerializer<DataStatistics<D, S>> statisticsSerializer) {
+      Schema schema,
+      SortOrder sortOrder,
+      int downstreamParallelism,

Review Comment:
   The coordinator needs to know the downstream write operator parallelism (a.k.a. the number of partitions) to determine the sketch reservoir size and calculate the range bounds array.
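A hypothetical sketch of why the downstream parallelism matters: with `n` downstream partitions, `n - 1` split keys are taken at evenly spaced positions of the sorted sample, and the reservoir is sized in proportion to `n` so each range is backed by enough samples. The sizing constants (100 samples per partition, minimum 1,000) are illustrative assumptions, not values from this PR, and `String` keys stand in for `SortKey`:

```java
import java.util.Arrays;

// Hypothetical helper: String keys stand in for SortKey; constants are illustrative.
final class RangeBoundsSketch {
  private static final int SAMPLES_PER_PARTITION = 100; // assumption
  private static final int MIN_RESERVOIR_SIZE = 1_000; // assumption

  // Reservoir size grows with the number of downstream partitions.
  static int reservoirSize(int numPartitions) {
    return Math.max(MIN_RESERVOIR_SIZE, numPartitions * SAMPLES_PER_PARTITION);
  }

  // Picks (numPartitions - 1) split keys at evenly spaced positions of the sorted sample.
  static String[] rangeBounds(int numPartitions, String[] samples) {
    if (numPartitions <= 1 || samples.length == 0) {
      return new String[0];
    }
    String[] sorted = samples.clone();
    Arrays.sort(sorted);
    String[] bounds = new String[numPartitions - 1];
    for (int i = 1; i < numPartitions; i++) {
      int idx = (int) ((long) i * sorted.length / numPartitions); // always < sorted.length
      bounds[i - 1] = sorted[idx];
    }
    return bounds;
  }

  // Keys below bounds[0] go to partition 0; keys in [bounds[i-1], bounds[i]) go to partition i.
  static int partition(String[] bounds, String key) {
    int pos = Arrays.binarySearch(bounds, key);
    return pos >= 0 ? pos + 1 : -(pos + 1);
  }
}
```

A key equal to a bound is routed to the partition on its right, so `partition(bounds, "c")` returns `1` for bounds `[c, e, g]`.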



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@iceberg.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

