package org.nd4j.linalg.cpu.nativecpu.ops;

import java.util.Arrays;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.complex.IComplexNDArray;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.Accumulation;
import org.nd4j.linalg.api.ops.BroadcastOp;
import org.nd4j.linalg.api.ops.IndexAccumulation;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.ScalarOp;
import org.nd4j.linalg.api.ops.TransformOp;
import org.nd4j.linalg.api.ops.executioner.DefaultOpExecutioner;
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
import org.nd4j.linalg.api.ops.impl.accum.Variance;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.ArrayUtil;
import org.nd4j.nativeblas.NativeOps;

/* loaded from: input_file:org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.class */
/**
 * {@link DefaultOpExecutioner} backed by the {@link NativeOps} CPU kernels.
 *
 * <p>Every op is routed to the {@code ...Double} or {@code ...Float} native entry point according
 * to the data type of its {@code x} array. Whole-array scalar, accumulation and index-accumulation
 * ops on {@link IComplexNDArray} inputs, or while the execution mode is
 * {@link OpExecutioner.ExecutionMode#JAVA}, are handed to the superclass instead.
 */
public class NativeOpExecutioner extends DefaultOpExecutioner {
    private final NativeOps loop = new NativeOps();

    /**
     * Runs {@code op} by dispatching on its op kind.
     *
     * <p>Broadcast ops run along {@link BroadcastOp#getDimension()}. An op of any other kind is
     * returned without being executed.
     *
     * @param op the op to execute
     * @return {@code op} itself
     */
    public Op exec(Op op) {
        if (op instanceof ScalarOp) {
            exec((ScalarOp) op);
        } else if (op instanceof TransformOp) {
            exec((TransformOp) op);
        } else if (op instanceof Accumulation) {
            exec((Accumulation) op);
        } else if (op instanceof IndexAccumulation) {
            exec((IndexAccumulation) op);
        } else if (op instanceof BroadcastOp) {
            BroadcastOp broadcastOp = (BroadcastOp) op;
            exec(broadcastOp, broadcastOp.getDimension());
        }
        return op;
    }

    /**
     * Runs an index accumulation along the given dimensions into a freshly allocated {@code z}.
     *
     * <p>Negative dimensions count from the end. Requesting as many dimensions as {@code x} has
     * reduces the whole array. If {@code x} is a vector that the requested dimensions would not
     * shrink, {@code x} itself is returned and nothing is executed.
     *
     * @param indexAccumulation the op to run
     * @param dimension the dimensions to reduce along; this array is not modified
     * @return the op's {@code z} array, or {@code x} in the vector case above
     */
    public INDArray exec(IndexAccumulation indexAccumulation, int... dimension) {
        INDArray x = indexAccumulation.x();
        int[] dims = normalizeReductionDimensions(dimension, x.rank());
        int[] retShape = reductionResultShape(x, dims);
        if (x.isVector() && x.length() == ArrayUtil.prod(retShape)) {
            return x;
        }
        indexAccumulation.setZ(Nd4j.valueArrayOf(retShape, indexAccumulation.zeroDouble()));
        INDArray z = indexAccumulation.z();

        // First argument of every NativeOps call; always a fresh one-element array here.
        // NOTE(review): its purpose is not visible from this file.
        long[] extraPointers = new long[1];
        DataBuffer dimensionBuffer = Nd4j.createBuffer(dims);
        long xAddress = x.data().address();
        long xShapeInfo = x.shapeInfoDataBuffer().address();
        long extraArgs = getAddressForExtraArgs(indexAccumulation);
        long zAddress = z.data().address();
        long zShapeInfo = z.shapeInfoDataBuffer().address();
        if (x.data().dataType() == DataBuffer.Type.DOUBLE) {
            loop.execIndexReduceDouble(extraPointers, indexAccumulation.opNum(), xAddress, xShapeInfo,
                    extraArgs, zAddress, zShapeInfo, dimensionBuffer.address(), dims.length);
        } else {
            loop.execIndexReduceFloat(extraPointers, indexAccumulation.opNum(), xAddress, xShapeInfo,
                    extraArgs, zAddress, zShapeInfo, dimensionBuffer.address(), dims.length);
        }
        return z;
    }

    protected void doAccumulationOp(Accumulation accumulation) {
        exec(accumulation);
    }

    protected void doBroadcastOp(BroadcastOp broadcastOp) {
        exec((Op) broadcastOp);
    }

    protected void doIndexAccumulationOp(IndexAccumulation indexAccumulation) {
        exec(indexAccumulation);
    }

    protected void doScalarOp(ScalarOp scalarOp) {
        exec(scalarOp);
    }

    protected void doTransformOp(TransformOp transformOp) {
        exec(transformOp);
    }

    /**
     * Runs an accumulation along the given dimensions into a freshly allocated result array.
     *
     * <p>{@link Variance} ops use the summary-stats kernels and honour
     * {@link Variance#isBiasCorrected()}. Ops with a {@code y} operand use the reduce3 kernels.
     * All other ops use the plain reduce kernels. A scalar-shaped result is computed with the
     * scalar kernel and written with {@code putScalar}.
     *
     * @param accumulation the op to run
     * @param dimension the dimensions to reduce along; this array is not modified
     * @return the result array, or {@code accumulation.noOp()} when {@code x} is a vector that
     *     the requested dimensions would not shrink
     */
    public INDArray exec(Accumulation accumulation, int... dimension) {
        INDArray x = accumulation.x();
        int[] dims = normalizeReductionDimensions(dimension, x.rank());
        int[] retShape = reductionResultShape(x, dims);
        if (x.isVector() && x.length() == ArrayUtil.prod(retShape)) {
            return accumulation.noOp();
        }
        INDArray ret = Nd4j.valueArrayOf(retShape, accumulation.zeroDouble());
        accumulation.setZ(ret);
        INDArray z = accumulation.z();

        long[] extraPointers = new long[1];
        DataBuffer dimensionBuffer = Nd4j.createBuffer(dims);
        boolean isDouble = x.data().dataType() == DataBuffer.Type.DOUBLE;
        int opNum = accumulation.opNum();
        long xAddress = x.data().address();
        long xShapeInfo = x.shapeInfoDataBuffer().address();
        long extraArgs = getAddressForExtraArgs(accumulation);

        if (accumulation instanceof Variance) {
            boolean biasCorrected = ((Variance) accumulation).isBiasCorrected();
            if (ret.isScalar()) {
                if (isDouble) {
                    ret.putScalar(0, loop.execSummaryStatsScalarDouble(extraPointers, opNum, xAddress,
                            xShapeInfo, extraArgs, biasCorrected));
                } else {
                    ret.putScalar(0, loop.execSummaryStatsScalarFloat(extraPointers, opNum, xAddress,
                            xShapeInfo, extraArgs, biasCorrected));
                }
            } else if (isDouble) {
                loop.execSummaryStatsDouble(extraPointers, opNum, xAddress, xShapeInfo, extraArgs,
                        z.data().address(), z.shapeInfoDataBuffer().address(),
                        dimensionBuffer.address(), dims.length, biasCorrected);
            } else {
                loop.execSummaryStatsFloat(extraPointers, opNum, xAddress, xShapeInfo, extraArgs,
                        z.data().address(), z.shapeInfoDataBuffer().address(),
                        dimensionBuffer.address(), dims.length, biasCorrected);
            }
            return ret;
        }

        INDArray y = accumulation.y();
        if (y != null) {
            long yAddress = y.data().address();
            long yShapeInfo = y.shapeInfoDataBuffer().address();
            if (ret.isScalar()) {
                if (isDouble) {
                    ret.putScalar(0, loop.execReduce3ScalarDouble(extraPointers, opNum, xAddress,
                            xShapeInfo, extraArgs, yAddress, yShapeInfo));
                } else {
                    ret.putScalar(0, loop.execReduce3ScalarFloat(extraPointers, opNum, xAddress,
                            xShapeInfo, extraArgs, yAddress, yShapeInfo));
                }
            } else if (isDouble) {
                loop.execReduce3Double(extraPointers, opNum, xAddress, xShapeInfo, extraArgs,
                        yAddress, yShapeInfo, z.data().address(), z.shapeInfoDataBuffer().address(),
                        dimensionBuffer.address(), dims.length);
            } else {
                loop.execReduce3Float(extraPointers, opNum, xAddress, xShapeInfo, extraArgs,
                        yAddress, yShapeInfo, z.data().address(), z.shapeInfoDataBuffer().address(),
                        dimensionBuffer.address(), dims.length);
            }
            return ret;
        }

        if (ret.isScalar()) {
            if (isDouble) {
                ret.putScalar(0, loop.execReduceScalarDouble(extraPointers, opNum, xAddress,
                        xShapeInfo, extraArgs));
            } else {
                ret.putScalar(0, loop.execReduceScalarFloat(extraPointers, opNum, xAddress,
                        xShapeInfo, extraArgs));
            }
        } else if (isDouble) {
            loop.execReduceDouble(extraPointers, opNum, xAddress, xShapeInfo, extraArgs,
                    z.data().address(), z.shapeInfoDataBuffer().address(),
                    dimensionBuffer.address(), dims.length);
        } else {
            loop.execReduceFloat(extraPointers, opNum, xAddress, xShapeInfo, extraArgs,
                    z.data().address(), z.shapeInfoDataBuffer().address(),
                    dimensionBuffer.address(), dims.length);
        }
        return ret;
    }

    /**
     * Runs a scalar op natively, or via the superclass for complex inputs / JAVA execution mode.
     *
     * <p>The strided overload receives only base addresses, element-wise strides and
     * {@code n}. It is therefore used only when both {@code x} and {@code z} have a positive
     * element-wise stride and share the same ordering. Otherwise the shape-info overload is used.
     */
    private void exec(ScalarOp scalarOp) {
        INDArray x = scalarOp.x();
        if ((x instanceof IComplexNDArray) || executionMode() == OpExecutioner.ExecutionMode.JAVA) {
            super.exec(scalarOp);
            return;
        }
        INDArray z = scalarOp.z();
        long[] extraPointers = new long[1];
        boolean useShapeInfo = scalarOp.isExecSpecial()
                || x.elementWiseStride() < 1
                || z.elementWiseStride() < 1
                || x.ordering() != z.ordering();
        boolean isDouble = x.data().dataType() == DataBuffer.Type.DOUBLE;

        if (useShapeInfo) {
            if (isDouble) {
                loop.execScalarDouble(extraPointers, scalarOp.opNum(), x.data().address(),
                        x.shapeInfoDataBuffer().address(), z.data().address(),
                        z.shapeInfoDataBuffer().address(), scalarOp.scalar().doubleValue(),
                        getAddressForExtraArgs(scalarOp));
            } else {
                loop.execScalarFloat(extraPointers, scalarOp.opNum(), x.data().address(),
                        x.shapeInfoDataBuffer().address(), z.data().address(),
                        z.shapeInfoDataBuffer().address(), scalarOp.scalar().floatValue(),
                        getAddressForExtraArgs(scalarOp));
            }
        } else if (isDouble) {
            loop.execScalarDouble(extraPointers, scalarOp.opNum(), x.data().address(),
                    x.elementWiseStride(), z.data().address(), z.elementWiseStride(),
                    scalarOp.scalar().doubleValue(), getAddressForExtraArgs(scalarOp), scalarOp.n());
        } else {
            loop.execScalarFloat(extraPointers, scalarOp.opNum(), x.data().address(),
                    x.elementWiseStride(), z.data().address(), z.elementWiseStride(),
                    scalarOp.scalar().floatValue(), getAddressForExtraArgs(scalarOp), scalarOp.n());
        }
    }

    /** Returns the native address of the op's extra-args buffer, or {@code 0L} when it has none. */
    private long getAddressForExtraArgs(Op op) {
        if (op.extraArgs() != null) {
            return op.extraArgsDataBuff().address();
        }
        return 0L;
    }

    /**
     * Runs a transform natively. Without a {@code y} operand it runs a plain transform; with one,
     * a pairwise transform.
     *
     * <p>The strided overloads receive only base addresses, element-wise strides and {@code n}.
     * They are used only when every operand has a positive element-wise stride and all operands
     * share one ordering. Pairwise additionally requires {@code x} and {@code y} to have equal
     * strides. Every other case goes through the shape-info overloads.
     */
    private void exec(TransformOp transformOp) {
        INDArray x = transformOp.x();
        INDArray y = transformOp.y();
        INDArray z = transformOp.z();
        boolean isDouble = x.data().dataType() == DataBuffer.Type.DOUBLE;
        long[] extraPointers = new long[1];
        int opNum = transformOp.opNum();

        if (y == null) {
            boolean useShapeInfo = transformOp.isExecSpecial()
                    || x.elementWiseStride() < 1
                    || z.elementWiseStride() < 1
                    || x.ordering() != z.ordering();
            if (useShapeInfo) {
                if (isDouble) {
                    loop.execTransformDouble(extraPointers, opNum, x.data().address(),
                            x.shapeInfoDataBuffer().address(), z.data().address(),
                            z.shapeInfoDataBuffer().address(), getAddressForExtraArgs(transformOp));
                } else {
                    loop.execTransformFloat(extraPointers, opNum, x.data().address(),
                            x.shapeInfoDataBuffer().address(), z.data().address(),
                            z.shapeInfoDataBuffer().address(), getAddressForExtraArgs(transformOp));
                }
            } else if (isDouble) {
                loop.execTransformDouble(extraPointers, opNum, x.data().address(),
                        x.elementWiseStride(), z.data().address(), z.elementWiseStride(),
                        getAddressForExtraArgs(transformOp), transformOp.n());
            } else {
                loop.execTransformFloat(extraPointers, opNum, x.data().address(),
                        x.elementWiseStride(), z.data().address(), z.elementWiseStride(),
                        getAddressForExtraArgs(transformOp), transformOp.n());
            }
            return;
        }

        boolean useShapeInfo = transformOp.isExecSpecial()
                || x.elementWiseStride() < 1
                || y.elementWiseStride() < 1
                || z.elementWiseStride() < 1
                || x.elementWiseStride() != y.elementWiseStride()
                || x.ordering() != y.ordering()
                || x.ordering() != z.ordering();
        if (useShapeInfo) {
            if (isDouble) {
                loop.execPairwiseTransformDouble(extraPointers, opNum, x.data().address(),
                        x.shapeInfoDataBuffer().address(), y.data().address(),
                        y.shapeInfoDataBuffer().address(), z.data().address(),
                        z.shapeInfoDataBuffer().address(), getAddressForExtraArgs(transformOp));
            } else {
                loop.execPairwiseTransformFloat(extraPointers, opNum, x.data().address(),
                        x.shapeInfoDataBuffer().address(), y.data().address(),
                        y.shapeInfoDataBuffer().address(), z.data().address(),
                        z.shapeInfoDataBuffer().address(), getAddressForExtraArgs(transformOp));
            }
        } else if (isDouble) {
            loop.execPairwiseTransformDouble(extraPointers, opNum, x.data().address(),
                    x.elementWiseStride(), y.data().address(), y.elementWiseStride(),
                    z.data().address(), z.elementWiseStride(), getAddressForExtraArgs(transformOp),
                    transformOp.n());
        } else {
            loop.execPairwiseTransformFloat(extraPointers, opNum, x.data().address(),
                    x.elementWiseStride(), y.data().address(), y.elementWiseStride(),
                    z.data().address(), z.elementWiseStride(), getAddressForExtraArgs(transformOp),
                    transformOp.n());
        }
    }

    /**
     * Runs a broadcast op natively along the given dimensions.
     *
     * @param broadcastOp the op to run
     * @param dimension the broadcast dimensions; negative values count from the end of
     *     {@code x}'s shape. This array is not modified.
     * @return the op's {@code z} array
     */
    public INDArray exec(BroadcastOp broadcastOp, int... dimension) {
        INDArray x = broadcastOp.x();
        INDArray y = broadcastOp.y();
        INDArray z = broadcastOp.z();
        int[] dims = wrapAndSortDimensions(dimension, x.rank());
        long[] extraPointers = new long[1];
        DataBuffer dimensionBuffer = Nd4j.createBuffer(dims);
        if (x.data().dataType() == DataBuffer.Type.DOUBLE) {
            loop.execBroadcastDouble(extraPointers, broadcastOp.opNum(), x.data().address(),
                    x.shapeInfoDataBuffer().address(), y.data().address(),
                    y.shapeInfoDataBuffer().address(), z.data().address(),
                    z.shapeInfoDataBuffer().address(), dimensionBuffer.address(), dims.length);
        } else {
            loop.execBroadcastFloat(extraPointers, broadcastOp.opNum(), x.data().address(),
                    x.shapeInfoDataBuffer().address(), y.data().address(),
                    y.shapeInfoDataBuffer().address(), z.data().address(),
                    z.shapeInfoDataBuffer().address(), dimensionBuffer.address(), dims.length);
        }
        return broadcastOp.z();
    }

    /**
     * Runs a whole-array index accumulation and stores the resulting index via
     * {@code setFinalResult}. Complex inputs and JAVA execution mode are delegated to the superclass.
     */
    private void exec(IndexAccumulation indexAccumulation) {
        INDArray x = indexAccumulation.x();
        if ((x instanceof IComplexNDArray) || executionMode() == OpExecutioner.ExecutionMode.JAVA) {
            super.exec(indexAccumulation);
            return;
        }
        long[] extraPointers = new long[1];
        long xAddress = x.data().address();
        long xShapeInfo = x.shapeInfoDataBuffer().address();
        if (x.data().dataType() == DataBuffer.Type.DOUBLE) {
            indexAccumulation.setFinalResult((int) loop.execIndexReduceScalarDouble(extraPointers,
                    indexAccumulation.opNum(), xAddress, xShapeInfo,
                    getAddressForExtraArgs(indexAccumulation)));
        } else {
            indexAccumulation.setFinalResult((int) loop.execIndexReduceScalarFloat(extraPointers,
                    indexAccumulation.opNum(), xAddress, xShapeInfo,
                    getAddressForExtraArgs(indexAccumulation)));
        }
    }

    /**
     * Runs a whole-array accumulation and stores the value via {@code setFinalResult}. Complex
     * inputs and JAVA execution mode are delegated to the superclass.
     *
     * <p>{@link Variance} honours {@link Variance#isBiasCorrected()} for both DOUBLE and FLOAT
     * data. The DOUBLE path previously hard-coded {@code true}.
     */
    private void exec(Accumulation accumulation) {
        INDArray x = accumulation.x();
        if ((x instanceof IComplexNDArray) || executionMode() == OpExecutioner.ExecutionMode.JAVA) {
            super.exec(accumulation);
            return;
        }
        long[] extraPointers = new long[1];
        boolean isDouble = x.data().dataType() == DataBuffer.Type.DOUBLE;
        int opNum = accumulation.opNum();
        long xAddress = x.data().address();
        long xShapeInfo = x.shapeInfoDataBuffer().address();

        if (accumulation instanceof Variance) {
            boolean biasCorrected = ((Variance) accumulation).isBiasCorrected();
            if (isDouble) {
                accumulation.setFinalResult(Double.valueOf(loop.execSummaryStatsScalarDouble(extraPointers,
                        opNum, xAddress, xShapeInfo, getAddressForExtraArgs(accumulation), biasCorrected)));
            } else {
                accumulation.setFinalResult(Float.valueOf(loop.execSummaryStatsScalarFloat(extraPointers,
                        opNum, xAddress, xShapeInfo, getAddressForExtraArgs(accumulation), biasCorrected)));
            }
        } else if (accumulation.y() != null) {
            INDArray y = accumulation.y();
            if (isDouble) {
                accumulation.setFinalResult(Double.valueOf(loop.execReduce3ScalarDouble(extraPointers,
                        opNum, xAddress, xShapeInfo, getAddressForExtraArgs(accumulation),
                        y.data().address(), y.shapeInfoDataBuffer().address())));
            } else {
                accumulation.setFinalResult(Float.valueOf(loop.execReduce3ScalarFloat(extraPointers,
                        opNum, xAddress, xShapeInfo, getAddressForExtraArgs(accumulation),
                        y.data().address(), y.shapeInfoDataBuffer().address())));
            }
        } else if (isDouble) {
            accumulation.setFinalResult(Double.valueOf(loop.execReduceScalarDouble(extraPointers,
                    opNum, xAddress, xShapeInfo, getAddressForExtraArgs(accumulation))));
        } else {
            accumulation.setFinalResult(Float.valueOf(loop.execReduceScalarFloat(extraPointers,
                    opNum, xAddress, xShapeInfo, getAddressForExtraArgs(accumulation))));
        }
    }

    /**
     * Returns a sorted copy of {@code dimensions} in which each negative entry has had
     * {@code rank} added. Wrapping happens before sorting so the result is always ascending.
     */
    private static int[] wrapAndSortDimensions(int[] dimensions, int rank) {
        int[] result = dimensions.clone();
        for (int i = 0; i < result.length; i++) {
            if (result[i] < 0) {
                result[i] += rank;
            }
        }
        Arrays.sort(result);
        return result;
    }

    /**
     * Normalizes reduction dimensions with {@link #wrapAndSortDimensions}. When as many
     * dimensions as {@code rank} are requested, it returns the single sentinel
     * {@code Integer.MAX_VALUE}, which {@link #reductionResultShape} checks for with
     * {@link Shape#wholeArrayDimension}.
     */
    private static int[] normalizeReductionDimensions(int[] dimensions, int rank) {
        int[] result = wrapAndSortDimensions(dimensions, rank);
        return result.length == rank ? new int[] {Integer.MAX_VALUE} : result;
    }

    /**
     * Computes the shape of a reduction result.
     *
     * <ul>
     *   <li>A whole-array reduction gives {@code [1, 1]}.</li>
     *   <li>Otherwise the result is {@code x}'s shape with the reduced dimensions removed.</li>
     *   <li>A 1-element shape becomes a row vector when reducing along dimension 0, and a
     *       column vector otherwise.</li>
     *   <li>An empty shape becomes {@code [1, 1]}.</li>
     * </ul>
     */
    private static int[] reductionResultShape(INDArray x, int[] dims) {
        int[] shape = Shape.wholeArrayDimension(dims)
                ? new int[] {1, 1}
                : ArrayUtil.removeIndex(x.shape(), dims);
        if (shape.length == 1) {
            return dims.length > 0 && dims[0] == 0 ? new int[] {1, shape[0]} : new int[] {shape[0], 1};
        }
        if (shape.length == 0) {
            return new int[] {1, 1};
        }
        return shape;
    }
}
