/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.runtime.checkpoint;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.runtime.checkpoint.OperatorState;
import org.apache.flink.runtime.checkpoint.OperatorStateRepartitioner;
import org.apache.flink.runtime.checkpoint.OperatorSubtaskState;
import org.apache.flink.runtime.checkpoint.RoundRobinOperatorStateRepartitioner;
import org.apache.flink.runtime.executiongraph.Execution;
import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
import org.apache.flink.runtime.jobgraph.JobVertexID;
import org.apache.flink.runtime.jobgraph.OperatorID;
import org.apache.flink.runtime.state.ChainedStateHandle;
import org.apache.flink.runtime.state.KeyGroupRange;
import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
import org.apache.flink.runtime.state.KeyedStateHandle;
import org.apache.flink.runtime.state.OperatorStateHandle;
import org.apache.flink.runtime.state.StreamStateHandle;
import org.apache.flink.runtime.state.TaskStateHandles;
import org.apache.flink.util.Preconditions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Assigns the operator states of a checkpoint to the current execution attempts of the tasks
 * of an execution graph.
 *
 * <p>For every task that has checkpointed state for at least one of its chained operators, the
 * state is split per subtask: non-partitioned (legacy) state is only reassigned when the
 * parallelism is unchanged, partitionable operator state is redistributed through a
 * {@link RoundRobinOperatorStateRepartitioner}, and keyed state is assigned by intersecting the
 * old keyed state handles with the key-group range of each new subtask.
 */
public class StateAssignmentOperation {
    private static final Logger LOG = LoggerFactory.getLogger(StateAssignmentOperation.class);

    /** The tasks to restore, keyed by job vertex id. */
    private final Map<JobVertexID, ExecutionJobVertex> tasks;

    /** The checkpointed operator states, keyed by the operator id they were taken for. */
    private final Map<OperatorID, OperatorState> operatorStates;

    /** Whether checkpointed state that matches no operator of the job may be skipped. */
    private final boolean allowNonRestoredState;

    /**
     * Creates a new state assignment.
     *
     * @param tasks the tasks to restore; must not be {@code null}
     * @param operatorStates the checkpointed operator states; must not be {@code null}
     * @param allowNonRestoredState if {@code true}, state without a matching operator is logged
     *     and skipped instead of failing the assignment
     */
    public StateAssignmentOperation(
            Map<JobVertexID, ExecutionJobVertex> tasks,
            Map<OperatorID, OperatorState> operatorStates,
            boolean allowNonRestoredState) {
        this.tasks = Preconditions.checkNotNull(tasks);
        this.operatorStates = Preconditions.checkNotNull(operatorStates);
        this.allowNonRestoredState = allowNonRestoredState;
    }

    /**
     * Assigns the checkpointed state to the current execution attempts of all tasks that have
     * state for at least one of their chained operators. Tasks without any state are skipped.
     *
     * @return always {@code true}
     * @throws IllegalStateException if some state matches no operator and non-restored state is
     *     not allowed, or if a task's (max) parallelism is incompatible with its state
     */
    public boolean assignStates() throws Exception {
        // Working copy: matched states are removed so that every state is consumed at most once.
        Map<OperatorID, OperatorState> localOperators = new HashMap<>(this.operatorStates);

        checkStateMappingCompleteness(this.allowNonRestoredState, this.operatorStates, this.tasks);

        for (ExecutionJobVertex executionJobVertex : this.tasks.values()) {
            List<OperatorID> operatorIDs = executionJobVertex.getOperatorIDs();
            List<OperatorID> altOperatorIDs = executionJobVertex.getUserDefinedOperatorIDs();

            List<OperatorState> chainedOperatorStates = new ArrayList<>(operatorIDs.size());
            boolean statelessTask = true;

            for (int x = 0; x < operatorIDs.size(); ++x) {
                // A user-defined operator id takes precedence over the generated one.
                OperatorID userDefinedID = altOperatorIDs.get(x);
                OperatorID operatorID = userDefinedID != null ? userDefinedID : operatorIDs.get(x);

                OperatorState operatorState = localOperators.remove(operatorID);
                if (operatorState == null) {
                    // Empty placeholder keeps operator indexes aligned with getOperatorIDs().
                    operatorState = new OperatorState(
                            operatorID,
                            executionJobVertex.getParallelism(),
                            executionJobVertex.getMaxParallelism());
                } else {
                    statelessTask = false;
                }
                chainedOperatorStates.add(operatorState);
            }

            if (!statelessTask) {
                this.assignAttemptState(executionJobVertex, chainedOperatorStates);
            }
        }
        return true;
    }

    /**
     * Computes the initial state of every subtask of the given task and sets it on the subtask's
     * current execution attempt. Subtasks whose computed state is entirely empty are left
     * untouched.
     *
     * @param executionJobVertex the task to restore
     * @param operatorStates one state per chained operator, in the order of
     *     {@code executionJobVertex.getOperatorIDs()}
     */
    private void assignAttemptState(ExecutionJobVertex executionJobVertex, List<OperatorState> operatorStates) {
        List<OperatorID> operatorIDs = executionJobVertex.getOperatorIDs();

        // May override the vertex's max parallelism, so it must run before key groups are computed.
        this.checkParallelismPreconditions(operatorStates, executionJobVertex);

        int newParallelism = executionJobVertex.getParallelism();
        List<KeyGroupRange> keyGroupPartitions =
                createKeyGroupPartitions(executionJobVertex.getMaxParallelism(), newParallelism);

        List<List<Collection<OperatorStateHandle>>> newManagedOperatorStates = new ArrayList<>();
        List<List<Collection<OperatorStateHandle>>> newRawOperatorStates = new ArrayList<>();
        this.reDistributePartitionableStates(
                operatorStates, newParallelism, newManagedOperatorStates, newRawOperatorStates);

        int lastOperatorIndex = operatorIDs.size() - 1;

        for (int subTaskIndex = 0; subTaskIndex < newParallelism; ++subTaskIndex) {
            Execution currentExecutionAttempt =
                    executionJobVertex.getTaskVertices()[subTaskIndex].getCurrentExecutionAttempt();

            List<StreamStateHandle> subNonPartitionableState = new ArrayList<>();
            Tuple2<Collection<KeyedStateHandle>, Collection<KeyedStateHandle>> subKeyedState = null;
            List<Collection<OperatorStateHandle>> subManagedOperatorState = new ArrayList<>();
            List<Collection<OperatorStateHandle>> subRawOperatorState = new ArrayList<>();

            for (int operatorIndex = 0; operatorIndex <= lastOperatorIndex; ++operatorIndex) {
                OperatorState operatorState = operatorStates.get(operatorIndex);
                int oldParallelism = operatorState.getParallelism();

                this.reAssignSubNonPartitionedStates(
                        operatorState, subTaskIndex, newParallelism, oldParallelism, subNonPartitionableState);
                this.reAssignSubPartitionableState(
                        newManagedOperatorStates, newRawOperatorStates, subTaskIndex, operatorIndex,
                        subManagedOperatorState, subRawOperatorState);

                // Keyed state is only taken from the last operator in the list.
                // NOTE(review): presumably the chain head, the only operator that can be keyed —
                // confirm against the ordering of getOperatorIDs().
                if (operatorIndex == lastOperatorIndex) {
                    subKeyedState = this.reAssignSubKeyedStates(
                            operatorState, keyGroupPartitions, subTaskIndex, newParallelism, oldParallelism);
                }
            }

            boolean hasNoState = this.allElementsAreNull(subNonPartitionableState)
                    && this.allElementsAreNull(subManagedOperatorState)
                    && this.allElementsAreNull(subRawOperatorState)
                    && subKeyedState == null;
            if (hasNoState) {
                continue;
            }

            TaskStateHandles taskStateHandles = new TaskStateHandles(
                    new ChainedStateHandle<StreamStateHandle>(subNonPartitionableState),
                    subManagedOperatorState,
                    subRawOperatorState,
                    subKeyedState != null ? subKeyedState.f0 : null,
                    subKeyedState != null ? subKeyedState.f1 : null);

            currentExecutionAttempt.setInitialState(taskStateHandles);
        }
    }

    /**
     * Checks (and, where allowed, adjusts) the parallelism of the task against each of the given
     * operator states.
     *
     * @throws IllegalStateException if a state is incompatible with the task's (max) parallelism
     */
    public void checkParallelismPreconditions(List<OperatorState> operatorStates, ExecutionJobVertex executionJobVertex) {
        for (OperatorState operatorState : operatorStates) {
            checkParallelismPreconditions(operatorState, executionJobVertex);
        }
    }

    /**
     * Appends the redistributed managed and raw operator state of one operator for one subtask.
     * A {@code null} entry is appended when the operator has no state of that kind.
     */
    private void reAssignSubPartitionableState(
            List<List<Collection<OperatorStateHandle>>> newManagedOperatorStates,
            List<List<Collection<OperatorStateHandle>>> newRawOperatorStates,
            int subTaskIndex,
            int operatorIndex,
            List<Collection<OperatorStateHandle>> subManagedOperatorState,
            List<Collection<OperatorStateHandle>> subRawOperatorState) {

        List<Collection<OperatorStateHandle>> managedPerSubtask = newManagedOperatorStates.get(operatorIndex);
        subManagedOperatorState.add(managedPerSubtask != null ? managedPerSubtask.get(subTaskIndex) : null);

        List<Collection<OperatorStateHandle>> rawPerSubtask = newRawOperatorStates.get(operatorIndex);
        subRawOperatorState.add(rawPerSubtask != null ? rawPerSubtask.get(subTaskIndex) : null);
    }

    /**
     * Computes the managed and raw keyed state of one subtask.
     *
     * <p>With unchanged parallelism the subtask keeps its own handles. Otherwise it receives the
     * intersection of all old handles with its key-group range.
     *
     * @return (managed, raw) keyed state, or {@code null} if both are {@code null}
     */
    private Tuple2<Collection<KeyedStateHandle>, Collection<KeyedStateHandle>> reAssignSubKeyedStates(
            OperatorState operatorState,
            List<KeyGroupRange> keyGroupPartitions,
            int subTaskIndex,
            int newParallelism,
            int oldParallelism) {

        Collection<KeyedStateHandle> subManagedKeyedState;
        Collection<KeyedStateHandle> subRawKeyedState;

        if (newParallelism == oldParallelism) {
            OperatorSubtaskState subtaskState = operatorState.getState(subTaskIndex);
            if (subtaskState != null) {
                KeyedStateHandle oldSubManagedKeyedState = subtaskState.getManagedKeyedState();
                KeyedStateHandle oldSubRawKeyedState = subtaskState.getRawKeyedState();
                subManagedKeyedState = oldSubManagedKeyedState != null
                        ? Collections.singletonList(oldSubManagedKeyedState)
                        : null;
                subRawKeyedState = oldSubRawKeyedState != null
                        ? Collections.singletonList(oldSubRawKeyedState)
                        : null;
            } else {
                subManagedKeyedState = null;
                subRawKeyedState = null;
            }
        } else {
            KeyGroupRange subtaskKeyGroupRange = keyGroupPartitions.get(subTaskIndex);
            subManagedKeyedState = getManagedKeyedStateHandles(operatorState, subtaskKeyGroupRange);
            subRawKeyedState = getRawKeyedStateHandles(operatorState, subtaskKeyGroupRange);
        }

        if (subManagedKeyedState == null && subRawKeyedState == null) {
            return null;
        }
        return new Tuple2<>(subManagedKeyedState, subRawKeyedState);
    }

    /** Returns {@code true} if the list is empty or contains only {@code null} elements. */
    private <X> boolean allElementsAreNull(List<X> nonPartitionableStates) {
        for (X element : nonPartitionableStates) {
            if (element != null) {
                return false;
            }
        }
        return true;
    }

    /**
     * Appends the legacy (non-partitioned) operator state of one subtask. This state can only be
     * reused when the parallelism is unchanged; otherwise {@code null} is appended.
     */
    private void reAssignSubNonPartitionedStates(
            OperatorState operatorState,
            int subTaskIndex,
            int newParallelism,
            int oldParallelism,
            List<StreamStateHandle> subNonPartitionableState) {

        OperatorSubtaskState subtaskState =
                oldParallelism == newParallelism ? operatorState.getState(subTaskIndex) : null;
        subNonPartitionableState.add(subtaskState != null ? subtaskState.getLegacyOperatorState() : null);
    }

    /**
     * Redistributes the managed and raw operator state of every operator to
     * {@code newParallelism} subtasks.
     *
     * <p>Appends one entry per operator to each output list. An entry is {@code null} if the
     * operator has no state of that kind; otherwise it is indexable by the new subtask index.
     */
    private void reDistributePartitionableStates(
            List<OperatorState> operatorStates,
            int newParallelism,
            List<List<Collection<OperatorStateHandle>>> newManagedOperatorStates,
            List<List<Collection<OperatorStateHandle>>> newRawOperatorStates) {

        List<List<OperatorStateHandle>> oldManagedOperatorStates = new ArrayList<>(operatorStates.size());
        List<List<OperatorStateHandle>> oldRawOperatorStates = new ArrayList<>(operatorStates.size());
        this.collectPartitionableStates(operatorStates, oldManagedOperatorStates, oldRawOperatorStates);

        OperatorStateRepartitioner opStateRepartitioner = RoundRobinOperatorStateRepartitioner.INSTANCE;

        for (int operatorIndex = 0; operatorIndex < operatorStates.size(); ++operatorIndex) {
            OperatorState operatorState = operatorStates.get(operatorIndex);
            newManagedOperatorStates.add(redistributeOperatorState(
                    opStateRepartitioner, operatorState, oldManagedOperatorStates.get(operatorIndex), true, newParallelism));
            newRawOperatorStates.add(redistributeOperatorState(
                    opStateRepartitioner, operatorState, oldRawOperatorStates.get(operatorIndex), false, newParallelism));
        }
    }

    /**
     * Redistributes one kind (managed or raw) of operator state of a single operator.
     *
     * <p>Delegates to {@link #applyRepartitioner}. Its unchanged-parallelism shortcut yields one
     * entry per collected handle. However, subtasks without state are skipped during collection.
     * When that makes the result shorter than the parallelism, the per-subtask list is rebuilt
     * by subtask index. Otherwise later lookups by subtask index would fail or return another
     * subtask's state.
     *
     * @param collectedHandles the non-null handles of this kind, in subtask order, or {@code null}
     * @param managed {@code true} for managed operator state, {@code false} for raw
     * @return one entry per new subtask (possibly {@code null}), or {@code null} if there is no state
     */
    private static List<Collection<OperatorStateHandle>> redistributeOperatorState(
            OperatorStateRepartitioner opStateRepartitioner,
            OperatorState operatorState,
            List<OperatorStateHandle> collectedHandles,
            boolean managed,
            int newParallelism) {

        int oldParallelism = operatorState.getParallelism();
        List<Collection<OperatorStateHandle>> result =
                applyRepartitioner(opStateRepartitioner, collectedHandles, oldParallelism, newParallelism);

        if (result == null || result.size() == newParallelism || oldParallelism != newParallelism) {
            return result;
        }

        // Unchanged-parallelism shortcut over sparse subtask states: rebuild positionally.
        List<Collection<OperatorStateHandle>> bySubtask = new ArrayList<>(newParallelism);
        for (int subtaskIndex = 0; subtaskIndex < newParallelism; ++subtaskIndex) {
            OperatorSubtaskState subtaskState = operatorState.getState(subtaskIndex);
            OperatorStateHandle handle = null;
            if (subtaskState != null) {
                handle = managed ? subtaskState.getManagedOperatorState() : subtaskState.getRawOperatorState();
            }
            bySubtask.add(handle != null ? Collections.singletonList(handle) : null);
        }
        return bySubtask;
    }

    /**
     * Collects, per operator, the non-null managed and raw operator state handles of all old
     * subtasks in subtask order. A {@code null} list is appended for an operator that has no
     * state of that kind.
     */
    private void collectPartitionableStates(
            List<OperatorState> operatorStates,
            List<List<OperatorStateHandle>> managedOperatorStates,
            List<List<OperatorStateHandle>> rawOperatorStates) {

        for (OperatorState operatorState : operatorStates) {
            List<OperatorStateHandle> managedOperatorState = null;
            List<OperatorStateHandle> rawOperatorState = null;

            for (int i = 0; i < operatorState.getParallelism(); ++i) {
                OperatorSubtaskState operatorSubtaskState = operatorState.getState(i);
                if (operatorSubtaskState == null) {
                    continue;
                }

                OperatorStateHandle managedHandle = operatorSubtaskState.getManagedOperatorState();
                if (managedHandle != null) {
                    if (managedOperatorState == null) {
                        managedOperatorState = new ArrayList<>();
                    }
                    managedOperatorState.add(managedHandle);
                }

                OperatorStateHandle rawHandle = operatorSubtaskState.getRawOperatorState();
                if (rawHandle != null) {
                    if (rawOperatorState == null) {
                        rawOperatorState = new ArrayList<>();
                    }
                    rawOperatorState.add(rawHandle);
                }
            }

            managedOperatorStates.add(managedOperatorState);
            rawOperatorStates.add(rawOperatorState);
        }
    }

    /**
     * Intersects the managed keyed state of every old subtask with the given key-group range.
     *
     * @return the non-null intersections, or {@code null} if there are none
     */
    public static List<KeyedStateHandle> getManagedKeyedStateHandles(OperatorState operatorState, KeyGroupRange subtaskKeyGroupRange) {
        return intersectKeyedStateHandles(operatorState, subtaskKeyGroupRange, true);
    }

    /**
     * Intersects the raw keyed state of every old subtask with the given key-group range.
     *
     * @return the non-null intersections, or {@code null} if there are none
     */
    public static List<KeyedStateHandle> getRawKeyedStateHandles(OperatorState operatorState, KeyGroupRange subtaskKeyGroupRange) {
        return intersectKeyedStateHandles(operatorState, subtaskKeyGroupRange, false);
    }

    /** Shared implementation of the managed/raw keyed-state intersection; {@code null} if empty. */
    private static List<KeyedStateHandle> intersectKeyedStateHandles(
            OperatorState operatorState, KeyGroupRange subtaskKeyGroupRange, boolean managed) {

        List<KeyedStateHandle> subtaskKeyedStateHandles = null;

        for (int i = 0; i < operatorState.getParallelism(); ++i) {
            OperatorSubtaskState subtaskState = operatorState.getState(i);
            if (subtaskState == null) {
                continue;
            }
            KeyedStateHandle keyedStateHandle =
                    managed ? subtaskState.getManagedKeyedState() : subtaskState.getRawKeyedState();
            if (keyedStateHandle == null) {
                continue;
            }
            KeyedStateHandle intersectedKeyedStateHandle = keyedStateHandle.getIntersection(subtaskKeyGroupRange);
            if (intersectedKeyedStateHandle == null) {
                continue;
            }
            if (subtaskKeyedStateHandles == null) {
                subtaskKeyedStateHandles = new ArrayList<>();
            }
            subtaskKeyedStateHandles.add(intersectedKeyedStateHandle);
        }
        return subtaskKeyedStateHandles;
    }

    /**
     * Computes the key-group range of every operator index for the given parallelism.
     *
     * @param numberKeyGroups the total number of key groups (the max parallelism)
     * @param parallelism the number of ranges to compute
     * @return one range per operator index, in index order
     * @throws IllegalArgumentException if {@code numberKeyGroups < parallelism}
     */
    public static List<KeyGroupRange> createKeyGroupPartitions(int numberKeyGroups, int parallelism) {
        Preconditions.checkArgument(
                numberKeyGroups >= parallelism,
                "Number of key groups (%s) must not be smaller than the parallelism (%s).",
                numberKeyGroups,
                parallelism);

        List<KeyGroupRange> result = new ArrayList<>(parallelism);
        for (int i = 0; i < parallelism; ++i) {
            result.add(KeyGroupRangeAssignment.computeKeyGroupRangeForOperatorIndex(numberKeyGroups, parallelism, i));
        }
        return result;
    }

    /**
     * Verifies that the task can restore the given operator state.
     *
     * <p>If the max parallelism differs and the task's max parallelism was not configured
     * explicitly, it is overridden with the state's max parallelism.
     *
     * @throws IllegalStateException if an explicitly configured max parallelism differs from the
     *     state's, or if the state has non-partitioned state and the parallelism changed
     */
    private static void checkParallelismPreconditions(OperatorState operatorState, ExecutionJobVertex executionJobVertex) {
        if (operatorState.getMaxParallelism() != executionJobVertex.getMaxParallelism()) {
            if (!executionJobVertex.isMaxParallelismConfigured()) {
                LOG.debug("Overriding maximum parallelism for JobVertex {} from {} to {}",
                        executionJobVertex.getJobVertexId(),
                        executionJobVertex.getMaxParallelism(),
                        operatorState.getMaxParallelism());
                executionJobVertex.setMaxParallelism(operatorState.getMaxParallelism());
            } else {
                throw new IllegalStateException("The maximum parallelism (" + operatorState.getMaxParallelism()
                        + ") with which the latest checkpoint of the execution job vertex " + executionJobVertex
                        + " has been taken and the current maximum parallelism ("
                        + executionJobVertex.getMaxParallelism() + ") changed. This is currently not supported.");
            }
        }

        int oldParallelism = operatorState.getParallelism();
        int newParallelism = executionJobVertex.getParallelism();
        if (operatorState.hasNonPartitionedState() && oldParallelism != newParallelism) {
            throw new IllegalStateException("Cannot restore the latest checkpoint because the operator "
                    + executionJobVertex.getJobVertexId() + " has non-partitioned state and its parallelism changed. The operator "
                    + executionJobVertex.getJobVertexId() + " has parallelism " + newParallelism
                    + " whereas the corresponding state object has a parallelism of " + oldParallelism);
        }
    }

    /**
     * Verifies that every checkpointed operator state belongs to some operator of the job.
     *
     * <p>Both generated and user-defined operator ids are accepted, because
     * {@link #assignStates()} resolves state by the user-defined id when one is set.
     *
     * @throws IllegalStateException if a state matches no operator and
     *     {@code allowNonRestoredState} is {@code false}
     */
    private static void checkStateMappingCompleteness(
            boolean allowNonRestoredState,
            Map<OperatorID, OperatorState> operatorStates,
            Map<JobVertexID, ExecutionJobVertex> tasks) {

        HashSet<OperatorID> allOperatorIDs = new HashSet<>();
        for (ExecutionJobVertex executionJobVertex : tasks.values()) {
            allOperatorIDs.addAll(executionJobVertex.getOperatorIDs());
            for (OperatorID userDefinedOperatorID : executionJobVertex.getUserDefinedOperatorIDs()) {
                if (userDefinedOperatorID != null) {
                    allOperatorIDs.add(userDefinedOperatorID);
                }
            }
        }

        for (Map.Entry<OperatorID, OperatorState> entry : operatorStates.entrySet()) {
            if (allOperatorIDs.contains(entry.getKey())) {
                continue;
            }
            OperatorState operatorState = entry.getValue();
            if (allowNonRestoredState) {
                LOG.info("Skipped checkpoint state for operator {}.", operatorState.getOperatorID());
            } else {
                throw new IllegalStateException("There is no operator for the state " + operatorState.getOperatorID());
            }
        }
    }

    /**
     * Redistributes the operator state handles of one operator to {@code newParallelism} subtasks.
     *
     * <p>If the parallelism changed, or any handle contains broadcast state, the repartitioner
     * is used. Otherwise each handle is wrapped in its own singleton list, in input order. In
     * that shortcut the result has one entry per input handle. It is therefore only indexable by
     * subtask if one handle per subtask was supplied.
     *
     * @return the redistributed state, or {@code null} if {@code chainOpParallelStates} is {@code null}
     */
    public static List<Collection<OperatorStateHandle>> applyRepartitioner(
            OperatorStateRepartitioner opStateRepartitioner,
            List<OperatorStateHandle> chainOpParallelStates,
            int oldParallelism,
            int newParallelism) {

        if (chainOpParallelStates == null) {
            return null;
        }

        if (newParallelism != oldParallelism) {
            return opStateRepartitioner.repartitionState(chainOpParallelStates, newParallelism);
        }

        List<Collection<OperatorStateHandle>> repackStream = new ArrayList<>(newParallelism);
        for (OperatorStateHandle operatorStateHandle : chainOpParallelStates) {
            Map<String, OperatorStateHandle.StateMetaInfo> partitionOffsets =
                    operatorStateHandle.getStateNameToPartitionOffsets();

            for (OperatorStateHandle.StateMetaInfo metaInfo : partitionOffsets.values()) {
                // Broadcast state must go through the repartitioner even without a parallelism change.
                if (OperatorStateHandle.Mode.BROADCAST.equals(metaInfo.getDistributionMode())) {
                    return opStateRepartitioner.repartitionState(chainOpParallelStates, newParallelism);
                }
            }

            repackStream.add(Collections.singletonList(operatorStateHandle));
        }
        return repackStream;
    }

    /**
     * Intersects each given keyed state handle with the key-group range and returns the non-null
     * intersections (never {@code null}, possibly empty).
     */
    public static List<KeyedStateHandle> getKeyedStateHandles(Collection<? extends KeyedStateHandle> keyedStateHandles, KeyGroupRange subtaskKeyGroupRange) {
        List<KeyedStateHandle> subtaskKeyedStateHandles = new ArrayList<>();

        for (KeyedStateHandle keyedStateHandle : keyedStateHandles) {
            KeyedStateHandle intersectedKeyedStateHandle = keyedStateHandle.getIntersection(subtaskKeyGroupRange);
            if (intersectedKeyedStateHandle != null) {
                subtaskKeyedStateHandles.add(intersectedKeyedStateHandle);
            }
        }
        return subtaskKeyedStateHandles;
    }
}

