/*
 * Decompiled with CFR 0.152.
 */
package org.noear.solon.ai.rag.splitter;

import com.knuddels.jtokkit.Encodings;
import com.knuddels.jtokkit.api.Encoding;
import com.knuddels.jtokkit.api.EncodingRegistry;
import com.knuddels.jtokkit.api.EncodingType;
import com.knuddels.jtokkit.api.IntArrayList;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import org.noear.solon.ai.rag.splitter.TextSplitter;

/**
 * A {@link TextSplitter} that cuts text into chunks bounded by a token count, where tokens are
 * produced by a jtokkit {@link Encoding} (default {@link EncodingType#CL100K_BASE}).
 *
 * <p>Each chunk starts as a window of at most {@code chunkSize} tokens. If the decoded window
 * contains a terminator ({@code '.'}, {@code '?'}, {@code '!'} or {@code '\n'}) whose index is
 * greater than {@code minChunkSizeChars}, the chunk is shortened to end at that terminator and the
 * unused tokens roll over into the next window. At most {@code maxChunkCount} such chunks are
 * produced; any tokens left after that are emitted as one final chunk (which may be larger than
 * {@code chunkSize}). Chunks whose trimmed length is not greater than
 * {@code minChunkLengthToEmbed} are dropped.
 */
public class TokenSizeTextSplitter extends TextSplitter {
    private EncodingRegistry encodingRegistry = Encodings.newLazyEncodingRegistry();
    private EncodingType encodingType = EncodingType.CL100K_BASE;
    /** Maximum number of tokens in one chunk window; always positive. */
    private final int chunkSize;
    /** A chunk is only shortened at a terminator whose character index exceeds this value. */
    private final int minChunkSizeChars;
    /** Chunks whose trimmed length is not greater than this value are discarded. */
    private final int minChunkLengthToEmbed;
    /** Maximum number of windowed chunks before the remainder is emitted as a single chunk. */
    private final int maxChunkCount;
    /** When {@code false}, platform line separators in emitted chunks are replaced by spaces. */
    private final boolean keepSeparator;

    /** Creates a splitter with a 500-token chunk size and the default limits. */
    public TokenSizeTextSplitter() {
        this(500);
    }

    /**
     * Creates a splitter with the given chunk size and the default limits.
     *
     * @param chunkSize maximum number of tokens per chunk window; must be positive
     */
    public TokenSizeTextSplitter(int chunkSize) {
        this(chunkSize, 300);
    }

    /**
     * Creates a splitter with the given chunk size and minimum shortening index.
     *
     * @param chunkSize         maximum number of tokens per chunk window; must be positive
     * @param minChunkSizeChars a chunk is only shortened at a terminator beyond this character index
     */
    public TokenSizeTextSplitter(int chunkSize, int minChunkSizeChars) {
        this(chunkSize, minChunkSizeChars, 5, 1000, true);
    }

    /**
     * Creates a fully configured splitter.
     *
     * @param chunkSize             maximum number of tokens per chunk window; must be positive
     * @param minChunkSizeChars     a chunk is only shortened at a terminator beyond this character index
     * @param minChunkLengthToEmbed chunks whose trimmed length is not greater than this are dropped
     * @param maxChunkCount         maximum number of windowed chunks before the remainder is emitted whole
     * @param keepSeparator         whether to keep line separators ({@code false} replaces them with spaces)
     * @throws IllegalArgumentException if {@code chunkSize} is not positive
     */
    public TokenSizeTextSplitter(int chunkSize, int minChunkSizeChars, int minChunkLengthToEmbed, int maxChunkCount, boolean keepSeparator) {
        // A non-positive window can never consume tokens (0 loops forever in splitText,
        // negative fails deep inside subList), so reject it up front.
        if (chunkSize <= 0) {
            throw new IllegalArgumentException("chunkSize must be positive: " + chunkSize);
        }
        this.chunkSize = chunkSize;
        this.minChunkSizeChars = minChunkSizeChars;
        this.minChunkLengthToEmbed = minChunkLengthToEmbed;
        this.maxChunkCount = maxChunkCount;
        this.keepSeparator = keepSeparator;
    }

    /**
     * Replaces the registry used to look up the token encoding; {@code null} is ignored.
     *
     * @param encodingRegistry the registry to use
     */
    public void setEncodingRegistry(EncodingRegistry encodingRegistry) {
        if (encodingRegistry != null) {
            this.encodingRegistry = encodingRegistry;
        }
    }

    /**
     * Replaces the token encoding type; {@code null} is ignored.
     *
     * @param encodingType the encoding type to use
     */
    public void setEncodingType(EncodingType encodingType) {
        if (encodingType != null) {
            this.encodingType = encodingType;
        }
    }

    /**
     * Splits {@code text} into token-bounded chunks as described on the class.
     *
     * @param text the text to split; {@code null} or blank yields an empty list
     * @return the chunks in document order (mutable list)
     */
    @Override
    protected List<String> splitText(String text) {
        List<String> chunks = new ArrayList<>();
        if (text == null || text.trim().isEmpty()) {
            return chunks;
        }
        Encoding encoding = this.encodingRegistry.getEncoding(this.encodingType);
        List<Integer> tokens = this.encodeTokens(encoding, text);
        int total = tokens.size();
        int start = 0;
        int chunksCount = 0;
        while (start < total && chunksCount < this.maxChunkCount) {
            // Written as start + min(...) rather than min(start + chunkSize, ...) to avoid int overflow.
            int end = start + Math.min(this.chunkSize, total - start);
            List<Integer> window = tokens.subList(start, end);
            String chunkText = this.decodeTokens(encoding, window);
            if (chunkText.trim().isEmpty()) {
                // Whitespace-only window: skip it without counting it as a chunk.
                start = end;
                continue;
            }
            int consumed = window.size();
            int lastPunctuation = lastTerminatorIndex(chunkText);
            if (lastPunctuation > 0 && lastPunctuation > this.minChunkSizeChars
                    && lastPunctuation + 1 < chunkText.length()) {
                chunkText = chunkText.substring(0, lastPunctuation + 1);
                // The re-encoded prefix need not align exactly with the original tokens (BPE merges,
                // U+FFFD from a window that ended mid-character), so clamp to [1, window size]:
                // always make progress and never consume past the current window.
                int reEncoded = this.encodeTokens(encoding, chunkText).size();
                consumed = Math.max(1, Math.min(window.size(), reEncoded));
            }
            String chunkTextToAppend = this.normalizeChunk(chunkText);
            if (chunkTextToAppend.length() > this.minChunkLengthToEmbed) {
                chunks.add(chunkTextToAppend);
            }
            start += consumed;
            ++chunksCount;
        }
        if (start < total) {
            // maxChunkCount reached: emit everything that is left as one final chunk.
            String remainingText = this.normalizeChunk(this.decodeTokens(encoding, tokens.subList(start, total)));
            if (remainingText.length() > this.minChunkLengthToEmbed) {
                chunks.add(remainingText);
            }
        }
        return chunks;
    }

    /** Returns the index of the last '.', '?', '!' or '\n' in {@code text}, or -1 if there is none. */
    private static int lastTerminatorIndex(String text) {
        int index = text.lastIndexOf('.');
        index = Math.max(index, text.lastIndexOf('?'));
        index = Math.max(index, text.lastIndexOf('!'));
        return Math.max(index, text.lastIndexOf('\n'));
    }

    /** Trims {@code text} and, unless separators are kept, replaces platform line separators with spaces. */
    private String normalizeChunk(String text) {
        String normalized = this.keepSeparator ? text : text.replace(System.lineSeparator(), " ");
        return normalized.trim();
    }

    /**
     * Encodes {@code text} into boxed token ids.
     *
     * @param encoding the encoding to use
     * @param text     the text to encode
     * @return the token ids
     * @throws NullPointerException if {@code text} is {@code null}
     */
    protected List<Integer> encodeTokens(Encoding encoding, String text) {
        Objects.requireNonNull(text, "text is null");
        return encoding.encode(text).boxed();
    }

    /**
     * Decodes boxed token ids back into text.
     *
     * @param encoding the encoding to use
     * @param tokens   the token ids to decode
     * @return the decoded text
     * @throws NullPointerException if {@code tokens} (or any element) is {@code null}
     */
    protected String decodeTokens(Encoding encoding, List<Integer> tokens) {
        Objects.requireNonNull(tokens, "tokens is null");
        IntArrayList buffer = new IntArrayList(tokens.size());
        for (Integer token : tokens) {
            buffer.add(token);
        }
        return encoding.decode(buffer);
    }
}

