/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.api.shape.tensor;

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.api.shape.tensor.TensorCalculator;
import org.nd4j.linalg.util.ArrayUtil;

/**
 * {@link TensorCalculator} for one-dimensional tensors: views an array as a collection of
 * vectors running along a single dimension ({@code tensorDim}), one vector for every
 * combination of indices over the remaining dimensions.
 *
 * <p>Every tensor is reported with shape {@code {1, shape[tensorDim]}} and stride
 * {@code {1, stride[tensorDim]}}; {@link #getOffsetForTensor(int)} gives the buffer offset of a
 * tensor's first element.
 *
 * <p>The shape and stride passed in are copied, and the array-valued getters return copies, so
 * callers mutating those arrays cannot corrupt this calculator's state.
 */
public class TensorCalculator1d
implements TensorCalculator {
    /** Base offset added to every tensor offset (from {@code INDArray.offset()} in the array constructor). */
    private final int baseOffset;
    /** Private copy of the full array shape. */
    private final int[] shape;
    /** Private copy of the full array stride; same length as {@link #shape}. */
    private final int[] stride;
    /** Dimension the tensors run along, normalized to {@code [0, shape.length)}. */
    private final int tensorDim;
    /** {@link #shape} with {@link #tensorDim} removed; its index space enumerates the tensors. */
    private final int[] shapeMinusTensorDim;
    /** Distance between consecutive elements of one tensor, i.e. {@code stride[tensorDim]}. */
    private final int elementWiseStride;
    /** Shape reported for every tensor: {@code {1, shape[tensorDim]}}. */
    private final int[] tensorShape;
    /** Stride reported for every tensor: {@code {1, elementWiseStride}}. */
    private final int[] tensorStride;
    /** Number of tensors: the product of {@link #shapeMinusTensorDim}. */
    private final int numTensors;

    /**
     * Creates a calculator over the tensors of {@code arr} along {@code tensorDim}.
     *
     * @param arr       the array whose offset, shape and stride are used
     * @param tensorDim dimension the tensors run along; negative values count from the end
     * @throws IllegalArgumentException if {@code tensorDim} is outside {@code [-rank, rank)}
     */
    public TensorCalculator1d(INDArray arr, int tensorDim) {
        this(arr.offset(), arr.shape(), arr.stride(), tensorDim);
    }

    /**
     * Creates a calculator from an explicit offset, shape and stride.
     *
     * @param baseOffset offset added to every tensor offset
     * @param shape      full array shape (copied)
     * @param stride     full array stride, one entry per dimension of {@code shape} (copied)
     * @param tensorDim  dimension the tensors run along; negative values count from the end
     * @throws IllegalArgumentException if {@code stride} and {@code shape} differ in length, or
     *                                  {@code tensorDim} is outside {@code [-rank, rank)}
     */
    public TensorCalculator1d(int baseOffset, int[] shape, int[] stride, int tensorDim) {
        int rank = shape.length;
        if (stride.length != rank) {
            throw new IllegalArgumentException("Shape and stride must have the same length: shape has "
                    + rank + " dimensions, stride has " + stride.length);
        }
        if (tensorDim < -rank || tensorDim >= rank) {
            throw new IllegalArgumentException("Tensor dimension " + tensorDim
                    + " is out of range for an array of rank " + rank);
        }
        if (tensorDim < 0) {
            // Python-style negative dimension: -1 is the last dimension.
            tensorDim += rank;
        }
        this.baseOffset = baseOffset;
        // Defensive copies: the caller (or the INDArray) may reuse or mutate these arrays.
        this.shape = shape.clone();
        this.stride = stride.clone();
        this.tensorDim = tensorDim;
        this.shapeMinusTensorDim = ArrayUtil.removeIndex(this.shape, tensorDim);
        this.elementWiseStride = this.stride[tensorDim];
        this.tensorShape = new int[]{1, this.shape[tensorDim]};
        this.tensorStride = new int[]{1, this.elementWiseStride};
        this.numTensors = ArrayUtil.prod(this.shapeMinusTensorDim);
    }

    @Override
    public int getNumTensors() {
        return this.numTensors;
    }

    /**
     * Returns the buffer offset of the first element of tensor {@code tensorIdx}.
     *
     * <p>The index is decomposed into sub-indices over every dimension except
     * {@code tensorDim} via {@code Shape.ind2subC} (presumably row-major order — confirm against
     * that helper). Each sub-index is then weighted by the corresponding stride and the sum is
     * added to the base offset.
     *
     * @param tensorIdx index of the tensor, in {@code [0, getNumTensors())}
     * @return offset of the tensor's first element
     * @throws IndexOutOfBoundsException if {@code tensorIdx} is out of range; without this check
     *                                   an out-of-range index yields an offset outside the array
     */
    @Override
    public int getOffsetForTensor(int tensorIdx) {
        if (tensorIdx < 0 || tensorIdx >= this.numTensors) {
            throw new IndexOutOfBoundsException("Tensor index " + tensorIdx
                    + " out of range [0, " + this.numTensors + ")");
        }
        int[] indicesMinusTensorDim = Shape.ind2subC(this.shapeMinusTensorDim, (long) tensorIdx);
        int offset = this.baseOffset;
        int j = 0;
        for (int i = 0; i < this.shape.length; ++i) {
            if (i == this.tensorDim) {
                // The tensor dimension is not part of the enumeration index.
                continue;
            }
            offset += indicesMinusTensorDim[j++] * this.stride[i];
        }
        return offset;
    }

    /**
     * Returns the shape of each tensor, {@code {1, tensorLength}}.
     *
     * @return a fresh copy; mutating it does not affect this calculator
     */
    @Override
    public int[] getShape() {
        return this.tensorShape.clone();
    }

    /**
     * Returns the stride of each tensor, {@code {1, elementWiseStride}}.
     *
     * @return a fresh copy; mutating it does not affect this calculator
     */
    @Override
    public int[] getStride() {
        return this.tensorStride.clone();
    }

    @Override
    public int getBaseOffset() {
        return this.baseOffset;
    }

    /** Returns the distance between consecutive elements of a tensor, {@code stride[tensorDim]}. */
    @Override
    public int getElementWiseStrideForTensor() {
        return this.elementWiseStride;
    }

    /** Returns the number of elements in each tensor, {@code shape[tensorDim]}. */
    @Override
    public int getTensorLength() {
        return this.shape[this.tensorDim];
    }
}

