/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.table.planner.plan.nodes.exec.batch;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rex.RexWindowBound;
import org.apache.calcite.sql.SqlKind;
import org.apache.flink.api.dag.Transformation;
import org.apache.flink.configuration.MemorySize;
import org.apache.flink.streaming.api.operators.SimpleOperatorFactory;
import org.apache.flink.streaming.api.operators.StreamOperator;
import org.apache.flink.table.api.TableConfig;
import org.apache.flink.table.api.TableException;
import org.apache.flink.table.api.config.ExecutionConfigOptions;
import org.apache.flink.table.data.RowData;
import org.apache.flink.table.data.binary.BinaryRowData;
import org.apache.flink.table.planner.codegen.CodeGeneratorContext;
import org.apache.flink.table.planner.codegen.agg.AggsHandlerCodeGenerator;
import org.apache.flink.table.planner.codegen.over.MultiFieldRangeBoundComparatorCodeGenerator;
import org.apache.flink.table.planner.codegen.over.RangeBoundComparatorCodeGenerator;
import org.apache.flink.table.planner.codegen.sort.ComparatorCodeGenerator;
import org.apache.flink.table.planner.delegation.PlannerBase;
import org.apache.flink.table.planner.plan.nodes.exec.ExecEdge;
import org.apache.flink.table.planner.plan.nodes.exec.InputProperty;
import org.apache.flink.table.planner.plan.nodes.exec.batch.BatchExecOverAggregateBase;
import org.apache.flink.table.planner.plan.nodes.exec.spec.OverSpec;
import org.apache.flink.table.planner.plan.nodes.exec.spec.SortSpec;
import org.apache.flink.table.planner.plan.nodes.exec.utils.ExecNodeUtil;
import org.apache.flink.table.planner.plan.utils.AggregateInfoList;
import org.apache.flink.table.planner.plan.utils.AggregateUtil;
import org.apache.flink.table.planner.plan.utils.OverAggregateUtil;
import org.apache.flink.table.planner.plan.utils.SortUtil;
import org.apache.flink.table.planner.utils.JavaScalaConversionUtil;
import org.apache.flink.table.runtime.generated.GeneratedAggsHandleFunction;
import org.apache.flink.table.runtime.generated.GeneratedRecordComparator;
import org.apache.flink.table.runtime.operators.over.BufferDataOverWindowOperator;
import org.apache.flink.table.runtime.operators.over.NonBufferOverWindowOperator;
import org.apache.flink.table.runtime.operators.over.frame.InsensitiveOverFrame;
import org.apache.flink.table.runtime.operators.over.frame.OffsetOverFrame;
import org.apache.flink.table.runtime.operators.over.frame.OverWindowFrame;
import org.apache.flink.table.runtime.operators.over.frame.RangeSlidingOverFrame;
import org.apache.flink.table.runtime.operators.over.frame.RangeUnboundedFollowingOverFrame;
import org.apache.flink.table.runtime.operators.over.frame.RangeUnboundedPrecedingOverFrame;
import org.apache.flink.table.runtime.operators.over.frame.RowSlidingOverFrame;
import org.apache.flink.table.runtime.operators.over.frame.RowUnboundedFollowingOverFrame;
import org.apache.flink.table.runtime.operators.over.frame.RowUnboundedPrecedingOverFrame;
import org.apache.flink.table.runtime.operators.over.frame.UnboundedOverWindowFrame;
import org.apache.flink.table.runtime.typeutils.InternalTypeInfo;
import org.apache.flink.table.types.logical.LogicalType;
import org.apache.flink.table.types.logical.RowType;

public class BatchExecOverAggregate
extends BatchExecOverAggregateBase {
    public BatchExecOverAggregate(OverSpec overSpec, InputProperty inputProperty, RowType outputType, String description) {
        super(overSpec, inputProperty, outputType, description);
    }

    @Override
    protected Transformation<RowData> translateToPlanInternal(PlannerBase planner) {
        long managedMemory;
        BufferDataOverWindowOperator operator;
        ExecEdge inputEdge = this.getInputEdges().get(0);
        Transformation<?> inputTransform = inputEdge.translateToPlan(planner);
        RowType inputType = (RowType)inputEdge.getOutputType();
        TableConfig tableConfig = planner.getTableConfig();
        int[] partitionFields = this.overSpec.getPartition().getFieldIndices();
        GeneratedRecordComparator genComparator = ComparatorCodeGenerator.gen(tableConfig, "SortComparator", inputType, SortUtil.getAscendingSortSpec(partitionFields));
        RowType inputTypeWithConstants = this.getInputTypeWithConstants();
        SortSpec sortSpec = this.overSpec.getGroups().get(this.overSpec.getGroups().size() - 1).getSort();
        if (!this.needBufferData()) {
            int numOfGroup = this.overSpec.getGroups().size();
            GeneratedAggsHandleFunction[] aggsHandlers = new GeneratedAggsHandleFunction[numOfGroup];
            boolean[] resetAccumulators = new boolean[numOfGroup];
            for (int i = 0; i < numOfGroup; ++i) {
                OverSpec.GroupSpec group = this.overSpec.getGroups().get(i);
                AggregateInfoList aggInfoList = AggregateUtil.transformToBatchAggregateInfoList(inputTypeWithConstants, JavaScalaConversionUtil.toScala(group.getAggCalls()), null, sortSpec.getFieldIndices());
                AggsHandlerCodeGenerator generator = new AggsHandlerCodeGenerator(new CodeGeneratorContext(tableConfig), planner.getRelBuilder(), JavaScalaConversionUtil.toScala(inputType.getChildren()), false);
                aggsHandlers[i] = generator.needAccumulate().withConstants(JavaScalaConversionUtil.toScala(this.getConstants())).generateAggsHandler("BoundedOverAggregateHelper", aggInfoList);
                BatchExecOverAggregateBase.OverWindowMode mode = this.inferGroupMode(group);
                resetAccumulators[i] = mode == BatchExecOverAggregateBase.OverWindowMode.ROW && group.getLowerBound().isCurrentRow() && group.getUpperBound().isCurrentRow();
            }
            operator = new NonBufferOverWindowOperator(aggsHandlers, genComparator, resetAccumulators);
            managedMemory = 0L;
        } else {
            List<OverWindowFrame> windowFrames = this.createOverWindowFrames(planner, inputType, sortSpec, inputTypeWithConstants);
            operator = new BufferDataOverWindowOperator(windowFrames.toArray(new OverWindowFrame[0]), genComparator, inputType.getChildren().stream().allMatch(BinaryRowData::isInFixedLengthPart));
            managedMemory = ((MemorySize)tableConfig.getConfiguration().get(ExecutionConfigOptions.TABLE_EXEC_RESOURCE_EXTERNAL_BUFFER_MEMORY)).getBytes();
        }
        return ExecNodeUtil.createOneInputTransformation(inputTransform, this.getDescription(), SimpleOperatorFactory.of((StreamOperator)operator), InternalTypeInfo.of((LogicalType)this.getOutputType()), inputTransform.getParallelism(), managedMemory);
    }

    private List<OverWindowFrame> createOverWindowFrames(PlannerBase planner, RowType inputType, SortSpec sortSpec, RowType inputTypeWithConstants) {
        ArrayList<OverWindowFrame> windowFrames = new ArrayList<OverWindowFrame>();
        for (OverSpec.GroupSpec group : this.overSpec.getGroups()) {
            InsensitiveOverFrame frame;
            BatchExecOverAggregateBase.OverWindowMode mode = this.inferGroupMode(group);
            if (mode == BatchExecOverAggregateBase.OverWindowMode.OFFSET) {
                for (AggregateCall aggCall : group.getAggCalls()) {
                    OffsetOverFrame.CalcOffsetFunc & Serializable calcOffsetFunc;
                    Long offset;
                    long flag;
                    AggregateInfoList aggInfoList = AggregateUtil.transformToBatchAggregateInfoList(inputTypeWithConstants, JavaScalaConversionUtil.toScala(Collections.singletonList(aggCall)), new boolean[]{true}, sortSpec.getFieldIndices());
                    AggsHandlerCodeGenerator generator = new AggsHandlerCodeGenerator(new CodeGeneratorContext(planner.getTableConfig()), planner.getRelBuilder(), JavaScalaConversionUtil.toScala(inputType.getChildren()), false);
                    GeneratedAggsHandleFunction genAggsHandler = generator.needAccumulate().needRetract().withConstants(JavaScalaConversionUtil.toScala(this.getConstants())).generateAggsHandler("BoundedOverAggregateHelper", aggInfoList);
                    long l = flag = aggCall.getAggregation().kind == SqlKind.LEAD ? 1L : -1L;
                    if (aggCall.getArgList().size() >= 2) {
                        int constantIndex = aggCall.getArgList().get(1) - this.overSpec.getOriginalInputFields();
                        if (constantIndex < 0) {
                            offset = null;
                            int rowIndex = aggCall.getArgList().get(1);
                            switch (inputType.getTypeAt(rowIndex).getTypeRoot()) {
                                case BIGINT: {
                                    calcOffsetFunc = (OffsetOverFrame.CalcOffsetFunc & Serializable)row -> row.getLong(rowIndex) * flag;
                                    break;
                                }
                                case INTEGER: {
                                    calcOffsetFunc = (OffsetOverFrame.CalcOffsetFunc & Serializable)row -> (long)row.getInt(rowIndex) * flag;
                                    break;
                                }
                                case SMALLINT: {
                                    calcOffsetFunc = (OffsetOverFrame.CalcOffsetFunc & Serializable)row -> (long)row.getShort(rowIndex) * flag;
                                    break;
                                }
                                default: {
                                    throw new RuntimeException("The column type must be in long/int/short.");
                                }
                            }
                        } else {
                            long constantOffset = this.getConstants().get(constantIndex).getValueAs(Long.class);
                            offset = constantOffset * flag;
                            calcOffsetFunc = null;
                        }
                    } else {
                        offset = flag;
                        calcOffsetFunc = null;
                    }
                    windowFrames.add((OverWindowFrame)new OffsetOverFrame(genAggsHandler, offset, calcOffsetFunc));
                }
                continue;
            }
            AggregateInfoList aggInfoList = AggregateUtil.transformToBatchAggregateInfoList(inputTypeWithConstants, JavaScalaConversionUtil.toScala(group.getAggCalls()), null, sortSpec.getFieldIndices());
            AggsHandlerCodeGenerator generator = new AggsHandlerCodeGenerator(new CodeGeneratorContext(planner.getTableConfig()), planner.getRelBuilder(), JavaScalaConversionUtil.toScala(inputType.getChildren()), false);
            GeneratedAggsHandleFunction genAggsHandler = generator.needAccumulate().withConstants(JavaScalaConversionUtil.toScala(this.getConstants())).generateAggsHandler("BoundedOverAggregateHelper", aggInfoList);
            RowType valueType = generator.valueType();
            switch (mode) {
                case RANGE: {
                    GeneratedRecordComparator genBoundComparator;
                    if (this.isUnboundedWindow(group)) {
                        frame = new UnboundedOverWindowFrame(genAggsHandler, valueType);
                        break;
                    }
                    if (this.isUnboundedPrecedingWindow(group)) {
                        genBoundComparator = this.createBoundComparator(planner, sortSpec, group.getUpperBound(), false, inputType);
                        frame = new RangeUnboundedPrecedingOverFrame(genAggsHandler, genBoundComparator);
                        break;
                    }
                    if (this.isUnboundedFollowingWindow(group)) {
                        genBoundComparator = this.createBoundComparator(planner, sortSpec, group.getLowerBound(), true, inputType);
                        frame = new RangeUnboundedFollowingOverFrame(valueType, genAggsHandler, genBoundComparator);
                        break;
                    }
                    if (this.isSlidingWindow(group)) {
                        GeneratedRecordComparator genLeftBoundComparator = this.createBoundComparator(planner, sortSpec, group.getLowerBound(), true, inputType);
                        GeneratedRecordComparator genRightBoundComparator = this.createBoundComparator(planner, sortSpec, group.getUpperBound(), false, inputType);
                        frame = new RangeSlidingOverFrame(inputType, valueType, genAggsHandler, genLeftBoundComparator, genRightBoundComparator);
                        break;
                    }
                    throw new TableException("This should not happen.");
                }
                case ROW: {
                    if (this.isUnboundedWindow(group)) {
                        frame = new UnboundedOverWindowFrame(genAggsHandler, valueType);
                        break;
                    }
                    if (this.isUnboundedPrecedingWindow(group)) {
                        frame = new RowUnboundedPrecedingOverFrame(genAggsHandler, OverAggregateUtil.getLongBoundary(this.overSpec, group.getUpperBound()));
                        break;
                    }
                    if (this.isUnboundedFollowingWindow(group)) {
                        frame = new RowUnboundedFollowingOverFrame(valueType, genAggsHandler, OverAggregateUtil.getLongBoundary(this.overSpec, group.getLowerBound()));
                        break;
                    }
                    if (this.isSlidingWindow(group)) {
                        frame = new RowSlidingOverFrame(inputType, valueType, genAggsHandler, OverAggregateUtil.getLongBoundary(this.overSpec, group.getLowerBound()), OverAggregateUtil.getLongBoundary(this.overSpec, group.getUpperBound()));
                        break;
                    }
                    throw new TableException("This should not happen.");
                }
                case INSENSITIVE: {
                    frame = new InsensitiveOverFrame(genAggsHandler);
                    break;
                }
                default: {
                    throw new TableException("This should not happen.");
                }
            }
            windowFrames.add((OverWindowFrame)frame);
        }
        return windowFrames;
    }

    private GeneratedRecordComparator createBoundComparator(PlannerBase planner, SortSpec sortSpec, RexWindowBound windowBound, boolean isLowerBound, RowType inputType) {
        Object bound = OverAggregateUtil.getBoundary(this.overSpec, windowBound);
        if (!windowBound.isCurrentRow()) {
            int sortKey = sortSpec.getFieldIndices()[0];
            return new RangeBoundComparatorCodeGenerator(planner.getRelBuilder(), planner.getTableConfig(), inputType, bound, sortKey, inputType.getTypeAt(sortKey), sortSpec.getAscendingOrders()[0], isLowerBound).generateBoundComparator("RangeBoundComparator");
        }
        return new MultiFieldRangeBoundComparatorCodeGenerator(planner.getTableConfig(), inputType, sortSpec, isLowerBound).generateBoundComparator("MultiFieldRangeBoundComparator");
    }

    private boolean needBufferData() {
        return this.overSpec.getGroups().stream().anyMatch(group -> {
            BatchExecOverAggregateBase.OverWindowMode mode = this.inferGroupMode((OverSpec.GroupSpec)group);
            switch (mode) {
                case INSENSITIVE: {
                    return false;
                }
                case ROW: {
                    return !(group.getLowerBound().isCurrentRow() && group.getUpperBound().isCurrentRow() || group.getLowerBound().isUnbounded() && group.getUpperBound().isCurrentRow());
                }
            }
            return true;
        });
    }
}

