/*
 * Decompiled with CFR 0.152.
 */
package com.alibaba.cloud.ai.advisor;

import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.document.Document;
import org.springframework.ai.rag.Query;
import org.springframework.ai.rag.retrieval.search.DocumentRetriever;
import org.springframework.util.Assert;

public class CompositeDocumentRetriever
implements DocumentRetriever {
    private static final Logger logger = LoggerFactory.getLogger(CompositeDocumentRetriever.class);
    private final List<DocumentRetriever> retrievers;
    private final Integer maxResultsPerRetriever;
    private final ResultMergeStrategy mergeStrategy;

    public CompositeDocumentRetriever(List<DocumentRetriever> retrievers) {
        this(retrievers, 10, ResultMergeStrategy.SCORE_BASED);
    }

    public CompositeDocumentRetriever(List<DocumentRetriever> retrievers, Integer maxResultsPerRetriever) {
        this(retrievers, maxResultsPerRetriever, ResultMergeStrategy.SCORE_BASED);
    }

    public CompositeDocumentRetriever(List<DocumentRetriever> retrievers, Integer maxResultsPerRetriever, ResultMergeStrategy mergeStrategy) {
        Assert.notNull(retrievers, (String)"Retrievers list must not be null!");
        Assert.isTrue((!retrievers.isEmpty() ? 1 : 0) != 0, (String)"Retrievers list must not be empty!");
        Assert.isTrue((maxResultsPerRetriever > 0 ? 1 : 0) != 0, (String)"MaxResultsPerRetriever must be positive!");
        Assert.notNull((Object)((Object)mergeStrategy), (String)"MergeStrategy must not be null!");
        this.retrievers = new ArrayList<DocumentRetriever>(retrievers);
        this.maxResultsPerRetriever = maxResultsPerRetriever;
        this.mergeStrategy = mergeStrategy;
    }

    public List<Document> retrieve(Query query) {
        if (this.mergeStrategy == ResultMergeStrategy.ROUND_ROBIN) {
            return this.roundRobinRetrieve(query);
        }
        ArrayList<Document> allDocuments = new ArrayList<Document>();
        for (DocumentRetriever retriever : this.retrievers) {
            try {
                List documents = retriever.retrieve(query);
                if (documents == null || documents.isEmpty()) continue;
                List limitedDocuments = documents.stream().limit(this.maxResultsPerRetriever.intValue()).collect(Collectors.toList());
                allDocuments.addAll(limitedDocuments);
            }
            catch (Exception e) {
                logger.error("Error retrieving from one of the retrievers: {}", (Object)e.getMessage(), (Object)e);
            }
        }
        return this.mergeResults(allDocuments);
    }

    private List<Document> roundRobinRetrieve(Query query) {
        ArrayList<List<Object>> allResults = new ArrayList<List<Object>>();
        for (DocumentRetriever retriever : this.retrievers) {
            try {
                List documents = retriever.retrieve(query);
                if (documents != null && !documents.isEmpty()) {
                    List limitedDocuments = documents.stream().limit(this.maxResultsPerRetriever.intValue()).collect(Collectors.toList());
                    allResults.add(limitedDocuments);
                    continue;
                }
                allResults.add(new ArrayList());
            }
            catch (Exception e) {
                logger.error("Error retrieving from one of the retrievers: {}", (Object)e.getMessage(), (Object)e);
                allResults.add(new ArrayList());
            }
        }
        Integer maxSize = allResults.stream().mapToInt(List::size).max().orElse(0);
        return IntStream.range(0, maxSize).boxed().flatMap(i -> allResults.stream().filter(documents -> i < documents.size()).map(documents -> (Document)documents.get((int)i))).collect(Collectors.toList());
    }

    private List<Document> mergeResults(List<Document> documents) {
        if (documents.isEmpty()) {
            return documents;
        }
        return switch (this.mergeStrategy) {
            case ResultMergeStrategy.SIMPLE_MERGE -> documents;
            case ResultMergeStrategy.SCORE_BASED -> documents.stream().sorted((d1, d2) -> {
                Double score1 = d1.getScore();
                Double score2 = d2.getScore();
                if (score1 == null) {
                    score1 = 0.0;
                }
                if (score2 == null) {
                    score2 = 0.0;
                }
                return Double.compare(score2, score1);
            }).collect(Collectors.toList());
            case ResultMergeStrategy.ROUND_ROBIN -> documents;
            default -> documents;
        };
    }

    public static Builder builder() {
        return new Builder();
    }

    public static enum ResultMergeStrategy {
        SIMPLE_MERGE,
        SCORE_BASED,
        ROUND_ROBIN;

    }

    public static class Builder {
        private List<DocumentRetriever> retrievers = new ArrayList<DocumentRetriever>();
        private Integer maxResultsPerRetriever = 10;
        private ResultMergeStrategy mergeStrategy = ResultMergeStrategy.SCORE_BASED;

        private Builder() {
        }

        public Builder addRetriever(DocumentRetriever retriever) {
            if (retriever != null) {
                this.retrievers.add(retriever);
            }
            return this;
        }

        public Builder retrievers(List<DocumentRetriever> retrievers) {
            if (retrievers != null) {
                this.retrievers.addAll(retrievers);
            }
            return this;
        }

        public Builder maxResultsPerRetriever(Integer maxResultsPerRetriever) {
            this.maxResultsPerRetriever = maxResultsPerRetriever;
            return this;
        }

        public Builder mergeStrategy(ResultMergeStrategy mergeStrategy) {
            this.mergeStrategy = mergeStrategy;
            return this;
        }

        public CompositeDocumentRetriever build() {
            return new CompositeDocumentRetriever(this.retrievers, this.maxResultsPerRetriever, this.mergeStrategy);
        }
    }
}

