/*
 * Decompiled with CFR 0.152.
 */
package com.alibaba.cloud.ai.advisor;

import com.alibaba.cloud.ai.advisor.CompositeDocumentRetriever;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import org.springframework.ai.chat.client.ChatClientRequest;
import org.springframework.ai.chat.client.ChatClientResponse;
import org.springframework.ai.chat.client.advisor.api.AdvisorChain;
import org.springframework.ai.chat.client.advisor.api.BaseAdvisor;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.ai.document.Document;
import org.springframework.ai.rag.Query;
import org.springframework.ai.rag.retrieval.search.DocumentRetriever;
import org.springframework.util.Assert;

/**
 * Advisor that augments the user prompt with documents fetched by a {@link DocumentRetriever}.
 *
 * <p>In {@link #before} the user message text, the prompt's instructions (as history) and a copy
 * of the advisor context are turned into a {@link Query}. Documents are retrieved and stored in
 * the advisor context under {@link #RETRIEVED_DOCUMENTS}. Their texts are then rendered into the
 * configured {@link PromptTemplate}, which replaces the user message.
 *
 * <p>In {@link #after} the documents found in the response context under the same key are copied
 * into the {@link ChatResponse} metadata.
 *
 * <p>Constructors that accept a list of retrievers wrap them in a
 * {@link CompositeDocumentRetriever}.
 */
public class DocumentRetrievalAdvisor implements BaseAdvisor {

    /**
     * Default template. It uses the {@code query} placeholder (original user text) and the
     * {@code question_answer_context} placeholder (retrieved document texts joined by the
     * platform line separator).
     */
    private static final PromptTemplate DEFAULT_PROMPT_TEMPLATE = new PromptTemplate("{query}\n\nContext information is below, surrounded by ---------------------\n---------------------\n{question_answer_context}\n---------------------\nGiven the context and provided history information and not prior knowledge,\nreply to the user comment. If the answer is not in the context, inform\nthe user that you can't answer the question.\n");

    /** Advisor order used when none is supplied. */
    private static final int DEFAULT_ORDER = 0;

    /**
     * Key under which the retrieved documents are stored. It is used both in the advisor context
     * and in the {@link ChatResponse} metadata.
     *
     * <p>The field is {@code final}. Reassigning it at runtime would be unsynchronized global state
     * that silently changes where documents are published.
     */
    public static final String RETRIEVED_DOCUMENTS = "question_answer_context";

    private final DocumentRetriever retriever;
    private final PromptTemplate promptTemplate;
    private final int order;

    /**
     * Creates an advisor with the default prompt template and order.
     *
     * @param retriever the retriever used to fetch documents; must not be {@code null}
     */
    public DocumentRetrievalAdvisor(DocumentRetriever retriever) {
        this(retriever, DEFAULT_PROMPT_TEMPLATE);
    }

    /**
     * Creates an advisor with a custom prompt template and the default order.
     *
     * @param retriever the retriever used to fetch documents; must not be {@code null}
     * @param promptTemplate template rendered with {@code query} and
     *     {@code question_answer_context}; must not be {@code null}
     */
    public DocumentRetrievalAdvisor(DocumentRetriever retriever, PromptTemplate promptTemplate) {
        this(retriever, promptTemplate, DEFAULT_ORDER);
    }

    /**
     * Creates an advisor with a custom prompt template and order.
     *
     * @param retriever the retriever used to fetch documents; must not be {@code null}
     * @param promptTemplate template rendered with {@code query} and
     *     {@code question_answer_context}; must not be {@code null}
     * @param order the value returned by {@link #getOrder()}
     * @throws IllegalArgumentException if {@code retriever} or {@code promptTemplate} is {@code null}
     */
    public DocumentRetrievalAdvisor(DocumentRetriever retriever, PromptTemplate promptTemplate, int order) {
        Assert.notNull(retriever, "The retriever must not be null!");
        Assert.notNull(promptTemplate, "The promptTemplate must not be null!");
        this.retriever = retriever;
        this.promptTemplate = promptTemplate;
        this.order = order;
    }

    /**
     * Creates an advisor over several retrievers, using the default template and order.
     *
     * @param retrievers retrievers wrapped in a {@link CompositeDocumentRetriever}; must not be
     *     {@code null} or empty
     */
    public DocumentRetrievalAdvisor(List<DocumentRetriever> retrievers) {
        this(retrievers, DEFAULT_PROMPT_TEMPLATE, DEFAULT_ORDER);
    }

    /**
     * Creates an advisor over several retrievers with a custom template and the default order.
     *
     * @param retrievers retrievers wrapped in a {@link CompositeDocumentRetriever}; must not be
     *     {@code null} or empty
     * @param promptTemplate the template to render; must not be {@code null}
     */
    public DocumentRetrievalAdvisor(List<DocumentRetriever> retrievers, PromptTemplate promptTemplate) {
        this(retrievers, promptTemplate, DEFAULT_ORDER);
    }

    /**
     * Creates an advisor over several retrievers with a custom template and order.
     *
     * @param retrievers retrievers wrapped in a {@link CompositeDocumentRetriever}; must not be
     *     {@code null} or empty
     * @param promptTemplate the template to render; must not be {@code null}
     * @param order the value returned by {@link #getOrder()}
     * @throws IllegalArgumentException if {@code retrievers} is null/empty or
     *     {@code promptTemplate} is {@code null}
     */
    public DocumentRetrievalAdvisor(List<DocumentRetriever> retrievers, PromptTemplate promptTemplate, int order) {
        Assert.notEmpty(retrievers, "The retrievers list must not be null or empty!");
        Assert.notNull(promptTemplate, "The promptTemplate must not be null!");
        this.retriever = new CompositeDocumentRetriever(retrievers);
        this.promptTemplate = promptTemplate;
        this.order = order;
    }

    /**
     * Creates an advisor over several retrievers with an explicit merge strategy, using the
     * default template and order.
     *
     * @param retrievers retrievers wrapped in a {@link CompositeDocumentRetriever}; must not be
     *     {@code null} or empty
     * @param mergeStrategy forwarded to {@link CompositeDocumentRetriever}
     * @param maxResultsPerRetriever forwarded to {@link CompositeDocumentRetriever}
     */
    public DocumentRetrievalAdvisor(List<DocumentRetriever> retrievers, CompositeDocumentRetriever.ResultMergeStrategy mergeStrategy, int maxResultsPerRetriever) {
        this(retrievers, mergeStrategy, maxResultsPerRetriever, DEFAULT_PROMPT_TEMPLATE, DEFAULT_ORDER);
    }

    /**
     * Creates an advisor over several retrievers with full control over merging, template and
     * order.
     *
     * @param retrievers retrievers wrapped in a {@link CompositeDocumentRetriever}; must not be
     *     {@code null} or empty
     * @param mergeStrategy forwarded to {@link CompositeDocumentRetriever}
     * @param maxResultsPerRetriever forwarded to {@link CompositeDocumentRetriever}
     * @param promptTemplate the template to render; must not be {@code null}
     * @param order the value returned by {@link #getOrder()}
     * @throws IllegalArgumentException if {@code retrievers} is null/empty or
     *     {@code promptTemplate} is {@code null}
     */
    public DocumentRetrievalAdvisor(List<DocumentRetriever> retrievers, CompositeDocumentRetriever.ResultMergeStrategy mergeStrategy, int maxResultsPerRetriever, PromptTemplate promptTemplate, int order) {
        Assert.notEmpty(retrievers, "The retrievers list must not be null or empty!");
        Assert.notNull(promptTemplate, "The promptTemplate must not be null!");
        this.retriever = new CompositeDocumentRetriever(retrievers, maxResultsPerRetriever, mergeStrategy);
        this.promptTemplate = promptTemplate;
        this.order = order;
    }

    /** {@inheritDoc} */
    @Override
    public int getOrder() {
        return this.order;
    }

    /**
     * Retrieves documents for the current user message and augments the prompt with them.
     *
     * @param request the incoming request; its context map is copied, never modified
     * @param advisorChain the advisor chain (unused)
     * @return a new request with two changes. The user message is rewritten via
     *     {@code Prompt#augmentUserMessage} using the rendered template. The context is a copy
     *     of the original plus the retrieved documents under {@link #RETRIEVED_DOCUMENTS}.
     */
    @Override
    public ChatClientRequest before(ChatClientRequest request, AdvisorChain advisorChain) {
        UserMessage userMessage = request.prompt().getUserMessage();

        // Work on a copy. The incoming context map may be unmodifiable (put would throw) or
        // shared with the caller, so it must not be mutated in place.
        Map<String, Object> context = new HashMap<>(request.context());

        Query query = new Query(userMessage.getText(), request.prompt().getInstructions(), context);
        List<Document> retrieved = this.retriever.retrieve(query);
        // Defensive: treat a null result as "nothing found" rather than failing with an NPE below.
        List<Document> documents = retrieved != null ? retrieved : List.of();
        context.put(RETRIEVED_DOCUMENTS, documents);

        // Skip documents without text (e.g. media-only) so the string "null" never reaches the prompt.
        String documentContext = documents.stream()
                .map(Document::getText)
                .filter(Objects::nonNull)
                .collect(Collectors.joining(System.lineSeparator()));

        String augmentedUserText = this.promptTemplate.render(
                Map.of("query", userMessage.getText(), "question_answer_context", documentContext));

        return request.mutate()
                .prompt(request.prompt().augmentUserMessage(augmentedUserText))
                .context(context)
                .build();
    }

    /**
     * Publishes the retrieved documents in the chat response metadata.
     *
     * <p>The value stored in the response context under {@link #RETRIEVED_DOCUMENTS} is copied
     * into the {@link ChatResponse} metadata under the same key. When the response carries no
     * {@link ChatResponse}, a new one is built.
     *
     * @param chatClientResponse the response produced by the rest of the chain
     * @param advisorChain the advisor chain (unused)
     * @return a new response with the enriched {@link ChatResponse} and the original context
     */
    @Override
    public ChatClientResponse after(ChatClientResponse chatClientResponse, AdvisorChain advisorChain) {
        ChatResponse.Builder chatResponseBuilder = chatClientResponse.chatResponse() == null
                ? ChatResponse.builder()
                : ChatResponse.builder().from(chatClientResponse.chatResponse());
        chatResponseBuilder.metadata(RETRIEVED_DOCUMENTS, chatClientResponse.context().get(RETRIEVED_DOCUMENTS));
        return ChatClientResponse.builder()
                .chatResponse(chatResponseBuilder.build())
                .context(chatClientResponse.context())
                .build();
    }
}

