/*
 * Decompiled with CFR 0.152.
 */
package dev.langchain4j.community.store.embedding.alloydb;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.SerializationFeature;
import com.pgvector.PGvector;
import dev.langchain4j.community.store.embedding.alloydb.AlloyDBEngine;
import dev.langchain4j.community.store.embedding.alloydb.filter.AlloyDBFilterMapper;
import dev.langchain4j.community.store.embedding.alloydb.index.BaseIndex;
import dev.langchain4j.community.store.embedding.alloydb.index.DistanceStrategy;
import dev.langchain4j.community.store.embedding.alloydb.index.ScaNNIndex;
import dev.langchain4j.community.store.embedding.alloydb.index.query.QueryOptions;
import dev.langchain4j.data.document.Metadata;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.internal.Utils;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
import dev.langchain4j.store.embedding.EmbeddingSearchResult;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.RelevanceScore;
import java.sql.Array;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.sql.Types;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.UUID;
import java.util.stream.Collectors;

/**
 * {@link EmbeddingStore} for {@link TextSegment}s backed by a single AlloyDB table whose
 * embedding column is read and written through pgvector ({@link PGvector}).
 *
 * <p>The table is expected to exist already: on construction the store reads the table's columns
 * from {@code information_schema.columns} and fails with {@link IllegalStateException} if the
 * configured id, content or embedding column is missing or has an unsuitable type.
 *
 * <p>Segment metadata is written to the configured metadata columns (one value per column); any
 * remaining metadata keys are serialized as JSON into the metadata JSON column, if that column
 * exists in the table.
 */
public class AlloyDBEmbeddingStore implements EmbeddingStore<TextSegment> {

    private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper().enable(SerializationFeature.INDENT_OUTPUT);

    /** Appended to the table name to build the index name used when none is supplied. */
    private static final String DEFAULT_INDEX_NAME_SUFFIX = "langchainvectorindex";

    private final AlloyDBFilterMapper filterMapper = new AlloyDBFilterMapper();
    private final AlloyDBEngine engine;
    private final String tableName;
    private final String schemaName;
    private final String contentColumn;
    private final String embeddingColumn;
    private final String idColumn;
    /** Private, mutable copy; filled from the table when ignore-mode is used. */
    private final List<String> metadataColumns;
    private final DistanceStrategy distanceStrategy;
    private final QueryOptions queryOptions;
    /** Reset to {@code null} during verification when the table has no such column. */
    private String metadataJsonColumn;
    private final String insertQuery;
    private final String deleteQuery;

    /**
     * Creates the store and validates the target table against the builder's configuration.
     *
     * @param builder the configuration
     * @throws IllegalArgumentException if both metadata columns and ignored metadata columns are set
     * @throws IllegalStateException if a configured column is missing or has an unsuitable type
     * @throws RuntimeException wrapping any {@link SQLException} raised while reading the table schema
     */
    public AlloyDBEmbeddingStore(Builder builder) {
        this.engine = builder.engine;
        this.tableName = builder.tableName;
        this.schemaName = builder.schemaName;
        this.contentColumn = builder.contentColumn;
        this.embeddingColumn = builder.embeddingColumn;
        this.idColumn = builder.idColumn;
        this.metadataJsonColumn = builder.metadataJsonColumn;
        // Copy so verification can extend the list without touching (possibly immutable) caller lists.
        this.metadataColumns = builder.metadataColumns != null
                ? new ArrayList<>(builder.metadataColumns)
                : new ArrayList<>();
        this.distanceStrategy = builder.distanceStrategy;
        this.queryOptions = builder.queryOptions;
        this.verifyEmbeddingStoreColumns(builder.ignoreMetadataColumnNames);
        this.insertQuery = this.generateInsertQuery();
        this.deleteQuery = String.format(
                "DELETE FROM %s.%s WHERE %s = ANY(?)", quote(this.schemaName), quote(this.tableName), quote(this.idColumn));
    }

    /** Quotes a SQL identifier, doubling any embedded double quote. */
    private static String quote(String identifier) {
        return "\"" + identifier.replace("\"", "\"\"") + "\"";
    }

    /** Returns the index name used when the caller supplies none. */
    private String defaultIndexName() {
        return this.tableName + DEFAULT_INDEX_NAME_SUFFIX;
    }

    /** True when a metadata JSON column is configured and present in the table. */
    private boolean hasMetadataJsonColumn() {
        return Utils.isNotNullOrEmpty(this.metadataJsonColumn);
    }

    /**
     * Validates the configured columns against the table and, in ignore-mode, derives the metadata
     * columns from every table column not otherwise in use.
     *
     * @param ignoredColumns columns to exclude from the derived metadata columns; may be {@code null}
     */
    private void verifyEmbeddingStoreColumns(List<String> ignoredColumns) {
        boolean ignoreMode = ignoredColumns != null && !ignoredColumns.isEmpty();
        if (!this.metadataColumns.isEmpty() && ignoreMode) {
            throw new IllegalArgumentException("Cannot use both metadataColumns and ignoreMetadataColumns at the same time.");
        }
        Map<String, String> allColumns = this.fetchColumnTypes();
        if (!allColumns.containsKey(this.idColumn)) {
            throw new IllegalStateException("Id column, " + this.idColumn + ", does not exist.");
        }
        if (!allColumns.containsKey(this.contentColumn)) {
            throw new IllegalStateException("Content column, " + this.contentColumn + ", does not exist.");
        }
        String contentType = allColumns.get(this.contentColumn);
        if (!contentType.equalsIgnoreCase("text") && !contentType.contains("char")) {
            throw new IllegalStateException("Content column, is type " + contentType + ". It must be a type of character string.");
        }
        if (!allColumns.containsKey(this.embeddingColumn)) {
            throw new IllegalStateException("Embedding column, " + this.embeddingColumn + ", does not exist.");
        }
        // information_schema reports extension types such as pgvector's "vector" as USER-DEFINED.
        if (!allColumns.get(this.embeddingColumn).equalsIgnoreCase("USER-DEFINED")) {
            throw new IllegalStateException("Embedding column, " + this.embeddingColumn + ", is not type Vector.");
        }
        if (!allColumns.containsKey(this.metadataJsonColumn)) {
            this.metadataJsonColumn = null;
        }
        for (String metadataColumn : this.metadataColumns) {
            if (!allColumns.containsKey(metadataColumn)) {
                throw new IllegalStateException("Metadata column, " + metadataColumn + ", does not exist.");
            }
        }
        if (ignoreMode) {
            Set<String> excluded = new HashSet<>(ignoredColumns);
            excluded.add(this.idColumn);
            excluded.add(this.contentColumn);
            excluded.add(this.embeddingColumn);
            // The JSON column is written separately; listing it as a metadata column too would put it
            // into the INSERT column list twice.
            if (this.metadataJsonColumn != null) {
                excluded.add(this.metadataJsonColumn);
            }
            for (String column : allColumns.keySet()) {
                if (!excluded.contains(column)) {
                    this.metadataColumns.add(column);
                }
            }
        }
    }

    /** Reads column name to data type for the store's table, in table column order. */
    private Map<String, String> fetchColumnTypes() {
        String query = "SELECT column_name, data_type FROM information_schema.columns"
                + " WHERE table_name = ? AND table_schema = ? ORDER BY ordinal_position";
        Map<String, String> columns = new LinkedHashMap<>();
        try (Connection conn = this.engine.getConnection();
                PreparedStatement statement = conn.prepareStatement(query)) {
            statement.setString(1, this.tableName);
            statement.setString(2, this.schemaName);
            try (ResultSet resultSet = statement.executeQuery()) {
                while (resultSet.next()) {
                    columns.put(resultSet.getString("column_name"), resultSet.getString("data_type"));
                }
            }
        } catch (SQLException ex) {
            throw new RuntimeException("Exception caught when verifying vector store table: \"" + this.schemaName + "\".\"" + this.tableName + "\"", ex);
        }
        return columns;
    }

    /**
     * Builds the parameterized INSERT statement: id, embedding, content, then each metadata column,
     * then the metadata JSON column if present.
     */
    private String generateInsertQuery() {
        List<String> columns = new ArrayList<>();
        columns.add(quote(this.idColumn));
        columns.add(quote(this.embeddingColumn));
        columns.add(quote(this.contentColumn));
        for (String metadataColumn : this.metadataColumns) {
            columns.add(quote(metadataColumn));
        }
        if (this.hasMetadataJsonColumn()) {
            columns.add(quote(this.metadataJsonColumn));
        }
        String placeholders = String.join(", ", Collections.nCopies(columns.size(), "?"));
        return String.format("INSERT INTO %s.%s (%s) VALUES (%s)",
                quote(this.schemaName), quote(this.tableName), String.join(", ", columns), placeholders);
    }

    /** Adds an embedding under a random id, without content, and returns the id. */
    @Override
    public String add(Embedding embedding) {
        String id = Utils.randomUUID();
        this.addInternal(id, embedding, null);
        return id;
    }

    /** Adds an embedding under the given id (must parse as a UUID), without content. */
    @Override
    public void add(String id, Embedding embedding) {
        this.addInternal(id, embedding, null);
    }

    /** Adds an embedding with its text segment under a random id and returns the id. */
    @Override
    public String add(Embedding embedding, TextSegment textSegment) {
        String id = Utils.randomUUID();
        this.addInternal(id, embedding, textSegment);
        return id;
    }

    /** Adds embeddings without content under random ids; returns the ids in input order. */
    @Override
    public List<String> addAll(List<Embedding> embeddings) {
        List<String> ids = embeddings.stream().map(ignored -> Utils.randomUUID()).collect(Collectors.toList());
        List<TextSegment> emptyTextSegments = Collections.nCopies(ids.size(), null);
        this.addAll(ids, embeddings, emptyTextSegments);
        return ids;
    }

    /** Adds embeddings with their text segments under random ids; returns the ids in input order. */
    @Override
    public List<String> addAll(List<Embedding> embeddings, List<TextSegment> textSegments) {
        List<String> ids = embeddings.stream().map(ignored -> Utils.randomUUID()).collect(Collectors.toList());
        this.addAll(ids, embeddings, textSegments);
        return ids;
    }

    /**
     * Returns the rows nearest to the request's query embedding, ordered by the distance strategy's
     * operator, limited to {@code maxResults} and filtered by {@code minScore}.
     *
     * <p>When query options are configured they are applied with {@code SET LOCAL} inside the same
     * transaction as the query, since {@code SET LOCAL} only lasts until the end of a transaction.
     *
     * @throws RuntimeException wrapping SQL or JSON-metadata errors
     */
    @Override
    public EmbeddingSearchResult<TextSegment> search(EmbeddingSearchRequest request) {
        String selectQuery = this.buildSelectQuery(request);
        try (Connection conn = this.engine.getConnection()) {
            PGvector.registerTypes(conn);
            List<EmbeddingMatch<TextSegment>> matches = this.queryOptions == null
                    ? this.executeSearch(conn, selectQuery, request)
                    : this.executeSearchWithQueryOptions(conn, selectQuery, request);
            return new EmbeddingSearchResult<>(matches);
        } catch (JsonProcessingException ex) {
            throw new RuntimeException("Exception caught when processing JSON metadata", ex);
        } catch (SQLException ex) {
            throw new RuntimeException("Exception caught when searching in store table: \"" + this.schemaName + "\".\"" + this.tableName + "\"", ex);
        }
    }

    /** Builds the search SELECT with three parameters: query vector (twice) and limit. */
    private String buildSelectQuery(EmbeddingSearchRequest request) {
        List<String> columns = new ArrayList<>(this.metadataColumns);
        columns.add(this.idColumn);
        columns.add(this.contentColumn);
        columns.add(this.embeddingColumn);
        if (this.hasMetadataJsonColumn()) {
            columns.add(this.metadataJsonColumn);
        }
        String columnNames = columns.stream().map(AlloyDBEmbeddingStore::quote).collect(Collectors.joining(", "));
        String filterString = this.filterMapper.map(request.filter());
        String whereClause = Utils.isNotNullOrBlank(filterString) ? String.format("WHERE %s", filterString) : "";
        return String.format("SELECT %s, %s(%s, ?) as distance FROM %s.%s %s ORDER BY %s %s ? LIMIT ?;",
                columnNames, this.distanceStrategy.getSearchFunction(), quote(this.embeddingColumn),
                quote(this.schemaName), quote(this.tableName), whereClause,
                quote(this.embeddingColumn), this.distanceStrategy.getOperator());
    }

    /** Applies the query options and runs the search in one transaction, restoring autocommit after. */
    private List<EmbeddingMatch<TextSegment>> executeSearchWithQueryOptions(
            Connection conn, String selectQuery, EmbeddingSearchRequest request) throws SQLException, JsonProcessingException {
        boolean previousAutoCommit = conn.getAutoCommit();
        conn.setAutoCommit(false);
        try {
            try (Statement statement = conn.createStatement()) {
                for (String option : this.queryOptions.getParameterSettings()) {
                    // SET returns no result set, so execute() rather than executeQuery().
                    statement.execute(String.format("SET LOCAL %s;", option));
                }
            }
            List<EmbeddingMatch<TextSegment>> matches = this.executeSearch(conn, selectQuery, request);
            conn.commit();
            return matches;
        } catch (SQLException | JsonProcessingException | RuntimeException ex) {
            try {
                conn.rollback();
            } catch (SQLException rollbackEx) {
                ex.addSuppressed(rollbackEx);
            }
            throw ex;
        } finally {
            conn.setAutoCommit(previousAutoCommit);
        }
    }

    /** Executes the prepared search and converts rows scoring at least {@code minScore} to matches. */
    private List<EmbeddingMatch<TextSegment>> executeSearch(
            Connection conn, String selectQuery, EmbeddingSearchRequest request) throws SQLException, JsonProcessingException {
        List<EmbeddingMatch<TextSegment>> matches = new ArrayList<>();
        PGvector queryVector = new PGvector(request.queryEmbedding().vector());
        try (PreparedStatement preparedStatement = conn.prepareStatement(selectQuery)) {
            preparedStatement.setObject(1, queryVector);
            preparedStatement.setObject(2, queryVector);
            preparedStatement.setInt(3, request.maxResults());
            try (ResultSet resultSet = preparedStatement.executeQuery()) {
                while (resultSet.next()) {
                    double score = this.calculateRelevanceScore(resultSet.getDouble("distance"));
                    if (score >= request.minScore()) {
                        matches.add(this.toEmbeddingMatch(resultSet, score));
                    }
                }
            }
        }
        return matches;
    }

    /** Maps the current result row to a match; metadata merges non-null columns and the JSON column. */
    private EmbeddingMatch<TextSegment> toEmbeddingMatch(ResultSet resultSet, double score)
            throws SQLException, JsonProcessingException {
        String embeddingId = resultSet.getString(this.idColumn);
        PGvector pgVector = (PGvector) resultSet.getObject(this.embeddingColumn);
        Embedding embedding = Embedding.from(pgVector.toArray());
        String embeddedText = resultSet.getString(this.contentColumn);
        Map<String, Object> metadataMap = new HashMap<>();
        for (String metadataColumn : this.metadataColumns) {
            Object value = resultSet.getObject(metadataColumn);
            if (value != null) {
                metadataMap.put(metadataColumn, value);
            }
        }
        if (this.hasMetadataJsonColumn()) {
            String metadataJsonString = Utils.getOrDefault(resultSet.getString(this.metadataJsonColumn), "{}");
            Map<?, ?> metadataJsonMap = OBJECT_MAPPER.readValue(metadataJsonString, Map.class);
            if (metadataJsonMap != null) {
                for (Map.Entry<?, ?> entry : metadataJsonMap.entrySet()) {
                    metadataMap.put(String.valueOf(entry.getKey()), entry.getValue());
                }
            }
        }
        Metadata metadata = Metadata.from(metadataMap);
        TextSegment embedded = embeddedText != null ? new TextSegment(embeddedText, metadata) : null;
        return new EmbeddingMatch<>(score, embeddingId, embedding, embedded);
    }

    /**
     * Deletes the rows whose id is in {@code ids}.
     *
     * @param ids ids to delete; each must parse as a UUID
     * @throws IllegalArgumentException if {@code ids} is null/empty or an id is not a valid UUID
     */
    @Override
    public void removeAll(Collection<String> ids) {
        if (ids == null || ids.isEmpty()) {
            throw new IllegalArgumentException("ids cannot be null or empty");
        }
        Object[] uuids = ids.stream().map(UUID::fromString).toArray();
        try (Connection conn = this.engine.getConnection();
                PreparedStatement preparedStatement = conn.prepareStatement(this.deleteQuery)) {
            Array array = conn.createArrayOf("uuid", uuids);
            try {
                preparedStatement.setArray(1, array);
                preparedStatement.executeUpdate();
            } finally {
                array.free();
            }
        } catch (SQLException ex) {
            throw new RuntimeException(String.format("Exception caught when deleting from vector store table: \"%s\".\"%s\"", this.schemaName, this.tableName), ex);
        }
    }

    private void addInternal(String id, Embedding embedding, TextSegment textSegment) {
        this.addAll(Collections.singletonList(id), Collections.singletonList(embedding), Collections.singletonList(textSegment));
    }

    /**
     * Inserts one row per index position as a single JDBC batch.
     *
     * @param ids row ids; each must parse as a UUID
     * @param embeddings embeddings, parallel to {@code ids}
     * @param textSegments segments parallel to {@code ids}; elements may be {@code null}
     * @throws IllegalArgumentException if the three lists differ in size
     * @throws RuntimeException wrapping SQL or JSON-metadata errors
     */
    @Override
    public void addAll(List<String> ids, List<Embedding> embeddings, List<TextSegment> textSegments) {
        if (ids.size() != embeddings.size() || embeddings.size() != textSegments.size()) {
            throw new IllegalArgumentException("List parameters ids and embeddings and textSegments shouldn't be different sizes!");
        }
        if (ids.isEmpty()) {
            return;
        }
        try (Connection connection = this.engine.getConnection();
                PreparedStatement preparedStatement = connection.prepareStatement(this.insertQuery)) {
            PGvector.registerTypes(connection);
            for (int i = 0; i < ids.size(); i++) {
                this.bindInsertParameters(preparedStatement, ids.get(i), embeddings.get(i), textSegments.get(i));
                preparedStatement.addBatch();
            }
            preparedStatement.executeBatch();
        } catch (JsonProcessingException ex) {
            throw new RuntimeException("Exception caught when processing JSON metadata", ex);
        } catch (SQLException ex) {
            throw new RuntimeException("Exception caught when inserting into vector store table: \"" + this.schemaName + "\".\"" + this.tableName + "\"", ex);
        }
    }

    /**
     * Binds one row in {@link #generateInsertQuery()} column order. Metadata keys matching a metadata
     * column go to that column; the remaining keys are written as JSON, or NULL when the segment had
     * no metadata at all.
     */
    private void bindInsertParameters(PreparedStatement preparedStatement, String id, Embedding embedding,
            TextSegment textSegment) throws SQLException, JsonProcessingException {
        preparedStatement.setObject(1, UUID.fromString(id), Types.OTHER);
        preparedStatement.setObject(2, new PGvector(embedding.vector()));
        preparedStatement.setString(3, textSegment != null ? textSegment.text() : null);
        Map<String, Object> remainingMetadata = textSegment != null
                ? new HashMap<>(textSegment.metadata().toMap())
                : new HashMap<>();
        boolean hasMetadata = !remainingMetadata.isEmpty();
        int parameterIndex = 4;
        for (String metadataColumn : this.metadataColumns) {
            // remove() yields null for absent keys, which binds SQL NULL.
            preparedStatement.setObject(parameterIndex++, remainingMetadata.remove(metadataColumn));
        }
        if (this.hasMetadataJsonColumn()) {
            String json = hasMetadata ? OBJECT_MAPPER.writeValueAsString(remainingMetadata) : null;
            preparedStatement.setObject(parameterIndex, json, Types.OTHER);
        }
    }

    /**
     * Creates a vector index on the embedding column, or drops the default-named index when
     * {@code index} is {@code null}.
     *
     * @param index index definition, or {@code null} to drop the default index
     * @param name index name; falls back to {@code index.getName()}, then to the default name
     * @param concurrently whether to use {@code CREATE INDEX CONCURRENTLY}; {@code null} means no
     */
    public void applyVectorIndex(BaseIndex index, String name, Boolean concurrently) {
        if (index == null) {
            this.dropVectorIndex(null);
            return;
        }
        String indexName;
        if (Utils.isNotNullOrBlank(name)) {
            indexName = name;
        } else if (Utils.isNotNullOrBlank(index.getName())) {
            indexName = index.getName();
        } else {
            indexName = this.defaultIndexName();
        }
        try (Connection conn = this.engine.getConnection();
                Statement statement = conn.createStatement()) {
            String function;
            if (index instanceof ScaNNIndex) {
                ScaNNIndex scaNNIndex = (ScaNNIndex) index;
                statement.execute("CREATE EXTENSION IF NOT EXISTS alloydb_scann");
                function = scaNNIndex.getDistanceStrategy().getScannIndexFunction();
            } else {
                function = index.getDistanceStrategy().getIndexFunction();
            }
            // NOTE(review): several predicates joined with ", " do not form a valid partial-index
            // WHERE clause; confirm whether callers expect a single predicate here.
            String filter = index.getPartialIndexes() != null && !index.getPartialIndexes().isEmpty()
                    ? String.format("WHERE %s", String.join(", ", index.getPartialIndexes()))
                    : "";
            String params = String.format("WITH %s", index.getIndexOptions());
            String concurrentlyString = Boolean.TRUE.equals(concurrently) ? "CONCURRENTLY" : "";
            String stmt = String.format("CREATE INDEX %s %s ON %s.%s USING %s (%s %s) %s %s;",
                    concurrentlyString, indexName, quote(this.schemaName), quote(this.tableName),
                    index.getIndexType(), quote(this.embeddingColumn), function, params, filter);
            statement.execute(stmt);
        } catch (SQLException ex) {
            throw new RuntimeException("Exception caught when creating " + indexName + " index in vector store table: \"" + this.schemaName + "\".\"" + this.tableName + "\"", ex);
        }
    }

    /**
     * Drops the named index if it exists.
     *
     * @param name index name; blank or {@code null} selects the default name
     */
    public void dropVectorIndex(String name) {
        String indexName = Utils.isNotNullOrBlank(name) ? name : this.defaultIndexName();
        String query = String.format("DROP INDEX IF EXISTS %s;", indexName);
        try (Connection conn = this.engine.getConnection();
                Statement statement = conn.createStatement()) {
            statement.execute(query);
        } catch (SQLException ex) {
            throw new RuntimeException("Exception caught when removing " + indexName + " index in vector store table: \"" + this.schemaName + "\".\"" + this.tableName + "\"", ex);
        }
    }

    /**
     * Rebuilds the named index.
     *
     * @param name index name; blank or {@code null} selects the default name
     */
    public void reindex(String name) {
        String indexName = Utils.isNotNullOrBlank(name) ? name : this.defaultIndexName();
        String query = String.format("REINDEX INDEX %s;", indexName);
        try (Connection conn = this.engine.getConnection();
                Statement statement = conn.createStatement()) {
            statement.execute(query);
        } catch (SQLException ex) {
            throw new RuntimeException("Exception caught when reindexing " + indexName + " index in vector store table: \"" + this.schemaName + "\".\"" + this.tableName + "\"", ex);
        }
    }

    /**
     * Converts the search function's distance into a relevance score for the configured strategy.
     *
     * @throws UnsupportedOperationException for strategies other than EUCLIDEAN, COSINE_DISTANCE
     *     and INNER_PRODUCT
     */
    private double calculateRelevanceScore(double distance) {
        switch (this.distanceStrategy.name()) {
            case "EUCLIDEAN":
                return 1.0 - distance / Math.sqrt(2.0);
            case "COSINE_DISTANCE":
                return RelevanceScore.fromCosineSimilarity(1.0 - distance);
            case "INNER_PRODUCT":
                return distance > 0.0 ? 1.0 - distance : -1.0 * distance;
            default:
                throw new UnsupportedOperationException(String.format("Unable to calculate relevance score for search function: %s ", this.distanceStrategy.getSearchFunction()));
        }
    }

    /** Starts a builder for a store over {@code tableName} using {@code engine} for connections. */
    public static Builder builder(AlloyDBEngine engine, String tableName) {
        return new Builder(engine, tableName);
    }

    /**
     * Builder for {@link AlloyDBEmbeddingStore}. Defaults: schema {@code public}, content column
     * {@code content}, embedding column {@code embedding}, id column {@code langchain_id}, metadata
     * JSON column {@code langchain_metadata}, distance strategy {@code COSINE_DISTANCE}.
     */
    public static class Builder {
        private final AlloyDBEngine engine;
        private final String tableName;
        private String schemaName = "public";
        private String contentColumn = "content";
        private String embeddingColumn = "embedding";
        private String idColumn = "langchain_id";
        private List<String> metadataColumns = new ArrayList<String>();
        private String metadataJsonColumn = "langchain_metadata";
        private List<String> ignoreMetadataColumnNames = new ArrayList<String>();
        private DistanceStrategy distanceStrategy = DistanceStrategy.COSINE_DISTANCE;
        private QueryOptions queryOptions;

        public Builder(AlloyDBEngine engine, String tableName) {
            this.engine = engine;
            this.tableName = tableName;
        }

        /** Sets the schema containing the table. */
        public Builder schemaName(String schemaName) {
            this.schemaName = schemaName;
            return this;
        }

        /** Sets the character-string column holding segment text. */
        public Builder contentColumn(String contentColumn) {
            this.contentColumn = contentColumn;
            return this;
        }

        /** Sets the vector column holding embeddings. */
        public Builder embeddingColumn(String embeddingColumn) {
            this.embeddingColumn = embeddingColumn;
            return this;
        }

        /** Sets the id column; ids are bound as UUIDs. */
        public Builder idColumn(String idColumn) {
            this.idColumn = idColumn;
            return this;
        }

        /** Sets explicit metadata columns; mutually exclusive with {@link #ignoreMetadataColumnNames}. */
        public Builder metadataColumns(List<String> metadataColumns) {
            this.metadataColumns = metadataColumns;
            return this;
        }

        /** Sets the JSON column for metadata keys without a dedicated column. */
        public Builder metadataJsonColumn(String metadataJsonColumn) {
            this.metadataJsonColumn = metadataJsonColumn;
            return this;
        }

        /** Uses every other table column as metadata except these; exclusive with {@link #metadataColumns}. */
        public Builder ignoreMetadataColumnNames(List<String> ignoreMetadataColumnNames) {
            this.ignoreMetadataColumnNames = ignoreMetadataColumnNames;
            return this;
        }

        /** Sets the distance strategy used for search and scoring. */
        public Builder distanceStrategy(DistanceStrategy distanceStrategy) {
            this.distanceStrategy = distanceStrategy;
            return this;
        }

        /** Sets options applied with {@code SET LOCAL} before each search. */
        public Builder queryOptions(QueryOptions queryOptions) {
            this.queryOptions = queryOptions;
            return this;
        }

        /** Builds the store, validating the table schema (requires a database connection). */
        public AlloyDBEmbeddingStore build() {
            return new AlloyDBEmbeddingStore(this);
        }
    }
}

