/*
 * Decompiled with CFR 0.152.
 */
package org.noear.solon.ai.chat;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Date;
import java.util.List;
import java.util.Map;
import java.util.function.Consumer;
import org.noear.snack4.ONode;
import org.noear.solon.Utils;
import org.noear.solon.ai.chat.ChatChoice;
import org.noear.solon.ai.chat.ChatConfig;
import org.noear.solon.ai.chat.ChatException;
import org.noear.solon.ai.chat.ChatOptions;
import org.noear.solon.ai.chat.ChatRequest;
import org.noear.solon.ai.chat.ChatRequestDesc;
import org.noear.solon.ai.chat.ChatResponse;
import org.noear.solon.ai.chat.ChatResponseDefault;
import org.noear.solon.ai.chat.ChatSession;
import org.noear.solon.ai.chat.ChatSubscriberProxy;
import org.noear.solon.ai.chat.dialect.ChatDialect;
import org.noear.solon.ai.chat.interceptor.CallChain;
import org.noear.solon.ai.chat.interceptor.ChatInterceptor;
import org.noear.solon.ai.chat.interceptor.StreamChain;
import org.noear.solon.ai.chat.interceptor.ToolChain;
import org.noear.solon.ai.chat.interceptor.ToolRequest;
import org.noear.solon.ai.chat.message.AssistantMessage;
import org.noear.solon.ai.chat.message.ChatMessage;
import org.noear.solon.ai.chat.message.ToolMessage;
import org.noear.solon.ai.chat.tool.FunctionTool;
import org.noear.solon.ai.chat.tool.ToolCall;
import org.noear.solon.ai.chat.tool.ToolCallBuilder;
import org.noear.solon.ai.chat.tool.ToolCallException;
import org.noear.solon.core.util.RankEntity;
import org.noear.solon.net.http.HttpException;
import org.noear.solon.net.http.HttpResponse;
import org.noear.solon.net.http.HttpUtils;
import org.noear.solon.net.http.textstream.ServerSentEvent;
import org.noear.solon.net.http.textstream.TextStreamUtil;
import org.noear.solon.rx.SimpleSubscriber;
import org.reactivestreams.Publisher;
import org.reactivestreams.Subscriber;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Default {@link ChatRequestDesc} implementation.
 *
 * <p>Builds a {@link ChatRequest} from the session messages and the current
 * {@link ChatOptions}, runs it through the configured interceptors and sends it
 * over HTTP, either as a single blocking call ({@link #call()}) or as a
 * streamed response ({@link #stream()}). Assistant messages and tool results
 * are appended to the {@link ChatSession} as they are produced. When the model
 * asks for tool calls, the tools are executed and the request is re-issued,
 * unless every executed tool is flagged "return direct", in which case the
 * tool output is returned as the answer.
 */
public class ChatRequestDescDefault implements ChatRequestDesc {
    private static final Logger log = LoggerFactory.getLogger(ChatRequestDescDefault.class);

    private final ChatConfig config;
    private final ChatDialect dialect;
    private final ChatSession session;
    private ChatOptions options;

    /**
     * Creates a request description whose options start out as a fresh
     * {@link ChatOptions} seeded with the config's default tools context and
     * default options (when those are non-empty).
     *
     * @param config  chat configuration (defaults, HTTP client factory, default tools)
     * @param dialect dialect used to parse responses and build assistant messages
     * @param session session that supplies the request messages and records new ones
     */
    public ChatRequestDescDefault(ChatConfig config, ChatDialect dialect, ChatSession session) {
        this.config = config;
        this.dialect = dialect;
        this.session = session;
        this.options = new ChatOptions();

        if (Utils.isNotEmpty(config.getDefaultToolsContext())) {
            this.options.toolsContext().putAll(config.getDefaultToolsContext());
        }
        if (Utils.isNotEmpty(config.getDefaultOptions())) {
            this.options.options().putAll(config.getDefaultOptions());
        }
    }

    /**
     * Replaces the current options. A {@code null} argument is ignored and the
     * existing options are kept.
     *
     * @param options the options to use, may be {@code null}
     * @return this instance, for chaining
     */
    @Override
    public ChatRequestDesc options(ChatOptions options) {
        if (options != null) {
            this.options = options;
        }
        return this;
    }

    /**
     * Lets the caller adjust the current options in place.
     *
     * @param optionsBuilder callback that receives the current options
     * @return this instance, for chaining
     */
    @Override
    public ChatRequestDesc options(Consumer<ChatOptions> optionsBuilder) {
        optionsBuilder.accept(this.options);
        return this;
    }

    /**
     * Sends the session messages as a non-streaming request, passing through the
     * interceptor chain before {@link #doCall(ChatRequest)} is reached.
     *
     * @return the chat response
     * @throws IOException if the chain or the HTTP exchange fails with an I/O error
     */
    @Override
    public ChatResponse call() throws IOException {
        List<RankEntity<ChatInterceptor>> interceptorList = buildInterceptorList();
        ChatRequest req = new ChatRequest(config, dialect, options, false, session.getMessages());
        CallChain chain = new CallChain(interceptorList, this::doCall);
        return chain.doIntercept(req);
    }

    /**
     * Terminal handler of the call chain: posts the request, parses the reply and
     * handles any tool calls the assistant asked for.
     *
     * @param req the request to send
     * @return the parsed response, or the response of a follow-up {@link #call()}
     *         when tools were executed and their output must go back to the model
     * @throws IOException if the HTTP exchange fails
     */
    private ChatResponse doCall(ChatRequest req) throws IOException {
        HttpUtils httpUtils = config.createHttpUtils();
        String reqJson = req.toRequestData();
        if (log.isDebugEnabled()) {
            log.debug("ai-request: {}", reqJson);
        }

        String respJson = httpUtils.bodyOfJson(reqJson).post();
        if (log.isDebugEnabled()) {
            log.debug("ai-response: {}", respJson);
        }

        ChatResponseDefault resp = new ChatResponseDefault(false);
        resp.setResponseData(respJson);
        dialect.parseResponseJson(config, resp, respJson);

        if (resp.getError() != null) {
            throw resp.getError();
        }
        if (!resp.hasChoices()) {
            return resp;
        }

        AssistantMessage choiceMessage = resp.getMessage();
        session.addMessage(choiceMessage);

        if (Utils.isNotEmpty(choiceMessage.getToolCalls())) {
            List<ToolMessage> returnDirectMessages = buildToolMessage(resp, choiceMessage);
            if (Utils.isEmpty(returnDirectMessages)) {
                // Tool results were added to the session; ask the model again.
                // Note: this re-enters call(), so the interceptor chain runs again.
                return call();
            }

            // Every executed tool is "return direct": answer with the tool output.
            AssistantMessage directMessage = dialect.buildAssistantMessageByToolMessages(returnDirectMessages);
            resp.reset();
            resp.addChoice(new ChatChoice(0, new Date(), "tool", directMessage));
            session.addMessage(directMessage);
        }
        return resp;
    }

    /**
     * Sends the session messages as a streaming request, passing through the
     * interceptor chain before {@link #doStream(ChatRequest)} is reached.
     *
     * @return a publisher of response chunks
     */
    @Override
    public Publisher<ChatResponse> stream() {
        List<RankEntity<ChatInterceptor>> interceptorList = buildInterceptorList();
        ChatRequest req = new ChatRequest(config, dialect, options, true, session.getMessages());
        StreamChain chain = new StreamChain(interceptorList, this::doStream);
        return chain.doIntercept(req);
    }

    /**
     * Terminal handler of the stream chain. The request body is serialized
     * immediately; the HTTP POST is only issued when the returned publisher is
     * subscribed to.
     *
     * @param req the request to send
     * @return a publisher that posts the request on subscription and relays the
     *         parsed chunks, or an error for transport failures and HTTP status
     *         codes of 400 and above
     */
    private Publisher<ChatResponse> doStream(ChatRequest req) {
        HttpUtils httpUtils = config.createHttpUtils();
        String reqJson = req.toRequestData();
        if (log.isDebugEnabled()) {
            log.debug("ai-request: {}", reqJson);
        }

        return subscriber -> {
            httpUtils.bodyOfJson(reqJson).execAsync("POST").whenComplete((httpResp, err) -> {
                Subscriber<? super ChatResponse> subscriberProxy = ChatSubscriberProxy.of(subscriber);

                if (err != null) {
                    subscriberProxy.onError(err);
                    return;
                }

                try {
                    if (httpResp.code() < 400) {
                        parseResp(httpResp, subscriberProxy);
                        return;
                    }

                    String message = httpResp.bodyAsString();
                    String description = Utils.isEmpty(message)
                            ? "Error code:" + httpResp.code()
                            : "Error code:" + httpResp.code() + ", message:" + message;
                    subscriberProxy.onError(new HttpException(description));
                } catch (IOException e) {
                    subscriberProxy.onError(e);
                }
            });
        };
    }

    /**
     * Feeds the HTTP response body to the subscriber. Bodies whose Content-Type
     * starts with {@code text/event-stream} are parsed as server-sent events; any
     * other body is read line by line, each line being wrapped in a
     * {@link ServerSentEvent} without meta data. Anything thrown while setting
     * this up is reported through {@code subscriber.onError}.
     *
     * @param httpResp   the HTTP response to read
     * @param subscriber receiver of the parsed chat responses
     * @throws IOException declared for callers; failures raised here are routed
     *                     to the subscriber instead
     */
    private void parseResp(HttpResponse httpResp, Subscriber<? super ChatResponse> subscriber) throws IOException {
        ChatResponseDefault resp = new ChatResponseDefault(true);
        String contentType = httpResp.header("Content-Type");

        try {
            if (contentType != null && contentType.startsWith("text/event-stream")) {
                TextStreamUtil.parseSseStream(httpResp, new SimpleSubscriber<ServerSentEvent>()
                        .doOnSubscribe(subscriber::onSubscribe)
                        .doOnNext(event -> this.onEventStream(resp, event, subscriber))
                        .doOnComplete(() -> this.onEventEnd(resp, subscriber))
                        .doOnError(subscriber::onError));
            } else {
                TextStreamUtil.parseLineStream(httpResp, new SimpleSubscriber<String>()
                        .doOnSubscribe(subscriber::onSubscribe)
                        .doOnNext(data -> this.onEventStream(resp, new ServerSentEvent(null, data), subscriber))
                        .doOnComplete(() -> this.onEventEnd(resp, subscriber))
                        .doOnError(subscriber::onError));
            }
        } catch (Throwable ex) {
            subscriber.onError(ex);
        }
    }

    /**
     * Called when the upstream text stream completes. Pending tool calls are
     * executed first; if that hands the subscriber over to a follow-up stream,
     * nothing more is done here. Otherwise the aggregated assistant message (if
     * any) is stored in the session and the subscriber is completed.
     *
     * @param resp       the streaming response accumulator
     * @param subscriber receiver of the chat responses
     */
    private void onEventEnd(ChatResponseDefault resp, Subscriber<? super ChatResponse> subscriber) {
        if (!resp.toolCallBuilders.isEmpty()) {
            boolean completeHere;
            try {
                completeHere = buildStreamToolMessage(resp, subscriber);
            } catch (RuntimeException ex) {
                // A failing tool call must not escape from the completion callback;
                // report it through the subscriber's error channel instead.
                subscriber.onError(ex);
                return;
            }
            if (!completeHere) {
                return;
            }
        }

        AssistantMessage aggregationMessage = resp.getAggregationMessage();
        if (aggregationMessage != null) {
            session.addMessage(aggregationMessage);
        }
        subscriber.onComplete();
    }

    /**
     * Handles one streamed event: parses it, accumulates tool-call fragments and
     * publishes each resulting choice to the subscriber.
     *
     * @param resp       the streaming response accumulator (reset for every non-empty event)
     * @param event      the event to process
     * @param subscriber receiver of the chat responses
     * @return {@code false} when the parsed event carried an error (which has
     *         been passed to {@code subscriber.onError}); {@code true} otherwise
     */
    private boolean onEventStream(ChatResponseDefault resp, ServerSentEvent event, Subscriber<? super ChatResponse> subscriber) {
        if (log.isDebugEnabled()) {
            log.debug("ai-response: {}", event.data());
        }

        resp.setResponseData(event.data());
        if (Utils.isEmpty(event.data())) {
            return true;
        }

        resp.reset();
        if (!dialect.parseResponseJson(config, resp, event.data())) {
            return true;
        }

        if (resp.getError() != null) {
            subscriber.onError(resp.getError());
            return false;
        }

        if (resp.hasChoices()) {
            AssistantMessage choiceMessage = resp.getMessage();
            // Check for null before touching the message.
            if (choiceMessage != null) {
                // No-op when the message carries no tool calls.
                buildToolCallBuilder(resp, choiceMessage);

                if (resp.getChoices().size() > 1) {
                    // Publish one choice at a time; iterate over a copy because the
                    // response is reset and refilled for every choice.
                    List<ChatChoice> choices = new ArrayList<>(resp.getChoices());
                    for (ChatChoice choice : choices) {
                        resp.reset();
                        resp.addChoice(choice);
                        publishResponse(subscriber, resp, choice);
                    }
                } else {
                    publishResponse(subscriber, resp, resp.getChoices().get(0));
                }
            }
        }
        return true;
    }

    /**
     * Turns the accumulated tool-call fragments into assistant messages, records
     * them in the session and executes the tools of the first message.
     *
     * <p>If the tool output has to go back to the model, a new {@link #stream()}
     * is started and the subscriber is attached to it. If every executed tool is
     * "return direct", the tool output is published to the subscriber as a
     * response. The accumulated builders are cleared in all cases, including
     * when an exception is thrown.
     *
     * @param resp       the streaming response accumulator
     * @param subscriber receiver of the chat responses
     * @return {@code false} when the subscriber was handed to a follow-up stream
     *         (the caller must not complete it); {@code true} when a direct tool
     *         response was published and the caller should finish the stream
     */
    private boolean buildStreamToolMessage(ChatResponseDefault resp, Subscriber<? super ChatResponse> subscriber) {
        try {
            ONode oNode = dialect.buildAssistantMessageNode(resp.toolCallBuilders);
            List<AssistantMessage> assistantMessages = dialect.parseAssistantMessage(resp, oNode);
            session.addMessage(assistantMessages);

            List<ToolMessage> returnDirectMessages = buildToolMessage(resp, assistantMessages.get(0));
            if (Utils.isEmpty(returnDirectMessages)) {
                stream().subscribe(subscriber);
                return false;
            }

            AssistantMessage message = dialect.buildAssistantMessageByToolMessages(returnDirectMessages);
            resp.reset();
            resp.addChoice(new ChatChoice(0, new Date(), "tool", message));
            // Drop previously aggregated text before the tool answer is appended.
            resp.aggregationMessageContent.setLength(0);
            publishResponse(subscriber, resp, resp.lastChoice());
            return true;
        } finally {
            resp.toolCallBuilders.clear();
        }
    }

    /**
     * Appends the choice's content (when not {@code null}) to the response's
     * aggregated content and emits the response to the subscriber.
     *
     * @param subscriber receiver of the chat responses
     * @param resp       the response to emit
     * @param choice     the choice whose content is aggregated
     */
    private void publishResponse(Subscriber<? super ChatResponse> subscriber, ChatResponseDefault resp, ChatChoice choice) {
        String content = choice.getMessage().getContent();
        if (content != null) {
            resp.aggregationMessageContent.append(content);
        }
        subscriber.onNext(resp);
    }

    /**
     * Accumulates streamed tool-call fragments. Fragments are grouped by
     * {@code call.index()}; non-null id, name and argument text are appended to
     * the builder kept for that index.
     *
     * @param resp the streaming response accumulator holding the builders
     * @param acm  the assistant message whose tool calls are accumulated
     */
    private void buildToolCallBuilder(ChatResponseDefault resp, AssistantMessage acm) {
        if (Utils.isEmpty(acm.getToolCalls())) {
            return;
        }

        for (ToolCall call : acm.getToolCalls()) {
            ToolCallBuilder callBuilder = resp.toolCallBuilders.computeIfAbsent(call.index(), k -> new ToolCallBuilder());
            if (call.id() != null) {
                callBuilder.idBuilder.append(call.id());
            }
            if (call.name() != null) {
                callBuilder.nameBuilder.append(call.name());
            }
            if (call.argumentsStr() != null) {
                callBuilder.argumentsBuilder.append(call.argumentsStr());
            }
        }
    }

    /**
     * Executes the tool calls requested by an assistant message and stores each
     * resulting tool message in the session. Tools are looked up in the config's
     * default tools first, then in the current options; a call whose tool cannot
     * be found is logged and skipped.
     *
     * @param resp the response being processed (not read by this method)
     * @param acm  the assistant message carrying the tool calls
     * @return the produced tool messages when at least one tool ran and all of
     *         them are "return direct"; {@code null} otherwise
     * @throws ChatException (as {@link ToolCallException}) when running a tool or
     *                       recording its result fails
     */
    private List<ToolMessage> buildToolMessage(ChatResponseDefault resp, AssistantMessage acm) throws ChatException {
        if (Utils.isEmpty(acm.getToolCalls())) {
            return null;
        }

        List<ToolMessage> toolMessages = new ArrayList<>();
        for (ToolCall call : acm.getToolCalls()) {
            FunctionTool func = config.getDefaultTool(call.name());
            if (func == null) {
                func = options.tool(call.name());
            }
            if (func == null) {
                log.warn("Tool call not found: {}", call.name());
                continue;
            }

            try {
                String content = doToolCall(func, call.arguments());
                ToolMessage toolMessage = ChatMessage.ofTool(content, call.name(), call.id(), func.returnDirect());
                session.addMessage(toolMessage);
                toolMessages.add(toolMessage);
            } catch (Throwable ex) {
                // Report the tool's name (the key it was looked up by), not the tool object.
                throw new ToolCallException("The tool call failed, name: '" + call.name() + "'", ex);
            }
        }

        if (!toolMessages.isEmpty() && toolMessages.stream().allMatch(ToolMessage::isReturnDirect)) {
            return toolMessages;
        }
        return null;
    }

    /**
     * Runs a single tool through the interceptor chain.
     *
     * @param func the tool to execute
     * @param args the arguments supplied by the model
     * @return the tool's textual result
     * @throws Throwable whatever the chain or the tool throws
     */
    private String doToolCall(FunctionTool func, Map<String, Object> args) throws Throwable {
        List<RankEntity<ChatInterceptor>> interceptorList = buildInterceptorList();
        ToolRequest req = new ToolRequest(config, options, args);
        ToolChain chain = new ToolChain(interceptorList, func);
        return chain.doIntercept(req);
    }

    /**
     * Merges the config's default interceptors with those of the current options
     * into a new list, sorted by the entries' natural order.
     *
     * @return a fresh, sorted interceptor list
     */
    private List<RankEntity<ChatInterceptor>> buildInterceptorList() {
        List<RankEntity<ChatInterceptor>> interceptorList = new ArrayList<>();
        interceptorList.addAll(config.getDefaultInterceptors());
        interceptorList.addAll(options.interceptors());
        if (interceptorList.size() > 1) {
            Collections.sort(interceptorList);
        }
        return interceptorList;
    }
}

