package org.nd4j.linalg.api.parallel;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinTask;
import java.util.concurrent.Future;
import java.util.concurrent.RunnableFuture;
import org.apache.commons.math3.util.Pair;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.Accumulation;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
import org.nd4j.linalg.api.parallel.TaskCreator;
import org.nd4j.linalg.executors.ExecutorServiceProvider;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.ArrayUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/nd4j/linalg/api/parallel/DefaultParallelExecutioner.class */
public class DefaultParallelExecutioner implements ParallelExecutioner {
    /**
     * System property read by {@link #getEnabled()}. Any value other than {@code "true"}
     * (case-insensitive) disables parallel execution; when unset, parallelism is enabled.
     */
    public static final String ENABLED = "org.nd4j.parallel.enabled";

    private static final Logger log = LoggerFactory.getLogger(DefaultParallelExecutioner.class);

    private ExecutorService executorService;
    private ForkJoinPool forkJoinPool;
    private boolean enable = true;

    /**
     * Creates an executioner that prefers the given fork/join pool for parallel work.
     *
     * @param forkJoinPool pool to use; may be {@code null}, in which case a pool is obtained from
     *     {@link ExecutorServiceProvider} the first time a parallel path runs
     */
    public DefaultParallelExecutioner(ForkJoinPool forkJoinPool) {
        this.forkJoinPool = forkJoinPool;
        this.enable = getEnabled();
        if (!this.enable) {
            log.warn("Nd4j Parallel execution disabled");
        }
    }

    /**
     * Creates an executioner that submits parallel work to the given executor service.
     *
     * @param executorService executor to use; may be {@code null}, in which case a fork/join pool
     *     is obtained from {@link ExecutorServiceProvider} the first time a parallel path runs
     */
    public DefaultParallelExecutioner(ExecutorService executorService) {
        this.executorService = executorService;
        this.enable = getEnabled();
        if (!this.enable) {
            log.warn("Nd4j Parallel execution disabled");
        }
    }

    /**
     * Creates an executioner using the provider's shared fork/join pool, or no pool at all when
     * parallelism is disabled by the {@link #ENABLED} system property.
     */
    public DefaultParallelExecutioner() {
        this(getEnabled() ? ExecutorServiceProvider.getForkJoinPool() : null);
    }

    /**
     * Reads the {@link #ENABLED} system property.
     *
     * @return {@code true} unless the property is set to something other than {@code "true"}
     */
    public static boolean getEnabled() {
        return Boolean.parseBoolean(System.getProperty(ENABLED, "true"));
    }

    /**
     * Turns parallel execution on or off. The configured pools are kept, so re-enabling keeps
     * using whichever pool this instance was built with.
     */
    @Override // org.nd4j.linalg.api.parallel.ParallelExecutioner
    public void setParallelEnabled(boolean enabled) {
        this.enable = enabled;
    }

    @Override // org.nd4j.linalg.api.parallel.ParallelExecutioner
    public boolean parallelEnabled() {
        return this.enable;
    }

    /**
     * Runs {@code accumulation} over every tensor along {@code iArr} and collects one scalar per
     * tensor into a freshly created result array, whose shape is {@code accumulation.x()}'s shape
     * with the dimensions in {@code iArr} removed.
     *
     * @return the result array (filled serially, or by the partitioned parallel tasks)
     */
    @Override // org.nd4j.linalg.api.parallel.ParallelExecutioner
    public INDArray execBasedOnArraysAlongDimension(INDArray iNDArray, Accumulation accumulation, OpExecutioner opExecutioner, int... iArr) {
        INDArray result = Nd4j.create(ArrayUtil.removeIndex(accumulation.x().shape(), iArr));
        if (!parallelEnabled()) {
            int tensorCount = accumulation.x().tensorssAlongDimension(iArr);
            for (int i = 0; i < tensorCount; i++) {
                double value = opExecutioner.execAndReturn((Accumulation) accumulation.opForDimension(i, iArr)).getFinalResult().doubleValue();
                result.putScalar(i, value);
            }
            return result;
        }
        ensureBackingPool();
        if (this.forkJoinPool != null) {
            List<ForkJoinTask<INDArray>> tasks = TaskCreator.parititonForkJoinBasedOnTensorsAlongDimension(iNDArray, accumulation, opExecutioner, result, iArr);
            awaitAll(submitForkJoin(tasks));
        } else {
            // NOTE(review): unlike the fork/join branch, these runnables are not handed `result`;
            // unless TaskCreator writes the reductions into it some other way, the returned array
            // may stay unfilled on this path. TODO confirm against TaskCreator.
            Pair<List<Runnable>, CountDownLatch> partition = TaskCreator.parititonRunnablesBasedOnTensorsAlongDimension(iNDArray, accumulation, opExecutioner, iArr);
            awaitAll(submitRunnables(partition.getFirst()));
        }
        return result;
    }

    /**
     * Executes {@code op} once per tensor along {@code iArr} of {@code iNDArray}, blocking until all
     * per-tensor executions have finished.
     */
    @Override // org.nd4j.linalg.api.parallel.ParallelExecutioner
    public void execBasedOnArraysAlongDimension(INDArray iNDArray, Op op, OpExecutioner opExecutioner, int... iArr) {
        if (!parallelEnabled()) {
            int tensorCount = iNDArray.tensorssAlongDimension(iArr);
            for (int i = 0; i < tensorCount; i++) {
                opExecutioner.exec(op.opForDimension(i, iArr));
            }
            return;
        }
        ensureBackingPool();
        if (this.forkJoinPool != null) {
            Pair<CountDownLatch, List<ForkJoinTask<INDArray>>> partition = TaskCreator.parititonForkJoinBasedOnTensorsAlongDimension(iNDArray, op, opExecutioner, iArr);
            awaitAll(submitForkJoin(partition.getSecond()));
            return;
        }
        Pair<List<Runnable>, CountDownLatch> partition = TaskCreator.parititonRunnablesBasedOnTensorsAlongDimension(iNDArray, op, opExecutioner, iArr);
        awaitAll(submitRunnables(partition.getFirst()));
    }

    /**
     * Executes {@code op} once per slice of {@code iNDArray}, blocking until every slice is done.
     *
     * <p>On the serial path the op's x/y/z are temporarily pointed at each slice and restored to the
     * original arrays afterwards (also if an execution throws).
     */
    @Override // org.nd4j.linalg.api.parallel.ParallelExecutioner
    public void execBasedOnSlices(INDArray iNDArray, Op op, OpExecutioner opExecutioner) {
        if (!parallelEnabled()) {
            INDArray x = op.x();
            INDArray y = op.y();
            INDArray z = op.z();
            int sliceCount = iNDArray.slices();
            try {
                for (int i = 0; i < sliceCount; i++) {
                    op.setX(x.slice(i));
                    if (y != null) {
                        op.setY(y.slice(i));
                    }
                    op.setZ(z.slice(i));
                    opExecutioner.exec(op);
                }
            } finally {
                // Leave the op referencing the full arrays rather than the last slice.
                op.setX(x);
                if (y != null) {
                    op.setY(y);
                }
                op.setZ(z);
            }
            return;
        }
        ensureBackingPool();
        if (this.forkJoinPool != null) {
            @SuppressWarnings("unchecked") // element type erased in TaskCreator's partition result
            List<? extends ForkJoinTask<?>> tasks = (List<? extends ForkJoinTask<?>>) TaskCreator.parititonForkJoinBasedOnSlices(iNDArray, op, opExecutioner).getFirst();
            // Wait for completion like every other path, so callers observe finished results.
            awaitAll(submitForkJoin(tasks));
            return;
        }
        Pair<List<Runnable>, CountDownLatch> partition = TaskCreator.parititonRunnablesBasedOnSlices(iNDArray, op, opExecutioner);
        awaitAll(submitRunnables(partition.getFirst()));
    }

    /**
     * Calls {@code iNDArrayTask.perform} on every tensor along {@code iArr} of {@code iNDArray},
     * blocking until all calls have finished.
     */
    @Override // org.nd4j.linalg.api.parallel.ParallelExecutioner
    public void execBasedOnArraysAlongDimension(INDArray iNDArray, TaskCreator.INDArrayTask iNDArrayTask, int... iArr) {
        if (!parallelEnabled()) {
            int tensorCount = iNDArray.tensorssAlongDimension(iArr);
            for (int i = 0; i < tensorCount; i++) {
                iNDArrayTask.perform(iNDArray.tensorAlongDimension(i, iArr));
            }
            return;
        }
        ensureBackingPool();
        if (this.forkJoinPool != null) {
            Pair<List<ForkJoinTask<INDArray>>, CountDownLatch> partition = TaskCreator.parititonForkJoinBasedOnTensorsAlongDimension(iNDArray, iNDArrayTask, iArr);
            awaitAll(submitForkJoin(partition.getFirst()));
            return;
        }
        List<Runnable> runnables = TaskCreator.parititonRunnablesBasedOnTensorsAlongDimension(iNDArray, iNDArrayTask, iArr);
        awaitAll(submitRunnables(runnables));
    }

    /**
     * For each tensor index {@code t} along {@code iArr}, calls {@code iNDArrayTask.perform} with the
     * {@code t}-th tensor of every array in {@code iNDArrayArr}, blocking until all calls finish.
     * The tensor count is taken from {@code iNDArrayArr[0]}.
     */
    @Override // org.nd4j.linalg.api.parallel.ParallelExecutioner
    public void execBasedOnArraysAlongDimension(INDArray[] iNDArrayArr, TaskCreator.INDArrayTask iNDArrayTask, int... iArr) {
        if (!parallelEnabled()) {
            int tensorCount = iNDArrayArr[0].tensorssAlongDimension(iArr);
            for (int t = 0; t < tensorCount; t++) {
                // Fresh array per call so a task that keeps the reference is not overwritten.
                INDArray[] views = new INDArray[iNDArrayArr.length];
                for (int a = 0; a < iNDArrayArr.length; a++) {
                    views[a] = iNDArrayArr[a].tensorAlongDimension(t, iArr);
                }
                iNDArrayTask.perform(views);
            }
            return;
        }
        ensureBackingPool();
        if (this.forkJoinPool != null) {
            Pair<List<ForkJoinTask<INDArray[]>>, CountDownLatch> partition = TaskCreator.parititonForkJoinBasedOnTensorsAlongDimension(iNDArrayArr, iNDArrayTask, iArr);
            awaitAll(submitForkJoin(partition.getFirst()));
            return;
        }
        List<Runnable> runnables = TaskCreator.parititonRunnablesBasedOnTensorsAlongDimension(iNDArrayArr, iNDArrayTask, iArr);
        awaitAll(submitRunnables(runnables));
    }

    /**
     * Calls {@code iNDArrayTask.perform} on every slice of {@code iNDArray}, blocking until all calls
     * have finished.
     */
    @Override // org.nd4j.linalg.api.parallel.ParallelExecutioner
    public void execBasedOnSlices(INDArray iNDArray, TaskCreator.INDArrayTask iNDArrayTask) {
        if (!parallelEnabled()) {
            int sliceCount = iNDArray.slices();
            for (int i = 0; i < sliceCount; i++) {
                iNDArrayTask.perform(iNDArray.slice(i));
            }
            return;
        }
        ensureBackingPool();
        if (this.forkJoinPool != null) {
            Pair<List<ForkJoinTask<INDArray>>, CountDownLatch> partition = TaskCreator.parititonForkJoinBasedOnSlices(iNDArray, iNDArrayTask);
            awaitAll(submitForkJoin(partition.getFirst()));
            return;
        }
        Pair<List<Runnable>, CountDownLatch> partition = TaskCreator.parititonRunnablesBasedOnSlices(iNDArray, iNDArrayTask);
        awaitAll(submitRunnables(partition.getFirst()));
    }

    /**
     * Submits {@code runnable} to the executor service, obtaining one from
     * {@link ExecutorServiceProvider} on first use. Does not wait for completion.
     */
    @Override // org.nd4j.linalg.api.parallel.ParallelExecutioner
    public Future exec(Runnable runnable) {
        if (this.executorService == null) {
            this.executorService = ExecutorServiceProvider.getExecutorService();
        }
        return this.executorService.submit(runnable);
    }

    /**
     * Hands {@code forkJoinTask} to the fork/join pool, obtaining one from
     * {@link ExecutorServiceProvider} on first use. Does not wait for completion.
     */
    @Override // org.nd4j.linalg.api.parallel.ParallelExecutioner
    public <T> void exec(ForkJoinTask<T> forkJoinTask) {
        if (this.forkJoinPool == null) {
            this.forkJoinPool = ExecutorServiceProvider.getForkJoinPool();
        }
        this.forkJoinPool.execute(forkJoinTask);
    }

    /** Obtains a fork/join pool when neither backing pool is configured, so parallel paths never NPE. */
    private void ensureBackingPool() {
        if (this.forkJoinPool == null && this.executorService == null) {
            this.forkJoinPool = ExecutorServiceProvider.getForkJoinPool();
        }
    }

    /** Submits every task to the fork/join pool and returns the handles to wait on. */
    private List<ForkJoinTask<?>> submitForkJoin(List<? extends ForkJoinTask<?>> tasks) {
        List<ForkJoinTask<?>> submitted = new ArrayList<>(tasks.size());
        for (ForkJoinTask<?> task : tasks) {
            submitted.add(this.forkJoinPool.submit(task));
        }
        return submitted;
    }

    /** Submits every runnable to the executor service and returns the futures to wait on. */
    private List<Future<?>> submitRunnables(List<? extends Runnable> tasks) {
        List<Future<?>> futures = new ArrayList<>(tasks.size());
        for (Runnable task : tasks) {
            futures.add(this.executorService.submit(task));
        }
        return futures;
    }

    /**
     * Blocks until every future has completed. A failed task is logged and waiting continues; an
     * interrupt restores the thread's interrupt flag and stops waiting for the remaining futures.
     */
    private static void awaitAll(List<? extends Future<?>> futures) {
        for (Future<?> future : futures) {
            try {
                future.get();
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                log.warn("Interrupted while waiting for {} parallel task(s)", futures.size());
                return;
            } catch (ExecutionException e) {
                log.error("Parallel task failed", e);
            }
        }
    }
}
