/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.basicdataset.nlp;

import ai.djl.basicdataset.BasicDatasets;
import ai.djl.basicdataset.utils.TextData;
import ai.djl.engine.Engine;
import ai.djl.modality.nlp.Vocabulary;
import ai.djl.modality.nlp.embedding.EmbeddingException;
import ai.djl.modality.nlp.embedding.TextEmbedding;
import ai.djl.ndarray.NDManager;
import ai.djl.repository.Repository;
import ai.djl.repository.Resource;
import ai.djl.training.dataset.Dataset;
import ai.djl.training.dataset.RandomAccessDataset;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;

/**
 * A {@link RandomAccessDataset} backed by two {@link TextData} instances: one for the source side
 * and one for the target side. Accessors that take a {@code boolean source} flag read the source
 * side when the flag is {@code true} and the target side otherwise.
 */
public abstract class TextDataset extends RandomAccessDataset {
    /** Source-side text data, built from the default configuration updated with the builder's. */
    protected TextData sourceTextData;
    /** Target-side text data, built from the default configuration updated with the builder's. */
    protected TextData targetTextData;
    /** Manager passed to {@link TextData#preprocess}; taken over from the builder. */
    protected NDManager manager;
    /** Usage (split) of this dataset; the builder defaults it to {@code Dataset.Usage.TRAIN}. */
    protected Dataset.Usage usage;
    /** Not read or written in this class; presumably populated by subclasses. */
    protected Resource resource;
    /** Not read or written in this class; presumably maintained by subclasses. */
    protected boolean prepared;
    /** Cache for {@link #getSamples()}; {@code null} until the first call. */
    protected List<Sample> samples;

    /**
     * Creates the dataset from the given builder.
     *
     * <p>Each side's {@link TextData} starts from {@link TextData#getDefaultConfiguration()} and is
     * updated with the corresponding builder configuration. The builder's manager is adopted by
     * this dataset and renamed to {@code "textDataset"}.
     *
     * @param builder the builder holding the configuration, manager and usage
     */
    public TextDataset(Builder<?> builder) {
        super(builder);
        sourceTextData =
                new TextData(
                        TextData.getDefaultConfiguration().update(builder.sourceConfiguration));
        targetTextData =
                new TextData(
                        TextData.getDefaultConfiguration().update(builder.targetConfiguration));
        manager = builder.manager;
        // This dataset now holds the builder's manager; a later optManager() call on the same
        // builder must therefore not close it.
        builder.defaultManagerUnclaimed = false;
        manager.setName("textDataset");
        usage = builder.usage;
    }

    /** Returns the source-side text data if {@code source} is true, else the target side. */
    private TextData side(boolean source) {
        return source ? sourceTextData : targetTextData;
    }

    /**
     * Returns the text embedding of the source or target side.
     *
     * @param source {@code true} for the source side, {@code false} for the target side
     * @return the embedding held by the selected {@link TextData}
     */
    public TextEmbedding getTextEmbedding(boolean source) {
        return side(source).getTextEmbedding();
    }

    /**
     * Returns the vocabulary of the source or target side.
     *
     * @param source {@code true} for the source side, {@code false} for the target side
     * @return the vocabulary held by the selected {@link TextData}
     */
    public Vocabulary getVocabulary(boolean source) {
        return side(source).getVocabulary();
    }

    /**
     * Returns the raw text at {@code index} on the source or target side.
     *
     * @param index the entry index
     * @param source {@code true} for the source side, {@code false} for the target side
     * @return the raw text reported by the selected {@link TextData}
     */
    public String getRawText(long index, boolean source) {
        return side(source).getRawText(index);
    }

    /**
     * Returns the processed tokens at {@code index} on the source or target side.
     *
     * @param index the entry index
     * @param source {@code true} for the source side, {@code false} for the target side
     * @return the processed text reported by the selected {@link TextData}
     */
    public List<String> getProcessedText(long index, boolean source) {
        return side(source).getProcessedText(index);
    }

    /**
     * Returns one {@link Sample} per index in {@code [0, size())}, sorted by ascending length of
     * the processed source text.
     *
     * <p>The list is built on the first call and cached in {@link #samples}. Later calls return
     * the same mutable list.
     *
     * @return the length-sorted samples
     */
    public List<Sample> getSamples() {
        if (samples == null) {
            samples = new ArrayList<>();
            long total = size();
            for (int i = 0; i < total; i++) {
                samples.add(new Sample(i, getProcessedText(i, true).size()));
            }
            samples.sort(Comparator.comparingInt(Sample::getSentenceLength));
        }
        return samples;
    }

    /**
     * Preprocesses new text into the source or target {@link TextData} using {@link #manager}.
     *
     * <p>At most the inherited {@code limit} entries from the front of {@code newTextData} are
     * passed on.
     *
     * @param newTextData the raw text entries to preprocess
     * @param source {@code true} for the source side, {@code false} for the target side
     * @throws EmbeddingException if {@link TextData#preprocess} throws it
     */
    protected void preprocess(List<String> newTextData, boolean source)
            throws EmbeddingException {
        int count = (int) Math.min(limit, (long) newTextData.size());
        side(source).preprocess(manager, newTextData.subList(0, count));
    }

    /**
     * Base builder for {@link TextDataset} subclasses.
     *
     * @param <T> the concrete builder type
     */
    public abstract static class Builder<T extends Builder<T>>
            extends RandomAccessDataset.BaseBuilder<T> {
        TextData.Configuration sourceConfiguration = new TextData.Configuration();
        TextData.Configuration targetConfiguration = new TextData.Configuration();
        /** Defaults to a fresh base manager; replaced by a sub-manager in {@link #optManager}. */
        NDManager manager = Engine.getInstance().newBaseManager();
        protected Repository repository = BasicDatasets.REPOSITORY;
        protected String groupId = "ai.djl.basicdataset";
        protected String artifactId;
        protected Dataset.Usage usage = Dataset.Usage.TRAIN;

        /**
         * True while {@link #manager} is still the base manager created above and no dataset has
         * adopted it. In that state the builder is the only holder, so it must close the manager
         * when replacing it.
         */
        private boolean defaultManagerUnclaimed = true;

        Builder() {}

        /**
         * Sets the configuration for the source side.
         *
         * @param sourceConfiguration the configuration to apply on top of the defaults
         * @return this builder
         */
        public T setSourceConfiguration(TextData.Configuration sourceConfiguration) {
            this.sourceConfiguration = sourceConfiguration;
            return self();
        }

        /**
         * Sets the configuration for the target side.
         *
         * @param targetConfiguration the configuration to apply on top of the defaults
         * @return this builder
         */
        public T setTargetConfiguration(TextData.Configuration targetConfiguration) {
            this.targetConfiguration = targetConfiguration;
            return self();
        }

        /**
         * Uses a new sub-manager of {@code manager} instead of the default base manager.
         *
         * <p>If the builder still holds its own unclaimed default manager, that manager is
         * closed. Otherwise nothing would ever release it.
         *
         * @param manager the parent manager to derive a sub-manager from
         * @return this builder
         */
        public T optManager(NDManager manager) {
            NDManager replacement = manager.newSubManager();
            if (defaultManagerUnclaimed) {
                this.manager.close();
                defaultManagerUnclaimed = false;
            }
            this.manager = replacement;
            return self();
        }

        /**
         * Sets the dataset usage.
         *
         * @param usage the usage
         * @return this builder
         */
        public T optUsage(Dataset.Usage usage) {
            this.usage = usage;
            return self();
        }

        /**
         * Sets the repository to load artifacts from.
         *
         * @param repository the repository
         * @return this builder
         */
        public T optRepository(Repository repository) {
            this.repository = repository;
            return self();
        }

        /**
         * Sets the artifact group id.
         *
         * @param groupId the group id
         * @return this builder
         */
        public T optGroupId(String groupId) {
            this.groupId = groupId;
            return self();
        }

        /**
         * Sets the artifact id.
         *
         * <p>The value may be a plain artifact id, or the form {@code groupId:artifactId}. The
         * latter also sets the group id.
         *
         * @param artifactId the artifact id, optionally prefixed by {@code groupId:}
         * @return this builder
         * @throws IllegalArgumentException if {@code artifactId} contains {@code ':'} but is not
         *     exactly two non-empty parts
         */
        public T optArtifactId(String artifactId) {
            if (artifactId.contains(":")) {
                // Limit -1 keeps trailing empty tokens so "group:" is rejected, not mis-indexed.
                String[] tokens = artifactId.split(":", -1);
                if (tokens.length != 2 || tokens[0].isEmpty() || tokens[1].isEmpty()) {
                    throw new IllegalArgumentException(
                            "Invalid artifactId, expected \"groupId:artifactId\": " + artifactId);
                }
                this.groupId = tokens[0];
                this.artifactId = tokens[1];
            } else {
                this.artifactId = artifactId;
            }
            return self();
        }
    }

    /** An immutable pair of a dataset index and the token length of its source text. */
    public static final class Sample {
        private final int sentenceLength;
        private final long index;

        /**
         * Creates a sample.
         *
         * @param index the dataset index
         * @param sentenceLength the number of processed source tokens at that index
         */
        public Sample(int index, int sentenceLength) {
            this.index = index;
            this.sentenceLength = sentenceLength;
        }

        /**
         * Returns the number of processed source tokens.
         *
         * @return the sentence length
         */
        public int getSentenceLength() {
            return sentenceLength;
        }

        /**
         * Returns the dataset index.
         *
         * @return the index
         */
        public long getIndex() {
            return index;
        }
    }
}

