package org.nd4j.linalg.api.parallel.tasks.cpu;

import java.util.Iterator;
import java.util.List;
import java.util.concurrent.Future;
import java.util.concurrent.RecursiveAction;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.parallel.tasks.Task;
import org.nd4j.linalg.api.parallel.tasks.TaskExecutorProvider;
import org.nd4j.linalg.util.ArrayUtil;

/* loaded from: input_file:org/nd4j/linalg/api/parallel/tasks/cpu/BaseCPUAction.class */
/**
 * Base class for CPU tasks that produce no result ({@code Task<Void>}) and are run as a
 * {@link RecursiveAction}.
 *
 * <p>The class stores the element count, the offsets and the increments for the x, y and z
 * operands of an {@link Op}. Two set-up modes are supported:
 * <ul>
 *   <li>direct: the values are passed in, or read from the op's arrays, at construction time;</li>
 *   <li>tensor-first ({@link #doTensorFirst} is {@code true}): only the tensor index and
 *       dimension are stored at construction time, and {@link #doTensorFirst(Op)} fills in the
 *       remaining values later.</li>
 * </ul>
 */
public abstract class BaseCPUAction extends RecursiveAction implements Task<Void> {
    /** Threshold supplied at construction; how it is used is up to subclasses. */
    protected final int threshold;
    /** Number of elements to process. */
    protected int n;
    /** Offset of the first element in x. */
    protected int offsetX;
    /** Offset of the first element in y (0 when the op has no y). */
    protected int offsetY;
    /** Offset of the first element in z (0 when the op has no z). */
    protected int offsetZ;
    /** Increment between consecutive elements of x. */
    protected int incrX;
    /** Increment between consecutive elements of y (0 when the op has no y). */
    protected int incrY;
    /** Increment between consecutive elements of z (0 when the op has no z). */
    protected int incrZ;
    /** {@code true} when n/offsets/increments must be derived via {@link #doTensorFirst(Op)}. */
    protected boolean doTensorFirst;
    /** Tensor index passed to {@code tensorAlongDimension} in tensor-first mode. */
    protected int tensorIdx;
    /** Tensor dimension passed to {@code tensorAlongDimension} in tensor-first mode. */
    protected int tensorDim;
    /** Initialised to {@code false}; never modified by this base class. */
    protected boolean executed;
    /** Handle returned by the task executor once {@link #invokeAsync()} has been called. */
    protected Future<Void> future;
    /** Optional sub-tasks that {@link #blockUntilComplete()} also waits for; may be {@code null}. */
    protected List<Task<Void>> subTasks;

    /**
     * Creates an action from explicit values.
     *
     * @param threshold value stored in {@link #threshold}
     * @param n         number of elements
     * @param offsetX   offset into x
     * @param offsetY   offset into y
     * @param offsetZ   offset into z
     * @param incrX     increment for x
     * @param incrY     increment for y
     * @param incrZ     increment for z
     */
    public BaseCPUAction(int threshold, int n, int offsetX, int offsetY, int offsetZ,
                         int incrX, int incrY, int incrZ) {
        this.executed = false;
        this.threshold = threshold;
        this.n = n;
        this.offsetX = offsetX;
        this.offsetY = offsetY;
        this.offsetZ = offsetZ;
        this.incrX = incrX;
        this.incrY = incrY;
        this.incrZ = incrZ;
        this.doTensorFirst = false;
    }

    /**
     * Creates an action whose length, offsets and increments are read from the op's arrays.
     *
     * <p>A {@code null} y or z yields offset 0 and increment 0. When an array reports an
     * element-wise stride of {@code -1}, the increment is instead taken from {@code stride(1)}
     * of the array reshaped to {@code (1, prod(shape))}; if y or z is the same instance as x,
     * the x increment is reused.
     *
     * @param op        the op supplying x (required), y and z (both optional)
     * @param threshold value stored in {@link #threshold}
     */
    public BaseCPUAction(Op op, int threshold) {
        this.executed = false;
        this.threshold = threshold;
        this.n = op.x().length();
        this.offsetX = op.x().offset();
        this.offsetY = op.y() != null ? op.y().offset() : 0;
        this.offsetZ = op.z() != null ? op.z().offset() : 0;
        this.incrX = op.x().elementWiseStride();
        this.incrY = op.y() != null ? op.y().elementWiseStride() : 0;
        this.incrZ = op.z() != null ? op.z().elementWiseStride() : 0;
        this.doTensorFirst = false;

        // An increment of -1 can only come from elementWiseStride() of a non-null array
        // (null arrays were given 0 above), so op.y()/op.z() are safe to dereference below.
        if (this.incrX == -1) {
            this.incrX = op.x().reshape(1, ArrayUtil.prod(op.x().shape())).stride(1);
        }
        if (this.incrY == -1) {
            if (op.y() == op.x()) {
                this.incrY = this.incrX;
            } else {
                this.incrY = op.y().reshape(1, ArrayUtil.prod(op.y().shape())).stride(1);
            }
        }
        if (this.incrZ == -1) {
            if (op.z() == op.x()) {
                this.incrZ = this.incrX;
            } else {
                // The z fallback must be stored in incrZ; writing it to incrY would clobber
                // the y increment and leave incrZ at -1.
                this.incrZ = op.z().reshape(1, ArrayUtil.prod(op.z().shape())).stride(1);
            }
        }
    }

    /**
     * Creates a tensor-first action. Only the threshold, tensor index and tensor dimension are
     * stored; n, offsets and increments stay at their defaults until
     * {@link #doTensorFirst(Op)} is called.
     *
     * @param op        not read by this constructor
     * @param threshold value stored in {@link #threshold}
     * @param tensorIdx tensor index used later by {@link #doTensorFirst(Op)}
     * @param tensorDim tensor dimension used later by {@link #doTensorFirst(Op)}
     */
    public BaseCPUAction(Op op, int threshold, int tensorIdx, int tensorDim) {
        this.executed = false;
        this.doTensorFirst = true;
        this.threshold = threshold;
        this.tensorIdx = tensorIdx;
        this.tensorDim = tensorDim;
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Fills in n, offsets and increments from the tensor selected by {@link #tensorIdx} and
     * {@link #tensorDim} along each of the op's arrays.
     *
     * <p>x is always sliced. For y and z: {@code null} gives offset 0 and increment 0; an
     * array that is the same instance as one already handled reuses that array's values;
     * otherwise the array is sliced with the same index and dimension.
     *
     * @param op the op supplying x (required), y and z (both optional)
     */
    public void doTensorFirst(Op op) {
        INDArray x = op.x();
        INDArray y = op.y();
        INDArray z = op.z();

        INDArray tensorX = x.tensorAlongDimension(this.tensorIdx, this.tensorDim);
        this.n = tensorX.length();
        this.offsetX = tensorX.offset();
        this.incrX = tensorX.elementWiseStride();

        if (y == null) {
            this.offsetY = 0;
            this.incrY = 0;
        } else if (y == x) {
            this.offsetY = this.offsetX;
            this.incrY = this.incrX;
        } else {
            INDArray tensorY = y.tensorAlongDimension(this.tensorIdx, this.tensorDim);
            this.offsetY = tensorY.offset();
            this.incrY = tensorY.elementWiseStride();
        }

        if (z == null) {
            this.offsetZ = 0;
            this.incrZ = 0;
        } else if (z == x) {
            this.offsetZ = this.offsetX;
            this.incrZ = this.incrX;
        } else if (z == y) {
            this.offsetZ = this.offsetY;
            this.incrZ = this.incrY;
        } else {
            INDArray tensorZ = z.tensorAlongDimension(this.tensorIdx, this.tensorDim);
            this.offsetZ = tensorZ.offset();
            this.incrZ = tensorZ.elementWiseStride();
        }
    }

    /**
     * Submits this action to the executor obtained from {@link TaskExecutorProvider} and
     * stores the returned future in {@link #future}.
     */
    @Override // org.nd4j.linalg.api.parallel.tasks.Task
    public void invokeAsync() {
        this.future = TaskExecutorProvider.getTaskExecutor().executeAsync(this);
    }

    /* JADX WARN: Can't rename method to resolve collision */
    /**
     * Waits for this action, and then for every task in {@link #subTasks} (if any), to finish.
     * If the action has not been submitted yet it is submitted first.
     *
     * @return always {@code null}
     * @throws RuntimeException wrapping any exception raised while waiting; if the waiting
     *                          thread was interrupted, its interrupt status is restored first
     */
    @Override // org.nd4j.linalg.api.parallel.tasks.Task
    public Void blockUntilComplete() {
        if (this.future == null) {
            invokeAsync();
        }
        try {
            this.future.get();
            if (this.subTasks != null) {
                for (Task<Void> subTask : this.subTasks) {
                    subTask.blockUntilComplete();
                }
            }
            return null;
        } catch (InterruptedException e) {
            // Keep the interrupt visible to callers further up the stack.
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /* JADX WARN: Can't rename method to resolve collision */
    /**
     * Submits this action and waits for it (and any sub-tasks) to finish.
     *
     * @return always {@code null}
     */
    @Override // org.nd4j.linalg.api.parallel.tasks.Task
    public Void invokeBlocking() {
        invokeAsync();
        return blockUntilComplete();
    }
}
