package org.apache.seatunnel.transform.nlpmodel.llm.remote;

import com.google.common.annotations.VisibleForTesting;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import org.apache.seatunnel.api.table.type.SeaTunnelDataType;
import org.apache.seatunnel.api.table.type.SeaTunnelRow;
import org.apache.seatunnel.api.table.type.SeaTunnelRowType;
import org.apache.seatunnel.api.table.type.SqlType;
import org.apache.seatunnel.format.json.RowToJsonConverters;
import org.apache.seatunnel.shade.com.fasterxml.jackson.databind.JsonNode;
import org.apache.seatunnel.shade.com.fasterxml.jackson.databind.ObjectMapper;
import org.apache.seatunnel.shade.com.fasterxml.jackson.databind.node.ArrayNode;

/* loaded from: input_file:org/apache/seatunnel/transform/nlpmodel/llm/remote/AbstractModel.class */
/**
 * Base implementation of {@link Model} that serializes a batch of rows to a JSON array, appends
 * output-format rules to the configured prompt, and hands both to {@link #chatWithModel}.
 *
 * <p>When projection columns are configured, only those columns (in the configured order) are
 * serialized; otherwise the full row is serialized.
 */
public abstract class AbstractModel implements Model {
    protected static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();

    // Assigned at the end of the constructor: building it reads rowType and projectionColumns.
    private final RowToJsonConverters.RowToJsonConverter rowToJsonConverter;
    private final SeaTunnelRowType rowType;
    private final String prompt;
    private final SqlType outputType;
    private final List<String> projectionColumns;

    /**
     * Creates the model and builds the row-to-JSON converter for the (optionally projected) row.
     *
     * @param seaTunnelRowType type of the incoming rows
     * @param sqlType element type announced to the model in the prompt rules
     * @param list names of the columns to send to the model; {@code null} or empty sends all
     * @param str user prompt that the formatting rules are appended to
     * @throws IllegalArgumentException if a projection column is not part of the row type
     */
    public AbstractModel(SeaTunnelRowType seaTunnelRowType, SqlType sqlType, List<String> list, String str) {
        this.rowType = seaTunnelRowType;
        this.prompt = str;
        this.outputType = sqlType;
        this.projectionColumns = list;
        // Must stay the last statement. As a field initializer this call ran before the
        // assignments above, so it saw a null row type and a null projection list.
        this.rowToJsonConverter = getRowToJsonConverter();
    }

    /**
     * Builds a converter for the row type that is actually serialized: the full row type when no
     * projection is configured, otherwise a row type made of the projected columns only.
     *
     * <p>This method is invoked from the constructor, so an override must not rely on state of
     * the overriding class.
     *
     * @return a new row-to-JSON converter
     * @throws IllegalArgumentException if a projection column is not part of the row type
     */
    public RowToJsonConverters.RowToJsonConverter getRowToJsonConverter() {
        RowToJsonConverters converters = new RowToJsonConverters();
        if (!hasProjection()) {
            return converters.createConverter(this.rowType, null);
        }
        List<SeaTunnelDataType<?>> projectedTypes = new ArrayList<>(this.projectionColumns.size());
        for (String column : this.projectionColumns) {
            projectedTypes.add(this.rowType.getFieldType(requireFieldIndex(column)));
        }
        SeaTunnelRowType projectedRowType =
                new SeaTunnelRowType(
                        this.projectionColumns.toArray(new String[0]),
                        projectedTypes.toArray(new SeaTunnelDataType<?>[0]));
        return converters.createConverter(projectedRowType, null);
    }

    /** Returns the configured prompt followed by the fixed output-format rules. */
    private String getPromptWithLimit() {
        return this.prompt
                + "\n The following rules need to be followed: \n 1. The received data is an array, and the result is returned in the form of an array.\n 2. Only the result needs to be returned, and no other information can be returned.\n 3. The element type of the array is "
                + this.outputType.toString()
                + ".\n Eg: [\"value1\", \"value2\"]";
    }

    /**
     * Serializes the given rows (after projection) into one JSON array and sends it, together
     * with the rule-augmented prompt, to {@link #chatWithModel}.
     *
     * @param list rows to run inference on
     * @return whatever {@link #chatWithModel} returns for the batch
     * @throws IOException if serialization or the model call fails
     */
    @Override // org.apache.seatunnel.transform.nlpmodel.llm.remote.Model
    public List<String> inference(List<SeaTunnelRow> list) throws IOException {
        ArrayNode rowsNode = OBJECT_MAPPER.createArrayNode();
        for (SeaTunnelRow row : list) {
            JsonNode rowNode = OBJECT_MAPPER.createObjectNode();
            this.rowToJsonConverter.convert(OBJECT_MAPPER, rowNode, createProjectionSeaTunnelRow(row));
            rowsNode.add(rowNode);
        }
        return chatWithModel(getPromptWithLimit(), OBJECT_MAPPER.writeValueAsString(rowsNode));
    }

    /**
     * Copies the projected columns of a row into a new row, in projection order.
     *
     * @param seaTunnelRow source row; may be {@code null}
     * @return the same instance when it is {@code null} or no projection is configured, otherwise
     *     a new row holding only the projected fields
     * @throws IllegalArgumentException if a projection column is not part of the row type
     */
    @VisibleForTesting
    public SeaTunnelRow createProjectionSeaTunnelRow(SeaTunnelRow seaTunnelRow) {
        if (seaTunnelRow == null || !hasProjection()) {
            return seaTunnelRow;
        }
        int columnCount = this.projectionColumns.size();
        SeaTunnelRow projected = new SeaTunnelRow(columnCount);
        for (int target = 0; target < columnCount; target++) {
            int source = requireFieldIndex(this.projectionColumns.get(target));
            projected.setField(target, seaTunnelRow.getField(source));
        }
        return projected;
    }

    /**
     * Sends the prompt and the JSON-serialized rows to the concrete model.
     *
     * @param str prompt including the output-format rules
     * @param str2 JSON array string with one object per input row
     * @return the model's results as strings
     * @throws IOException if the request fails
     */
    protected abstract List<String> chatWithModel(String str, String str2) throws IOException;

    /**
     * Normalizes a value returned by the model: values are lower-cased when the output type is
     * {@link SqlType#BOOLEAN} and returned unchanged otherwise.
     *
     * @param str raw value returned by the model
     * @return the normalized value
     */
    public String convertData(String str) {
        // Locale.ROOT keeps the result independent of the JVM's default locale.
        return this.outputType == SqlType.BOOLEAN ? str.toLowerCase(Locale.ROOT) : str;
    }

    /** Tells whether a non-empty list of projection columns is configured. */
    private boolean hasProjection() {
        return this.projectionColumns != null && !this.projectionColumns.isEmpty();
    }

    /** Looks up a column's index in the row type, rejecting a lookup result of {@code -1}. */
    private int requireFieldIndex(String fieldName) {
        int index = this.rowType.indexOf(fieldName);
        if (index == -1) {
            throw new IllegalArgumentException("Field name " + fieldName + " does not exist in the row type.");
        }
        return index;
    }
}
