/*
 * Decompiled with CFR 0.152.
 */
package com.alibaba.cloud.ai.graph.agent.flow.node;

import com.alibaba.cloud.ai.graph.OverAllState;
import com.alibaba.cloud.ai.graph.action.AsyncEdgeAction;
import com.alibaba.cloud.ai.graph.agent.Agent;
import com.alibaba.cloud.ai.graph.agent.flow.agent.LlmRoutingAgent;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.converter.BeanOutputConverter;
import org.springframework.util.StringUtils;

/**
 * Edge action that asks a chat model which sub-agent the graph should route to next.
 *
 * <p>The system prompt is assembled once in the constructor. When the root agent is an
 * {@link LlmRoutingAgent} with a non-empty system prompt, that prompt is used; otherwise a
 * generic prompt listing every sub-agent's name and description is generated. In both cases
 * the text returned by the output converter's {@code getFormat()} is appended.
 *
 * <p>At routing time the model's answer is accepted only if it exactly equals the name of one
 * of the configured sub-agents; otherwise the request is retried up to
 * {@link #DEFAULT_MAX_RETRIES} times.
 */
public class RoutingEdgeAction implements AsyncEdgeAction {
    private static final Logger logger = LoggerFactory.getLogger(RoutingEdgeAction.class);

    /** Number of additional attempts made after the first routing request. */
    private static final int DEFAULT_MAX_RETRIES = 2;

    /** User message appended when the root agent supplies no instruction of its own. */
    private static final String DEFAULT_INSTRUCTION =
            "Based on the chat history and current task progress, please decide the next agent to delegate the task to.";

    private final ChatClient chatClient;
    private final BeanOutputConverter<RoutingDecision> outputConverter;
    private final Agent rootAgent;
    private final List<Agent> subAgents;

    /**
     * Creates the routing action and builds the chat client with its default system prompt.
     *
     * @param chatModel model used to build the {@link ChatClient}
     * @param current   the routing (root) agent; its name is used in log messages and, when it
     *                  is an {@link LlmRoutingAgent}, its system prompt and instruction are used
     * @param subAgents candidate agents the model may route to
     */
    public RoutingEdgeAction(ChatModel chatModel, Agent current, List<Agent> subAgents) {
        this.rootAgent = current;
        this.subAgents = subAgents;
        String routingPrompt = buildRoutingPrompt(current, subAgents);
        this.outputConverter = new BeanOutputConverter<>(RoutingDecision.class);
        String systemPrompt = routingPrompt + "\n\n" + this.outputConverter.getFormat();
        this.chatClient = ChatClient.builder(chatModel).defaultSystem(systemPrompt).build();
    }

    /**
     * Resolves the name of the sub-agent to route to.
     *
     * <p>The model call is performed synchronously inside this method, so the returned future
     * is already completed (normally or exceptionally) when the method returns.
     *
     * @param state graph state; must contain a {@code "messages"} entry holding the chat history
     * @return a future completed with the chosen sub-agent name, or completed exceptionally if
     *         the state has no messages, the model call keeps failing, or no valid agent name
     *         is obtained within the retry budget
     */
    public CompletableFuture<String> apply(OverAllState state) {
        try {
            @SuppressWarnings("unchecked") // the "messages" entry is expected to hold chat messages
            List<Message> messages = (List<Message>) state.value("messages").orElseThrow();
            List<Message> messagesWithInstruction = this.prepareMessagesWithInstruction(messages);
            // getDecisionWithRetry only returns names that match a sub-agent, so no re-validation here.
            String decisionValue = this.getDecisionWithRetry(messagesWithInstruction, DEFAULT_MAX_RETRIES);
            logger.info("RoutingAgent {} routed to sub-agent {}.", this.rootAgent.name(), decisionValue);
            return CompletableFuture.completedFuture(decisionValue);
        }
        catch (Exception e) {
            logger.error("Error during routing decision: ", e);
            return CompletableFuture.failedFuture(e);
        }
    }

    /**
     * Builds the routing part of the system prompt (without the converter's format text).
     */
    private static String buildRoutingPrompt(Agent rootAgent, List<Agent> subAgents) {
        StringBuilder sb = new StringBuilder();
        sb.append("You are responsible for task routing in a graph-based AI system.\n");

        String customPrompt = rootAgent instanceof LlmRoutingAgent llmRoutingAgent
                ? llmRoutingAgent.getSystemPrompt()
                : null;
        if (StringUtils.hasLength(customPrompt)) {
            sb.append("The instruction that you should follow to finish this task is:\n\n ");
            sb.append(customPrompt);
            return sb.toString();
        }

        sb.append("\n\n");
        sb.append("You have access to some specialized agents that can handle this task. You must delegate the task to ONE of the following agents.\n");
        sb.append("The available agents and their capabilities are listed below:\n");
        for (Agent subAgent : subAgents) {
            sb.append("- ").append(subAgent.name()).append(": ").append(subAgent.description()).append("\n");
        }
        sb.append("\n");
        sb.append("Return ONLY the exact agent name from the list above, without any explanation or additional text.\n");
        sb.append("Available names: ");
        sb.append(joinAgentNames(subAgents));
        sb.append("\n\n");
        sb.append("Example: prose_writer_agent");
        return sb.toString();
    }

    /** Joins the agents' names with {@code ", "}. */
    private static String joinAgentNames(List<Agent> agents) {
        return String.join(", ", agents.stream().map(Agent::name).toList());
    }

    /**
     * Returns a copy of {@code messages} with one trailing user message: the root agent's own
     * instruction when it is an {@link LlmRoutingAgent} with a non-empty instruction, otherwise
     * the default routing instruction.
     */
    private List<Message> prepareMessagesWithInstruction(List<Message> messages) {
        List<Message> messagesWithInstruction = new ArrayList<>(messages);
        String instruction = DEFAULT_INSTRUCTION;
        if (this.rootAgent instanceof LlmRoutingAgent llmRoutingAgent) {
            String customInstruction = llmRoutingAgent.getInstruction();
            if (StringUtils.hasLength(customInstruction)) {
                instruction = customInstruction;
            }
        }
        messagesWithInstruction.add(new UserMessage(instruction));
        return messagesWithInstruction;
    }

    /**
     * Requests a routing decision, retrying on invalid answers and on runtime failures.
     *
     * @param messages   conversation to send on every attempt
     * @param maxRetries number of retries after the first attempt (total attempts = maxRetries + 1)
     * @return a name that equals the name of one of the configured sub-agents
     * @throws IllegalStateException if every attempt produced an invalid (or missing) agent name
     * @throws RuntimeException      the failure of the final attempt, if that attempt threw
     */
    private String getDecisionWithRetry(List<Message> messages, int maxRetries) {
        String lastInvalidDecision = null;
        // Distinguishes "model answered with a bad name" from "the call itself failed", so a retry
        // after an exception does not tell the model it returned the agent name 'null'.
        boolean sawInvalidDecision = false;

        for (int attempt = 0; attempt <= maxRetries; ++attempt) {
            try {
                List<Message> attemptMessages = messages;
                if (attempt > 0) {
                    logger.warn("RoutingAgent {} retry attempt {}/{}. Previous invalid decision: {}",
                            this.rootAgent.name(), attempt, maxRetries, lastInvalidDecision);
                    if (sawInvalidDecision) {
                        attemptMessages = this.withErrorFeedback(messages, lastInvalidDecision);
                    }
                }

                RoutingDecision decision =
                        this.chatClient.prompt().messages(attemptMessages).call().entity(this.outputConverter);
                // A missing decision is handled like an invalid name instead of relying on an NPE.
                String decisionValue = decision == null ? null : decision.agent();

                if (this.isKnownSubAgent(decisionValue)) {
                    if (attempt > 0) {
                        logger.info("RoutingAgent {} succeeded on retry attempt {}. Routed to sub-agent: {}",
                                this.rootAgent.name(), attempt, decisionValue);
                    }
                    return decisionValue;
                }

                lastInvalidDecision = decisionValue;
                sawInvalidDecision = true;
                logger.warn("RoutingAgent {} attempt {}/{} returned invalid agent name: {}",
                        this.rootAgent.name(), attempt, maxRetries, decisionValue);
            }
            catch (RuntimeException e) {
                if (attempt == maxRetries) {
                    logger.error("RoutingAgent {} failed on final attempt {}/{}",
                            this.rootAgent.name(), attempt, maxRetries, e);
                    throw e;
                }
                logger.warn("RoutingAgent {} attempt {}/{} encountered an error, will retry",
                        this.rootAgent.name(), attempt, maxRetries, e);
            }
        }
        throw new IllegalStateException(String.format(
                "Failed to get valid decision after %d retries. Last invalid decision: %s",
                maxRetries, lastInvalidDecision));
    }

    /**
     * Returns a copy of {@code messages} carrying feedback about the rejected agent name.
     *
     * <p>The feedback is appended to the first {@link SystemMessage} if one exists; otherwise it
     * is added as a trailing {@link UserMessage}.
     */
    private List<Message> withErrorFeedback(List<Message> messages, String invalidDecision) {
        String errorFeedback = String.format(
                "Previous attempt returned an invalid agent name '%s'. Please choose from the available agents: %s.",
                invalidDecision, joinAgentNames(this.subAgents));

        List<Message> messagesWithFeedback = new ArrayList<>(messages.size() + 1);
        boolean systemMessageFound = false;
        for (Message msg : messages) {
            if (!systemMessageFound && msg instanceof SystemMessage) {
                messagesWithFeedback.add(new SystemMessage(msg.getText() + "\n\n" + errorFeedback));
                systemMessageFound = true;
            } else {
                messagesWithFeedback.add(msg);
            }
        }
        if (!systemMessageFound) {
            messagesWithFeedback.add(new UserMessage(errorFeedback));
        }
        return messagesWithFeedback;
    }

    /** Returns whether {@code name} exactly equals the name of one of the configured sub-agents. */
    private boolean isKnownSubAgent(String name) {
        return this.subAgents.stream().anyMatch(agent -> agent.name().equals(name));
    }

    /**
     * Structured output the model's answer is converted into.
     *
     * @param agent name of the sub-agent chosen by the model
     */
    public record RoutingDecision(String agent) {
    }
}

