package org.springframework.ai.vectorstore.milvus;

import com.google.gson.Gson;
import com.google.gson.JsonObject;
import com.google.gson.reflect.TypeToken;
import io.milvus.client.MilvusServiceClient;
import io.milvus.common.clientenum.ConsistencyLevelEnum;
import io.milvus.exception.ParamException;
import io.milvus.grpc.DataType;
import io.milvus.grpc.MutationResult;
import io.milvus.grpc.SearchResults;
import io.milvus.param.IndexType;
import io.milvus.param.MetricType;
import io.milvus.param.R;
import io.milvus.param.collection.CreateCollectionParam;
import io.milvus.param.collection.DropCollectionParam;
import io.milvus.param.collection.FieldType;
import io.milvus.param.collection.HasCollectionParam;
import io.milvus.param.collection.LoadCollectionParam;
import io.milvus.param.collection.ReleaseCollectionParam;
import io.milvus.param.dml.DeleteParam;
import io.milvus.param.dml.InsertParam;
import io.milvus.param.dml.SearchParam;
import io.milvus.param.index.CreateIndexParam;
import io.milvus.param.index.DescribeIndexParam;
import io.milvus.param.index.DropIndexParam;
import io.milvus.response.QueryResultsWrapper;
import io.milvus.response.SearchResultsWrapper;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.document.Document;
import org.springframework.ai.document.DocumentMetadata;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.embedding.EmbeddingOptionsBuilder;
import org.springframework.ai.model.EmbeddingUtils;
import org.springframework.ai.observation.conventions.VectorStoreProvider;
import org.springframework.ai.observation.conventions.VectorStoreSimilarityMetric;
import org.springframework.ai.vectorstore.AbstractVectorStoreBuilder;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.filter.Filter;
import org.springframework.ai.vectorstore.filter.FilterExpressionConverter;
import org.springframework.ai.vectorstore.observation.AbstractObservationVectorStore;
import org.springframework.ai.vectorstore.observation.VectorStoreObservationContext;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;

/**
 * {@code VectorStore} implementation backed by a Milvus collection.
 *
 * <p>Each document is stored as one row with four fields: a VarChar primary key, a VarChar
 * content field, a JSON metadata field and a FloatVector embedding field. Field names, the
 * target database/collection and the index/metric configuration are supplied through
 * {@link Builder}. When {@code initializeSchema} is enabled, {@link #afterPropertiesSet()}
 * creates the collection and its vector index if they are missing and loads the collection.
 */
public class MilvusVectorStore extends AbstractObservationVectorStore implements InitializingBean {

    /** Fallback embedding dimension used when neither the builder nor the model supplies one. */
    public static final int OPENAI_EMBEDDING_DIMENSION_SIZE = 1536;

    /** Sentinel meaning "no embedding dimension configured on the builder". */
    public static final int INVALID_EMBEDDING_DIMENSION = -1;

    public static final String DEFAULT_DATABASE_NAME = "default";

    public static final String DEFAULT_COLLECTION_NAME = "vector_store";

    public static final String DOC_ID_FIELD_NAME = "doc_id";

    public static final String CONTENT_FIELD_NAME = "content";

    public static final String METADATA_FIELD_NAME = "metadata";

    public static final String EMBEDDING_FIELD_NAME = "embedding";

    /** Key under which the search result wrapper exposes each hit's score. */
    public static final String SIMILARITY_FIELD_NAME = "score";

    private static final Logger logger = LoggerFactory.getLogger(MilvusVectorStore.class);

    private static final Map<MetricType, VectorStoreSimilarityMetric> SIMILARITY_TYPE_MAPPING = Map.of(
            MetricType.COSINE, VectorStoreSimilarityMetric.COSINE,
            MetricType.L2, VectorStoreSimilarityMetric.EUCLIDEAN,
            MetricType.IP, VectorStoreSimilarityMetric.DOT);

    /** Gson instances are thread-safe, so one instance is shared for all metadata conversions. */
    private static final Gson GSON = new Gson();

    /** Target type used when converting stored JSON metadata back into a map. */
    private static final TypeToken<Map<String, Object>> METADATA_TYPE = new TypeToken<Map<String, Object>>() {
    };

    public final FilterExpressionConverter filterExpressionConverter;

    private final MilvusServiceClient milvusClient;

    private final boolean initializeSchema;

    private final String databaseName;

    private final String collectionName;

    private final int embeddingDimension;

    private final IndexType indexType;

    private final MetricType metricType;

    private final String indexParameters;

    private final String idFieldName;

    private final boolean isAutoId;

    private final String contentFieldName;

    private final String metadataFieldName;

    private final String embeddingFieldName;

    /**
     * Fluent builder for {@link MilvusVectorStore}. Options that are not set keep the defaults
     * declared on the fields below.
     */
    public static class Builder extends AbstractVectorStoreBuilder<Builder> {

        private final MilvusServiceClient milvusClient;

        private String databaseName = DEFAULT_DATABASE_NAME;

        private String collectionName = DEFAULT_COLLECTION_NAME;

        private int embeddingDimension = INVALID_EMBEDDING_DIMENSION;

        private IndexType indexType = IndexType.IVF_FLAT;

        private MetricType metricType = MetricType.COSINE;

        private String indexParameters = "{\"nlist\":1024}";

        private String idFieldName = DOC_ID_FIELD_NAME;

        private boolean isAutoId = false;

        private String contentFieldName = CONTENT_FIELD_NAME;

        private String metadataFieldName = METADATA_FIELD_NAME;

        private String embeddingFieldName = EMBEDDING_FIELD_NAME;

        private boolean initializeSchema = false;

        private Builder(MilvusServiceClient milvusServiceClient, EmbeddingModel embeddingModel) {
            super(embeddingModel);
            Assert.notNull(milvusServiceClient, "milvusClient must not be null");
            this.milvusClient = milvusServiceClient;
        }

        /**
         * Sets the similarity metric used for the index and for searches.
         * @param metricType one of {@code IP}, {@code L2} or {@code COSINE}
         * @return this builder
         * @throws IllegalArgumentException if {@code metricType} is null or unsupported
         */
        public Builder metricType(MetricType metricType) {
            Assert.notNull(metricType, "Metric type must not be null");
            Assert.isTrue(metricType == MetricType.IP || metricType == MetricType.L2 || metricType == MetricType.COSINE,
                    "Only the text metric types IP, L2 and COSINE are supported");
            this.metricType = metricType;
            return this;
        }

        /** Sets the vector index type created by schema initialization. */
        public Builder indexType(IndexType indexType) {
            this.indexType = indexType;
            return this;
        }

        /** Sets the JSON extra parameters passed when the vector index is created. */
        public Builder indexParameters(String indexParameters) {
            this.indexParameters = indexParameters;
            return this;
        }

        /** Sets the Milvus database name. */
        public Builder databaseName(String databaseName) {
            this.databaseName = databaseName;
            return this;
        }

        /** Sets the Milvus collection name. */
        public Builder collectionName(String collectionName) {
            this.collectionName = collectionName;
            return this;
        }

        /**
         * Sets an explicit embedding dimension used when creating the collection.
         * @param dimension value in the inclusive range 1..32768
         * @return this builder
         * @throws IllegalArgumentException if {@code dimension} is out of range
         */
        public Builder embeddingDimension(int dimension) {
            Assert.isTrue(dimension >= 1 && dimension <= 32768,
                    "Dimension has to be within the boundaries 1 and 32768 (inclusively)");
            this.embeddingDimension = dimension;
            return this;
        }

        /** Sets the name of the primary-key (document id) field. */
        public Builder iDFieldName(String idFieldName) {
            this.idFieldName = idFieldName;
            return this;
        }

        /** When {@code true}, Milvus generates primary keys and document ids are not inserted. */
        public Builder autoId(boolean autoId) {
            this.isAutoId = autoId;
            return this;
        }

        /** Sets the name of the document content field. */
        public Builder contentFieldName(String contentFieldName) {
            this.contentFieldName = contentFieldName;
            return this;
        }

        /** Sets the name of the JSON metadata field. */
        public Builder metadataFieldName(String metadataFieldName) {
            this.metadataFieldName = metadataFieldName;
            return this;
        }

        /** Sets the name of the vector field. */
        public Builder embeddingFieldName(String embeddingFieldName) {
            this.embeddingFieldName = embeddingFieldName;
            return this;
        }

        /** When {@code true}, the collection/index are created and loaded on startup. */
        public Builder initializeSchema(boolean initializeSchema) {
            this.initializeSchema = initializeSchema;
            return this;
        }

        /** Creates the configured {@link MilvusVectorStore}. */
        @Override
        public MilvusVectorStore build() {
            return new MilvusVectorStore(this);
        }

    }

    protected MilvusVectorStore(Builder builder) {
        super(builder);
        this.filterExpressionConverter = new MilvusFilterExpressionConverter();
        Assert.notNull(builder.milvusClient, "milvusClient must not be null");
        this.milvusClient = builder.milvusClient;
        this.initializeSchema = builder.initializeSchema;
        this.databaseName = builder.databaseName;
        this.collectionName = builder.collectionName;
        this.embeddingDimension = builder.embeddingDimension;
        this.indexType = builder.indexType;
        this.metricType = builder.metricType;
        this.indexParameters = builder.indexParameters;
        this.idFieldName = builder.idFieldName;
        this.isAutoId = builder.isAutoId;
        this.contentFieldName = builder.contentFieldName;
        this.metadataFieldName = builder.metadataFieldName;
        this.embeddingFieldName = builder.embeddingFieldName;
    }

    /**
     * Starts building a store.
     * @param milvusServiceClient client used for all Milvus calls; must not be null
     * @param embeddingModel model used to embed documents and queries
     * @return a new builder
     */
    public static Builder builder(MilvusServiceClient milvusServiceClient, EmbeddingModel embeddingModel) {
        return new Builder(milvusServiceClient, embeddingModel);
    }

    /**
     * Embeds the documents and inserts them as rows. The id column is only written when
     * auto-id is disabled.
     * @param documents documents to insert; must not be null
     * @throws RuntimeException if the insert call reports an exception
     */
    @Override
    public void doAdd(List<Document> documents) {
        Assert.notNull(documents, "Documents must not be null");

        List<float[]> embeddings = this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(),
                this.batchingStrategy);

        int size = documents.size();
        List<String> docIds = new ArrayList<>(size);
        List<String> contents = new ArrayList<>(size);
        List<JsonObject> metadatas = new ArrayList<>(size);
        List<List<Float>> vectors = new ArrayList<>(size);

        // Embeddings are positional: the i-th embedding belongs to the i-th document.
        int index = 0;
        for (Document document : documents) {
            docIds.add(document.getId());
            contents.add(document.getText());
            metadatas.add(GSON.fromJson(GSON.toJson(document.getMetadata()), JsonObject.class));
            vectors.add(EmbeddingUtils.toList(embeddings.get(index++)));
        }

        List<InsertParam.Field> fields = new ArrayList<>(4);
        if (!this.isAutoId) {
            fields.add(new InsertParam.Field(this.idFieldName, docIds));
        }
        fields.add(new InsertParam.Field(this.contentFieldName, contents));
        fields.add(new InsertParam.Field(this.metadataFieldName, metadatas));
        fields.add(new InsertParam.Field(this.embeddingFieldName, vectors));

        R<MutationResult> response = this.milvusClient.insert(InsertParam.newBuilder()
            .withDatabaseName(this.databaseName)
            .withCollectionName(this.collectionName)
            .withFields(fields)
            .build());
        if (response.getException() != null) {
            throw new RuntimeException("Failed to insert:", response.getException());
        }
    }

    /**
     * Deletes the rows whose primary key is in {@code idList}. Logs a warning when fewer rows
     * than requested were deleted.
     * @param idList ids to delete; must not be null. An empty list is a no-op.
     * @throws RuntimeException if the delete call reports an exception
     */
    @Override
    public void doDelete(List<String> idList) {
        Assert.notNull(idList, "Document id list must not be null");
        if (idList.isEmpty()) {
            return;
        }

        String inList = idList.stream().map(MilvusVectorStore::toQuotedLiteral).collect(Collectors.joining(","));
        String expression = String.format("%s in [%s]", this.idFieldName, inList);

        R<MutationResult> response = this.milvusClient.delete(DeleteParam.newBuilder()
            .withDatabaseName(this.databaseName)
            .withCollectionName(this.collectionName)
            .withExpr(expression)
            .build());
        if (response.getException() != null) {
            throw new RuntimeException("Failed to delete:", response.getException());
        }

        long deleteCount = response.getData().getDeleteCnt();
        if (deleteCount != idList.size()) {
            logger.warn("Deleted only {} entries from requested {} ", deleteCount, idList.size());
        }
    }

    /**
     * Quotes a value as a single-quoted Milvus expression string literal, escaping
     * backslashes and single quotes so the value cannot terminate the literal early.
     */
    private static String toQuotedLiteral(String value) {
        String text = String.valueOf(value);
        return "'" + text.replace("\\", "\\\\").replace("'", "\\'") + "'";
    }

    /**
     * Deletes every row matching the portable filter expression.
     * @param filterExpression filter to convert into a Milvus boolean expression; must not be null
     * @throws IllegalStateException if conversion fails, the call throws, or Milvus reports a
     * non-success status
     */
    @Override
    protected void doDelete(Filter.Expression filterExpression) {
        Assert.notNull(filterExpression, "Filter expression must not be null");

        R<MutationResult> response;
        try {
            String nativeExpression = this.filterExpressionConverter.convertExpression(filterExpression);
            response = this.milvusClient.delete(DeleteParam.newBuilder()
                .withDatabaseName(this.databaseName)
                .withCollectionName(this.collectionName)
                .withExpr(nativeExpression)
                .build());
        }
        catch (RuntimeException e) {
            logger.error("Failed to delete documents by filter: {}", e.getMessage(), e);
            throw new IllegalStateException("Failed to delete documents by filter", e);
        }

        // Checked outside the try block so this exception is not caught and re-wrapped above.
        if (response.getStatus().intValue() != R.Status.Success.getCode()) {
            throw new IllegalStateException("Failed to delete documents by filter: " + response.getMessage());
        }
        logger.debug("Deleted {} documents matching filter expression", response.getData().getDeleteCnt());
    }

    /**
     * Runs a top-K vector search for the request's query and keeps the hits whose similarity
     * is at least the request's threshold. A {@link MilvusSearchRequest} may supply a native
     * filter expression (preferred over the portable filter) and raw search params JSON.
     * @param request search request; its query must not be null
     * @return the matching documents, with similarity as the score and {@code 1 - similarity}
     * stored under the distance metadata key
     * @throws RuntimeException if the search call reports an exception
     */
    @Override
    public List<Document> doSimilaritySearch(SearchRequest request) {
        Assert.notNull(request.getQuery(), "Query string must not be null");

        String filterExpression;
        String searchParamsJson = null;
        if (request instanceof MilvusSearchRequest milvusRequest) {
            filterExpression = StringUtils.hasText(milvusRequest.getNativeExpression())
                    ? milvusRequest.getNativeExpression() : getConvertedFilterExpression(request);
            searchParamsJson = milvusRequest.getSearchParamsJson();
        }
        else {
            filterExpression = getConvertedFilterExpression(request);
        }

        List<String> outFields = new ArrayList<>(3);
        outFields.add(this.idFieldName);
        outFields.add(this.contentFieldName);
        outFields.add(this.metadataFieldName);

        List<Float> queryVector = EmbeddingUtils.toList(this.embeddingModel.embed(request.getQuery()));

        SearchParam.Builder searchParamBuilder = SearchParam.newBuilder()
            .withDatabaseName(this.databaseName)
            .withCollectionName(this.collectionName)
            .withConsistencyLevel(ConsistencyLevelEnum.STRONG)
            .withMetricType(this.metricType)
            .withOutFields(outFields)
            .withTopK(request.getTopK())
            .withVectors(List.of(queryVector))
            .withVectorFieldName(this.embeddingFieldName);
        if (StringUtils.hasText(filterExpression)) {
            searchParamBuilder.withExpr(filterExpression);
        }
        if (StringUtils.hasText(searchParamsJson)) {
            searchParamBuilder.withParams(searchParamsJson);
        }

        R<SearchResults> response = this.milvusClient.search(searchParamBuilder.build());
        if (response.getException() != null) {
            throw new RuntimeException("Search failed!", response.getException());
        }

        double threshold = request.getSimilarityThreshold();
        return new SearchResultsWrapper(response.getData().getResults()).getRowRecords(0)
            .stream()
            .map(row -> toDocument(row, getResultSimilarity(row)))
            .filter(document -> document.getScore() >= threshold)
            .toList();
    }

    /**
     * Converts one search hit into a {@link Document}.
     * @param row the hit returned by Milvus
     * @param similarity the hit's normalized similarity (see {@link #getResultSimilarity})
     */
    private Document toDocument(QueryResultsWrapper.RowRecord row, float similarity) {
        String id = String.valueOf(row.get(this.idFieldName));
        String content = (String) row.get(this.contentFieldName);

        JsonObject metadataJson = readMetadata(row);
        metadataJson.addProperty(DocumentMetadata.DISTANCE.value(), 1.0f - similarity);
        Map<String, Object> metadata = GSON.fromJson(metadataJson, METADATA_TYPE.getType());

        return Document.builder().id(id).text(content).metadata(metadata).score((double) similarity).build();
    }

    /** Returns the row's metadata JSON, or a fresh empty object when the field is absent or null. */
    private JsonObject readMetadata(QueryResultsWrapper.RowRecord row) {
        try {
            JsonObject metadata = (JsonObject) row.get(this.metadataFieldName);
            return (metadata != null) ? metadata : new JsonObject();
        }
        catch (ParamException ignored) {
            // The row has no metadata field; treat it as empty metadata rather than failing the search.
            return new JsonObject();
        }
    }

    /** Converts the request's portable filter to a Milvus expression, or "" when there is none. */
    private String getConvertedFilterExpression(SearchRequest request) {
        return (request.getFilterExpression() != null)
                ? this.filterExpressionConverter.convertExpression(request.getFilterExpression()) : "";
    }

    /**
     * Normalizes a hit's raw score to "higher is more similar": IP and COSINE scores are used
     * as-is, while any other metric (L2 distance) is converted as {@code 1 - score}.
     */
    private float getResultSimilarity(QueryResultsWrapper.RowRecord row) {
        float score = (Float) row.get(SIMILARITY_FIELD_NAME);
        return (this.metricType == MetricType.IP || this.metricType == MetricType.COSINE) ? score : 1.0f - score;
    }

    /** Initializes the collection, index and load state when schema initialization is enabled. */
    @Override
    public void afterPropertiesSet() throws Exception {
        if (this.initializeSchema) {
            createCollection();
        }
    }

    /** Releases the collection from memory if it exists. */
    void releaseCollection() {
        if (isDatabaseCollectionExists()) {
            this.milvusClient.releaseCollection(ReleaseCollectionParam.newBuilder()
                .withDatabaseName(this.databaseName)
                .withCollectionName(this.collectionName)
                .build());
        }
    }

    /**
     * Returns whether the configured collection exists in the configured database.
     * @throws RuntimeException if the existence check reports an exception
     */
    private boolean isDatabaseCollectionExists() {
        R<Boolean> response = this.milvusClient.hasCollection(HasCollectionParam.newBuilder()
            .withDatabaseName(this.databaseName)
            .withCollectionName(this.collectionName)
            .build());
        if (response.getException() != null) {
            throw new RuntimeException("Failed to check if collection exists", response.getException());
        }
        return Boolean.TRUE.equals(response.getData());
    }

    /**
     * Creates the collection if missing. It then creates the vector index if describing the
     * index returns no data, and finally loads the collection.
     * @throws RuntimeException if index creation or collection loading reports an exception
     */
    void createCollection() {
        if (!isDatabaseCollectionExists()) {
            createCollection(this.databaseName, this.collectionName, this.idFieldName, this.isAutoId,
                    this.contentFieldName, this.metadataFieldName, this.embeddingFieldName);
        }

        boolean indexMissing = this.milvusClient.describeIndex(DescribeIndexParam.newBuilder()
            .withDatabaseName(this.databaseName)
            .withCollectionName(this.collectionName)
            .build())
            .getData() == null;
        if (indexMissing) {
            R<?> indexResponse = this.milvusClient.createIndex(CreateIndexParam.newBuilder()
                .withDatabaseName(this.databaseName)
                .withCollectionName(this.collectionName)
                .withFieldName(this.embeddingFieldName)
                .withIndexType(this.indexType)
                .withMetricType(this.metricType)
                .withExtraParam(this.indexParameters)
                .withSyncMode(Boolean.FALSE)
                .build());
            if (indexResponse.getException() != null) {
                throw new RuntimeException("Failed to create Index", indexResponse.getException());
            }
        }

        R<?> loadResponse = this.milvusClient.loadCollection(LoadCollectionParam.newBuilder()
            .withDatabaseName(this.databaseName)
            .withCollectionName(this.collectionName)
            .build());
        if (loadResponse.getException() != null) {
            throw new RuntimeException("Collection loading failed!", loadResponse.getException());
        }
    }

    /**
     * Creates a collection with a VarChar(36) primary key, a VarChar(65535) content field, a
     * JSON metadata field and a FloatVector field sized by {@link #embeddingDimensions()}.
     * @throws RuntimeException if collection creation reports an exception
     */
    void createCollection(String databaseName, String collectionName, String idFieldName, boolean isAutoId,
            String contentFieldName, String metadataFieldName, String embeddingFieldName) {
        FieldType idField = FieldType.newBuilder()
            .withName(idFieldName)
            .withDataType(DataType.VarChar)
            .withMaxLength(36)
            .withPrimaryKey(true)
            .withAutoID(isAutoId)
            .build();
        FieldType contentField = FieldType.newBuilder()
            .withName(contentFieldName)
            .withDataType(DataType.VarChar)
            .withMaxLength(65535)
            .build();
        FieldType metadataField = FieldType.newBuilder()
            .withName(metadataFieldName)
            .withDataType(DataType.JSON)
            .build();
        FieldType embeddingField = FieldType.newBuilder()
            .withName(embeddingFieldName)
            .withDataType(DataType.FloatVector)
            .withDimension(embeddingDimensions())
            .build();

        R<?> response = this.milvusClient.createCollection(CreateCollectionParam.newBuilder()
            .withDatabaseName(databaseName)
            .withCollectionName(collectionName)
            .withDescription("Spring AI Vector Store")
            .withConsistencyLevel(ConsistencyLevelEnum.STRONG)
            .withShardsNum(2)
            .addFieldType(idField)
            .addFieldType(contentField)
            .addFieldType(metadataField)
            .addFieldType(embeddingField)
            .build());
        if (response.getException() != null) {
            throw new RuntimeException("Failed to create collection", response.getException());
        }
    }

    /**
     * Resolves the vector dimension: the builder-configured value if set, otherwise the
     * embedding model's positive dimension, otherwise {@link #OPENAI_EMBEDDING_DIMENSION_SIZE}.
     */
    int embeddingDimensions() {
        if (this.embeddingDimension != INVALID_EMBEDDING_DIMENSION) {
            return this.embeddingDimension;
        }
        try {
            int dimensions = this.embeddingModel.dimensions();
            return (dimensions > 0) ? dimensions : OPENAI_EMBEDDING_DIMENSION_SIZE;
        }
        catch (Exception e) {
            logger.warn("Failed to obtain the embedding dimensions from the embedding model, falling back to default: {}",
                    OPENAI_EMBEDDING_DIMENSION_SIZE, e);
            return OPENAI_EMBEDDING_DIMENSION_SIZE;
        }
    }

    /**
     * Releases the collection, drops its index and drops the collection.
     * @throws RuntimeException if any of the three calls reports an exception
     */
    void dropCollection() {
        R<?> releaseResponse = this.milvusClient.releaseCollection(ReleaseCollectionParam.newBuilder()
            .withDatabaseName(this.databaseName)
            .withCollectionName(this.collectionName)
            .build());
        if (releaseResponse.getException() != null) {
            throw new RuntimeException("Release collection failed!", releaseResponse.getException());
        }

        R<?> dropIndexResponse = this.milvusClient.dropIndex(DropIndexParam.newBuilder()
            .withDatabaseName(this.databaseName)
            .withCollectionName(this.collectionName)
            .build());
        if (dropIndexResponse.getException() != null) {
            throw new RuntimeException("Drop Index failed!", dropIndexResponse.getException());
        }

        R<?> dropCollectionResponse = this.milvusClient.dropCollection(DropCollectionParam.newBuilder()
            .withDatabaseName(this.databaseName)
            .withCollectionName(this.collectionName)
            .build());
        if (dropCollectionResponse.getException() != null) {
            throw new RuntimeException("Drop Collection failed!", dropCollectionResponse.getException());
        }
    }

    /** Builds the observation context describing this Milvus store for the given operation. */
    @Override
    public VectorStoreObservationContext.Builder createObservationContextBuilder(String operationName) {
        return VectorStoreObservationContext.builder(VectorStoreProvider.MILVUS.value(), operationName)
            .collectionName(this.collectionName)
            .dimensions(this.embeddingModel.dimensions())
            .similarityMetric(getSimilarityMetric())
            .namespace(this.databaseName);
    }

    /** Maps the Milvus metric to its observation name, falling back to the enum constant name. */
    private String getSimilarityMetric() {
        VectorStoreSimilarityMetric mapped = SIMILARITY_TYPE_MAPPING.get(this.metricType);
        return (mapped != null) ? mapped.value() : this.metricType.name();
    }

    /** Exposes the underlying {@link MilvusServiceClient}; the caller chooses the type {@code T}. */
    @Override
    public <T> Optional<T> getNativeClient() {
        @SuppressWarnings("unchecked")
        T client = (T) this.milvusClient;
        return Optional.of(client);
    }

}
