/*
 * Decompiled with CFR 0.152.
 */
package org.apache.seatunnel.connectors.seatunnel.qdrant.sink;

import io.qdrant.client.PointIdFactory;
import io.qdrant.client.QdrantClient;
import io.qdrant.client.ValueFactory;
import io.qdrant.client.VectorFactory;
import io.qdrant.client.grpc.JsonWithInt;
import io.qdrant.client.grpc.Points;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.UUID;
import java.util.concurrent.ExecutionException;
import java.util.stream.Collectors;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.seatunnel.api.table.catalog.CatalogTable;
import org.apache.seatunnel.api.table.catalog.PrimaryKey;
import org.apache.seatunnel.api.table.type.SeaTunnelDataType;
import org.apache.seatunnel.api.table.type.SeaTunnelRow;
import org.apache.seatunnel.api.table.type.SeaTunnelRowType;
import org.apache.seatunnel.api.table.type.SqlType;
import org.apache.seatunnel.common.exception.CommonErrorCode;
import org.apache.seatunnel.common.exception.SeaTunnelErrorCode;
import org.apache.seatunnel.common.utils.VectorUtils;
import org.apache.seatunnel.connectors.seatunnel.qdrant.config.QdrantParameters;
import org.apache.seatunnel.connectors.seatunnel.qdrant.exception.QdrantConnectorException;

/**
 * Buffers {@link SeaTunnelRow}s as Qdrant points and upserts them into one collection in batches.
 *
 * <p>Each row becomes one {@link Points.PointStruct}:
 * <ul>
 *   <li>the primary-key column becomes the point ID;</li>
 *   <li>scalar columns become payload entries;</li>
 *   <li>vector columns become named vectors.</li>
 * </ul>
 * A point whose ID was not set from a key column receives a random UUID.
 *
 * <p>Only {@link #flush()} is synchronized; {@link #addToBatch(SeaTunnelRow)} is not, so
 * callers are expected to add rows from a single thread.
 */
public class QdrantBatchWriter {
    private final int batchSize;
    private final CatalogTable catalogTable;
    private final String collectionName;
    private final QdrantClient qdrantClient;
    private final List<Points.PointStruct> qdrantDataCache;
    /** Row type of {@link #catalogTable}, resolved once instead of once per row. */
    private final SeaTunnelRowType seaTunnelRowType;
    /** Primary key of {@link #catalogTable}, resolved once instead of once per row. */
    private final PrimaryKey primaryKey;
    private volatile int writeCount = 0;

    /**
     * Creates a writer for the collection configured in {@code params}.
     *
     * @param catalogTable table whose schema describes incoming rows
     * @param batchSize number of buffered rows after which {@link #needFlush()} reports true
     * @param params connection settings; a new {@link QdrantClient} is built from them
     */
    public QdrantBatchWriter(CatalogTable catalogTable, Integer batchSize, QdrantParameters params) {
        this.catalogTable = catalogTable;
        this.qdrantClient = params.buildQdrantClient();
        this.collectionName = params.getCollectionName();
        this.batchSize = batchSize;
        this.qdrantDataCache = new ArrayList<>(batchSize);
        this.seaTunnelRowType = catalogTable.getSeaTunnelRowType();
        this.primaryKey = catalogTable.getTableSchema().getPrimaryKey();
    }

    /**
     * Converts {@code element} into a point and appends it to the in-memory batch.
     *
     * @param element row to buffer
     */
    public void addToBatch(SeaTunnelRow element) {
        Points.PointStruct point = this.buildPoint(element);
        this.qdrantDataCache.add(point);
        // NOTE(review): ++ on a volatile is not atomic; this is only safe while
        // addToBatch is called from a single thread.
        ++this.writeCount;
    }

    /** Returns true once at least {@code batchSize} rows have been buffered since the last flush. */
    public boolean needFlush() {
        return this.writeCount >= this.batchSize;
    }

    /**
     * Upserts all buffered points and clears the buffer. Does nothing when the buffer is empty.
     * If the upsert fails, the buffer is left intact.
     */
    public synchronized void flush() {
        if (CollectionUtils.isEmpty(this.qdrantDataCache)) {
            return;
        }
        this.upsert();
        this.qdrantDataCache.clear();
        this.writeCount = 0;
    }

    /** Closes the underlying Qdrant client. Buffered points are not flushed here. */
    public void close() {
        this.qdrantClient.close();
    }

    /** Builds one Qdrant point (id, payload, named vectors) from a row. */
    private Points.PointStruct buildPoint(SeaTunnelRow element) {
        String[] fieldNames = this.seaTunnelRowType.getFieldNames();
        Points.PointStruct.Builder point = Points.PointStruct.newBuilder();
        Points.NamedVectors.Builder namedVectors = Points.NamedVectors.newBuilder();
        for (int i = 0; i < fieldNames.length; ++i) {
            String fieldName = fieldNames[i];
            SeaTunnelDataType<?> fieldType = this.seaTunnelRowType.getFieldType(i);
            Object value = element.getField(i);
            if (PrimaryKey.isPrimaryKeyField(this.primaryKey, fieldName)) {
                if (value == null) {
                    throw new IllegalArgumentException(
                            "Primary key field '" + fieldName + "' must not be null");
                }
                point.setId(pointId(fieldType, value));
                continue;
            }
            if (value == null) {
                // A null column carries no payload or vector data. Omit it from the point
                // rather than failing inside value.toString() / the vector decoder.
                continue;
            }
            JsonWithInt.Value payloadValue = buildPayload(fieldType, value);
            if (payloadValue != null) {
                point.putPayload(fieldName, payloadValue);
                continue;
            }
            Points.Vector vector = buildVector(fieldType, value);
            if (vector != null) {
                namedVectors.putVectors(fieldName, vector);
            }
        }
        if (!point.hasId()) {
            point.setId(PointIdFactory.id(UUID.randomUUID()));
        }
        point.setVectors(Points.Vectors.newBuilder().setVectors(namedVectors).build());
        return point.build();
    }

    /** Sends the buffered points in one upsert request and blocks until it completes. */
    private void upsert() {
        Points.UpsertPoints request = Points.UpsertPoints.newBuilder()
                .setCollectionName(this.collectionName)
                .addAllPoints(this.qdrantDataCache)
                .build();
        try {
            this.qdrantClient.upsertAsync(request).get();
        } catch (InterruptedException e) {
            // Restore the interrupt flag so callers further up can still observe it.
            Thread.currentThread().interrupt();
            throw new RuntimeException("Upsert failed for collection " + this.collectionName, e);
        } catch (ExecutionException e) {
            throw new RuntimeException("Upsert failed for collection " + this.collectionName, e);
        }
    }

    /**
     * Converts a primary-key value into a Qdrant point ID.
     *
     * <p>Integer-typed keys (TINYINT, SMALLINT, INT, BIGINT) become numeric IDs. STRING keys
     * must be UUID strings.
     *
     * @param fieldType SeaTunnel type of the key column
     * @param value non-null key value
     * @return the point ID
     * @throws QdrantConnectorException if the key type is not supported
     * @throws NumberFormatException if an integer key cannot be parsed
     * @throws IllegalArgumentException if a STRING key is not a valid UUID
     */
    public static Points.PointId pointId(SeaTunnelDataType<?> fieldType, Object value) {
        SqlType sqlType = fieldType.getSqlType();
        switch (sqlType) {
            case TINYINT:
            case SMALLINT:
            case INT:
            case BIGINT:
                // NOTE(review): Qdrant numeric IDs are unsigned; negative keys are not rejected here.
                return PointIdFactory.id(Long.parseLong(value.toString()));
            case STRING:
                return PointIdFactory.id(UUID.fromString(value.toString()));
            default:
                throw new QdrantConnectorException(
                        CommonErrorCode.UNSUPPORTED_DATA_TYPE,
                        "Unexpected value type for point ID: " + sqlType.name());
        }
    }

    /**
     * Converts a scalar column value into a Qdrant payload value.
     *
     * @param fieldType SeaTunnel type of the column
     * @param value non-null column value
     * @return the payload value, or {@code null} if {@code fieldType} is not a payload type
     */
    public static JsonWithInt.Value buildPayload(SeaTunnelDataType<?> fieldType, Object value) {
        SqlType sqlType = fieldType.getSqlType();
        switch (sqlType) {
            case TINYINT:
            case SMALLINT:
            case INT:
            case BIGINT:
                // Parse as long so BIGINT values outside the int range do not overflow.
                return ValueFactory.value(Long.parseLong(value.toString()));
            case FLOAT:
            case DOUBLE:
                // Floating-point columns must be parsed as double; parsing as long throws
                // NumberFormatException on any fractional value.
                return ValueFactory.value(Double.parseDouble(value.toString()));
            case STRING:
            case DATE:
                return ValueFactory.value(value.toString());
            case BOOLEAN:
                return ValueFactory.value(Boolean.parseBoolean(value.toString()));
            default:
                return null;
        }
    }

    /**
     * Converts a vector column value into a Qdrant vector.
     *
     * @param fieldType SeaTunnel type of the column
     * @param value column value; for vector types it is expected to be a {@link ByteBuffer}
     * @return the vector, or {@code null} if {@code fieldType} is not a vector type
     */
    public static Points.Vector buildVector(SeaTunnelDataType<?> fieldType, Object value) {
        SqlType sqlType = fieldType.getSqlType();
        switch (sqlType) {
            case FLOAT_VECTOR:
            case FLOAT16_VECTOR:
            case BFLOAT16_VECTOR:
            case BINARY_VECTOR:
                // NOTE(review): all vector kinds are decoded with VectorUtils.toFloatArray.
                // Confirm that this is correct for the FLOAT16 / BFLOAT16 / BINARY encodings.
                Float[] floats = VectorUtils.toFloatArray((ByteBuffer) value);
                return VectorFactory.vector(Arrays.stream(floats).collect(Collectors.toList()));
            default:
                return null;
        }
    }
}

