/*
 * Decompiled with CFR 0.152.
 */
package org.springframework.ai.azure.openai;

import com.azure.ai.openai.OpenAIClient;
import com.azure.ai.openai.models.ChatChoice;
import com.azure.ai.openai.models.ChatCompletions;
import com.azure.ai.openai.models.ChatCompletionsFunctionToolCall;
import com.azure.ai.openai.models.ChatCompletionsFunctionToolDefinition;
import com.azure.ai.openai.models.ChatCompletionsJsonResponseFormat;
import com.azure.ai.openai.models.ChatCompletionsOptions;
import com.azure.ai.openai.models.ChatCompletionsResponseFormat;
import com.azure.ai.openai.models.ChatCompletionsTextResponseFormat;
import com.azure.ai.openai.models.ChatCompletionsToolCall;
import com.azure.ai.openai.models.ChatCompletionsToolDefinition;
import com.azure.ai.openai.models.ChatRequestAssistantMessage;
import com.azure.ai.openai.models.ChatRequestMessage;
import com.azure.ai.openai.models.ChatRequestSystemMessage;
import com.azure.ai.openai.models.ChatRequestToolMessage;
import com.azure.ai.openai.models.ChatRequestUserMessage;
import com.azure.ai.openai.models.ChatResponseMessage;
import com.azure.ai.openai.models.CompletionsFinishReason;
import com.azure.ai.openai.models.FunctionCall;
import com.azure.ai.openai.models.FunctionDefinition;
import com.azure.core.util.BinaryData;
import com.azure.core.util.IterableStream;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.atomic.AtomicBoolean;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.azure.openai.AzureOpenAiChatOptions;
import org.springframework.ai.azure.openai.AzureOpenAiResponseFormat;
import org.springframework.ai.azure.openai.MergeUtils;
import org.springframework.ai.azure.openai.metadata.AzureOpenAiChatResponseMetadata;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
import org.springframework.ai.chat.metadata.PromptMetadata;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
import org.springframework.ai.chat.model.StreamingChatModel;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.ModelOptions;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.model.function.AbstractFunctionCallSupport;
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.model.function.FunctionCallbackContext;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

/**
 * Chat model backed by the Azure OpenAI {@link OpenAIClient}.
 *
 * <p>Supports blocking ({@link #call(Prompt)}) and streaming ({@link #stream(Prompt)})
 * chat completions. Function (tool) calling is delegated to the hooks inherited from
 * {@link AbstractFunctionCallSupport}, which this class implements for the Azure SDK
 * request/response types.
 */
public class AzureOpenAiChatModel
        extends AbstractFunctionCallSupport<ChatRequestMessage, ChatCompletionsOptions, ChatCompletions>
        implements ChatModel, StreamingChatModel {

    /** Deployment name used by the single-argument constructor. */
    private static final String DEFAULT_DEPLOYMENT_NAME = "gpt-35-turbo";

    /** Temperature used by the single-argument constructor. */
    private static final Float DEFAULT_TEMPERATURE = 0.7f;

    private final Logger logger = LoggerFactory.getLogger(getClass());

    /** Options applied to every request unless overridden by the prompt's own options. */
    private AzureOpenAiChatOptions defaultOptions;

    private final OpenAIClient openAIClient;

    /**
     * Creates a model using the default deployment name and temperature.
     *
     * @param microsoftOpenAiClient the Azure OpenAI client, must not be {@code null}
     */
    public AzureOpenAiChatModel(OpenAIClient microsoftOpenAiClient) {
        this(microsoftOpenAiClient,
                AzureOpenAiChatOptions.builder()
                    .withDeploymentName(DEFAULT_DEPLOYMENT_NAME)
                    .withTemperature(DEFAULT_TEMPERATURE)
                    .build());
    }

    /**
     * Creates a model with the given default options and no function callback context.
     *
     * @param microsoftOpenAiClient the Azure OpenAI client, must not be {@code null}
     * @param options the default options, must not be {@code null}
     */
    public AzureOpenAiChatModel(OpenAIClient microsoftOpenAiClient, AzureOpenAiChatOptions options) {
        this(microsoftOpenAiClient, options, null);
    }

    /**
     * Creates a model with the given default options and function callback context.
     *
     * @param microsoftOpenAiClient the Azure OpenAI client, must not be {@code null}
     * @param options the default options, must not be {@code null}
     * @param functionCallbackContext context handed to the superclass; may be {@code null}
     * @throws IllegalArgumentException if the client or the options are {@code null}
     */
    public AzureOpenAiChatModel(OpenAIClient microsoftOpenAiClient, AzureOpenAiChatOptions options,
            FunctionCallbackContext functionCallbackContext) {
        super(functionCallbackContext);
        Assert.notNull(microsoftOpenAiClient, "com.azure.ai.openai.OpenAIClient must not be null");
        Assert.notNull(options, "AzureOpenAiChatOptions must not be null");
        this.openAIClient = microsoftOpenAiClient;
        this.defaultOptions = options;
    }

    /**
     * Replaces the default options of this model.
     *
     * @param defaultOptions the new default options, must not be {@code null}
     * @return this model
     * @deprecated since 0.8.0 and marked for removal; prefer supplying the options through
     * one of the constructors that accept {@link AzureOpenAiChatOptions}.
     */
    @Deprecated(forRemoval=true, since="0.8.0")
    public AzureOpenAiChatModel withDefaultOptions(AzureOpenAiChatOptions defaultOptions) {
        Assert.notNull(defaultOptions, "DefaultOptions must not be null");
        this.defaultOptions = defaultOptions;
        return this;
    }

    /**
     * Returns the default options as produced by {@code AzureOpenAiChatOptions.fromOptions}
     * (presumably a copy, so callers cannot mutate this model's state - confirm against
     * that factory).
     *
     * @return the default options
     */
    public AzureOpenAiChatOptions getDefaultOptions() {
        return AzureOpenAiChatOptions.fromOptions(this.defaultOptions);
    }

    /**
     * Executes a blocking chat completion, resolving tool calls through the inherited
     * function-call support.
     *
     * @param prompt the prompt to send
     * @return one {@link Generation} per returned choice, plus response metadata
     */
    @Override
    public ChatResponse call(Prompt prompt) {
        ChatCompletionsOptions options = this.toAzureChatCompletionsOptions(prompt);
        options.setStream(false);
        this.logger.trace("Azure ChatCompletionsOptions: {}", options);

        ChatCompletions chatCompletions = this.callWithFunctionSupport(options);
        this.logger.trace("Azure ChatCompletions: {}", chatCompletions);

        // Tolerate a response without a choices list instead of failing with an NPE.
        List<Generation> generations = this.nullSafeList(chatCompletions.getChoices())
            .stream()
            .map(choice -> new Generation(choice.getMessage().getContent())
                .withGenerationMetadata(this.generateChoiceMetadata(choice)))
            .toList();

        PromptMetadata promptFilterMetadata = this.generatePromptMetadata(chatCompletions);
        ChatResponseMetadata responseMetadata = AzureOpenAiChatResponseMetadata.from(chatCompletions,
                promptFilterMetadata);
        return new ChatResponse(generations, responseMetadata);
    }

    /**
     * Executes a streaming chat completion.
     *
     * <p>Chunks that belong to a tool call are accumulated until the chunk whose finish
     * reason is {@code TOOL_CALLS} and merged into a single {@link ChatCompletions} before
     * being handed to the function-call support; every other chunk is forwarded on its own.
     *
     * @param prompt the prompt to send
     * @return a flux emitting one {@link ChatResponse} per streamed choice
     */
    @Override
    public Flux<ChatResponse> stream(Prompt prompt) {
        ChatCompletionsOptions options = this.toAzureChatCompletionsOptions(prompt);
        options.setStream(true);

        IterableStream<ChatCompletions> chatCompletionsStream = this.openAIClient
            .getChatCompletionsStream(options.getModel(), options);

        // Tracks whether the most recently seen chunk carried tool-call deltas.
        AtomicBoolean isFunctionCall = new AtomicBoolean(false);

        Flux<ChatCompletions> accessibleChatCompletionsFlux = Flux.fromIterable(chatCompletionsStream)
            // NOTE(review): the first chunk is always dropped; presumably it never carries
            // usable choices for this service - confirm.
            .skip(1L)
            // Chunks without choices cannot be inspected below (get(0) would throw).
            .filter(chatCompletions -> !CollectionUtils.isEmpty(chatCompletions.getChoices()))
            .map(chatCompletions -> {
                List<ChatCompletionsToolCall> toolCalls = chatCompletions.getChoices()
                    .get(0)
                    .getDelta()
                    .getToolCalls();
                isFunctionCall.set(toolCalls != null && !toolCalls.isEmpty());
                return chatCompletions;
            })
            // Close the window on every non-tool chunk, or on the final chunk of a tool call.
            .windowUntil(chatCompletions -> {
                if (isFunctionCall.get()
                        && chatCompletions.getChoices().get(0).getFinishReason() == CompletionsFinishReason.TOOL_CALLS) {
                    isFunctionCall.set(false);
                    return true;
                }
                return !isFunctionCall.get();
            })
            // Collapse each window into a single merged ChatCompletions.
            .concatMapIterable(window -> {
                Mono<ChatCompletions> reduce = window.reduce(MergeUtils.emptyChatCompletions(),
                        MergeUtils::mergeChatCompletions);
                return List.of(reduce);
            })
            .flatMap(mono -> mono);

        return accessibleChatCompletionsFlux
            .switchMap(accessibleChatCompletions -> this.handleFunctionCallOrReturnStream(options,
                    Flux.just(accessibleChatCompletions)))
            .flatMapIterable(ChatCompletions::getChoices)
            .map(choice -> {
                // Prefer the full message when present, otherwise fall back to the delta.
                String content = Optional.ofNullable(choice.getMessage()).orElse(choice.getDelta()).getContent();
                Generation generation = new Generation(content)
                    .withGenerationMetadata(this.generateChoiceMetadata(choice));
                return new ChatResponse(List.of(generation));
            });
    }

    /**
     * Builds the Azure request for a prompt.
     *
     * <p>Values are resolved in this order: prompt (runtime) options override the model's
     * default options. Functions enabled by either set of options are exposed as tools.
     *
     * @param prompt the prompt to convert
     * @return the Azure chat completions options
     * @throws IllegalArgumentException if the prompt options are not {@link ChatOptions},
     * or a message has an unsupported type
     */
    ChatCompletionsOptions toAzureChatCompletionsOptions(Prompt prompt) {
        Set<String> functionsForThisRequest = new HashSet<>();

        List<ChatRequestMessage> azureMessages = prompt.getInstructions()
            .stream()
            .map(this::fromSpringAiMessage)
            .toList();
        ChatCompletionsOptions options = new ChatCompletionsOptions(azureMessages);

        if (this.defaultOptions != null) {
            options = this.merge(options, this.defaultOptions);
            functionsForThisRequest.addAll(this.handleFunctionCallbackConfigurations(this.defaultOptions, false));
        }

        ModelOptions modelOptions = prompt.getOptions();
        if (modelOptions != null) {
            if (!(modelOptions instanceof ChatOptions)) {
                throw new IllegalArgumentException(
                        "Prompt options are not of type ChatOptions: " + modelOptions.getClass().getSimpleName());
            }
            ChatOptions runtimeOptions = (ChatOptions) modelOptions;
            AzureOpenAiChatOptions updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(runtimeOptions,
                    ChatOptions.class, AzureOpenAiChatOptions.class);
            options = this.merge(updatedRuntimeOptions, options);
            functionsForThisRequest.addAll(this.handleFunctionCallbackConfigurations(updatedRuntimeOptions, true));
        }

        if (!CollectionUtils.isEmpty(functionsForThisRequest)) {
            // setTools expects the base tool-definition type, so widen the element type.
            List<ChatCompletionsToolDefinition> tools = List.copyOf(this.getFunctionTools(functionsForThisRequest));
            options.setTools(tools);
        }
        return options;
    }

    /**
     * Resolves the named function callbacks and converts each into an Azure tool definition.
     */
    private List<ChatCompletionsFunctionToolDefinition> getFunctionTools(Set<String> functionNames) {
        return this.resolveFunctionCallbacks(functionNames)
            .stream()
            .map(this::toFunctionToolDefinition)
            .toList();
    }

    /**
     * Converts one function callback (name, description, JSON input schema) into an Azure
     * function tool definition.
     */
    private ChatCompletionsFunctionToolDefinition toFunctionToolDefinition(FunctionCallback functionCallback) {
        FunctionDefinition functionDefinition = new FunctionDefinition(functionCallback.getName());
        functionDefinition.setDescription(functionCallback.getDescription());
        BinaryData parameters = BinaryData
            .fromObject(ModelOptionsUtils.jsonToMap(functionCallback.getInputTypeSchema()));
        functionDefinition.setParameters(parameters);
        return new ChatCompletionsFunctionToolDefinition(functionDefinition);
    }

    /**
     * Converts a Spring AI message into the matching Azure request message.
     *
     * @throws IllegalArgumentException for any type other than USER, SYSTEM or ASSISTANT
     */
    private ChatRequestMessage fromSpringAiMessage(Message message) {
        return switch (message.getMessageType()) {
            case USER -> new ChatRequestUserMessage(message.getContent());
            case SYSTEM -> new ChatRequestSystemMessage(message.getContent());
            case ASSISTANT -> new ChatRequestAssistantMessage(message.getContent());
            default -> throw new IllegalArgumentException("Unknown message type " + message.getMessageType());
        };
    }

    /** Builds generation metadata from a choice's finish reason and content filter results. */
    private ChatGenerationMetadata generateChoiceMetadata(ChatChoice choice) {
        return ChatGenerationMetadata.from(String.valueOf(choice.getFinishReason()),
                choice.getContentFilterResults());
    }

    /** Builds prompt metadata from the response's prompt filter results (none if absent). */
    private PromptMetadata generatePromptMetadata(ChatCompletions chatCompletions) {
        return PromptMetadata.of(this.nullSafeList(chatCompletions.getPromptFilterResults())
            .stream()
            .map(promptFilterResult -> PromptMetadata.PromptFilterMetadata.from(promptFilterResult.getPromptIndex(),
                    promptFilterResult.getContentFilterResults()))
            .toList());
    }

    /** Returns the given list, or an empty list when it is {@code null}. */
    private <T> List<T> nullSafeList(List<T> list) {
        return list != null ? list : Collections.emptyList();
    }

    /**
     * Merges Spring AI options into Azure options, with the Azure values taking precedence.
     * A Spring AI value is used only where the Azure value is {@code null}.
     *
     * @return {@code fromAzureOptions} itself when {@code toSpringAiOptions} is {@code null},
     * otherwise a new options instance
     */
    private ChatCompletionsOptions merge(ChatCompletionsOptions fromAzureOptions,
            AzureOpenAiChatOptions toSpringAiOptions) {
        if (toSpringAiOptions == null) {
            return fromAzureOptions;
        }

        ChatCompletionsOptions mergedAzureOptions = new ChatCompletionsOptions(fromAzureOptions.getMessages());
        mergedAzureOptions.setStream(fromAzureOptions.isStream());

        mergedAzureOptions.setMaxTokens(fromAzureOptions.getMaxTokens() != null ? fromAzureOptions.getMaxTokens()
                : toSpringAiOptions.getMaxTokens());
        mergedAzureOptions.setLogitBias(fromAzureOptions.getLogitBias() != null ? fromAzureOptions.getLogitBias()
                : toSpringAiOptions.getLogitBias());
        mergedAzureOptions
            .setStop(fromAzureOptions.getStop() != null ? fromAzureOptions.getStop() : toSpringAiOptions.getStop());

        mergedAzureOptions.setTemperature(fromAzureOptions.getTemperature());
        if (mergedAzureOptions.getTemperature() == null && toSpringAiOptions.getTemperature() != null) {
            mergedAzureOptions.setTemperature(toSpringAiOptions.getTemperature().doubleValue());
        }

        mergedAzureOptions.setTopP(fromAzureOptions.getTopP());
        if (mergedAzureOptions.getTopP() == null && toSpringAiOptions.getTopP() != null) {
            mergedAzureOptions.setTopP(toSpringAiOptions.getTopP().doubleValue());
        }

        mergedAzureOptions.setFrequencyPenalty(fromAzureOptions.getFrequencyPenalty());
        if (mergedAzureOptions.getFrequencyPenalty() == null && toSpringAiOptions.getFrequencyPenalty() != null) {
            mergedAzureOptions.setFrequencyPenalty(Double.valueOf(toSpringAiOptions.getFrequencyPenalty()));
        }

        mergedAzureOptions.setPresencePenalty(fromAzureOptions.getPresencePenalty());
        if (mergedAzureOptions.getPresencePenalty() == null && toSpringAiOptions.getPresencePenalty() != null) {
            mergedAzureOptions.setPresencePenalty(Double.valueOf(toSpringAiOptions.getPresencePenalty()));
        }

        mergedAzureOptions.setResponseFormat(fromAzureOptions.getResponseFormat());
        if (mergedAzureOptions.getResponseFormat() == null && toSpringAiOptions.getResponseFormat() != null) {
            mergedAzureOptions.setResponseFormat(this.toAzureResponseFormat(toSpringAiOptions.getResponseFormat()));
        }

        mergedAzureOptions.setN(fromAzureOptions.getN() != null ? fromAzureOptions.getN() : toSpringAiOptions.getN());
        mergedAzureOptions
            .setUser(fromAzureOptions.getUser() != null ? fromAzureOptions.getUser() : toSpringAiOptions.getUser());
        mergedAzureOptions.setModel(fromAzureOptions.getModel() != null ? fromAzureOptions.getModel()
                : toSpringAiOptions.getDeploymentName());

        return mergedAzureOptions;
    }

    /**
     * Merges Spring AI options into Azure options, with the Spring AI values taking
     * precedence. Every non-{@code null} Spring AI value overrides the Azure value.
     *
     * @return {@code toAzureOptions} itself when {@code fromSpringAiOptions} is {@code null},
     * otherwise a new options instance
     */
    private ChatCompletionsOptions merge(AzureOpenAiChatOptions fromSpringAiOptions,
            ChatCompletionsOptions toAzureOptions) {
        if (fromSpringAiOptions == null) {
            return toAzureOptions;
        }

        ChatCompletionsOptions mergedAzureOptions = this.copy(toAzureOptions);

        if (fromSpringAiOptions.getMaxTokens() != null) {
            mergedAzureOptions.setMaxTokens(fromSpringAiOptions.getMaxTokens());
        }
        if (fromSpringAiOptions.getLogitBias() != null) {
            mergedAzureOptions.setLogitBias(fromSpringAiOptions.getLogitBias());
        }
        if (fromSpringAiOptions.getStop() != null) {
            mergedAzureOptions.setStop(fromSpringAiOptions.getStop());
        }
        if (fromSpringAiOptions.getTemperature() != null) {
            mergedAzureOptions.setTemperature(fromSpringAiOptions.getTemperature().doubleValue());
        }
        if (fromSpringAiOptions.getTopP() != null) {
            mergedAzureOptions.setTopP(fromSpringAiOptions.getTopP().doubleValue());
        }
        if (fromSpringAiOptions.getFrequencyPenalty() != null) {
            mergedAzureOptions.setFrequencyPenalty(Double.valueOf(fromSpringAiOptions.getFrequencyPenalty()));
        }
        if (fromSpringAiOptions.getPresencePenalty() != null) {
            mergedAzureOptions.setPresencePenalty(Double.valueOf(fromSpringAiOptions.getPresencePenalty()));
        }
        if (fromSpringAiOptions.getN() != null) {
            mergedAzureOptions.setN(fromSpringAiOptions.getN());
        }
        if (fromSpringAiOptions.getUser() != null) {
            mergedAzureOptions.setUser(fromSpringAiOptions.getUser());
        }
        if (fromSpringAiOptions.getDeploymentName() != null) {
            mergedAzureOptions.setModel(fromSpringAiOptions.getDeploymentName());
        }
        if (fromSpringAiOptions.getResponseFormat() != null) {
            mergedAzureOptions.setResponseFormat(this.toAzureResponseFormat(fromSpringAiOptions.getResponseFormat()));
        }

        return mergedAzureOptions;
    }

    /**
     * Merges two Azure options: starts from a copy of {@code toOptions} and overrides it
     * with every non-{@code null} setting of {@code fromOptions}. Messages and the stream
     * flag always come from {@code toOptions}.
     *
     * @return {@code toOptions} itself when {@code fromOptions} is {@code null}, otherwise
     * a new options instance
     */
    private ChatCompletionsOptions merge(ChatCompletionsOptions fromOptions, ChatCompletionsOptions toOptions) {
        if (fromOptions == null) {
            return toOptions;
        }
        ChatCompletionsOptions mergedOptions = this.copy(toOptions);
        this.overlayNonNullSettings(fromOptions, mergedOptions);
        return mergedOptions;
    }

    /**
     * Creates a new options instance with the same messages, stream flag and non-{@code null}
     * settings as the given one.
     *
     * <p>NOTE(review): only the settings handled by {@code overlayNonNullSettings} are copied;
     * tool definitions in particular are not carried over - confirm this is intended.
     */
    private ChatCompletionsOptions copy(ChatCompletionsOptions fromOptions) {
        ChatCompletionsOptions copyOptions = new ChatCompletionsOptions(fromOptions.getMessages());
        copyOptions.setStream(fromOptions.isStream());
        this.overlayNonNullSettings(fromOptions, copyOptions);
        return copyOptions;
    }

    /**
     * Writes every non-{@code null} tunable setting of {@code source} onto {@code target},
     * leaving the target's value untouched where the source has none.
     */
    private void overlayNonNullSettings(ChatCompletionsOptions source, ChatCompletionsOptions target) {
        if (source.getMaxTokens() != null) {
            target.setMaxTokens(source.getMaxTokens());
        }
        if (source.getLogitBias() != null) {
            target.setLogitBias(source.getLogitBias());
        }
        if (source.getStop() != null) {
            target.setStop(source.getStop());
        }
        if (source.getTemperature() != null) {
            target.setTemperature(source.getTemperature());
        }
        if (source.getTopP() != null) {
            target.setTopP(source.getTopP());
        }
        if (source.getFrequencyPenalty() != null) {
            target.setFrequencyPenalty(source.getFrequencyPenalty());
        }
        if (source.getPresencePenalty() != null) {
            target.setPresencePenalty(source.getPresencePenalty());
        }
        if (source.getN() != null) {
            target.setN(source.getN());
        }
        if (source.getUser() != null) {
            target.setUser(source.getUser());
        }
        if (source.getModel() != null) {
            target.setModel(source.getModel());
        }
        if (source.getResponseFormat() != null) {
            target.setResponseFormat(source.getResponseFormat());
        }
    }

    /**
     * Invokes the registered callback for each tool call of the assistant message, appends
     * each result to the conversation history as a tool message, and builds the follow-up
     * request from that history merged with the previous request's settings.
     *
     * @param previousRequest the request that produced the tool calls
     * @param responseMessage the assistant message holding the tool calls
     * @param conversationHistory the history to extend; it is modified in place
     * @return the follow-up request
     * @throws IllegalStateException if no callback is registered for a requested function
     */
    @Override
    protected ChatCompletionsOptions doCreateToolResponseRequest(ChatCompletionsOptions previousRequest,
            ChatRequestMessage responseMessage, List<ChatRequestMessage> conversationHistory) {
        for (ChatCompletionsToolCall toolCall : ((ChatRequestAssistantMessage) responseMessage).getToolCalls()) {
            FunctionCall function = ((ChatCompletionsFunctionToolCall) toolCall).getFunction();
            String functionName = function.getName();
            String functionArguments = function.getArguments();

            if (!this.functionCallbackRegister.containsKey(functionName)) {
                throw new IllegalStateException("No function callback found for function name: " + functionName);
            }

            FunctionCallback callback = this.functionCallbackRegister.get(functionName);
            String functionResponse = callback.call(functionArguments);
            conversationHistory.add(new ChatRequestToolMessage(functionResponse, toolCall.getId()));
        }

        ChatCompletionsOptions newRequest = new ChatCompletionsOptions(conversationHistory);
        return this.merge(previousRequest, newRequest);
    }

    /** Returns the messages of the given request. */
    @Override
    protected List<ChatRequestMessage> doGetUserMessages(ChatCompletionsOptions request) {
        return request.getMessages();
    }

    /**
     * Builds an assistant request message (with empty content) that repeats the tool calls
     * of the first choice of the response, read from its message or, if absent, its delta.
     */
    @Override
    protected ChatRequestMessage doGetToolResponseMessage(ChatCompletions response) {
        ChatChoice accessibleChatChoice = response.getChoices().get(0);
        ChatResponseMessage responseMessage = Optional.ofNullable(accessibleChatChoice.getMessage())
            .orElse(accessibleChatChoice.getDelta());

        List<ChatCompletionsToolCall> toolCalls = responseMessage.getToolCalls()
            .stream()
            .map(this::copyFunctionToolCall)
            .toList();

        ChatRequestAssistantMessage assistantMessage = new ChatRequestAssistantMessage("");
        assistantMessage.setToolCalls(toolCalls);
        return assistantMessage;
    }

    /**
     * Re-creates a function tool call from its id, function name and arguments. The result is
     * typed as the base tool-call type so the mapped list matches what {@code setToolCalls}
     * accepts.
     */
    private ChatCompletionsToolCall copyFunctionToolCall(ChatCompletionsToolCall toolCall) {
        FunctionCall function = ((ChatCompletionsFunctionToolCall) toolCall).getFunction();
        return new ChatCompletionsFunctionToolCall(toolCall.getId(),
                new FunctionCall(function.getName(), function.getArguments()));
    }

    /** Sends a blocking chat completion request for the request's model. */
    @Override
    protected ChatCompletions doChatCompletion(ChatCompletionsOptions request) {
        return this.openAIClient.getChatCompletions(request.getModel(), request);
    }

    /** Sends a streaming chat completion request for the request's model. */
    @Override
    protected Flux<ChatCompletions> doChatCompletionStream(ChatCompletionsOptions request) {
        return Flux.fromIterable(this.openAIClient.getChatCompletionsStream(request.getModel(), request));
    }

    /**
     * Tells whether the first choice of the response finished with {@code TOOL_CALLS}.
     *
     * @return {@code false} for a {@code null} response, missing choices, a {@code null}
     * first choice or any other finish reason
     */
    @Override
    protected boolean isToolFunctionCall(ChatCompletions chatCompletions) {
        if (chatCompletions == null || CollectionUtils.isEmpty(chatCompletions.getChoices())) {
            return false;
        }
        ChatChoice choice = chatCompletions.getChoices().get(0);
        return choice != null && choice.getFinishReason() == CompletionsFinishReason.TOOL_CALLS;
    }

    /** Maps the Spring AI response format to the Azure one; anything but JSON becomes text. */
    private ChatCompletionsResponseFormat toAzureResponseFormat(AzureOpenAiResponseFormat responseFormat) {
        if (responseFormat == AzureOpenAiResponseFormat.JSON) {
            return new ChatCompletionsJsonResponseFormat();
        }
        return new ChatCompletionsTextResponseFormat();
    }
}

