package org.nd4j.linalg.api.parallel.tasks.cpu;

import java.util.Arrays;
import org.apache.commons.math3.util.Pair;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.Accumulation;
import org.nd4j.linalg.api.ops.BroadcastOp;
import org.nd4j.linalg.api.ops.IndexAccumulation;
import org.nd4j.linalg.api.ops.ScalarOp;
import org.nd4j.linalg.api.ops.TransformOp;
import org.nd4j.linalg.api.ops.executioner.OpExecutionerUtil;
import org.nd4j.linalg.api.parallel.tasks.Task;
import org.nd4j.linalg.api.parallel.tasks.TaskFactory;
import org.nd4j.linalg.api.parallel.tasks.cpu.accumulation.CPUAccumulationAlongDimensionTask;
import org.nd4j.linalg.api.parallel.tasks.cpu.accumulation.CPUAccumulationTask;
import org.nd4j.linalg.api.parallel.tasks.cpu.accumulation.CPUAccumulationViaTensorTask;
import org.nd4j.linalg.api.parallel.tasks.cpu.indexaccum.CPUIndexAccumulationAlongDimensionTask;
import org.nd4j.linalg.api.parallel.tasks.cpu.indexaccum.CPUIndexAccumulationTask;
import org.nd4j.linalg.api.parallel.tasks.cpu.indexaccum.CPUIndexAccumulationViaTensorTask;
import org.nd4j.linalg.api.parallel.tasks.cpu.misc.CPUCol2ImTask;
import org.nd4j.linalg.api.parallel.tasks.cpu.misc.CPUIm2ColTask;
import org.nd4j.linalg.api.parallel.tasks.cpu.scalar.CPUScalarOpAction;
import org.nd4j.linalg.api.parallel.tasks.cpu.scalar.CPUScalarOpViaTensorAction;
import org.nd4j.linalg.api.parallel.tasks.cpu.transform.CPUTransformAlongDimensionTask;
import org.nd4j.linalg.api.parallel.tasks.cpu.transform.CPUTransformOpAction;
import org.nd4j.linalg.api.parallel.tasks.cpu.transform.CPUTransformOpViaTensorTask;
import org.nd4j.linalg.api.parallel.tasks.cpu.vector.CpuBroadcastOp;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/nd4j/linalg/api/parallel/tasks/cpu/CPUTaskFactory.class */
/**
 * {@link TaskFactory} that creates the CPU task implementations for nd4j ops.
 *
 * <p>For whole-array transform, scalar, accumulation and index-accumulation ops, the factory first
 * validates that the op's arrays have matching shapes. It then chooses between a "direct" task and
 * a "via tensor" task, based on what {@link OpExecutionerUtil#canDoOpDirectly} reports for those
 * arrays. Along-dimension ops, broadcast ops and im2col/col2im are handed to their dedicated task
 * classes.
 *
 * <p>Every created task receives {@link #parallelThreshold}. It defaults to 1024 and can be
 * overridden with the system property {@value #PARALLEL_THRESHOLD} (read once, in the constructor)
 * or with {@link #setParallelThreshold(int)}.
 */
public class CPUTaskFactory implements TaskFactory {
    /** System property that overrides the default parallel threshold; must be a positive integer. */
    public static final String PARALLEL_THRESHOLD = "org.nd4j.parallel.cpu.threshold";

    /** Threshold used when {@link #PARALLEL_THRESHOLD} is unset or invalid. */
    private static final int DEFAULT_PARALLEL_THRESHOLD = 1024;

    private static final Logger log = LoggerFactory.getLogger(CPUTaskFactory.class);

    /** Threshold forwarded to every task this factory creates. */
    protected int parallelThreshold;

    /**
     * Creates a factory with the default threshold, or with the value of the
     * {@value #PARALLEL_THRESHOLD} system property if it is set to a positive integer.
     * Unparseable or non-positive property values are logged and ignored.
     */
    public CPUTaskFactory() {
        this.parallelThreshold = DEFAULT_PARALLEL_THRESHOLD;
        String property = System.getProperty(PARALLEL_THRESHOLD);
        if (property == null) {
            return;
        }
        int parsed;
        try {
            parsed = Integer.parseInt(property.trim());
        } catch (NumberFormatException e) {
            log.warn("Error parsing CPUTaskFactory parallel threshold: \"{}\"", property);
            log.warn("CPUTaskFactory parallel threshold set to default: {}", this.parallelThreshold);
            return;
        }
        // Every non-positive value (including -1) is rejected with a warning; the previous code
        // used -1 as a "not parsed" sentinel and so silently ignored an explicit -1.
        if (parsed <= 0) {
            log.warn("Invalid CPUTaskFactory parallel threshold; using default: {}", this.parallelThreshold);
        } else {
            this.parallelThreshold = parsed;
        }
    }

    /**
     * Sets the threshold forwarded to tasks created after this call.
     * Unlike the system-property path, no validation is performed here.
     *
     * @param i the new threshold
     */
    public void setParallelThreshold(int i) {
        this.parallelThreshold = i;
    }

    /** @return the threshold currently forwarded to newly created tasks */
    public int getParallelThreshold() {
        return this.parallelThreshold;
    }

    /**
     * Creates a whole-array transform task.
     *
     * @param transformOp the op; its y may be null, and z may be the same array as x (in-place)
     * @return a direct task when the arrays allow it, otherwise a via-tensor task
     * @throws IllegalArgumentException if the x/y/z shapes do not match
     */
    @Override // org.nd4j.linalg.api.parallel.tasks.TaskFactory
    public Task<Void> getTransformAction(TransformOp transformOp) {
        INDArray x = transformOp.x();
        INDArray y = transformOp.y();
        INDArray z = transformOp.z();
        checkShapesMatch(x, y, z);

        // Only pass distinct arrays to canDoOpDirectly; an in-place op (x == z) is checked once.
        boolean canDoOpDirectly;
        if (y == null) {
            canDoOpDirectly = (x == z)
                    ? OpExecutionerUtil.canDoOpDirectly(x)
                    : OpExecutionerUtil.canDoOpDirectly(x, z);
        } else {
            canDoOpDirectly = (x == z)
                    ? OpExecutionerUtil.canDoOpDirectly(x, y)
                    : OpExecutionerUtil.canDoOpDirectly(x, y, z);
        }
        if (canDoOpDirectly) {
            return new CPUTransformOpAction(transformOp, this.parallelThreshold);
        }
        return new CPUTransformOpViaTensorTask(transformOp, this.parallelThreshold);
    }

    /**
     * Creates a transform task that runs along the given dimensions.
     *
     * @param transformOp the op to execute
     * @param dimensions  dimensions forwarded unchanged to the along-dimension task
     * @throws IllegalArgumentException if the x/y/z shapes do not match
     */
    @Override // org.nd4j.linalg.api.parallel.tasks.TaskFactory
    public Task<Void> getTransformAction(TransformOp transformOp, int... dimensions) {
        checkShapesMatch(transformOp.x(), transformOp.y(), transformOp.z());
        return new CPUTransformAlongDimensionTask(transformOp, this.parallelThreshold, dimensions);
    }

    /**
     * Creates a scalar-op task.
     *
     * @param scalarOp the op; z may be the same array as x (in-place)
     * @return a direct task when the arrays allow it, otherwise a via-tensor task
     * @throws IllegalArgumentException if x and z are distinct and their shapes differ
     */
    @Override // org.nd4j.linalg.api.parallel.tasks.TaskFactory
    public Task<Void> getScalarAction(ScalarOp scalarOp) {
        INDArray x = scalarOp.x();
        INDArray z = scalarOp.z();
        boolean canDoOpDirectly;
        if (x == z) {
            canDoOpDirectly = OpExecutionerUtil.canDoOpDirectly(x);
        } else {
            canDoOpDirectly = OpExecutionerUtil.canDoOpDirectly(x, z);
            if (!Arrays.equals(x.shape(), z.shape())) {
                throw new IllegalArgumentException("Shapes do not match: x.shape= " + Arrays.toString(x.shape()) + ", z.shape=" + Arrays.toString(z.shape()));
            }
        }
        if (canDoOpDirectly) {
            return new CPUScalarOpAction(scalarOp, this.parallelThreshold);
        }
        return new CPUScalarOpViaTensorAction(scalarOp, this.parallelThreshold);
    }

    /**
     * Creates a whole-array accumulation task.
     *
     * @param accumulation the op; its y may be null
     * @param z            flag forwarded unchanged to the created task (meaning defined by
     *                     CPUAccumulationTask / CPUAccumulationViaTensorTask)
     * @return a direct task when the arrays allow it, otherwise a via-tensor task
     * @throws IllegalArgumentException if y is present and its shape differs from x
     */
    @Override // org.nd4j.linalg.api.parallel.tasks.TaskFactory
    public Task<Double> getAccumulationTask(Accumulation accumulation, boolean z) {
        INDArray x = accumulation.x();
        INDArray y = accumulation.y();
        boolean canDoOpDirectly;
        if (y == null) {
            canDoOpDirectly = OpExecutionerUtil.canDoOpDirectly(x);
        } else {
            canDoOpDirectly = OpExecutionerUtil.canDoOpDirectly(x, y);
            if (!Arrays.equals(x.shape(), y.shape())) {
                throw new IllegalArgumentException("Shapes do not match: x.shape= " + Arrays.toString(x.shape()) + ", y.shape= " + Arrays.toString(y.shape()));
            }
        }
        if (canDoOpDirectly) {
            return new CPUAccumulationTask(accumulation, this.parallelThreshold, z);
        }
        return new CPUAccumulationViaTensorTask(accumulation, this.parallelThreshold, z);
    }

    /** Equivalent to {@code getAccumulationTask(accumulation, true)}. */
    @Override // org.nd4j.linalg.api.parallel.tasks.TaskFactory
    public Task<Double> getAccumulationTask(Accumulation accumulation) {
        return getAccumulationTask(accumulation, true);
    }

    /**
     * Creates an accumulation task that reduces along the given dimensions.
     *
     * <p>NOTE(review): when z is a distinct array, its shape is required to equal x's. That only
     * holds if callers leave z == x (or same-shaped) at this point; confirm this is intended for
     * a reduction whose result has a smaller shape.
     *
     * @param accumulation the op to execute
     * @param dimensions   dimensions forwarded unchanged to the along-dimension task
     * @throws IllegalArgumentException if the x/y/z shapes do not match
     */
    @Override // org.nd4j.linalg.api.parallel.tasks.TaskFactory
    public Task<INDArray> getAccumulationTask(Accumulation accumulation, int... dimensions) {
        checkShapesMatch(accumulation.x(), accumulation.y(), accumulation.z());
        return new CPUAccumulationAlongDimensionTask(accumulation, this.parallelThreshold, dimensions);
    }

    /**
     * Creates a whole-array index-accumulation task.
     *
     * <p>Vectors are always processed directly. A non-vector array is processed directly only if
     * it is in 'c' order and {@link OpExecutionerUtil#canDoOpDirectly} allows it. Everything else
     * goes through the via-tensor task.
     *
     * @param indexAccumulation the op; its y may be null
     * @throws IllegalArgumentException if y is present and its shape differs from x
     */
    @Override // org.nd4j.linalg.api.parallel.tasks.TaskFactory
    public Task<Pair<Double, Integer>> getIndexAccumulationTask(IndexAccumulation indexAccumulation) {
        INDArray x = indexAccumulation.x();
        INDArray y = indexAccumulation.y();
        if (y != null && !Arrays.equals(x.shape(), y.shape())) {
            throw new IllegalArgumentException("Shapes do not match: x.shape=" + Arrays.toString(x.shape()) + ", y.shape=" + Arrays.toString(y.shape()));
        }

        boolean canDoOpDirectly;
        if (x.isVector()) {
            canDoOpDirectly = true;
        } else if (x.ordering() == 'c') {
            canDoOpDirectly = (y == null)
                    ? OpExecutionerUtil.canDoOpDirectly(x)
                    : OpExecutionerUtil.canDoOpDirectly(x, y);
        } else {
            // Non-vector arrays in any other ordering always take the tensor path.
            canDoOpDirectly = false;
        }

        if (canDoOpDirectly) {
            return new CPUIndexAccumulationTask(indexAccumulation, this.parallelThreshold, true);
        }
        return new CPUIndexAccumulationViaTensorTask(indexAccumulation, this.parallelThreshold, true);
    }

    /**
     * Creates an index-accumulation task that runs along the given dimensions.
     *
     * @param indexAccumulation the op; its y may be null
     * @param dimensions        dimensions forwarded unchanged to the along-dimension task
     * @throws IllegalArgumentException if y is present and its shape differs from x
     */
    @Override // org.nd4j.linalg.api.parallel.tasks.TaskFactory
    public Task<INDArray> getIndexAccumulationTask(IndexAccumulation indexAccumulation, int... dimensions) {
        INDArray x = indexAccumulation.x();
        INDArray y = indexAccumulation.y();
        if (y != null && !Arrays.equals(x.shape(), y.shape())) {
            throw new IllegalArgumentException("Shapes do not match: x.shape=" + Arrays.toString(x.shape()) + ", y.shape=" + Arrays.toString(y.shape()));
        }
        return new CPUIndexAccumulationAlongDimensionTask(indexAccumulation, this.parallelThreshold, dimensions);
    }

    /**
     * Creates a broadcast task.
     *
     * @param broadcastOp the op; only the first entry of {@code getDimension()} is validated
     * @throws IllegalArgumentException if {@code x.size(getDimension()[0])} differs from {@code y.length()}
     */
    @Override // org.nd4j.linalg.api.parallel.tasks.TaskFactory
    public Task<Void> getBroadcastOpAction(BroadcastOp broadcastOp) {
        INDArray x = broadcastOp.x();
        INDArray y = broadcastOp.y();
        int dimension = broadcastOp.getDimension()[0];
        if (x.size(dimension) != y.length()) {
            // Report the dimension value that was checked; concatenating the array itself would
            // print its identity hash (e.g. "[I@1b6d3586") instead of its contents.
            throw new IllegalArgumentException("Shapes do not match: x.shape=" + Arrays.toString(x.shape()) + ", y.shape=" + Arrays.toString(y.shape()) + ", y should be vector with length=x.size(" + dimension + ")");
        }
        return new CpuBroadcastOp(broadcastOp, this.parallelThreshold);
    }

    /**
     * Creates an im2col task. The six int arguments and the boolean flag are forwarded unchanged
     * (presumably kernel/stride/padding sizes and a padding mode; confirm against CPUIm2ColTask).
     */
    @Override // org.nd4j.linalg.api.parallel.tasks.TaskFactory
    public Task<INDArray> getIm2ColTask(INDArray iNDArray, int i, int i2, int i3, int i4, int i5, int i6, boolean z) {
        return new CPUIm2ColTask(iNDArray, i, i2, i3, i4, i5, i6, z, this.parallelThreshold);
    }

    /**
     * Creates a col2im task. The six int arguments are forwarded unchanged
     * (presumably stride/padding/output sizes; confirm against CPUCol2ImTask).
     */
    @Override // org.nd4j.linalg.api.parallel.tasks.TaskFactory
    public Task<INDArray> getCol2ImTask(INDArray iNDArray, int i, int i2, int i3, int i4, int i5, int i6) {
        return new CPUCol2ImTask(iNDArray, i, i2, i3, i4, i5, i6, this.parallelThreshold);
    }

    /**
     * Validates the shapes of an op's x/y/z arrays.
     *
     * <ul>
     *   <li>y absent: x and z must share a shape unless they are the same array.</li>
     *   <li>y present and z == x: x and y must share a shape.</li>
     *   <li>y present and z distinct: x, y and z must all share a shape.</li>
     * </ul>
     *
     * @throws IllegalArgumentException describing the mismatching shapes
     */
    private static void checkShapesMatch(INDArray x, INDArray y, INDArray z) {
        if (y == null) {
            if (x != z && !Arrays.equals(x.shape(), z.shape())) {
                throw new IllegalArgumentException("Shapes do not match: x.shape=" + Arrays.toString(x.shape()) + ", z.shape=" + Arrays.toString(z.shape()));
            }
        } else if (x == z) {
            if (!Arrays.equals(x.shape(), y.shape())) {
                throw new IllegalArgumentException("Shapes do not match: x.shape=" + Arrays.toString(x.shape()) + ", y.shape=" + Arrays.toString(y.shape()));
            }
        } else if (!Arrays.equals(x.shape(), y.shape()) || !Arrays.equals(x.shape(), z.shape())) {
            throw new IllegalArgumentException("Shapes do not match: x.shape=" + Arrays.toString(x.shape()) + ", y.shape=" + Arrays.toString(y.shape()) + ", z.shape=" + Arrays.toString(z.shape()));
        }
    }
}
