/*
 * Decompiled with CFR 0.152.
 */
package com.alibaba.cloud.ai.graph.agent.hook.summarization;

import com.alibaba.cloud.ai.graph.RunnableConfig;
import com.alibaba.cloud.ai.graph.agent.hook.HookPosition;
import com.alibaba.cloud.ai.graph.agent.hook.HookPositions;
import com.alibaba.cloud.ai.graph.agent.hook.JumpTo;
import com.alibaba.cloud.ai.graph.agent.hook.TokenCounter;
import com.alibaba.cloud.ai.graph.agent.hook.messages.AgentCommand;
import com.alibaba.cloud.ai.graph.agent.hook.messages.MessagesModelHook;
import com.alibaba.cloud.ai.graph.agent.hook.messages.UpdatePolicy;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.ToolResponseMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.Prompt;

@HookPositions(value = {HookPosition.BEFORE_MODEL})
public class SummarizationHook extends MessagesModelHook {

    private static final Logger log = LoggerFactory.getLogger(SummarizationHook.class);

    /**
     * Default template for the summarization request. The single {@code %s} placeholder
     * receives the rendered messages being summarized.
     */
    private static final String DEFAULT_SUMMARY_PROMPT = "<role>\n"
            + "Context Extraction Assistant\n"
            + "</role>\n"
            + "\n"
            + "<primary_objective>\n"
            + "Your sole objective in this task is to extract the highest quality/most relevant context from the conversation history below.\n"
            + "</primary_objective>\n"
            + "\n"
            + "<instructions>\n"
            + "The conversation history below will be replaced with the context you extract in this step. Extract and record all of the most important context from the conversation history.\n"
            + "Respond ONLY with the extracted context. Do not include any additional information.\n"
            + "</instructions>\n"
            + "\n"
            + "<messages>\n"
            + "Messages to summarize:\n"
            + "%s\n"
            + "</messages>";

    /**
     * Default text placed before the generated summary in the replacement user message.
     * It is followed by a blank line and then the summary text.
     */
    private static final String SUMMARY_PREFIX = "Here is a summary of the conversation to date:";

    /** Default number of most recent messages that are never summarized. */
    private static final int DEFAULT_MESSAGES_TO_KEEP = 20;

    /**
     * How many messages before a candidate cutoff are inspected for assistant tool calls whose
     * tool responses would end up on the other side of the cutoff.
     */
    private static final int SEARCH_RANGE_FOR_TOOL_PAIRS = 5;

    private final ChatModel model;
    /** Token threshold that triggers summarization; {@code null} disables the hook. */
    private final Integer maxTokensBeforeSummary;
    private final int messagesToKeep;
    private final TokenCounter tokenCounter;
    private final String summaryPrompt;
    private final String summaryPrefix;

    private SummarizationHook(Builder builder) {
        this.model = builder.model;
        this.maxTokensBeforeSummary = builder.maxTokensBeforeSummary;
        this.messagesToKeep = builder.messagesToKeep;
        this.tokenCounter = builder.tokenCounter;
        this.summaryPrompt = builder.summaryPrompt;
        this.summaryPrefix = builder.summaryPrefix;
    }

    /**
     * Creates a new builder.
     *
     * @return a builder pre-populated with this hook's defaults
     */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * Summarizes the older part of the conversation once it exceeds the configured token budget.
     *
     * <p>The messages are returned unchanged in any of these cases:
     * <ul>
     *   <li>{@code maxTokensBeforeSummary} is unset;</li>
     *   <li>the configured {@link TokenCounter} reports fewer tokens than the threshold;</li>
     *   <li>no cutoff can be found that keeps assistant tool calls together with their tool
     *       responses;</li>
     *   <li>the summary cannot be produced.</li>
     * </ul>
     * Otherwise the messages before the cutoff are summarized by the model. The returned command
     * holds one {@link UserMessage} ({@code summaryPrefix}, a blank line, the summary) followed by
     * the preserved recent messages, with {@link UpdatePolicy#REPLACE}.
     *
     * @param previousMessages the current conversation messages
     * @param config the runnable configuration (not used by this hook)
     * @return the command carrying either the original or the summarized message list
     */
    @Override
    public AgentCommand beforeModel(List<Message> previousMessages, RunnableConfig config) {
        if (this.maxTokensBeforeSummary == null) {
            return new AgentCommand(previousMessages);
        }
        int totalTokens = this.tokenCounter.countTokens(previousMessages);
        if (totalTokens < this.maxTokensBeforeSummary) {
            return new AgentCommand(previousMessages);
        }
        log.info("Token count {} exceeds threshold {}, triggering summarization", totalTokens, this.maxTokensBeforeSummary);
        int cutoffIndex = findSafeCutoff(previousMessages);
        if (cutoffIndex <= 0) {
            log.warn("Cannot find safe cutoff point for summarization");
            return new AgentCommand(previousMessages);
        }
        List<Message> toSummarize = previousMessages.subList(0, cutoffIndex);
        List<Message> toPreserve = previousMessages.subList(cutoffIndex, previousMessages.size());
        String summary = createSummary(toSummarize);
        if (summary == null) {
            // Replacing the history with an error text would irreversibly drop the summarized
            // messages; keep them so a later invocation can try again.
            log.warn("Summary generation failed; keeping all {} messages unsummarized", previousMessages.size());
            return new AgentCommand(previousMessages);
        }
        List<Message> newMessages = new ArrayList<>(toPreserve.size() + 1);
        newMessages.add(new UserMessage(this.summaryPrefix + "\n\n" + summary));
        newMessages.addAll(toPreserve);
        log.info("Summarized {} messages, keeping {} recent messages", toSummarize.size(), toPreserve.size());
        return new AgentCommand(newMessages, UpdatePolicy.REPLACE);
    }

    /**
     * Finds the largest cutoff index that leaves at least {@code messagesToKeep} messages
     * unsummarized and does not split a tool call from its response.
     *
     * @param messages the conversation messages
     * @return the cutoff index (exclusive end of the summarized prefix), or 0 if none qualifies
     */
    private int findSafeCutoff(List<Message> messages) {
        if (messages.size() <= this.messagesToKeep) {
            return 0;
        }
        for (int i = messages.size() - this.messagesToKeep; i > 0; i--) {
            if (isSafeCutoffPoint(messages, i)) {
                return i;
            }
        }
        return 0;
    }

    /**
     * Checks whether cutting the conversation at {@code cutoffIndex} keeps every assistant tool
     * call in the inspected window on the same side as its tool responses.
     *
     * @param messages the conversation messages
     * @param cutoffIndex candidate index of the first preserved message
     * @return {@code true} if no tool-call/tool-response pair is split
     */
    private static boolean isSafeCutoffPoint(List<Message> messages, int cutoffIndex) {
        if (cutoffIndex >= messages.size()) {
            return true;
        }
        int searchStart = Math.max(0, cutoffIndex - SEARCH_RANGE_FOR_TOOL_PAIRS);
        // An assistant message at or after the cutoff is preserved together with every later
        // tool response, so only assistant messages before the cutoff can be split from theirs.
        for (int i = searchStart; i < cutoffIndex; i++) {
            Message candidate = messages.get(i);
            if (hasToolCalls(candidate)
                    && cutoffSeparatesToolPair(messages, i, cutoffIndex, extractToolCallIds((AssistantMessage) candidate))) {
                return false;
            }
        }
        return true;
    }

    /** Returns {@code true} if the message is an assistant message requesting at least one tool call. */
    private static boolean hasToolCalls(Message message) {
        return message instanceof AssistantMessage && !((AssistantMessage) message).getToolCalls().isEmpty();
    }

    /** Collects the ids of all tool calls requested by the given assistant message. */
    private static Set<String> extractToolCallIds(AssistantMessage aiMessage) {
        Set<String> toolCallIds = new HashSet<>();
        for (AssistantMessage.ToolCall toolCall : aiMessage.getToolCalls()) {
            toolCallIds.add(toolCall.id());
        }
        return toolCallIds;
    }

    /**
     * Reports whether any tool response answering one of {@code toolCallIds} lies on the other
     * side of {@code cutoffIndex} than the assistant message at {@code aiMessageIndex}.
     */
    private static boolean cutoffSeparatesToolPair(List<Message> messages, int aiMessageIndex, int cutoffIndex,
            Set<String> toolCallIds) {
        boolean aiBeforeCutoff = aiMessageIndex < cutoffIndex;
        for (int j = aiMessageIndex + 1; j < messages.size(); j++) {
            Message message = messages.get(j);
            if (!(message instanceof ToolResponseMessage)) {
                continue;
            }
            boolean toolBeforeCutoff = j < cutoffIndex;
            if (aiBeforeCutoff == toolBeforeCutoff) {
                continue;
            }
            for (ToolResponseMessage.ToolResponse response : ((ToolResponseMessage) message).getResponses()) {
                if (toolCallIds.contains(response.id())) {
                    return true;
                }
            }
        }
        return false;
    }

    /**
     * Asks the model to summarize the given messages.
     *
     * @param messages the messages to summarize
     * @return the summary text, {@code "No previous conversation."} for an empty list, or
     *         {@code null} if the request failed or the model returned no text
     */
    private String createSummary(List<Message> messages) {
        if (messages.isEmpty()) {
            return "No previous conversation.";
        }
        StringBuilder messageText = new StringBuilder();
        for (Message msg : messages) {
            messageText.append(getRoleName(msg)).append(": ").append(renderContent(msg)).append("\n");
        }
        try {
            // Formatting is inside the try: a malformed custom template must not crash the hook.
            String prompt = String.format(this.summaryPrompt, messageText.toString());
            ChatResponse response = this.model.call(new Prompt(List.of(new UserMessage(prompt))));
            String text = extractText(response);
            if (text == null) {
                log.warn("Summarization model returned no text");
            }
            return text;
        }
        catch (Exception e) {
            log.error("Failed to create summary: {}", e.getMessage(), e);
            return null;
        }
    }

    /** Returns the non-blank text of the first generation, or {@code null} if there is none. */
    private static String extractText(ChatResponse response) {
        if (response == null || response.getResult() == null || response.getResult().getOutput() == null) {
            return null;
        }
        String text = response.getResult().getOutput().getText();
        return (text == null || text.trim().isEmpty()) ? null : text;
    }

    /**
     * Renders a message's content for the summarization prompt.
     *
     * <p>Null text is rendered as empty. Assistant tool calls (name and arguments) and tool
     * response data are appended, so tool results are visible to the summarizer.
     */
    private static String renderContent(Message message) {
        StringBuilder content = new StringBuilder();
        String text = message.getText();
        if (text != null) {
            content.append(text);
        }
        if (message instanceof AssistantMessage) {
            for (AssistantMessage.ToolCall call : ((AssistantMessage) message).getToolCalls()) {
                if (content.length() > 0) {
                    content.append(' ');
                }
                content.append("[tool call ").append(call.name())
                        .append(" with arguments ").append(call.arguments()).append(']');
            }
        }
        else if (message instanceof ToolResponseMessage) {
            for (ToolResponseMessage.ToolResponse response : ((ToolResponseMessage) message).getResponses()) {
                if (content.length() > 0) {
                    content.append(' ');
                }
                content.append('[').append(response.name())
                        .append(" returned: ").append(response.responseData()).append(']');
            }
        }
        return content.toString();
    }

    /** Maps a message to the speaker label used in the summarization prompt. */
    private static String getRoleName(Message message) {
        if (message instanceof UserMessage) {
            return "Human";
        }
        if (message instanceof AssistantMessage) {
            return "Assistant";
        }
        if (message instanceof SystemMessage) {
            return "System";
        }
        if (message instanceof ToolResponseMessage) {
            return "Tool";
        }
        return "Unknown";
    }

    @Override
    public String getName() {
        return "Summarization";
    }

    /** This hook never redirects the agent flow. */
    @Override
    public List<JumpTo> canJumpTo() {
        return List.of();
    }

    /** Builder for {@link SummarizationHook}. Only {@link #model(ChatModel)} is mandatory. */
    public static class Builder {
        private ChatModel model;
        private Integer maxTokensBeforeSummary;
        private int messagesToKeep = DEFAULT_MESSAGES_TO_KEEP;
        private TokenCounter tokenCounter = TokenCounter.approximateMsgCounter();
        private String summaryPrompt = DEFAULT_SUMMARY_PROMPT;
        private String summaryPrefix = SUMMARY_PREFIX;

        /** Sets the chat model used to produce summaries (required). */
        public Builder model(ChatModel model) {
            this.model = model;
            return this;
        }

        /** Sets the token count at which summarization triggers; {@code null} disables the hook. */
        public Builder maxTokensBeforeSummary(Integer maxTokens) {
            this.maxTokensBeforeSummary = maxTokens;
            return this;
        }

        /** Sets how many of the most recent messages are always preserved (must be non-negative). */
        public Builder messagesToKeep(int count) {
            this.messagesToKeep = count;
            return this;
        }

        /**
         * Sets the summarization prompt template.
         *
         * <p>It is passed to {@link String#format}, so it should contain one {@code %s} for the
         * rendered messages, and any literal percent sign must be written as {@code %%}.
         */
        public Builder summaryPrompt(String prompt) {
            this.summaryPrompt = prompt;
            return this;
        }

        /** Sets the text placed before the summary (followed by a blank line) in the replacement message. */
        public Builder summaryPrefix(String prefix) {
            this.summaryPrefix = prefix;
            return this;
        }

        /** Sets the counter used to measure the conversation size against the threshold. */
        public Builder tokenCounter(TokenCounter counter) {
            this.tokenCounter = counter;
            return this;
        }

        /**
         * Builds the hook.
         *
         * @return the configured hook
         * @throws IllegalArgumentException if the model, token counter, prompt or prefix is
         *         {@code null}, or {@code messagesToKeep} is negative
         */
        public SummarizationHook build() {
            if (this.model == null) {
                throw new IllegalArgumentException("model must be specified");
            }
            if (this.tokenCounter == null) {
                throw new IllegalArgumentException("tokenCounter must not be null");
            }
            if (this.messagesToKeep < 0) {
                throw new IllegalArgumentException("messagesToKeep must not be negative: " + this.messagesToKeep);
            }
            if (this.summaryPrompt == null) {
                throw new IllegalArgumentException("summaryPrompt must not be null");
            }
            if (this.summaryPrefix == null) {
                throw new IllegalArgumentException("summaryPrefix must not be null");
            }
            return new SummarizationHook(this);
        }
    }
}

