/*
 * Decompiled with CFR 0.152.
 */
package com.alibaba.cloud.ai.graph.agent;

import com.alibaba.cloud.ai.graph.OverAllState;
import com.alibaba.cloud.ai.graph.agent.MessageToolCallResultConverter;
import com.alibaba.cloud.ai.graph.agent.ReactAgent;
import com.alibaba.cloud.ai.graph.serializer.AgentInstructionMessage;
import java.lang.reflect.Type;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.BiFunction;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.model.ToolContext;
import org.springframework.ai.tool.ToolCallback;
import org.springframework.ai.tool.execution.ToolCallResultConverter;
import org.springframework.ai.tool.function.FunctionToolCallback;
import org.springframework.ai.util.json.schema.JsonSchemaGenerator;
import org.springframework.util.StringUtils;

public class AgentTool
implements BiFunction<String, ToolContext, AssistantMessage> {
    private final ReactAgent agent;
    private static final ToolCallResultConverter CONVERTER = new MessageToolCallResultConverter();

    public AgentTool(ReactAgent agent) {
        this.agent = agent;
    }

    @Override
    public AssistantMessage apply(String input, ToolContext toolContext) {
        OverAllState state = (OverAllState)toolContext.getContext().get("_AGENT_STATE_");
        try {
            OverAllState newState = this.agent.getAndCompileGraph().cloneState(state.data());
            ArrayList<Object> messagesToAdd = new ArrayList<Object>();
            if (StringUtils.hasLength((String)this.agent.instruction())) {
                messagesToAdd.add(new AgentInstructionMessage(this.agent.instruction()));
            }
            messagesToAdd.add(new UserMessage(input));
            Map inputs = newState.updateState(Map.of("messages", messagesToAdd));
            Optional resultState = this.agent.getAndCompileGraph().invoke(inputs);
            Optional messages = resultState.flatMap(overAllState -> overAllState.value("messages", List.class));
            if (messages.isPresent()) {
                List messageList = (List)messages.get();
                AssistantMessage assistantMessage = (AssistantMessage)messageList.get(messageList.size() - 1);
                return assistantMessage;
            }
        }
        catch (Exception e) {
            throw new RuntimeException(e);
        }
        throw new RuntimeException("Failed to execute agent tool or failed to get agent tool result");
    }

    public static AgentTool create(ReactAgent agent) {
        return new AgentTool(agent);
    }

    public static ToolCallback getFunctionToolCallback(ReactAgent agent) {
        String inputSchema = StringUtils.hasLength((String)agent.getInputSchema()) ? agent.getInputSchema() : (agent.getInputType() != null ? JsonSchemaGenerator.generateForType((Type)agent.getInputType(), (JsonSchemaGenerator.SchemaOption[])new JsonSchemaGenerator.SchemaOption[0]) : null);
        return FunctionToolCallback.builder((String)agent.name(), (BiFunction)AgentTool.create(agent)).description(agent.description()).inputType(String.class).inputSchema(inputSchema).toolCallResultConverter(CONVERTER).build();
    }
}

