yegangy0718 commented on code in PR #7360: URL: https://github.com/apache/iceberg/pull/7360#discussion_r1335098182
########## flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/AggregatedStatisticsTracker.java: ########## @@ -0,0 +1,135 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.flink.sink.shuffle; + +import java.util.Set; +import org.apache.flink.annotation.Internal; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.iceberg.relocated.com.google.common.annotations.VisibleForTesting; +import org.apache.iceberg.relocated.com.google.common.collect.Sets; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * AggregatedStatisticsTracker is used by {@link DataStatisticsCoordinator} to track the in progress + * {@link AggregatedStatistics} received from {@link DataStatisticsOperator} subtasks for specific + * checkpoint. + */ +@Internal Review Comment: will remove ########## flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/AggregatedStatisticsTracker.java: ########## @@ -0,0 +1,135 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.flink.sink.shuffle; + +import java.util.Set; +import org.apache.flink.annotation.Internal; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.iceberg.relocated.com.google.common.annotations.VisibleForTesting; +import org.apache.iceberg.relocated.com.google.common.collect.Sets; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * AggregatedStatisticsTracker is used by {@link DataStatisticsCoordinator} to track the in progress + * {@link AggregatedStatistics} received from {@link DataStatisticsOperator} subtasks for specific + * checkpoint. + */ +@Internal +class AggregatedStatisticsTracker<D extends DataStatistics<D, S>, S> { + private static final Logger LOG = LoggerFactory.getLogger(AggregatedStatisticsTracker.class); + private static final double EXPECTED_DATA_STATISTICS_RECEIVED_PERCENTAGE = 90; Review Comment: will take the suggestion to rename the variable ########## flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/AggregatedStatisticsTracker.java: ########## @@ -0,0 +1,135 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.flink.sink.shuffle; + +import java.util.Set; +import org.apache.flink.annotation.Internal; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.iceberg.relocated.com.google.common.annotations.VisibleForTesting; +import org.apache.iceberg.relocated.com.google.common.collect.Sets; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * AggregatedStatisticsTracker is used by {@link DataStatisticsCoordinator} to track the in progress + * {@link AggregatedStatistics} received from {@link DataStatisticsOperator} subtasks for specific + * checkpoint. + */ +@Internal +class AggregatedStatisticsTracker<D extends DataStatistics<D, S>, S> { + private static final Logger LOG = LoggerFactory.getLogger(AggregatedStatisticsTracker.class); + private static final double EXPECTED_DATA_STATISTICS_RECEIVED_PERCENTAGE = 90; + private final String operatorName; + private final TypeSerializer<DataStatistics<D, S>> statisticsSerializer; + private final int parallelism; + private final Set<Integer> inProgressSubtaskSet; + private volatile AggregatedStatistics<D, S> inProgressStatistics; + + AggregatedStatisticsTracker( + String operatorName, + TypeSerializer<DataStatistics<D, S>> statisticsSerializer, + int parallelism) { + this.operatorName = operatorName; + this.statisticsSerializer = statisticsSerializer; + this.parallelism = parallelism; + this.inProgressSubtaskSet = Sets.newHashSet(); + } + + AggregatedStatistics<D, S> receiveDataStatisticEventAndCheckCompletion( + int subtask, DataStatisticsEvent<D, S> event) { + long checkpointId = event.checkpointId(); + + if (inProgressStatistics != null && inProgressStatistics.checkpointId() > checkpointId) { + LOG.info( + "Expect data statistics for operator {} checkpoint {}, but receive event from older checkpoint {}. Ignore it.", + operatorName, + inProgressStatistics.checkpointId(), + checkpointId); + return null; + } + + AggregatedStatistics<D, S> completedStatistics = null; + if (inProgressStatistics != null && inProgressStatistics.checkpointId() < checkpointId) { + if ((double) inProgressSubtaskSet.size() / parallelism * 100 + >= EXPECTED_DATA_STATISTICS_RECEIVED_PERCENTAGE) { + completedStatistics = inProgressStatistics; + LOG.info( + "Received data statistics from {} subtasks out of total {} for operator {} at checkpoint {}. " + + "Complete data statistics aggregation at checkpoint {} as it is more than the threshold of {} percentage", + inProgressSubtaskSet.size(), + parallelism, + operatorName, + checkpointId, + inProgressStatistics.checkpointId(), + EXPECTED_DATA_STATISTICS_RECEIVED_PERCENTAGE); + } else { + LOG.info( + "Received data statistics from {} subtasks out of total {} for operator {} at checkpoint {}. " + + "Aborting the incomplete aggregation for checkpoint {}", + inProgressSubtaskSet.size(), + parallelism, + operatorName, + checkpointId, + inProgressStatistics.checkpointId()); + } + + inProgressStatistics = null; + inProgressSubtaskSet.clear(); + } + + if (inProgressStatistics == null) { + LOG.info("Starting a new data statistics for checkpoint {}", checkpointId); + inProgressStatistics = new AggregatedStatistics<>(checkpointId, statisticsSerializer); + inProgressSubtaskSet.clear(); + } + + if (!inProgressSubtaskSet.add(subtask)) { + LOG.debug( + "Ignore duplicated data statistics from operator {} subtask {} for checkpoint {}.", + operatorName, + subtask, + checkpointId); + } else { + inProgressStatistics.mergeDataStatistic( + operatorName, + event.checkpointId(), + DataStatisticsUtil.deserializeDataStatistics( + event.statisticsBytes(), statisticsSerializer)); + } + + if (inProgressSubtaskSet.size() == parallelism) { + completedStatistics = inProgressStatistics; + LOG.info( + "Received data statistics from all {} operators {} for checkpoint {}. Return last completed aggregator {}.", + parallelism, + operatorName, + inProgressStatistics.checkpointId(), + completedStatistics.dataStatistics()); + inProgressStatistics = new AggregatedStatistics<>(checkpointId + 1, statisticsSerializer); Review Comment: yup, either works ########## flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsCoordinator.java: ########## @@ -0,0 +1,395 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.flink.sink.shuffle; + +import java.util.Map; +import java.util.concurrent.Callable; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.ThreadFactory; +import org.apache.flink.annotation.Internal; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.runtime.operators.coordination.OperatorCoordinator; +import org.apache.flink.runtime.operators.coordination.OperatorEvent; +import org.apache.flink.util.ExceptionUtils; +import org.apache.flink.util.FatalExitExceptionHandler; +import org.apache.flink.util.FlinkRuntimeException; +import org.apache.flink.util.Preconditions; +import org.apache.flink.util.ThrowableCatchingRunnable; +import org.apache.flink.util.function.ThrowingRunnable; +import org.apache.iceberg.relocated.com.google.common.annotations.VisibleForTesting; +import org.apache.iceberg.relocated.com.google.common.collect.Iterables; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.jetbrains.annotations.NotNull; +import org.jetbrains.annotations.Nullable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * DataStatisticsCoordinator receives {@link DataStatisticsEvent} from {@link + * DataStatisticsOperator} every subtask and then merge them together. Once aggregation for all + * subtasks data statistics completes, DataStatisticsCoordinator will send the aggregated data + * statistics back to {@link DataStatisticsOperator}. In the end a custom partitioner will + * distribute traffic based on the aggregated data statistics to improve data clustering. + */ +@Internal +class DataStatisticsCoordinator<D extends DataStatistics<D, S>, S> implements OperatorCoordinator { + private static final Logger LOG = LoggerFactory.getLogger(DataStatisticsCoordinator.class); + + private final String operatorName; + private final ExecutorService coordinatorExecutor; + private final OperatorCoordinator.Context operatorCoordinatorContext; + private final SubtaskGateways subtaskGateways; + private final CoordinatorExecutorThreadFactory coordinatorThreadFactory; + private final TypeSerializer<DataStatistics<D, S>> statisticsSerializer; + private final transient AggregatedStatisticsTracker<D, S> aggregatedStatisticsTracker; + private volatile AggregatedStatistics<D, S> completedStatistics; + private volatile boolean started; + + DataStatisticsCoordinator( + String operatorName, + OperatorCoordinator.Context context, + TypeSerializer<DataStatistics<D, S>> statisticsSerializer) { + this.operatorName = operatorName; + this.coordinatorThreadFactory = + new CoordinatorExecutorThreadFactory( + "DataStatisticsCoordinator-" + operatorName, context.getUserCodeClassloader()); + this.coordinatorExecutor = Executors.newSingleThreadExecutor(coordinatorThreadFactory); + this.operatorCoordinatorContext = context; + this.subtaskGateways = new SubtaskGateways(operatorName, parallelism()); + this.statisticsSerializer = statisticsSerializer; + this.aggregatedStatisticsTracker = + new AggregatedStatisticsTracker<>(operatorName, statisticsSerializer, parallelism()); + } + + @Override + public void start() throws Exception { + LOG.info("Starting data statistics coordinator: {}.", operatorName); + started = true; + } + + @Override + public void close() throws Exception { + coordinatorExecutor.shutdown(); + LOG.info("Closed data statistics coordinator: {}.", operatorName); + } + + @VisibleForTesting + void callInCoordinatorThread(Callable<Void> callable, String errorMessage) { + ensureStarted(); + // Ensure the task is done by the coordinator executor. + if (!coordinatorThreadFactory.isCurrentThreadCoordinatorThread()) { + try { + Callable<Void> guardedCallable = + () -> { + try { + return callable.call(); + } catch (Throwable t) { + LOG.error( + "Uncaught Exception in data statistics coordinator: {} executor", + operatorName, + t); + ExceptionUtils.rethrowException(t); + return null; + } + }; + + coordinatorExecutor.submit(guardedCallable).get(); + } catch (InterruptedException | ExecutionException e) { + throw new FlinkRuntimeException(errorMessage, e); + } + } else { + try { + callable.call(); + } catch (Throwable t) { + LOG.error( + "Uncaught Exception in data statistics coordinator: {} executor", operatorName, t); + throw new FlinkRuntimeException(errorMessage, t); + } + } + } + + public void runInCoordinatorThread(Runnable runnable) { + this.coordinatorExecutor.execute( + new ThrowableCatchingRunnable( + throwable -> + this.coordinatorThreadFactory.uncaughtException(Thread.currentThread(), throwable), + runnable)); + } + + private void runInCoordinatorThread(ThrowingRunnable<Throwable> action, String actionString) { + ensureStarted(); + runInCoordinatorThread( + () -> { + try { + action.run(); + } catch (Throwable t) { + ExceptionUtils.rethrowIfFatalErrorOrOOM(t); + LOG.error( + "Uncaught exception in the data statistics coordinator: {} while {}. Triggering job failover.", + operatorName, + actionString, + t); + operatorCoordinatorContext.failJob(t); + } + }); + } + + private void ensureStarted() { + Preconditions.checkState(started, "The coordinator of %s has not started yet.", operatorName); + } + + private int parallelism() { + return operatorCoordinatorContext.currentParallelism(); + } + + private void handleDataStatisticRequest(int subtask, DataStatisticsEvent<D, S> event) { + AggregatedStatistics<D, S> aggregatedStatistics = + aggregatedStatisticsTracker.receiveDataStatisticEventAndCheckCompletion(subtask, event); + + if (aggregatedStatistics != null) { + completedStatistics = aggregatedStatistics; + sendDataStatisticsToSubtasks( + completedStatistics.checkpointId(), completedStatistics.dataStatistics()); + } + } + + private void sendDataStatisticsToSubtasks( + long checkpointId, DataStatistics<D, S> globalDataStatistics) { + callInCoordinatorThread( + () -> { + DataStatisticsEvent<D, S> dataStatisticsEvent = + DataStatisticsEvent.create(checkpointId, globalDataStatistics, statisticsSerializer); + int parallelism = parallelism(); + for (int i = 0; i < parallelism; ++i) { + subtaskGateways.getSubtaskGateway(i).sendEvent(dataStatisticsEvent); + } + + return null; + }, + String.format( + "Failed to send operator %s coordinator global data statistics for checkpoint %d", + operatorName, checkpointId)); + } + + @Override + @SuppressWarnings("unchecked") + public void handleEventFromOperator(int subtask, int attemptNumber, OperatorEvent event) { + runInCoordinatorThread( + () -> { + LOG.debug( + "Handling event from subtask {} (#{}) of {}: {}", + subtask, + attemptNumber, + operatorName, + event); + Preconditions.checkArgument(event instanceof DataStatisticsEvent); + handleDataStatisticRequest(subtask, ((DataStatisticsEvent<D, S>) event)); + }, + String.format( + "handling operator event %s from subtask %d (#%d)", + event.getClass(), subtask, attemptNumber)); + } + + @Override + public void checkpointCoordinator(long checkpointId, CompletableFuture<byte[]> resultFuture) { + runInCoordinatorThread( + () -> { + LOG.debug( + "Snapshotting data statistics coordinator {} for checkpoint {}", + operatorName, + checkpointId); + resultFuture.complete( + DataStatisticsUtil.serializeAggregatedStatistics( + completedStatistics, statisticsSerializer)); + }, + String.format("taking checkpoint %d", checkpointId)); + } + + @Override + public void notifyCheckpointComplete(long checkpointId) {} + + @Override + public void resetToCheckpoint(long checkpointId, @Nullable byte[] checkpointData) + throws Exception { + Preconditions.checkState( + !started, "The coordinator %s can only be reset if it was not yet started", operatorName); + + if (checkpointData == null) { + LOG.info( + "Data statistic coordinator {} checkpoint {} data is null. Cannot be restored.", Review Comment: `Cannot be restored sounds like an error.` Make sense. will update ########## flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsCoordinatorProvider.java: ########## @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.flink.sink.shuffle; + +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.runtime.jobgraph.OperatorID; +import org.apache.flink.runtime.operators.coordination.OperatorCoordinator; +import org.apache.flink.runtime.operators.coordination.RecreateOnResetOperatorCoordinator; + +/** + * DataStatisticsCoordinatorProvider provides the method to create new {@link + * DataStatisticsCoordinator} + */ +public class DataStatisticsCoordinatorProvider<D extends DataStatistics<D, S>, S> Review Comment: yes, right ########## flink/v1.17/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestDataStatisticsCoordinatorProvider.java: ########## @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.flink.sink.shuffle; + +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.runtime.jobgraph.OperatorID; +import org.apache.flink.runtime.operators.coordination.EventReceivingTasks; +import org.apache.flink.runtime.operators.coordination.MockOperatorCoordinatorContext; +import org.apache.flink.runtime.operators.coordination.RecreateOnResetOperatorCoordinator; +import org.apache.flink.table.data.GenericRowData; +import org.apache.flink.table.data.RowData; +import org.apache.flink.table.data.StringData; +import org.apache.flink.table.data.binary.BinaryRowData; +import org.apache.flink.table.runtime.typeutils.RowDataSerializer; +import org.apache.flink.table.types.logical.RowType; +import org.apache.flink.table.types.logical.VarCharType; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +public class TestDataStatisticsCoordinatorProvider { + private static final OperatorID OPERATOR_ID = new OperatorID(); + private static final int NUM_SUBTASKS = 1; + + private DataStatisticsCoordinatorProvider<MapDataStatistics, Map<RowData, Long>> provider; + private EventReceivingTasks receivingTasks; + private TypeSerializer<DataStatistics<MapDataStatistics, Map<RowData, Long>>> + statisticsSerializer; + + @Before + public void before() { + statisticsSerializer = + MapDataStatisticsSerializer.fromKeySerializer( + new RowDataSerializer(RowType.of(new VarCharType()))); + provider = + new DataStatisticsCoordinatorProvider<>( + "DataStatisticsCoordinatorProvider", OPERATOR_ID, statisticsSerializer); + receivingTasks = EventReceivingTasks.createForRunningTasks(); + } + + @Test + @SuppressWarnings("unchecked") + public void testCheckpointAndReset() throws Exception { + RowType rowType = RowType.of(new VarCharType()); + // When coordinator handles events from operator, DataStatisticsUtil#deserializeDataStatistics + // deserializes bytes into BinaryRowData + BinaryRowData binaryRowDataA = + new RowDataSerializer(rowType).toBinaryRow(GenericRowData.of(StringData.fromString("a"))); + BinaryRowData binaryRowDataB = + new RowDataSerializer(rowType).toBinaryRow(GenericRowData.of(StringData.fromString("b"))); + BinaryRowData binaryRowDataC = + new RowDataSerializer(rowType).toBinaryRow(GenericRowData.of(StringData.fromString("c"))); + BinaryRowData binaryRowDataD = + new RowDataSerializer(rowType).toBinaryRow(GenericRowData.of(StringData.fromString("d"))); + BinaryRowData binaryRowDataE = + new RowDataSerializer(rowType).toBinaryRow(GenericRowData.of(StringData.fromString("e"))); + + RecreateOnResetOperatorCoordinator coordinator = + (RecreateOnResetOperatorCoordinator) + provider.create(new MockOperatorCoordinatorContext(OPERATOR_ID, NUM_SUBTASKS)); + DataStatisticsCoordinator<MapDataStatistics, Map<RowData, Long>> dataStatisticsCoordinator = + (DataStatisticsCoordinator<MapDataStatistics, Map<RowData, Long>>) + coordinator.getInternalCoordinator(); + + // Start the coordinator + coordinator.start(); + TestDataStatisticsCoordinator.setAllTasksReady( + NUM_SUBTASKS, dataStatisticsCoordinator, receivingTasks); + MapDataStatistics checkpoint1Subtask0DataStatistic = new MapDataStatistics(); + checkpoint1Subtask0DataStatistic.add(binaryRowDataA); + checkpoint1Subtask0DataStatistic.add(binaryRowDataB); + checkpoint1Subtask0DataStatistic.add(binaryRowDataC); + DataStatisticsEvent<MapDataStatistics, Map<RowData, Long>> + checkpoint1Subtask0DataStatisticEvent = + DataStatisticsEvent.create(1, checkpoint1Subtask0DataStatistic, statisticsSerializer); + + // Handle events from operators for checkpoint 1 + coordinator.handleEventFromOperator(0, 0, checkpoint1Subtask0DataStatisticEvent); + TestDataStatisticsCoordinator.waitForCoordinatorToProcessActions(dataStatisticsCoordinator); + // Verify checkpoint 1 global data statistics + MapDataStatistics checkpoint1GlobalDataStatistics = + (MapDataStatistics) dataStatisticsCoordinator.completedStatistics().dataStatistics(); + Assert.assertEquals( + checkpoint1Subtask0DataStatistic.statistics(), + checkpoint1GlobalDataStatistics.statistics()); + byte[] bytes = waitForCheckpoint(1L, dataStatisticsCoordinator); Review Comment: will rename ########## flink/v1.17/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestDataStatisticsOperator.java: ########## @@ -118,20 +126,16 @@ public void testProcessElement() throws Exception { testHarness = createHarness(this.operator)) { StateInitializationContext stateContext = getStateContext(); operator.initializeState(stateContext); - operator.processElement(new StreamRecord<>(GenericRowData.of(StringData.fromString("a")))); - operator.processElement(new StreamRecord<>(GenericRowData.of(StringData.fromString("a")))); - operator.processElement(new StreamRecord<>(GenericRowData.of(StringData.fromString("b")))); + operator.processElement(new StreamRecord<>(binaryRowDataA)); Review Comment: since binaryRowData is defined for `testRestoreState`, I used the defined class variable here as well. ########## flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsCoordinator.java: ########## @@ -0,0 +1,395 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.flink.sink.shuffle; + +import java.util.Map; +import java.util.concurrent.Callable; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.ThreadFactory; +import org.apache.flink.annotation.Internal; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.runtime.operators.coordination.OperatorCoordinator; +import org.apache.flink.runtime.operators.coordination.OperatorEvent; +import org.apache.flink.util.ExceptionUtils; +import org.apache.flink.util.FatalExitExceptionHandler; +import org.apache.flink.util.FlinkRuntimeException; +import org.apache.flink.util.Preconditions; +import org.apache.flink.util.ThrowableCatchingRunnable; +import org.apache.flink.util.function.ThrowingRunnable; +import org.apache.iceberg.relocated.com.google.common.annotations.VisibleForTesting; +import org.apache.iceberg.relocated.com.google.common.collect.Iterables; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.jetbrains.annotations.NotNull; +import org.jetbrains.annotations.Nullable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * DataStatisticsCoordinator receives {@link DataStatisticsEvent} from {@link + * DataStatisticsOperator} every subtask and then merge them together. Once aggregation for all + * subtasks data statistics completes, DataStatisticsCoordinator will send the aggregated data + * statistics back to {@link DataStatisticsOperator}. In the end a custom partitioner will + * distribute traffic based on the aggregated data statistics to improve data clustering. + */ +@Internal +class DataStatisticsCoordinator<D extends DataStatistics<D, S>, S> implements OperatorCoordinator { + private static final Logger LOG = LoggerFactory.getLogger(DataStatisticsCoordinator.class); + + private final String operatorName; + private final ExecutorService coordinatorExecutor; + private final OperatorCoordinator.Context operatorCoordinatorContext; + private final SubtaskGateways subtaskGateways; + private final CoordinatorExecutorThreadFactory coordinatorThreadFactory; + private final TypeSerializer<DataStatistics<D, S>> statisticsSerializer; + private final transient AggregatedStatisticsTracker<D, S> aggregatedStatisticsTracker; + private volatile AggregatedStatistics<D, S> completedStatistics; + private volatile boolean started; + + DataStatisticsCoordinator( + String operatorName, + OperatorCoordinator.Context context, + TypeSerializer<DataStatistics<D, S>> statisticsSerializer) { + this.operatorName = operatorName; + this.coordinatorThreadFactory = + new CoordinatorExecutorThreadFactory( + "DataStatisticsCoordinator-" + operatorName, context.getUserCodeClassloader()); + this.coordinatorExecutor = Executors.newSingleThreadExecutor(coordinatorThreadFactory); + this.operatorCoordinatorContext = context; + this.subtaskGateways = new SubtaskGateways(operatorName, parallelism()); + this.statisticsSerializer = statisticsSerializer; + this.aggregatedStatisticsTracker = + new AggregatedStatisticsTracker<>(operatorName, statisticsSerializer, parallelism()); + } + + @Override + public void start() throws Exception { + LOG.info("Starting data statistics coordinator: {}.", operatorName); + started = true; + } + + @Override + public void close() throws Exception { + coordinatorExecutor.shutdown(); + LOG.info("Closed data statistics coordinator: {}.", operatorName); + } + + @VisibleForTesting + void callInCoordinatorThread(Callable<Void> callable, String errorMessage) { + ensureStarted(); + // Ensure the task is done by the coordinator executor. + if (!coordinatorThreadFactory.isCurrentThreadCoordinatorThread()) { + try { + Callable<Void> guardedCallable = + () -> { + try { + return callable.call(); + } catch (Throwable t) { + LOG.error( + "Uncaught Exception in data statistics coordinator: {} executor", + operatorName, + t); + ExceptionUtils.rethrowException(t); + return null; + } + }; + + coordinatorExecutor.submit(guardedCallable).get(); + } catch (InterruptedException | ExecutionException e) { + throw new FlinkRuntimeException(errorMessage, e); + } + } else { + try { + callable.call(); + } catch (Throwable t) { + LOG.error( + "Uncaught Exception in data statistics coordinator: {} executor", operatorName, t); + throw new FlinkRuntimeException(errorMessage, t); + } + } + } + + public void runInCoordinatorThread(Runnable runnable) { + this.coordinatorExecutor.execute( + new ThrowableCatchingRunnable( + throwable -> + this.coordinatorThreadFactory.uncaughtException(Thread.currentThread(), throwable), + runnable)); + } + + private void runInCoordinatorThread(ThrowingRunnable<Throwable> action, String actionString) { + ensureStarted(); + runInCoordinatorThread( + () -> { + try { + action.run(); + } catch (Throwable t) { + ExceptionUtils.rethrowIfFatalErrorOrOOM(t); + LOG.error( + "Uncaught exception in the data statistics coordinator: {} while {}. Triggering job failover.", + operatorName, + actionString, + t); + operatorCoordinatorContext.failJob(t); + } + }); + } + + private void ensureStarted() { + Preconditions.checkState(started, "The coordinator of %s has not started yet.", operatorName); + } + + private int parallelism() { + return operatorCoordinatorContext.currentParallelism(); + } + + private void handleDataStatisticRequest(int subtask, DataStatisticsEvent<D, S> event) { + AggregatedStatistics<D, S> aggregatedStatistics = + aggregatedStatisticsTracker.receiveDataStatisticEventAndCheckCompletion(subtask, event); + + if (aggregatedStatistics != null) { + completedStatistics = aggregatedStatistics; + sendDataStatisticsToSubtasks( + completedStatistics.checkpointId(), completedStatistics.dataStatistics()); + } + } + + private void sendDataStatisticsToSubtasks( + long checkpointId, DataStatistics<D, S> globalDataStatistics) { + callInCoordinatorThread( + () -> { + DataStatisticsEvent<D, S> dataStatisticsEvent = + DataStatisticsEvent.create(checkpointId, globalDataStatistics, statisticsSerializer); + int parallelism = parallelism(); + for (int i = 0; i < parallelism; ++i) { + subtaskGateways.getSubtaskGateway(i).sendEvent(dataStatisticsEvent); + } + + return null; + }, + String.format( + "Failed to send operator %s coordinator global data statistics for checkpoint %d", + operatorName, checkpointId)); + } + + @Override + @SuppressWarnings("unchecked") + public void handleEventFromOperator(int subtask, int attemptNumber, OperatorEvent event) { + runInCoordinatorThread( + () -> { + LOG.debug( + "Handling event from subtask {} (#{}) of {}: {}", + subtask, + attemptNumber, + operatorName, + event); + Preconditions.checkArgument(event instanceof DataStatisticsEvent); + handleDataStatisticRequest(subtask, ((DataStatisticsEvent<D, S>) event)); + }, + String.format( + "handling operator event %s from subtask %d (#%d)", + event.getClass(), subtask, attemptNumber)); + } + + @Override + public void checkpointCoordinator(long checkpointId, CompletableFuture<byte[]> resultFuture) { + runInCoordinatorThread( + () -> { + LOG.debug( + "Snapshotting data statistics coordinator {} for checkpoint {}", + operatorName, + checkpointId); + resultFuture.complete( + DataStatisticsUtil.serializeAggregatedStatistics( + completedStatistics, statisticsSerializer)); + }, + String.format("taking checkpoint %d", checkpointId)); + } + + @Override + public void notifyCheckpointComplete(long checkpointId) {} + + @Override + public void resetToCheckpoint(long checkpointId, @Nullable byte[] checkpointData) + throws Exception { + Preconditions.checkState( + !started, "The coordinator %s can only be reset if it was not yet started", operatorName); + + if (checkpointData == null) { + LOG.info( + "Data statistic coordinator {} checkpoint {} data is null. Cannot be restored.", + operatorName, + checkpointId); + return; + } + + LOG.info( + "Restoring data statistic coordinator {} from checkpoint {}.", operatorName, checkpointId); + completedStatistics = + DataStatisticsUtil.deserializeAggregatedStatistics(checkpointData, statisticsSerializer); + } + + @Override + public void subtaskReset(int subtask, long checkpointId) { + runInCoordinatorThread( + () -> { + LOG.info( + "Resetting subtask {} to checkpoint {} for data statistics {}.", Review Comment: will take the suggestion -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected] --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
