/*
 * Decompiled with CFR 0.152.
 */
package org.apache.seatunnel.connectors.seatunnel.milvus.sink;

import com.google.gson.JsonObject;
import io.milvus.v2.client.ConnectConfig;
import io.milvus.v2.client.MilvusClientV2;
import io.milvus.v2.common.IndexParam;
import io.milvus.v2.service.collection.request.AlterCollectionReq;
import io.milvus.v2.service.collection.request.DescribeCollectionReq;
import io.milvus.v2.service.collection.request.GetLoadStateReq;
import io.milvus.v2.service.collection.request.LoadCollectionReq;
import io.milvus.v2.service.collection.response.DescribeCollectionResp;
import io.milvus.v2.service.index.request.CreateIndexReq;
import io.milvus.v2.service.partition.request.CreatePartitionReq;
import io.milvus.v2.service.partition.request.HasPartitionReq;
import io.milvus.v2.service.vector.request.InsertReq;
import io.milvus.v2.service.vector.request.UpsertReq;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicLong;
import org.apache.commons.lang3.StringUtils;
import org.apache.seatunnel.api.configuration.ReadonlyConfig;
import org.apache.seatunnel.api.table.catalog.CatalogTable;
import org.apache.seatunnel.api.table.catalog.PrimaryKey;
import org.apache.seatunnel.api.table.type.CommonOptions;
import org.apache.seatunnel.api.table.type.SeaTunnelRow;
import org.apache.seatunnel.common.exception.SeaTunnelErrorCode;
import org.apache.seatunnel.common.utils.SeaTunnelException;
import org.apache.seatunnel.connectors.seatunnel.milvus.config.MilvusSinkOptions;
import org.apache.seatunnel.connectors.seatunnel.milvus.exception.MilvusConnectionErrorCode;
import org.apache.seatunnel.connectors.seatunnel.milvus.exception.MilvusConnectorException;
import org.apache.seatunnel.connectors.seatunnel.milvus.utils.MilvusConnectorUtils;
import org.apache.seatunnel.connectors.seatunnel.milvus.utils.sink.MilvusSinkConverter;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Buffers {@link SeaTunnelRow}s per Milvus partition and writes them to one collection in batches.
 *
 * <p>Rows are converted to {@link JsonObject}s on {@link #addToBatch} and grouped by the row's
 * partition option (or {@code "_default"} when absent). {@link #flush} writes every buffered group
 * using upsert when upsert is enabled and the primary key is not auto-generated, and insert
 * otherwise. When Milvus rejects a write because of rate limiting or message size, the chunk is
 * split in half and retried (halving the flush threshold too) until chunks of 10 rows or fewer
 * still fail.
 *
 * <p>No synchronization is performed; the buffer is a plain {@link HashMap}.
 */
public class MilvusBufferBatchWriter {
    private static final Logger log = LoggerFactory.getLogger(MilvusBufferBatchWriter.class);

    /** Buffer key for rows that carry no partition option. */
    private static final String DEFAULT_PARTITION = "_default";

    private static final String INSERT_RATE_PROPERTY = "collection.insertRate.max.mb";
    private static final String UPSERT_RATE_PROPERTY = "collection.upsertRate.max.mb";

    /** Rate-limit property value written on close to lift the limits this writer set. */
    private static final String UNLIMITED_RATE = "-1";

    /** Error-message fragments that make a failed write eligible for split-and-retry. */
    private static final String RATE_LIMIT_EXCEEDED = "rate limit exceeded";

    private static final String MESSAGE_TOO_LARGE = "received message larger than max";

    /** Chunks of this many rows or fewer are not split further; the write fails instead. */
    private static final int MIN_SPLIT_SIZE = 10;

    /** Back-off before retrying after a rate-limit rejection. */
    private static final long RATE_LIMIT_BACKOFF_MILLIS = 60_000L;

    /** Value passed to {@code MilvusClientV2.close(long)} when shutting the client down. */
    private static final long CLIENT_CLOSE_TIMEOUT = 10L;

    private final CatalogTable catalogTable;
    private final ReadonlyConfig config;
    private final String collectionName;
    private final Boolean autoId;
    private final Boolean enableUpsert;
    private Boolean hasPartitionKey;
    private MilvusClientV2 milvusClient;
    private final MilvusSinkConverter milvusSinkConverter;
    /** Flush threshold; halved (never below 1) whenever a write has to be split. */
    private int batchSize;
    /** Buffered rows keyed by partition name. */
    private volatile Map<String, List<JsonObject>> milvusDataCache;
    /** Number of rows currently buffered. */
    private final AtomicLong writeCache = new AtomicLong();
    /** Total number of rows successfully written by this writer. */
    private final AtomicLong writeCount = new AtomicLong();
    private final List<String> jsonFieldNames;
    private final String dynamicFieldName;

    /**
     * Creates a writer for the collection named by {@code catalogTable}'s table path and connects
     * to Milvus.
     *
     * @param catalogTable schema and table path of the target collection
     * @param config sink options (URL, token, batch size, upsert/auto-id flags, ...)
     * @throws SeaTunnelException if the Milvus client cannot be initialized
     */
    public MilvusBufferBatchWriter(CatalogTable catalogTable, ReadonlyConfig config)
            throws SeaTunnelException {
        this.catalogTable = catalogTable;
        this.config = config;
        this.autoId =
                getAutoId(
                        catalogTable.getTableSchema().getPrimaryKey(),
                        config.get(MilvusSinkOptions.ENABLE_AUTO_ID));
        this.enableUpsert = config.get(MilvusSinkOptions.ENABLE_UPSERT);
        this.batchSize = config.get(MilvusSinkOptions.BATCH_SIZE);
        this.collectionName = catalogTable.getTablePath().getTableName();
        this.milvusDataCache = new HashMap<>();
        this.milvusSinkConverter = new MilvusSinkConverter();
        this.dynamicFieldName = MilvusConnectorUtils.getDynamicField(catalogTable);
        this.jsonFieldNames = MilvusConnectorUtils.getJsonField(catalogTable);
        initMilvusClient(config);
    }

    /**
     * Connects to Milvus and prepares the target collection.
     *
     * <p>Selects the database (when the table path has one), detects whether the collection has a
     * partition key, optionally applies the configured rate limit, optionally creates COSINE
     * indexes on all vector fields (best effort), and optionally loads the collection if it is
     * not loaded yet. If any step fails, the client is closed again.
     *
     * @throws SeaTunnelException wrapping any failure as {@code INIT_CLIENT_ERROR}
     */
    private void initMilvusClient(ReadonlyConfig config) throws SeaTunnelException {
        try {
            log.info("begin to init Milvus client");
            String dbName = catalogTable.getTablePath().getDatabaseName();
            ConnectConfig connectConfig =
                    ConnectConfig.builder()
                            .uri(config.get(MilvusSinkOptions.URL))
                            .token(config.get(MilvusSinkOptions.TOKEN))
                            .build();
            this.milvusClient = new MilvusClientV2(connectConfig);
            if (StringUtils.isNotEmpty(dbName)) {
                milvusClient.useDatabase(dbName);
            }
            this.hasPartitionKey =
                    MilvusConnectorUtils.hasPartitionKey(milvusClient, collectionName);

            Integer rateLimit = config.get(MilvusSinkOptions.RATE_LIMIT);
            if (rateLimit > 0) {
                log.info("set rate limit for collection: {}", collectionName);
                alterRateLimit(String.valueOf(rateLimit));
            }

            if (config.get(MilvusSinkOptions.CREATE_INDEX)) {
                createVectorIndexes();
            }

            if (config.get(MilvusSinkOptions.LOAD_COLLECTION)
                    && !milvusClient.getLoadState(
                            GetLoadStateReq.builder().collectionName(collectionName).build())) {
                log.info("load collection: {}", collectionName);
                milvusClient.loadCollection(
                        LoadCollectionReq.builder().collectionName(collectionName).build());
            }
            log.info("init Milvus client success");
        } catch (Exception e) {
            log.error("init Milvus client failed", e);
            // Do not leak the connection when a later init step fails.
            closeClientQuietly();
            throw new MilvusConnectorException(MilvusConnectionErrorCode.INIT_CLIENT_ERROR, e);
        }
    }

    /**
     * Creates a COSINE index on every vector field of the collection.
     *
     * <p>Best effort: failures are logged and ignored, because (as the log message notes) the
     * index may already exist.
     */
    private void createVectorIndexes() {
        try {
            log.info("create index for collection: {}", collectionName);
            DescribeCollectionResp description =
                    milvusClient.describeCollection(
                            DescribeCollectionReq.builder().collectionName(collectionName).build());
            List<IndexParam> indexParams = new ArrayList<>();
            for (String fieldName : description.getVectorFieldNames()) {
                indexParams.add(
                        IndexParam.builder()
                                .fieldName(fieldName)
                                .metricType(IndexParam.MetricType.COSINE)
                                .build());
            }
            milvusClient.createIndex(
                    CreateIndexReq.builder()
                            .collectionName(collectionName)
                            .indexParams(indexParams)
                            .build());
        } catch (Exception e) {
            log.warn("create index failed, maybe index already exists", e);
        }
    }

    /** Sets both the insert and upsert rate-limit properties of the collection to {@code value}. */
    private void alterRateLimit(String value) {
        Map<String, String> properties = new HashMap<>();
        properties.put(INSERT_RATE_PROPERTY, value);
        properties.put(UPSERT_RATE_PROPERTY, value);
        milvusClient.alterCollection(
                AlterCollectionReq.builder()
                        .collectionName(collectionName)
                        .properties(properties)
                        .build());
    }

    /** Returns true when the sink was configured with a positive rate limit. */
    private boolean isRateLimitConfigured() {
        return config.get(MilvusSinkOptions.RATE_LIMIT) > 0;
    }

    /**
     * Resolves whether primary keys are generated by Milvus.
     *
     * @return the primary key's own auto-id flag when the schema defines one, otherwise the
     *     configured {@code enableAutoId} value
     */
    private Boolean getAutoId(PrimaryKey primaryKey, Boolean enableAutoId) {
        if (primaryKey != null && primaryKey.getEnableAutoId() != null) {
            return primaryKey.getEnableAutoId();
        }
        return enableAutoId;
    }

    /**
     * Converts {@code element} and buffers it under its partition.
     *
     * <p>The first time a partition named by the row's partition option is seen since the last
     * flush, the partition is created in Milvus if it does not exist yet.
     */
    public void addToBatch(SeaTunnelRow element) {
        String partitionOption = CommonOptions.PARTITION.getName();
        boolean hasPartitionOption = element.getOptions().containsKey(partitionOption);
        String partitionName =
                hasPartitionOption
                        ? element.getOptions().get(partitionOption).toString()
                        : DEFAULT_PARTITION;
        if (hasPartitionOption && !milvusDataCache.containsKey(partitionName)) {
            ensurePartitionExists(partitionName);
        }
        JsonObject data =
                milvusSinkConverter.buildMilvusData(
                        catalogTable, config, jsonFieldNames, dynamicFieldName, element);
        milvusDataCache.computeIfAbsent(partitionName, k -> new ArrayList<>()).add(data);
        writeCache.incrementAndGet();
    }

    /** Creates partition {@code partitionName} in the collection unless it already exists. */
    private void ensurePartitionExists(String partitionName) {
        Boolean exists =
                milvusClient.hasPartition(
                        HasPartitionReq.builder()
                                .collectionName(collectionName)
                                .partitionName(partitionName)
                                .build());
        if (!exists) {
            log.info("create partition: {}", partitionName);
            milvusClient.createPartition(
                    CreatePartitionReq.builder()
                            .collectionName(collectionName)
                            .partitionName(partitionName)
                            .build());
            log.info("create partition success");
        }
    }

    /** Returns true once at least {@code batchSize} rows are buffered. */
    public boolean needFlush() {
        return writeCache.get() >= (long) batchSize;
    }

    /**
     * Writes all buffered rows to Milvus and clears the buffer.
     *
     * <p>If the write fails, the buffer is left untouched and the exception propagates.
     */
    public void flush() throws Exception {
        log.info("Starting to put {} records to Milvus.", writeCache.get());
        if (milvusDataCache.isEmpty()) {
            return;
        }
        writeData2Collection();
        log.info(
                "Successfully put {} records to Milvus. Total records written: {}",
                writeCache.get(),
                writeCount.get());
        milvusDataCache = new HashMap<>();
        writeCache.set(0L);
    }

    /**
     * Lifts the rate limit this writer applied (if any) and closes the Milvus client.
     *
     * <p>The rate-limit properties are only reset when this writer set them, so limits configured
     * on the collection by someone else are left alone. The client is closed even if resetting the
     * properties fails.
     */
    public void close() throws Exception {
        if (milvusClient == null) {
            return;
        }
        try {
            if (isRateLimitConfigured()) {
                alterRateLimit(UNLIMITED_RATE);
            }
        } finally {
            milvusClient.close(CLIENT_CLOSE_TIMEOUT);
        }
    }

    /** Closes the client, logging instead of throwing; restores the interrupt flag if needed. */
    private void closeClientQuietly() {
        if (milvusClient == null) {
            return;
        }
        try {
            milvusClient.close(CLIENT_CLOSE_TIMEOUT);
        } catch (Exception closeError) {
            if (closeError instanceof InterruptedException) {
                Thread.currentThread().interrupt();
            }
            log.warn("close Milvus client failed", closeError);
        }
    }

    /**
     * Writes every buffered partition group, then adds the buffered row count to the total.
     *
     * @throws MilvusConnectorException ({@code WRITE_DATA_FAIL}) wrapping the underlying failure
     */
    private void writeData2Collection() throws Exception {
        try {
            for (Map.Entry<String, List<JsonObject>> entry : milvusDataCache.entrySet()) {
                String partitionName = entry.getKey();
                // The default group, and any collection with a partition key, is written
                // without an explicit partition name.
                if (DEFAULT_PARTITION.equals(partitionName) || Boolean.TRUE.equals(hasPartitionKey)) {
                    partitionName = null;
                }
                if (enableUpsert && !autoId) {
                    upsertWrite(partitionName, entry.getValue());
                } else {
                    insertWrite(partitionName, entry.getValue());
                }
            }
        } catch (Exception e) {
            if (e instanceof InterruptedException) {
                // Interrupted during the rate-limit back-off: keep the flag for the caller.
                Thread.currentThread().interrupt();
            }
            log.error("write data to Milvus failed", e);
            log.error("error data: {}", milvusDataCache);
            throw new MilvusConnectorException(MilvusConnectionErrorCode.WRITE_DATA_FAIL, e);
        }
        writeCount.addAndGet(writeCache.get());
    }

    /**
     * Upserts {@code data} into the collection (and {@code partitionName} when non-empty),
     * splitting the chunk in half and retrying on rate-limit or message-size rejections.
     */
    private void upsertWrite(String partitionName, List<JsonObject> data)
            throws InterruptedException {
        UpsertReq upsertReq =
                UpsertReq.builder().collectionName(collectionName).data(data).build();
        if (StringUtils.isNotEmpty(partitionName)) {
            upsertReq.setPartitionName(partitionName);
        }
        try {
            milvusClient.upsert(upsertReq);
        } catch (Exception e) {
            prepareSplitRetry("upsert", data, e);
            int middle = data.size() / 2;
            upsertWrite(partitionName, data.subList(0, middle));
            upsertWrite(partitionName, data.subList(middle, data.size()));
            // Both halves succeeded: the chunk is written, so do not fall through to a failure.
            return;
        }
        log.info("upsert data success");
    }

    /**
     * Inserts {@code data} into the collection (and {@code partitionName} when non-empty),
     * splitting the chunk in half and retrying on rate-limit or message-size rejections.
     */
    private void insertWrite(String partitionName, List<JsonObject> data)
            throws InterruptedException {
        InsertReq insertReq =
                InsertReq.builder().collectionName(collectionName).data(data).build();
        if (StringUtils.isNotEmpty(partitionName)) {
            insertReq.setPartitionName(partitionName);
        }
        try {
            milvusClient.insert(insertReq);
        } catch (Exception e) {
            prepareSplitRetry("insert", data, e);
            int middle = data.size() / 2;
            insertWrite(partitionName, data.subList(0, middle));
            insertWrite(partitionName, data.subList(middle, data.size()));
        }
    }

    /**
     * Decides whether a failed write of {@code data} may be retried in two halves.
     *
     * <p>Returns normally when a split retry should be attempted. Before returning, it halves the
     * flush threshold (never below 1) and, for rate-limit rejections, sleeps for the back-off
     * period.
     *
     * @param operation "insert" or "upsert", used in log and error messages
     * @param data the chunk that failed
     * @param cause the failure reported by the Milvus client
     * @throws MilvusConnectorException if the failure is not retryable, or the chunk is already
     *     at or below {@link #MIN_SPLIT_SIZE} rows
     * @throws InterruptedException if interrupted during the back-off sleep
     */
    private void prepareSplitRetry(String operation, List<JsonObject> data, Exception cause)
            throws InterruptedException {
        String message = cause.getMessage();
        boolean rateLimited = message != null && message.contains(RATE_LIMIT_EXCEEDED);
        boolean tooLarge = message != null && message.contains(MESSAGE_TOO_LARGE);
        if (!rateLimited && !tooLarge) {
            throw new MilvusConnectorException(
                    MilvusConnectionErrorCode.WRITE_DATA_FAIL,
                    operation + " data failed with unknown exception",
                    cause);
        }
        if (data.size() <= MIN_SPLIT_SIZE) {
            throw new MilvusConnectorException(
                    MilvusConnectionErrorCode.WRITE_DATA_FAIL,
                    operation + " data failed, size down to " + MIN_SPLIT_SIZE + ", break",
                    cause);
        }
        log.warn("{} data failed, retry in smaller chunks: {} ", operation, data.size() / 2);
        // Keep future flushes smaller too; clamp so needFlush() never degenerates to "always".
        batchSize = Math.max(1, batchSize / 2);
        if (rateLimited) {
            log.info("sleep 1 minute to avoid rate limit");
            Thread.sleep(RATE_LIMIT_BACKOFF_MILLIS);
            log.info("sleep 1 minute success");
        }
    }
}

