stevenzwu commented on code in PR #10331: URL: https://github.com/apache/iceberg/pull/10331#discussion_r1600742620
########## flink/v1.19/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsOperator.java: ########## @@ -47,151 +48,182 @@ * distribution to downstream subtasks. */ @Internal -class DataStatisticsOperator<D extends DataStatistics<D, S>, S> - extends AbstractStreamOperator<DataStatisticsOrRecord<D, S>> - implements OneInputStreamOperator<RowData, DataStatisticsOrRecord<D, S>>, OperatorEventHandler { +public class DataStatisticsOperator extends AbstractStreamOperator<StatisticsOrRecord> + implements OneInputStreamOperator<RowData, StatisticsOrRecord>, OperatorEventHandler { private static final long serialVersionUID = 1L; private final String operatorName; private final RowDataWrapper rowDataWrapper; private final SortKey sortKey; private final OperatorEventGateway operatorEventGateway; - private final TypeSerializer<DataStatistics<D, S>> statisticsSerializer; - private transient volatile DataStatistics<D, S> localStatistics; - private transient volatile DataStatistics<D, S> globalStatistics; - private transient ListState<DataStatistics<D, S>> globalStatisticsState; + private final int downstreamParallelism; + private final StatisticsType statisticsType; + private final TypeSerializer<DataStatistics> taskStatisticsSerializer; + private final TypeSerializer<AggregatedStatistics> aggregatedStatisticsSerializer; + + private transient int parallelism; + private transient int subtaskIndex; + private transient ListState<AggregatedStatistics> globalStatisticsState; + // current statistics type may be different from the config due to possible + // migration from Map statistics to Sketch statistics when high cardinality detected + private transient volatile StatisticsType taskStatisticsType; + private transient volatile DataStatistics localStatistics; + private transient volatile AggregatedStatistics globalStatistics; DataStatisticsOperator( String operatorName, Schema schema, SortOrder sortOrder, OperatorEventGateway operatorEventGateway, - 
TypeSerializer<DataStatistics<D, S>> statisticsSerializer) { + int downstreamParallelism, + StatisticsType statisticsType) { this.operatorName = operatorName; this.rowDataWrapper = new RowDataWrapper(FlinkSchemaUtil.convert(schema), schema.asStruct()); this.sortKey = new SortKey(schema, sortOrder); this.operatorEventGateway = operatorEventGateway; - this.statisticsSerializer = statisticsSerializer; + this.downstreamParallelism = downstreamParallelism; + this.statisticsType = statisticsType; + + SortKeySerializer sortKeySerializer = new SortKeySerializer(schema, sortOrder); + this.taskStatisticsSerializer = new DataStatisticsSerializer(sortKeySerializer); + this.aggregatedStatisticsSerializer = new AggregatedStatisticsSerializer(sortKeySerializer); } @Override public void initializeState(StateInitializationContext context) throws Exception { - localStatistics = statisticsSerializer.createInstance(); - globalStatisticsState = + this.parallelism = getRuntimeContext().getTaskInfo().getNumberOfParallelSubtasks(); + this.subtaskIndex = getRuntimeContext().getTaskInfo().getIndexOfThisSubtask(); + this.globalStatisticsState = context .getOperatorStateStore() .getUnionListState( - new ListStateDescriptor<>("globalStatisticsState", statisticsSerializer)); + new ListStateDescriptor<>("globalStatisticsState", aggregatedStatisticsSerializer)); if (context.isRestored()) { - int subtaskIndex = getRuntimeContext().getIndexOfThisSubtask(); if (globalStatisticsState.get() == null || !globalStatisticsState.get().iterator().hasNext()) { LOG.warn( "Operator {} subtask {} doesn't have global statistics state to restore", operatorName, subtaskIndex); - globalStatistics = statisticsSerializer.createInstance(); } else { LOG.info( - "Restoring operator {} global statistics state for subtask {}", - operatorName, - subtaskIndex); - globalStatistics = globalStatisticsState.get().iterator().next(); + "Operator {} subtask {} restoring global statistics state", operatorName, subtaskIndex); + 
this.globalStatistics = globalStatisticsState.get().iterator().next(); } - } else { - globalStatistics = statisticsSerializer.createInstance(); } + + this.taskStatisticsType = StatisticsUtil.collectType(statisticsType, globalStatistics); + this.localStatistics = + StatisticsUtil.createTaskStatistics(taskStatisticsType, parallelism, downstreamParallelism); } @Override public void open() throws Exception { - if (!globalStatistics.isEmpty()) { - output.collect( - new StreamRecord<>(DataStatisticsOrRecord.fromDataStatistics(globalStatistics))); + if (globalStatistics != null) { + output.collect(new StreamRecord<>(StatisticsOrRecord.fromDataStatistics(globalStatistics))); } } @Override - @SuppressWarnings("unchecked") public void handleOperatorEvent(OperatorEvent event) { - int subtaskIndex = getRuntimeContext().getIndexOfThisSubtask(); Preconditions.checkArgument( - event instanceof DataStatisticsEvent, + event instanceof StatisticsEvent, String.format( "Operator %s subtask %s received unexpected operator event %s", operatorName, subtaskIndex, event.getClass())); - DataStatisticsEvent<D, S> statisticsEvent = (DataStatisticsEvent<D, S>) event; + StatisticsEvent statisticsEvent = (StatisticsEvent) event; LOG.info( - "Operator {} received global data event from coordinator checkpoint {}", + "Operator {} subtask {} received global data event from coordinator checkpoint {}", operatorName, + subtaskIndex, statisticsEvent.checkpointId()); globalStatistics = - DataStatisticsUtil.deserializeDataStatistics( - statisticsEvent.statisticsBytes(), statisticsSerializer); - output.collect(new StreamRecord<>(DataStatisticsOrRecord.fromDataStatistics(globalStatistics))); + StatisticsUtil.deserializeAggregatedStatistics( + statisticsEvent.statisticsBytes(), aggregatedStatisticsSerializer); + output.collect(new StreamRecord<>(StatisticsOrRecord.fromDataStatistics(globalStatistics))); } + @SuppressWarnings("unchecked") @Override public void processElement(StreamRecord<RowData> 
streamRecord) { RowData record = streamRecord.getValue(); StructLike struct = rowDataWrapper.wrap(record); sortKey.wrap(struct); localStatistics.add(sortKey); - output.collect(new StreamRecord<>(DataStatisticsOrRecord.fromRecord(record))); + + if (localStatistics.type() == StatisticsType.Map) { + Map<SortKey, Long> mapStatistics = (Map<SortKey, Long>) localStatistics.result(); + if (statisticsType == StatisticsType.Auto Review Comment: Good question. Each operator independently decides whether to switch from Map to Sketch statistics during the local collection phase. When operators receive the global statistics from the coordinator, they should also check whether a type switch is needed. It looks like I missed this logic; I will add it. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: issues-unsubscr...@iceberg.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: issues-unsubscr...@iceberg.apache.org For additional commands, e-mail: issues-h...@iceberg.apache.org