package org.nd4j.linalg.api.parallel.tasks.cpu.transform;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.RecursiveAction;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.TransformOp;
import org.nd4j.linalg.api.ops.executioner.OpExecutionerUtil;
import org.nd4j.linalg.api.parallel.tasks.cpu.BaseCPUAction;

/* loaded from: input_file:org/nd4j/linalg/api/parallel/tasks/cpu/transform/CPUTransformOpViaTensorTask.class */
/**
 * Executes a {@link TransformOp} by splitting its operands into tensors along one dimension and
 * running one {@link CPUTransformOpAction} per tensor.
 *
 * <p>The split dimension is chosen by {@code OpExecutionerUtil.chooseElementWiseTensorDimension}
 * from whichever of x, y and z are distinct arrays. Two execution modes are supported:
 *
 * <ul>
 *   <li>{@link #compute()} (fork/join): sub-actions are forked and then joined before returning.
 *   <li>{@link #call()}: sub-actions are started with {@code invokeAsync()} and recorded in the
 *       inherited {@code subTasks} list; this method does not wait for them (presumably the caller
 *       waits on {@code subTasks} — confirm against BaseCPUAction).
 * </ul>
 */
public class CPUTransformOpViaTensorTask extends BaseCPUAction {
    protected final TransformOp op;

    /**
     * @param transformOp the op to execute; its x/y/z arrays are read when the task runs
     * @param threshold forwarded as the first argument of the {@link BaseCPUAction} constructor
     *     (presumably the parallelism threshold later read back as {@code this.threshold} — confirm)
     */
    public CPUTransformOpViaTensorTask(TransformOp transformOp, int threshold) {
        super(threshold, 0, 0, 0, 0, 0, 0, 0);
        this.op = transformOp;
    }

    /** Starts all per-tensor sub-actions asynchronously and returns without waiting for them. */
    @Override
    public Void call() {
        execute(false);
        return null;
    }

    /** Forks all per-tensor sub-actions and joins them before returning. */
    @Override
    protected void compute() {
        execute(true);
    }

    /**
     * Splits the op into per-tensor actions and launches them.
     *
     * @param forkJoin {@code true} to fork and join the sub-actions (fork/join mode); {@code false}
     *     to start them via {@code invokeAsync()} and record them in {@code subTasks}
     */
    private void execute(boolean forkJoin) {
        INDArray x = op.x();
        INDArray y = op.y();
        INDArray z = op.z();
        int dimension = chooseDimension(x, y, z);
        int numTensors = x.tensorssAlongDimension(dimension);

        if (!forkJoin) {
            // Reset the inherited list of started sub-tasks for this run.
            this.subTasks = new ArrayList<>(numTensors);
        }

        if (numTensors == 1) {
            // Nothing to split: run the op over the whole array as a single action.
            CPUTransformOpAction whole = new CPUTransformOpAction(op, threshold);
            if (forkJoin) {
                whole.invoke();
            } else {
                whole.invokeAsync();
                subTasks.add(whole);
            }
            return;
        }

        List<CPUTransformOpAction> forked =
                forkJoin ? new ArrayList<CPUTransformOpAction>(numTensors) : null;

        if (x.rank() == 2) {
            // Every tensor of a rank-2 array is a 1-D vector, so explicit offsets and strides can
            // be precomputed and handed to each action.
            dispatchVectorTensors(x, y, z, dimension, numTensors, forkJoin, forked);
        } else {
            // Higher ranks: each action is given only its tensor index and the split dimension.
            for (int t = 0; t < numTensors; t++) {
                dispatch(new CPUTransformOpAction(op, threshold, t, dimension), forkJoin, forked);
            }
        }

        if (forkJoin) {
            for (CPUTransformOpAction action : forked) {
                action.join();
            }
        }
    }

    /** Chooses the split dimension, passing only the distinct operand arrays to the chooser. */
    private static int chooseDimension(INDArray x, INDArray y, INDArray z) {
        if (y == null) {
            return x == z
                    ? OpExecutionerUtil.chooseElementWiseTensorDimension(x)
                    : OpExecutionerUtil.chooseElementWiseTensorDimension(x, z);
        }
        return x == z
                ? OpExecutionerUtil.chooseElementWiseTensorDimension(x, y)
                : OpExecutionerUtil.chooseElementWiseTensorDimension(x, y, z);
    }

    /**
     * Creates and dispatches one action per vector tensor of a rank-2 op, with explicit
     * offset/stride arguments for x, y and z.
     */
    private void dispatchVectorTensors(INDArray x, INDArray y, INDArray z, int dimension,
                                       int numTensors, boolean forkJoin,
                                       List<CPUTransformOpAction> forked) {
        OpExecutionerUtil.Tensor1DStats xStats = OpExecutionerUtil.get1DTensorStats(x, dimension);
        // y is optional; when it is absent its offset and stride are passed as 0.
        OpExecutionerUtil.Tensor1DStats yStats =
                (y == null) ? null : OpExecutionerUtil.get1DTensorStats(y, dimension);
        // When z is the very same array object as x (in-place op), reuse x's layout for z.
        OpExecutionerUtil.Tensor1DStats zStats =
                (x == z) ? xStats : OpExecutionerUtil.get1DTensorStats(z, dimension);

        int tensorLength = xStats.getTensorLength();
        int xStride = xStats.getElementWiseStride();
        int yStride = (yStats == null) ? 0 : yStats.getElementWiseStride();
        int zStride = zStats.getElementWiseStride();

        for (int t = 0; t < numTensors; t++) {
            int xOffset = tensorOffset(xStats, t);
            int yOffset = (yStats == null) ? 0 : tensorOffset(yStats, t);
            int zOffset = tensorOffset(zStats, t);
            dispatch(new CPUTransformOpAction(op, threshold, tensorLength, xOffset, yOffset, zOffset,
                    xStride, yStride, zStride), forkJoin, forked);
        }
    }

    /** Offset of the {@code tensorIndex}-th tensor: first offset plus index times separation. */
    private static int tensorOffset(OpExecutionerUtil.Tensor1DStats stats, int tensorIndex) {
        return stats.getFirstTensorOffset() + tensorIndex * stats.getTensorStartSeparation();
    }

    /**
     * Launches a sub-action in the requested mode: forked and remembered for joining, or started
     * via {@code invokeAsync()} and recorded in {@code subTasks}.
     */
    private void dispatch(CPUTransformOpAction action, boolean forkJoin,
                          List<CPUTransformOpAction> forked) {
        if (forkJoin) {
            action.fork();
            forked.add(action);
        } else {
            action.invokeAsync();
            subTasks.add(action);
        }
    }
}
