package org.nd4j.linalg.api.parallel.tasks.cpu.misc;

import java.nio.DoubleBuffer;
import java.nio.FloatBuffer;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.Future;
import java.util.concurrent.RecursiveTask;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.parallel.tasks.Task;
import org.nd4j.linalg.api.parallel.tasks.TaskExecutorProvider;
import org.nd4j.linalg.factory.Nd4j;

/* loaded from: input_file:org/nd4j/linalg/api/parallel/tasks/cpu/misc/CPUCol2ImTask.class */
/**
 * Task that accumulates a 6-d {@code col} array into a 4-d image array ("col2im").
 *
 * <p>As used by this class, {@code col} is indexed as
 * {@code [example, depth, kernelRow, kernelCol, patchRow, patchCol]} and the output image as
 * {@code [example, depth, row, col]}. Every {@code col} element is added to the image element at
 * {@code row = patchRow * strideY - padHeight + kernelRow} and
 * {@code col = patchCol * strideX - padWidth + kernelCol}; when padding is present, positions that
 * fall outside the image are skipped.
 *
 * <p>Work larger than {@code parallelThreshold} is split in two, first along the example range and
 * then along the depth range. All sub-tasks share the same {@code col} and {@code imgOut} arrays
 * and each one processes only its own {@code [exampleFrom, exampleTo) x [depthFrom, depthTo)} range.
 */
public class CPUCol2ImTask extends RecursiveTask<INDArray> implements Task<INDArray> {
    /** Handle returned by the task executor; assigned by {@link #invokeAsync()}. */
    protected Future<INDArray> future;
    /** Children started through {@link #invokeAsync()}; only allocated in non-fork/join mode. */
    protected List<CPUCol2ImTask> subTasks;
    /** Input array (6-d, see class comment). */
    protected final INDArray col;
    /** Output image the column values are accumulated into. */
    protected INDArray imgOut;
    protected final int kernelHeight;
    protected final int kernelWidth;
    protected final int strideY;
    protected final int strideX;
    protected final int padHeight;
    protected final int padWidth;
    protected final int imgHeight;
    protected final int imgWidth;
    /** Work size above which the task is split; {@code Integer.MAX_VALUE} disables splitting. */
    protected final int parallelThreshold;
    protected final int exampleFrom;
    protected final int exampleTo;
    protected final int depthFrom;
    protected final int depthTo;

    /**
     * Creates a task covering every example and depth slice of {@code col}, writing into a newly
     * created {@code [col.size(0), col.size(1), imgHeight, imgWidth]} output array.
     *
     * @param col               input column array
     * @param strideY           stride along the image height
     * @param strideX           stride along the image width
     * @param padHeight         padding along the image height
     * @param padWidth          padding along the image width
     * @param imgHeight         height of the output image
     * @param imgWidth          width of the output image
     * @param parallelThreshold work size above which the task splits itself
     */
    public CPUCol2ImTask(INDArray col, int strideY, int strideX, int padHeight, int padWidth, int imgHeight, int imgWidth, int parallelThreshold) {
        this(col, getNewOutputArray(col, imgHeight, imgWidth), strideY, strideX, padHeight, padWidth, imgHeight, imgWidth,
                0, col.size(0), 0, col.size(1), parallelThreshold);
    }

    /**
     * Creates a task restricted to a sub-range of examples and depth slices.
     *
     * @param col               input column array; kernel height/width are read from its dimensions 2 and 3
     * @param imgOut            output image to accumulate into (shared between sub-tasks)
     * @param strideY           stride along the image height
     * @param strideX           stride along the image width
     * @param padHeight         padding along the image height
     * @param padWidth          padding along the image width
     * @param imgHeight         height of the output image
     * @param imgWidth          width of the output image
     * @param exampleFrom       first example index (inclusive)
     * @param exampleTo         last example index (exclusive)
     * @param depthFrom         first depth index (inclusive)
     * @param depthTo           last depth index (exclusive)
     * @param parallelThreshold work size above which the task splits itself
     */
    public CPUCol2ImTask(INDArray col, INDArray imgOut, int strideY, int strideX, int padHeight, int padWidth, int imgHeight, int imgWidth,
                         int exampleFrom, int exampleTo, int depthFrom, int depthTo, int parallelThreshold) {
        this.col = col;
        this.imgOut = imgOut;
        this.kernelHeight = col.size(2);
        this.kernelWidth = col.size(3);
        this.strideY = strideY;
        this.strideX = strideX;
        this.padHeight = padHeight;
        this.padWidth = padWidth;
        this.imgHeight = imgHeight;
        this.imgWidth = imgWidth;
        this.parallelThreshold = parallelThreshold;
        this.exampleFrom = exampleFrom;
        this.exampleTo = exampleTo;
        this.depthFrom = depthFrom;
        this.depthTo = depthTo;
    }

    /** Allocates the output image: {@code [col.size(0), col.size(1), imgHeight, imgWidth]}. */
    private static INDArray getNewOutputArray(INDArray col, int imgHeight, int imgWidth) {
        return Nd4j.create(col.size(0), col.size(1), imgHeight, imgWidth);
    }

    /**
     * Fork/join entry point: splits or executes the work, waits for any forked halves and returns
     * the output image.
     */
    @Override // java.util.concurrent.RecursiveTask
    public INDArray compute() {
        splitOrExecute(true);
        return this.imgOut;
    }

    /**
     * {@link java.util.concurrent.Callable} entry point. Performs the same fork/join style split as
     * {@link #compute()} and returns {@code null}; the result is exposed through
     * {@link #blockUntilComplete()} and {@link #imgOut}.
     *
     * <p>NOTE(review): the non-fork/join mode of {@code splitOrExecute} (which fills
     * {@link #subTasks}) is not selected by any caller in this class - confirm whether this method
     * was meant to pass {@code false}.
     */
    @Override // org.nd4j.linalg.api.parallel.tasks.Task, java.util.concurrent.Callable
    public INDArray call() {
        splitOrExecute(true);
        return null;
    }

    /**
     * Executes the work directly when it is small enough (or cannot be split further), otherwise
     * halves it along the example range, or along the depth range when only one example is left.
     *
     * @param forkJoin {@code true} to fork the halves and join them before returning;
     *                 {@code false} to start them through {@link #invokeAsync()} and record them in
     *                 {@link #subTasks} so that {@link #blockUntilComplete()} can wait for them
     */
    private void splitOrExecute(boolean forkJoin) {
        if (!forkJoin) {
            this.subTasks = new ArrayList<>();
        }
        if (this.parallelThreshold == Integer.MAX_VALUE || opSize() <= this.parallelThreshold) {
            execute();
            return;
        }

        final CPUCol2ImTask first;
        final CPUCol2ImTask second;
        final int exampleCount = this.exampleTo - this.exampleFrom;
        final int depthCount = this.depthTo - this.depthFrom;
        if (exampleCount > 1) {
            final int mid = this.exampleFrom + exampleCount / 2;
            first = subTask(this.exampleFrom, mid, this.depthFrom, this.depthTo);
            second = subTask(mid, this.exampleTo, this.depthFrom, this.depthTo);
        } else if (depthCount > 1) {
            final int mid = this.depthFrom + depthCount / 2;
            first = subTask(this.exampleFrom, this.exampleTo, this.depthFrom, mid);
            second = subTask(this.exampleFrom, this.exampleTo, mid, this.depthTo);
        } else {
            // A single example and a single depth slice: nothing left to split on.
            execute();
            return;
        }

        // Both kinds of split are started the same way, so a depth split in non-fork/join mode is
        // tracked in subTasks instead of being forked and then never waited for.
        start(first, forkJoin);
        start(second, forkJoin);
        if (forkJoin) {
            first.join();
            second.join();
        }
    }

    /** Creates a child task over the given ranges that shares this task's arrays and settings. */
    private CPUCol2ImTask subTask(int exFrom, int exTo, int dFrom, int dTo) {
        return new CPUCol2ImTask(this.col, this.imgOut, this.strideY, this.strideX, this.padHeight, this.padWidth,
                this.imgHeight, this.imgWidth, exFrom, exTo, dFrom, dTo, this.parallelThreshold);
    }

    /** Starts a child either by forking it or by submitting it and remembering it in subTasks. */
    private void start(CPUCol2ImTask task, boolean forkJoin) {
        if (forkJoin) {
            task.fork();
        } else {
            task.invokeAsync();
            this.subTasks.add(task);
        }
    }

    /**
     * Number of {@code col} elements this task has to process. Computed in {@code long} so that a
     * large problem cannot overflow into a negative size and look "small".
     */
    private long opSize() {
        return (long) (this.exampleTo - this.exampleFrom) * (this.depthTo - this.depthFrom)
                * this.col.size(4) * this.col.size(5) * this.kernelHeight * this.kernelWidth;
    }

    /** Dispatches to the kernel matching the allocation mode and data type of {@code col}'s buffer. */
    private void execute() {
        final DataBuffer colBuffer = this.col.data();
        final boolean heap = colBuffer.allocationMode() == DataBuffer.AllocationMode.HEAP;
        final boolean isFloat = colBuffer.dataType() == DataBuffer.Type.FLOAT;
        if (heap) {
            if (isFloat) {
                doHeapFloat();
            } else {
                doHeapDouble();
            }
        } else if (isFloat) {
            doDirectFloat();
        } else {
            doDirectDouble();
        }
    }

    /** col2im over heap {@code float[]} backing arrays, honouring the offsets of both arrays. */
    private void doHeapFloat() {
        final int[] outShape = this.imgOut.shape();
        final int[] outStride = this.imgOut.stride();
        final int[] colShape = this.col.shape();
        final int[] colStride = this.col.stride();
        // Same offset handling as doHeapDouble: indices address the raw backing arrays.
        final int outOffset = this.imgOut.offset();
        final int colOffset = this.col.offset();

        final int colStrideKH = colStride[2];
        final int colStrideKW = colStride[3];
        final int outStrideH = outStride[2];
        final int outStrideW = outStride[3];
        final int outHeight = outShape[2];
        final int outWidth = outShape[3];
        final int patchRows = colShape[4];
        final int patchCols = colShape[5];

        // Bounds checks (and the padding shift) are only applied when some padding is present.
        final boolean padded = this.padHeight > 0 || this.padWidth > 0;
        final int padTop = padded ? this.padHeight : 0;
        final int padLeft = padded ? this.padWidth : 0;
        // Iterate the kernel dimension with the smaller col stride in the inner loop.
        final boolean kernelWidthOuter = colStrideKH <= colStrideKW;

        final float[] colData = (float[]) this.col.data().array();
        final float[] outData = (float[]) this.imgOut.data().array();

        final int[] outIdx = new int[4];
        final int[] colIdx = new int[6];
        for (int ex = this.exampleFrom; ex < this.exampleTo; ex++) {
            for (int d = this.depthFrom; d < this.depthTo; d++) {
                colIdx[0] = ex;
                colIdx[1] = d;
                outIdx[0] = ex;
                outIdx[1] = d;
                for (int px = 0; px < patchCols; px++) {
                    for (int py = 0; py < patchRows; py++) {
                        colIdx[4] = py;
                        colIdx[5] = px;
                        final int colBase = getOffsetUnsafe6(colOffset, colShape, colStride, colIdx);
                        // Top-left image position of this patch; may be negative with padding.
                        final int top = py * this.strideY - padTop;
                        final int left = px * this.strideX - padLeft;
                        outIdx[2] = top;
                        outIdx[3] = left;
                        final int outBase = getOffsetUnsafe4(outOffset, outShape, outStride, outIdx);
                        if (kernelWidthOuter) {
                            for (int kw = 0; kw < this.kernelWidth; kw++) {
                                final int x = left + kw;
                                if (padded && (x < 0 || x >= outWidth)) {
                                    continue;
                                }
                                for (int kh = 0; kh < this.kernelHeight; kh++) {
                                    final int y = top + kh;
                                    if (padded && (y < 0 || y >= outHeight)) {
                                        continue;
                                    }
                                    outData[outBase + kh * outStrideH + kw * outStrideW] +=
                                            colData[colBase + kh * colStrideKH + kw * colStrideKW];
                                }
                            }
                        } else {
                            for (int kh = 0; kh < this.kernelHeight; kh++) {
                                final int y = top + kh;
                                if (padded && (y < 0 || y >= outHeight)) {
                                    continue;
                                }
                                for (int kw = 0; kw < this.kernelWidth; kw++) {
                                    final int x = left + kw;
                                    if (padded && (x < 0 || x >= outWidth)) {
                                        continue;
                                    }
                                    outData[outBase + kh * outStrideH + kw * outStrideW] +=
                                            colData[colBase + kh * colStrideKH + kw * colStrideKW];
                                }
                            }
                        }
                    }
                }
            }
        }
    }

    /** col2im over heap {@code double[]} backing arrays, honouring the offsets of both arrays. */
    private void doHeapDouble() {
        final int[] outShape = this.imgOut.shape();
        final int[] outStride = this.imgOut.stride();
        final int[] colShape = this.col.shape();
        final int[] colStride = this.col.stride();
        final int outOffset = this.imgOut.offset();
        final int colOffset = this.col.offset();

        final int colStrideKH = colStride[2];
        final int colStrideKW = colStride[3];
        final int outStrideH = outStride[2];
        final int outStrideW = outStride[3];
        final int outHeight = outShape[2];
        final int outWidth = outShape[3];
        final int patchRows = colShape[4];
        final int patchCols = colShape[5];

        // Bounds checks (and the padding shift) are only applied when some padding is present.
        final boolean padded = this.padHeight > 0 || this.padWidth > 0;
        final int padTop = padded ? this.padHeight : 0;
        final int padLeft = padded ? this.padWidth : 0;
        // Iterate the kernel dimension with the smaller col stride in the inner loop.
        final boolean kernelWidthOuter = colStrideKH <= colStrideKW;

        final double[] colData = (double[]) this.col.data().array();
        final double[] outData = (double[]) this.imgOut.data().array();

        final int[] outIdx = new int[4];
        final int[] colIdx = new int[6];
        for (int ex = this.exampleFrom; ex < this.exampleTo; ex++) {
            for (int d = this.depthFrom; d < this.depthTo; d++) {
                colIdx[0] = ex;
                colIdx[1] = d;
                outIdx[0] = ex;
                outIdx[1] = d;
                for (int px = 0; px < patchCols; px++) {
                    for (int py = 0; py < patchRows; py++) {
                        colIdx[4] = py;
                        colIdx[5] = px;
                        final int colBase = getOffsetUnsafe6(colOffset, colShape, colStride, colIdx);
                        // Top-left image position of this patch; may be negative with padding.
                        final int top = py * this.strideY - padTop;
                        final int left = px * this.strideX - padLeft;
                        outIdx[2] = top;
                        outIdx[3] = left;
                        final int outBase = getOffsetUnsafe4(outOffset, outShape, outStride, outIdx);
                        if (kernelWidthOuter) {
                            for (int kw = 0; kw < this.kernelWidth; kw++) {
                                final int x = left + kw;
                                if (padded && (x < 0 || x >= outWidth)) {
                                    continue;
                                }
                                for (int kh = 0; kh < this.kernelHeight; kh++) {
                                    final int y = top + kh;
                                    if (padded && (y < 0 || y >= outHeight)) {
                                        continue;
                                    }
                                    outData[outBase + kh * outStrideH + kw * outStrideW] +=
                                            colData[colBase + kh * colStrideKH + kw * colStrideKW];
                                }
                            }
                        } else {
                            for (int kh = 0; kh < this.kernelHeight; kh++) {
                                final int y = top + kh;
                                if (padded && (y < 0 || y >= outHeight)) {
                                    continue;
                                }
                                for (int kw = 0; kw < this.kernelWidth; kw++) {
                                    final int x = left + kw;
                                    if (padded && (x < 0 || x >= outWidth)) {
                                        continue;
                                    }
                                    outData[outBase + kh * outStrideH + kw * outStrideW] +=
                                            colData[colBase + kh * colStrideKH + kw * colStrideKW];
                                }
                            }
                        }
                    }
                }
            }
        }
    }

    /**
     * col2im over NIO {@link FloatBuffer} views of both data buffers.
     *
     * <p>NOTE(review): unlike the heap kernels, no array offset is added to the indices here.
     * Presumably the NIO views already account for it - confirm against
     * {@code DataBuffer.asNioFloat()} before using this with offset views.
     */
    private void doDirectFloat() {
        final int[] outShape = this.imgOut.shape();
        final int[] outStride = this.imgOut.stride();
        final int[] colShape = this.col.shape();
        final int[] colStride = this.col.stride();
        final int outOffset = 0;
        final int colOffset = 0;

        final int colStrideKH = colStride[2];
        final int colStrideKW = colStride[3];
        final int outStrideH = outStride[2];
        final int outStrideW = outStride[3];
        final int outHeight = outShape[2];
        final int outWidth = outShape[3];
        final int patchRows = colShape[4];
        final int patchCols = colShape[5];

        // Bounds checks (and the padding shift) are only applied when some padding is present.
        final boolean padded = this.padHeight > 0 || this.padWidth > 0;
        final int padTop = padded ? this.padHeight : 0;
        final int padLeft = padded ? this.padWidth : 0;
        // Iterate the kernel dimension with the smaller col stride in the inner loop.
        final boolean kernelWidthOuter = colStrideKH <= colStrideKW;

        final FloatBuffer colData = this.col.data().asNioFloat();
        final FloatBuffer outData = this.imgOut.data().asNioFloat();

        final int[] outIdx = new int[4];
        final int[] colIdx = new int[6];
        for (int ex = this.exampleFrom; ex < this.exampleTo; ex++) {
            for (int d = this.depthFrom; d < this.depthTo; d++) {
                colIdx[0] = ex;
                colIdx[1] = d;
                outIdx[0] = ex;
                outIdx[1] = d;
                for (int px = 0; px < patchCols; px++) {
                    for (int py = 0; py < patchRows; py++) {
                        colIdx[4] = py;
                        colIdx[5] = px;
                        final int colBase = getOffsetUnsafe6(colOffset, colShape, colStride, colIdx);
                        // Top-left image position of this patch; may be negative with padding.
                        final int top = py * this.strideY - padTop;
                        final int left = px * this.strideX - padLeft;
                        outIdx[2] = top;
                        outIdx[3] = left;
                        final int outBase = getOffsetUnsafe4(outOffset, outShape, outStride, outIdx);
                        if (kernelWidthOuter) {
                            for (int kw = 0; kw < this.kernelWidth; kw++) {
                                final int x = left + kw;
                                if (padded && (x < 0 || x >= outWidth)) {
                                    continue;
                                }
                                for (int kh = 0; kh < this.kernelHeight; kh++) {
                                    final int y = top + kh;
                                    if (padded && (y < 0 || y >= outHeight)) {
                                        continue;
                                    }
                                    final int outPos = outBase + kh * outStrideH + kw * outStrideW;
                                    outData.put(outPos, outData.get(outPos) + colData.get(colBase + kh * colStrideKH + kw * colStrideKW));
                                }
                            }
                        } else {
                            for (int kh = 0; kh < this.kernelHeight; kh++) {
                                final int y = top + kh;
                                if (padded && (y < 0 || y >= outHeight)) {
                                    continue;
                                }
                                for (int kw = 0; kw < this.kernelWidth; kw++) {
                                    final int x = left + kw;
                                    if (padded && (x < 0 || x >= outWidth)) {
                                        continue;
                                    }
                                    final int outPos = outBase + kh * outStrideH + kw * outStrideW;
                                    outData.put(outPos, outData.get(outPos) + colData.get(colBase + kh * colStrideKH + kw * colStrideKW));
                                }
                            }
                        }
                    }
                }
            }
        }
    }

    /**
     * col2im over NIO {@link DoubleBuffer} views of both data buffers.
     *
     * <p>NOTE(review): unlike the heap kernels, no array offset is added to the indices here.
     * Presumably the NIO views already account for it - confirm against
     * {@code DataBuffer.asNioDouble()} before using this with offset views.
     */
    private void doDirectDouble() {
        final int[] outShape = this.imgOut.shape();
        final int[] outStride = this.imgOut.stride();
        final int[] colShape = this.col.shape();
        final int[] colStride = this.col.stride();
        final int outOffset = 0;
        final int colOffset = 0;

        final int colStrideKH = colStride[2];
        final int colStrideKW = colStride[3];
        final int outStrideH = outStride[2];
        final int outStrideW = outStride[3];
        final int outHeight = outShape[2];
        final int outWidth = outShape[3];
        final int patchRows = colShape[4];
        final int patchCols = colShape[5];

        // Bounds checks (and the padding shift) are only applied when some padding is present.
        final boolean padded = this.padHeight > 0 || this.padWidth > 0;
        final int padTop = padded ? this.padHeight : 0;
        final int padLeft = padded ? this.padWidth : 0;
        // Iterate the kernel dimension with the smaller col stride in the inner loop.
        final boolean kernelWidthOuter = colStrideKH <= colStrideKW;

        final DoubleBuffer colData = this.col.data().asNioDouble();
        final DoubleBuffer outData = this.imgOut.data().asNioDouble();

        final int[] outIdx = new int[4];
        final int[] colIdx = new int[6];
        for (int ex = this.exampleFrom; ex < this.exampleTo; ex++) {
            for (int d = this.depthFrom; d < this.depthTo; d++) {
                colIdx[0] = ex;
                colIdx[1] = d;
                outIdx[0] = ex;
                outIdx[1] = d;
                for (int px = 0; px < patchCols; px++) {
                    for (int py = 0; py < patchRows; py++) {
                        colIdx[4] = py;
                        colIdx[5] = px;
                        final int colBase = getOffsetUnsafe6(colOffset, colShape, colStride, colIdx);
                        // Top-left image position of this patch; may be negative with padding.
                        final int top = py * this.strideY - padTop;
                        final int left = px * this.strideX - padLeft;
                        outIdx[2] = top;
                        outIdx[3] = left;
                        final int outBase = getOffsetUnsafe4(outOffset, outShape, outStride, outIdx);
                        if (kernelWidthOuter) {
                            for (int kw = 0; kw < this.kernelWidth; kw++) {
                                final int x = left + kw;
                                if (padded && (x < 0 || x >= outWidth)) {
                                    continue;
                                }
                                for (int kh = 0; kh < this.kernelHeight; kh++) {
                                    final int y = top + kh;
                                    if (padded && (y < 0 || y >= outHeight)) {
                                        continue;
                                    }
                                    final int outPos = outBase + kh * outStrideH + kw * outStrideW;
                                    outData.put(outPos, outData.get(outPos) + colData.get(colBase + kh * colStrideKH + kw * colStrideKW));
                                }
                            }
                        } else {
                            for (int kh = 0; kh < this.kernelHeight; kh++) {
                                final int y = top + kh;
                                if (padded && (y < 0 || y >= outHeight)) {
                                    continue;
                                }
                                for (int kw = 0; kw < this.kernelWidth; kw++) {
                                    final int x = left + kw;
                                    if (padded && (x < 0 || x >= outWidth)) {
                                        continue;
                                    }
                                    final int outPos = outBase + kh * outStrideH + kw * outStrideW;
                                    outData.put(outPos, outData.get(outPos) + colData.get(colBase + kh * colStrideKH + kw * colStrideKW));
                                }
                            }
                        }
                    }
                }
            }
        }
    }

    /**
     * Linear offset of a 4-d index into the output image.
     *
     * <p>The example/depth terms are skipped for size-1 dimensions (their index is always 0 here).
     * The row/column terms are always applied: with padding they hold the possibly negative patch
     * origin, and the kernel row/column term added by the caller must cancel against it even when
     * the image dimension has size 1.
     */
    private static int getOffsetUnsafe4(int baseOffset, int[] shape, int[] stride, int[] indices) {
        int offset = baseOffset;
        if (shape[0] != 1) {
            offset += indices[0] * stride[0];
        }
        if (shape[1] != 1) {
            offset += indices[1] * stride[1];
        }
        offset += indices[2] * stride[2];
        offset += indices[3] * stride[3];
        return offset;
    }

    /**
     * Linear offset of the {@code [example, depth, *, *, patchRow, patchCol]} part of a 6-d index
     * into {@code col}. Dimensions 2 and 3 (kernel row/column) are left out on purpose; callers add
     * them inside the kernel loops. Terms for size-1 dimensions are skipped.
     */
    private static int getOffsetUnsafe6(int baseOffset, int[] shape, int[] stride, int[] indices) {
        int offset = baseOffset;
        if (shape[0] != 1) {
            offset += indices[0] * stride[0];
        }
        if (shape[1] != 1) {
            offset += indices[1] * stride[1];
        }
        if (shape[4] != 1) {
            offset += indices[4] * stride[4];
        }
        if (shape[5] != 1) {
            offset += indices[5] * stride[5];
        }
        return offset;
    }

    /** Submits this task to the task executor and waits for it (and any sub-tasks) to finish. */
    @Override // org.nd4j.linalg.api.parallel.tasks.Task
    public INDArray invokeBlocking() {
        invokeAsync();
        return blockUntilComplete();
    }

    /** Submits this task to the executor from {@link TaskExecutorProvider} and keeps the future. */
    @Override // org.nd4j.linalg.api.parallel.tasks.Task
    public void invokeAsync() {
        this.future = TaskExecutorProvider.getTaskExecutor().executeAsync(this);
    }

    /**
     * Waits for this task's future and then, recursively, for every recorded sub-task.
     *
     * @return the output image
     * @throws IllegalStateException if the task was never submitted through {@link #invokeAsync()}
     * @throws RuntimeException      wrapping any failure (or interruption) while waiting
     */
    @Override // org.nd4j.linalg.api.parallel.tasks.Task
    public INDArray blockUntilComplete() {
        if (this.future == null) {
            throw new IllegalStateException("blockUntilComplete() called before invokeAsync()");
        }
        try {
            this.future.get();
            if (this.subTasks != null) {
                for (CPUCol2ImTask child : this.subTasks) {
                    child.blockUntilComplete();
                }
            }
            return this.imgOut;
        } catch (Exception e) {
            if (e instanceof InterruptedException) {
                // Keep the interrupt visible to callers further up the stack.
                Thread.currentThread().interrupt();
            }
            throw new RuntimeException(e);
        }
    }
}
