package org.nd4j.linalg.api.parallel.tasks.cpu.indexaccum;

import io.netty.buffer.ByteBuf;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.apache.commons.math3.util.Pair;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ops.IndexAccumulation;
import org.nd4j.linalg.api.parallel.tasks.Task;

/* loaded from: input_file:org/nd4j/linalg/api/parallel/tasks/cpu/indexaccum/CPUIndexAccumulationTask.class */
/**
 * CPU task that evaluates an {@link IndexAccumulation} over a strided range of {@code n} elements
 * of the op's x buffer (and, when present, its y buffer), producing a (value, index) pair.
 *
 * <p>Two parallel paths are provided:
 * <ul>
 *   <li>{@link #call()} / {@link #blockUntilComplete()}: the range is split into
 *       {@code 1 + n / threshold} sub-tasks started with {@code invokeAsync()}; their results are
 *       folded together with {@code op.combineSubResults} when the caller blocks.</li>
 *   <li>{@link #compute()}: fork/join path that bisects the range recursively.</li>
 * </ul>
 * Ranges of at most {@code threshold} elements are scanned directly by {@link #execute()}.
 */
public class CPUIndexAccumulationTask extends BaseCPUIndexAccumulationTask {
    private static final int FLOAT_BYTES = 4;
    private static final int DOUBLE_BYTES = 8;

    /** Sub-tasks started by {@link #call()} when the range was split; {@code null} otherwise. */
    protected List<Task<Pair<Double, Integer>>> subTasks;

    /**
     * Creates a task over an explicit strided sub-range.
     *
     * @param indexAccumulation the op to evaluate
     * @param threshold         maximum number of elements scanned without further splitting
     * @param n                 number of elements in this task's range
     * @param offsetX           starting element offset into x's buffer
     * @param offsetY           starting element offset into y's buffer
     * @param incrX             element stride through x
     * @param incrY             element stride through y
     * @param elementOffset     added to range-local indices to obtain global element indices
     * @param z                 flag forwarded unchanged to the base constructor; sub-tasks created
     *                          by this class pass {@code false} (NOTE(review): presumably the
     *                          outer-task flag — confirm against BaseCPUIndexAccumulationTask)
     */
    public CPUIndexAccumulationTask(IndexAccumulation indexAccumulation, int threshold, int n, int offsetX, int offsetY,
                                    int incrX, int incrY, int elementOffset, boolean z) {
        super(indexAccumulation, threshold, n, offsetX, offsetY, incrX, incrY, elementOffset, z);
    }

    /** Creates a task; all arguments are forwarded unchanged to the base constructor. */
    public CPUIndexAccumulationTask(IndexAccumulation indexAccumulation, int i, boolean z) {
        super(indexAccumulation, i, z);
    }

    /** Creates a task; all arguments are forwarded unchanged to the base constructor. */
    public CPUIndexAccumulationTask(IndexAccumulation indexAccumulation, int i, int i2, int i3, boolean z) {
        super(indexAccumulation, i, i2, i3, z);
    }

    /**
     * Starts the task if necessary, waits for it and, when {@link #call()} split the work, folds the
     * sub-task results together. An outer task publishes the final index via
     * {@code op.setFinalResult}.
     *
     * @return the combined (value, index) pair
     * @throws RuntimeException wrapping any failure while waiting or combining
     */
    @Override
    public Pair<Double, Integer> blockUntilComplete() {
        if (future == null) {
            invokeAsync();
        }
        try {
            @SuppressWarnings("unchecked")
            Pair<Double, Integer> result = (Pair<Double, Integer>) future.get();
            // future.get() happens-after call(), so subTasks written there is visible here.
            if (subTasks != null) {
                result = op.zeroPair();
                for (Task<Pair<Double, Integer>> sub : subTasks) {
                    result = op.combineSubResults(result, sub.blockUntilComplete());
                }
            }
            if (outerTask) {
                op.setFinalResult(result.getSecond().intValue());
            }
            return result;
        } catch (Exception e) {
            if (e instanceof InterruptedException) {
                // Don't lose the interrupt status when translating to an unchecked exception.
                Thread.currentThread().interrupt();
            }
            throw new RuntimeException(e);
        }
    }

    /**
     * Executor entry point. Small ranges are scanned directly. Larger ranges are split into
     * {@code 1 + n / threshold} contiguous chunks (the last absorbs the remainder), each started
     * asynchronously.
     *
     * @return the scan result, or {@code null} when the work was split; in that case the result is
     *     assembled by {@link #blockUntilComplete()}
     */
    @Override
    public Pair<Double, Integer> call() {
        if (doTensorFirst) {
            doTensorFirst(op);
        }
        if (n <= threshold) {
            return execute();
        }
        int numTasks = 1 + n / threshold;
        int chunk = n / numTasks;
        subTasks = new ArrayList<>(numTasks);
        int start = 0;
        for (int t = 0; t < numTasks; t++) {
            int len = (t == numTasks - 1) ? n - start : chunk;
            CPUIndexAccumulationTask sub = new CPUIndexAccumulationTask(op, threshold, len,
                    offsetX + start * incrX, offsetY + start * incrY, incrX, incrY, elementOffset + start, false);
            sub.invokeAsync();
            subTasks.add(sub);
            start += len;
        }
        return null;
    }

    /**
     * Fork/join entry point: bisects the range until it is at most {@code threshold} elements,
     * then combines the two halves with {@code op.combineSubResults} (first half, second half).
     * An outer task publishes the combined index via {@code op.setFinalResult}, matching
     * {@link #execute()} and {@link #blockUntilComplete()}.
     */
    @Override
    protected Pair<Double, Integer> compute() {
        if (doTensorFirst) {
            doTensorFirst(op);
        }
        if (n <= threshold) {
            return execute(); // execute() already publishes the result for an outer task
        }
        int nFirst = n / 2;
        int nSecond = n - nFirst;
        CPUIndexAccumulationTask first = new CPUIndexAccumulationTask(op, threshold, nFirst,
                offsetX, offsetY, incrX, incrY, elementOffset, false);
        CPUIndexAccumulationTask second = new CPUIndexAccumulationTask(op, threshold, nSecond,
                offsetX + nFirst * incrX, offsetY + nFirst * incrY, incrX, incrY, elementOffset + nFirst, false);
        // Fork one half and compute the other in this worker rather than forking both and idling
        // in join(); the result is the same with one fewer task handed to the pool.
        first.fork();
        Pair<Double, Integer> secondResult = second.compute();
        Pair<Double, Integer> combined = op.combineSubResults(first.join(), secondResult);
        if (outerTask) {
            op.setFinalResult(combined.getSecond().intValue());
        }
        return combined;
    }

    /**
     * Scans this task's range directly and returns its (value, global index) pair.
     *
     * <p>Dispatches on whether y is present, whether x is heap-allocated (backed by a Java
     * {@code float[]}/{@code double[]}) or accessed through a Netty {@link ByteBuf}, and on x's
     * data type. y is read with the same layout and type as x.
     */
    private Pair<Double, Integer> execute() {
        DataBuffer xData = op.x().data();
        DataBuffer yData = op.y() != null ? op.y().data() : null;
        boolean heap = xData.allocationMode() == DataBuffer.AllocationMode.HEAP;
        boolean isFloat = xData.dataType() == DataBuffer.Type.FLOAT;

        if (yData == null) {
            if (heap) {
                return isFloat
                        ? scanHeapFloat((float[]) xData.array())
                        : scanHeapDouble((double[]) xData.array());
            }
            ByteBuf xBuf = xData.asNetty();
            return isFloat ? scanNettyFloat(xBuf) : scanNettyDouble(xBuf);
        }
        if (heap) {
            return isFloat
                    ? scanHeapFloat((float[]) xData.array(), (float[]) yData.array())
                    : scanHeapDouble((double[]) xData.array(), (double[]) yData.array());
        }
        ByteBuf xBuf = xData.asNetty();
        ByteBuf yBuf = yData.asNetty();
        return isFloat ? scanNettyFloat(xBuf, yBuf) : scanNettyDouble(xBuf, yBuf);
    }

    /** Scans x only, read from a heap {@code float[]}. */
    private Pair<Double, Integer> scanHeapFloat(float[] xs) {
        float best = op.zeroFloat();
        int bestIdx = -1;
        for (int k = 0, ix = offsetX; k < n; k++, ix += incrX) {
            float x = xs[ix];
            bestIdx = op.update(best, bestIdx, x, k);
            if (bestIdx == k) {
                best = op.op(x);
            }
        }
        return finish(bestIdx, best);
    }

    /** Scans x only, read from a heap {@code double[]}. */
    private Pair<Double, Integer> scanHeapDouble(double[] xs) {
        double best = op.zeroDouble();
        int bestIdx = -1;
        for (int k = 0, ix = offsetX; k < n; k++, ix += incrX) {
            double x = xs[ix];
            bestIdx = op.update(best, bestIdx, x, k);
            if (bestIdx == k) {
                best = op.op(x);
            }
        }
        return finish(bestIdx, best);
    }

    /** Scans x only, read as floats from a {@link ByteBuf} (absolute byte indexing). */
    private Pair<Double, Integer> scanNettyFloat(ByteBuf xBuf) {
        float best = op.zeroFloat();
        int bestIdx = -1;
        int step = FLOAT_BYTES * incrX;
        for (int k = 0, bx = FLOAT_BYTES * offsetX; k < n; k++, bx += step) {
            float x = xBuf.getFloat(bx);
            bestIdx = op.update(best, bestIdx, x, k);
            if (bestIdx == k) {
                best = op.op(x);
            }
        }
        return finish(bestIdx, best);
    }

    /** Scans x only, read as doubles from a {@link ByteBuf} (absolute byte indexing). */
    private Pair<Double, Integer> scanNettyDouble(ByteBuf xBuf) {
        double best = op.zeroDouble();
        int bestIdx = -1;
        int step = DOUBLE_BYTES * incrX;
        for (int k = 0, bx = DOUBLE_BYTES * offsetX; k < n; k++, bx += step) {
            double x = xBuf.getDouble(bx);
            bestIdx = op.update(best, bestIdx, x, k);
            if (bestIdx == k) {
                best = op.op(x);
            }
        }
        return finish(bestIdx, best);
    }

    /** Scans x and y pairwise, read from heap {@code float[]}s. */
    private Pair<Double, Integer> scanHeapFloat(float[] xs, float[] ys) {
        float best = op.zeroFloat();
        int bestIdx = -1;
        for (int k = 0, ix = offsetX, iy = offsetY; k < n; k++, ix += incrX, iy += incrY) {
            float x = xs[ix];
            float y = ys[iy];
            bestIdx = op.update(best, bestIdx, x, y, k);
            if (bestIdx == k) {
                best = op.op(x, y);
            }
        }
        return finish(bestIdx, best);
    }

    /** Scans x and y pairwise, read from heap {@code double[]}s. */
    private Pair<Double, Integer> scanHeapDouble(double[] xs, double[] ys) {
        double best = op.zeroDouble();
        int bestIdx = -1;
        for (int k = 0, ix = offsetX, iy = offsetY; k < n; k++, ix += incrX, iy += incrY) {
            double x = xs[ix];
            double y = ys[iy];
            bestIdx = op.update(best, bestIdx, x, y, k);
            if (bestIdx == k) {
                best = op.op(x, y);
            }
        }
        return finish(bestIdx, best);
    }

    /** Scans x and y pairwise, read as floats from {@link ByteBuf}s (absolute byte indexing). */
    private Pair<Double, Integer> scanNettyFloat(ByteBuf xBuf, ByteBuf yBuf) {
        float best = op.zeroFloat();
        int bestIdx = -1;
        int stepX = FLOAT_BYTES * incrX;
        int stepY = FLOAT_BYTES * incrY;
        for (int k = 0, bx = FLOAT_BYTES * offsetX, by = FLOAT_BYTES * offsetY; k < n;
                k++, bx += stepX, by += stepY) {
            float x = xBuf.getFloat(bx);
            float y = yBuf.getFloat(by);
            bestIdx = op.update(best, bestIdx, x, y, k);
            if (bestIdx == k) {
                best = op.op(x, y);
            }
        }
        return finish(bestIdx, best);
    }

    /** Scans x and y pairwise, read as doubles from {@link ByteBuf}s (absolute byte indexing). */
    private Pair<Double, Integer> scanNettyDouble(ByteBuf xBuf, ByteBuf yBuf) {
        double best = op.zeroDouble();
        int bestIdx = -1;
        int stepX = DOUBLE_BYTES * incrX;
        int stepY = DOUBLE_BYTES * incrY;
        for (int k = 0, bx = DOUBLE_BYTES * offsetX, by = DOUBLE_BYTES * offsetY; k < n;
                k++, bx += stepX, by += stepY) {
            double x = xBuf.getDouble(bx);
            double y = yBuf.getDouble(by);
            bestIdx = op.update(best, bestIdx, x, y, k);
            if (bestIdx == k) {
                best = op.op(x, y);
            }
        }
        return finish(bestIdx, best);
    }

    /**
     * Converts a range-local index to a global one, publishes it when this is the outer task,
     * and boxes the (value, index) result.
     */
    private Pair<Double, Integer> finish(int localIdx, double value) {
        // A negative local index means op.update never accepted an element of this range. Keep that
        // sentinel instead of shifting it to elementOffset - 1, which would name an element that
        // belongs to the preceding range.
        int index = localIdx < 0 ? localIdx : localIdx + elementOffset;
        if (outerTask) {
            op.setFinalResult(index);
        }
        return new Pair<>(Double.valueOf(value), Integer.valueOf(index));
    }
}
