package org.nd4j.linalg.api.parallel.tasks.cpu.accumulation;

import java.nio.ByteBuffer;
import java.nio.DoubleBuffer;
import java.nio.FloatBuffer;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.Future;
import java.util.concurrent.RecursiveAction;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.Accumulation;
import org.nd4j.linalg.api.parallel.tasks.Task;
import org.nd4j.linalg.api.parallel.tasks.TaskExecutorProvider;
import org.nd4j.linalg.api.shape.tensor.TensorCalculator;

/* loaded from: input_file:org/nd4j/linalg/api/parallel/tasks/cpu/accumulation/CPUAccumulations1dAction.class */
/**
 * Fork/join action that reduces each tensor in the inclusive index range
 * {@code [firstTensor, lastTensor]} of {@code op.x()} to a single value with the supplied
 * {@link Accumulation}. When {@code op.y()} is non-null, elements of x and y are combined
 * pairwise. The result for tensor {@code t} is written with {@code output.putScalar(t, ...)}.
 *
 * <p>Splitting strategy:
 * <ul>
 *   <li>more than one tensor and more than {@code threshold} elements in total: the range is
 *       halved and both halves run as sub-actions;</li>
 *   <li>exactly one tensor longer than {@code threshold}: the tensor is halved into two
 *       {@link CPUAccumulationTask}s whose partial results are merged with
 *       {@link Accumulation#combineSubResults};</li>
 *   <li>otherwise: every tensor in the range is reduced serially in this thread.</li>
 * </ul>
 */
public class CPUAccumulations1dAction extends RecursiveAction implements Task<Void> {
    /** Handle returned by the task executor in {@link #invokeAsync()}; {@code null} until then. */
    private Future<?> future;
    /**
     * Sub-tasks awaited by {@link #blockUntilComplete()}. Never populated by this class, so it
     * stays {@code null} here.
     */
    private List<Task<?>> subTasks = null;
    private final Accumulation op;
    /** Element count above which work is split further. */
    private final int threshold;
    /** Offsets, stride and length of the tensors of {@code op.x()}. */
    private final TensorCalculator tCalcx;
    /**
     * Offsets and stride of the tensors of {@code op.y()}. May be {@code null}; the
     * single-tensor split path then passes offset/stride 0 for y.
     */
    private final TensorCalculator tCalcy;
    /** First tensor index to reduce (inclusive). */
    private final int firstTensor;
    /** Last tensor index to reduce (inclusive). */
    private final int lastTensor;
    /** Destination; tensor {@code t}'s result is stored at index {@code t}. */
    private final INDArray output;

    /**
     * @param op          the accumulation to evaluate
     * @param threshold   element count above which work is split further
     * @param tCalcx      tensor offsets/stride/length for {@code op.x()}
     * @param tCalcy      tensor offsets/stride for {@code op.y()}, or {@code null}
     * @param firstTensor first tensor index to reduce (inclusive)
     * @param lastTensor  last tensor index to reduce (inclusive)
     * @param output      receives one scalar per reduced tensor index
     */
    public CPUAccumulations1dAction(Accumulation op, int threshold, TensorCalculator tCalcx,
                                    TensorCalculator tCalcy, int firstTensor, int lastTensor,
                                    INDArray output) {
        this.op = op;
        this.threshold = threshold;
        this.tCalcx = tCalcx;
        this.tCalcy = tCalcy;
        this.firstTensor = firstTensor;
        this.lastTensor = lastTensor;
        this.output = output;
    }

    /**
     * Splits the tensor range or the single tensor while it exceeds {@code threshold},
     * otherwise reduces serially.
     */
    @Override
    protected void compute() {
        int tensorCount = lastTensor - firstTensor + 1;
        int tensorLength = tCalcx.getTensorLength();
        // long product: an int product can overflow for large ranges, turn negative and
        // wrongly skip splitting.
        long totalLength = (long) tensorCount * tensorLength;

        if (tensorCount > 1 && totalLength > threshold) {
            int half = tensorCount / 2;
            // invokeAll forks one half and runs the other in the current worker.
            invokeAll(
                    new CPUAccumulations1dAction(op, threshold, tCalcx, tCalcy,
                            firstTensor, firstTensor + half - 1, output),
                    new CPUAccumulations1dAction(op, threshold, tCalcx, tCalcy,
                            firstTensor + half, lastTensor, output));
        } else if (tensorCount == 1 && totalLength > threshold) {
            reduceSingleTensorInHalves(tensorLength);
        } else {
            execute();
        }
    }

    /**
     * Reduces the single tensor {@code firstTensor} by splitting it into two halves, each
     * handled by a forked {@link CPUAccumulationTask}, and stores the combined result.
     *
     * @param tensorLength number of elements in the tensor
     */
    private void reduceSingleTensorInHalves(int tensorLength) {
        int xOffset = tCalcx.getOffsetForTensor(firstTensor);
        int xStride = tCalcx.getElementWiseStrideForTensor();
        // Without a y calculator the y offset/stride are passed as 0.
        int yOffset = tCalcy != null ? tCalcy.getOffsetForTensor(firstTensor) : 0;
        int yStride = tCalcy != null ? tCalcy.getElementWiseStrideForTensor() : 0;
        int half = tensorLength / 2;

        CPUAccumulationTask lower = new CPUAccumulationTask(op, threshold, half,
                xOffset, yOffset, xStride, yStride, false);
        lower.fork();
        CPUAccumulationTask upper = new CPUAccumulationTask(op, threshold, tensorLength - half,
                xOffset + half * xStride, yOffset + half * yStride, xStride, yStride, false);
        upper.fork();

        output.putScalar(firstTensor, op.calculateFinalResult(
                op.combineSubResults(lower.join().doubleValue(), upper.join().doubleValue()),
                tensorLength));
    }

    /**
     * Not supported for this action.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public Void call() {
        throw new UnsupportedOperationException("Not yet implemented");
    }

    /**
     * Serially reduces every tensor in {@code [firstTensor, lastTensor]} and writes each result
     * to {@code output}. FLOAT data uses the {@code float} overloads of the accumulation; any
     * other type uses the {@code double} overloads.
     */
    private void execute() {
        DataBuffer xData = op.x().data();
        DataBuffer yData = op.y() != null ? op.y().data() : null;
        int xStride = tCalcx.getElementWiseStrideForTensor();
        int tensorLength = tCalcx.getTensorLength();
        // NOTE(review): only x's data type is inspected; y is assumed to share it.
        boolean isFloat = xData.dataType() == DataBuffer.Type.FLOAT;

        if (yData == null) {
            if (isFloat) {
                reduceFloat(floatView(xData), xStride, tensorLength);
            } else {
                reduceDouble(doubleView(xData), xStride, tensorLength);
            }
            return;
        }

        int yStride = tCalcy.getElementWiseStrideForTensor();
        if (isFloat) {
            reduceFloatPairs(floatView(xData), xStride, floatView(yData), yStride, tensorLength);
        } else {
            reduceDoublePairs(doubleView(xData), xStride, doubleView(yData), yStride, tensorLength);
        }
    }

    /**
     * Returns an absolute-indexable float view of {@code buffer}. Heap buffers get a wrapper
     * over their backing array; any other allocation mode gets
     * {@code asNio().asFloatBuffer()}. Each buffer's own allocation mode decides, so x and y
     * need not share one.
     */
    private static FloatBuffer floatView(DataBuffer buffer) {
        if (buffer.allocationMode() == DataBuffer.AllocationMode.HEAP) {
            return FloatBuffer.wrap((float[]) buffer.array());
        }
        return buffer.asNio().asFloatBuffer();
    }

    /** Double counterpart of {@link #floatView(DataBuffer)}. */
    private static DoubleBuffer doubleView(DataBuffer buffer) {
        if (buffer.allocationMode() == DataBuffer.AllocationMode.HEAP) {
            return DoubleBuffer.wrap((double[]) buffer.array());
        }
        return buffer.asNio().asDoubleBuffer();
    }

    /**
     * Reduces each tensor of x alone using float arithmetic. Element {@code k} of tensor
     * {@code t} sits at {@code offset(t) + k * stride}, which also covers the contiguous
     * (stride 1) case.
     */
    private void reduceFloat(FloatBuffer x, int xStride, int tensorLength) {
        for (int t = firstTensor; t <= lastTensor; t++) {
            int xOffset = tCalcx.getOffsetForTensor(t);
            float acc = op.zeroFloat();
            for (int k = 0; k < tensorLength; k++) {
                acc = op.update(acc, op.op(x.get(xOffset + k * xStride)));
            }
            output.putScalar(t, op.calculateFinalResult(acc, tensorLength));
        }
    }

    /** Reduces each tensor of x alone using double arithmetic (see {@link #reduceFloat}). */
    private void reduceDouble(DoubleBuffer x, int xStride, int tensorLength) {
        for (int t = firstTensor; t <= lastTensor; t++) {
            int xOffset = tCalcx.getOffsetForTensor(t);
            double acc = op.zeroDouble();
            for (int k = 0; k < tensorLength; k++) {
                acc = op.update(acc, op.op(x.get(xOffset + k * xStride)));
            }
            output.putScalar(t, op.calculateFinalResult(acc, tensorLength));
        }
    }

    /** Reduces each tensor pair (x, y) element-wise using float arithmetic. */
    private void reduceFloatPairs(FloatBuffer x, int xStride, FloatBuffer y, int yStride,
                                  int tensorLength) {
        for (int t = firstTensor; t <= lastTensor; t++) {
            int xOffset = tCalcx.getOffsetForTensor(t);
            int yOffset = tCalcy.getOffsetForTensor(t);
            float acc = op.zeroFloat();
            for (int k = 0; k < tensorLength; k++) {
                acc = op.update(acc, op.op(x.get(xOffset + k * xStride), y.get(yOffset + k * yStride)));
            }
            output.putScalar(t, op.calculateFinalResult(acc, tensorLength));
        }
    }

    /** Reduces each tensor pair (x, y) element-wise using double arithmetic. */
    private void reduceDoublePairs(DoubleBuffer x, int xStride, DoubleBuffer y, int yStride,
                                   int tensorLength) {
        for (int t = firstTensor; t <= lastTensor; t++) {
            int xOffset = tCalcx.getOffsetForTensor(t);
            int yOffset = tCalcy.getOffsetForTensor(t);
            double acc = op.zeroDouble();
            for (int k = 0; k < tensorLength; k++) {
                acc = op.update(acc, op.op(x.get(xOffset + k * xStride), y.get(yOffset + k * yStride)));
            }
            output.putScalar(t, op.calculateFinalResult(acc, tensorLength));
        }
    }

    /** Submits this action and waits for it (and any sub-tasks) to finish. */
    @Override
    public Void invokeBlocking() {
        invokeAsync();
        return blockUntilComplete();
    }

    /** Submits this action to the task executor and records the returned future. */
    @Override
    public void invokeAsync() {
        this.future = TaskExecutorProvider.getTaskExecutor().executeAsync(this);
    }

    /**
     * Waits for this action, submitting it first if {@link #invokeAsync()} was never called,
     * then waits on any registered sub-tasks.
     *
     * @return always {@code null}
     * @throws RuntimeException wrapping any failure; if the wait is interrupted, the thread's
     *                          interrupt status is restored before throwing
     */
    @Override
    public Void blockUntilComplete() {
        if (future == null) {
            invokeAsync();
        }
        try {
            future.get();
            if (subTasks != null) {
                for (Task<?> subTask : subTasks) {
                    subTask.blockUntilComplete();
                }
            }
            return null;
        } catch (InterruptedException e) {
            // Preserve the interrupt for callers further up the stack.
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }
}
