package org.springframework.ai.oci.cohere;

import com.oracle.bmc.generativeaiinference.GenerativeAiInference;
import com.oracle.bmc.generativeaiinference.model.ChatDetails;
import com.oracle.bmc.generativeaiinference.model.CohereChatBotMessage;
import com.oracle.bmc.generativeaiinference.model.CohereChatRequest;
import com.oracle.bmc.generativeaiinference.model.CohereChatResponse;
import com.oracle.bmc.generativeaiinference.model.CohereMessage;
import com.oracle.bmc.generativeaiinference.model.CohereSystemMessage;
import com.oracle.bmc.generativeaiinference.model.CohereToolCall;
import com.oracle.bmc.generativeaiinference.model.CohereToolMessage;
import com.oracle.bmc.generativeaiinference.model.CohereToolResult;
import com.oracle.bmc.generativeaiinference.model.CohereUserMessage;
import com.oracle.bmc.generativeaiinference.requests.ChatRequest;
import io.micrometer.observation.ObservationRegistry;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.MessageType;
import org.springframework.ai.chat.messages.ToolResponseMessage;
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
import org.springframework.ai.chat.observation.ChatModelObservationContext;
import org.springframework.ai.chat.observation.ChatModelObservationConvention;
import org.springframework.ai.chat.observation.ChatModelObservationDocumentation;
import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.observation.conventions.AiProvider;
import org.springframework.ai.oci.ServingModeHelper;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;

/**
 * {@link ChatModel} implementation backed by Cohere models served through the OCI
 * Generative AI inference service ({@link GenerativeAiInference}).
 *
 * <p>Each call:
 * <ul>
 * <li>merges the prompt's {@link OCICohereChatOptions} (if any) over the configured defaults;</li>
 * <li>validates that model, compartment and serving mode are set;</li>
 * <li>translates the prompt into an OCI {@link ChatRequest};</li>
 * <li>wraps the single Cohere reply in a {@link ChatResponse}.</li>
 * </ul>
 * Calls are wrapped in a Micrometer observation.
 */
public class OCICohereChatModel implements ChatModel {

    private static final ChatModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultChatModelObservationConvention();

    /** Temperature sent to OCI when the merged options do not specify one. */
    private static final Double DEFAULT_TEMPERATURE = 0.7d;

    private final GenerativeAiInference genAi;

    private final OCICohereChatOptions defaultOptions;

    private final ObservationRegistry observationRegistry;

    private ChatModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION;

    /**
     * Creates a model that records no observations.
     * @param generativeAiInference OCI inference client, must not be null
     * @param oCICohereChatOptions default options, must not be null
     */
    public OCICohereChatModel(GenerativeAiInference generativeAiInference, OCICohereChatOptions oCICohereChatOptions) {
        this(generativeAiInference, oCICohereChatOptions, ObservationRegistry.NOOP);
    }

    /**
     * Creates a model that reports calls to the given observation registry.
     * @param generativeAiInference OCI inference client, must not be null
     * @param oCICohereChatOptions default options, must not be null
     * @param observationRegistry registry for observations; {@code null} is treated as
     * {@link ObservationRegistry#NOOP}
     * @throws IllegalArgumentException if the client or the default options are null
     */
    public OCICohereChatModel(GenerativeAiInference generativeAiInference, OCICohereChatOptions oCICohereChatOptions,
            ObservationRegistry observationRegistry) {
        Assert.notNull(generativeAiInference, "com.oracle.bmc.generativeaiinference.GenerativeAiInference must not be null");
        Assert.notNull(oCICohereChatOptions, "OCIChatOptions must not be null");
        this.genAi = generativeAiInference;
        this.defaultOptions = oCICohereChatOptions;
        // A null registry is still accepted for backward compatibility; normalise it to a no-op.
        this.observationRegistry = Objects.requireNonNullElse(observationRegistry, ObservationRegistry.NOOP);
    }

    /**
     * Sends the prompt to OCI inside a chat-model observation.
     * @param prompt the prompt; its first instruction is sent as the current message
     * @return the chat response with a single generation
     * @throws IllegalArgumentException if the prompt is empty or required options are missing
     */
    @Override
    public ChatResponse call(Prompt prompt) {
        // Merge once so that the observation reports the options actually sent to OCI.
        OCICohereChatOptions options = mergeOptions(prompt.getOptions(), this.defaultOptions);
        ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
            .prompt(prompt)
            .provider(AiProvider.OCI_GENAI.value())
            .requestOptions(options)
            .build();
        return ChatModelObservationDocumentation.CHAT_MODEL_OPERATION
            .observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext,
                    this.observationRegistry)
            .observe(() -> {
                ChatResponse response = doChatRequest(prompt, options);
                observationContext.setResponse(response);
                return response;
            });
    }

    /**
     * Returns a copy of the default options, as produced by {@code OCICohereChatOptions.fromOptions}.
     */
    @Override
    public ChatOptions getDefaultOptions() {
        return OCICohereChatOptions.fromOptions(this.defaultOptions);
    }

    /**
     * Replaces the observation convention used for subsequent calls.
     * @param chatModelObservationConvention the convention, must not be null
     */
    public void setObservationConvention(ChatModelObservationConvention chatModelObservationConvention) {
        Assert.notNull(chatModelObservationConvention, "observationConvention cannot be null");
        this.observationConvention = chatModelObservationConvention;
    }

    /** Validates the merged options, performs the OCI call and builds the Spring AI response. */
    private ChatResponse doChatRequest(Prompt prompt, OCICohereChatOptions options) {
        validateChatOptions(options);
        List<Generation> generations = getGenerations(prompt, options);
        ChatResponseMetadata metadata = ChatResponseMetadata.builder()
            .model(options.getModel())
            .keyValue("compartment", options.getCompartment())
            .build();
        return new ChatResponse(generations, metadata);
    }

    /**
     * Merges prompt options over the defaults.
     * NOTE(review): prompt options that are not {@link OCICohereChatOptions} are ignored
     * entirely and the defaults are used - confirm this is intended.
     */
    private OCICohereChatOptions mergeOptions(ChatOptions promptOptions, OCICohereChatOptions defaults) {
        if (promptOptions instanceof OCICohereChatOptions cohereOptions) {
            OCICohereChatOptions merged = ModelOptionsUtils.merge(cohereOptions, defaults, OCICohereChatOptions.class);
            if (merged != null) {
                return merged;
            }
        }
        return defaults;
    }

    /** Ensures model, compartment and serving mode are non-blank. */
    private void validateChatOptions(OCICohereChatOptions options) {
        if (!StringUtils.hasText(options.getModel())) {
            throw new IllegalArgumentException("Model is not set!");
        }
        if (!StringUtils.hasText(options.getCompartment())) {
            throw new IllegalArgumentException("Compartment is not set!");
        }
        if (!StringUtils.hasText(options.getServingMode())) {
            throw new IllegalArgumentException("ServingMode is not set!");
        }
    }

    /** Builds the OCI request, calls the service and converts the reply. */
    private List<Generation> getGenerations(Prompt prompt, OCICohereChatOptions options) {
        return toGenerations(this.genAi.chat(toCohereChatRequest(prompt, options)));
    }

    /**
     * Converts the OCI response into a single {@link Generation}.
     * @throws IllegalStateException if the response is not a {@link CohereChatResponse}
     */
    private List<Generation> toGenerations(com.oracle.bmc.generativeaiinference.responses.ChatResponse ociResponse) {
        var chatResult = ociResponse.getChatResult().getChatResponse();
        if (!(chatResult instanceof CohereChatResponse cohereResponse)) {
            String actualType = (chatResult == null) ? null : chatResult.getClass().getName();
            throw new IllegalStateException(String.format("Unexpected chat response type: %s", actualType));
        }
        // The finish reason may be absent; propagate null rather than failing with an NPE.
        var finishReason = cohereResponse.getFinishReason();
        ChatGenerationMetadata metadata = ChatGenerationMetadata.builder()
            .finishReason(finishReason != null ? finishReason.getValue() : null)
            .build();
        AssistantMessage assistantMessage = new AssistantMessage(cohereResponse.getText(), Map.of());
        return List.of(new Generation(assistantMessage, metadata));
    }

    /**
     * Builds the OCI chat request from the prompt instructions.
     * NOTE(review): the FIRST instruction is sent as Cohere's current {@code message} and
     * the remaining instructions become {@code chatHistory}. For multi-turn prompts this
     * may be the reverse of what Cohere expects - confirm the intended ordering.
     * @throws IllegalArgumentException if the prompt has no instructions
     */
    private ChatRequest toCohereChatRequest(Prompt prompt, OCICohereChatOptions options) {
        List<Message> messages = prompt.getInstructions();
        Assert.notEmpty(messages, "Prompt must contain at least one message");
        return newChatRequest(options, messages.get(0), getCohereMessages(messages));
    }

    /**
     * Maps every instruction after the first to its Cohere chat-history counterpart.
     * Tool messages that are not {@link ToolResponseMessage}s, and any other message
     * types, are skipped.
     */
    private List<CohereMessage> getCohereMessages(List<Message> messages) {
        List<CohereMessage> chatHistory = new ArrayList<>();
        for (int i = 1; i < messages.size(); i++) {
            Message message = messages.get(i);
            switch (message.getMessageType()) {
                case USER -> chatHistory.add(CohereUserMessage.builder().message(message.getText()).build());
                case ASSISTANT -> chatHistory.add(CohereChatBotMessage.builder().message(message.getText()).build());
                case SYSTEM -> chatHistory.add(CohereSystemMessage.builder().message(message.getText()).build());
                case TOOL -> {
                    if (message instanceof ToolResponseMessage toolResponseMessage) {
                        chatHistory.add(toToolMessage(toolResponseMessage));
                    }
                }
                default -> {
                    // No Cohere equivalent: skipped.
                }
            }
        }
        return chatHistory;
    }

    /** Converts each tool response into a Cohere tool result keyed by the tool name. */
    private CohereToolMessage toToolMessage(ToolResponseMessage toolResponseMessage) {
        List<CohereToolResult> toolResults = toolResponseMessage.getResponses()
            .stream()
            .map(response -> CohereToolResult.builder()
                .call(CohereToolCall.builder().name(response.name()).build())
                .outputs(List.of(response.responseData()))
                .build())
            .toList();
        return CohereToolMessage.builder().toolResults(toolResults).build();
    }

    /** Assembles the OCI {@link ChatRequest} envelope around a {@link CohereChatRequest}. */
    private ChatRequest newChatRequest(OCICohereChatOptions options, Message message, List<CohereMessage> chatHistory) {
        CohereChatRequest cohereRequest = CohereChatRequest.builder()
            .frequencyPenalty(options.getFrequencyPenalty())
            .presencePenalty(options.getPresencePenalty())
            .maxTokens(options.getMaxTokens())
            .topK(options.getTopK())
            .topP(options.getTopP())
            .temperature(Objects.requireNonNullElse(options.getTemperature(), DEFAULT_TEMPERATURE))
            .preambleOverride(options.getPreambleOverride())
            .stopSequences(options.getStopSequences())
            .documents(options.getDocuments())
            .tools(options.getTools())
            .chatHistory(chatHistory)
            .message(message.getText())
            .build();
        ChatDetails chatDetails = ChatDetails.builder()
            .compartmentId(options.getCompartment())
            .servingMode(ServingModeHelper.get(options.getServingMode(), options.getModel()))
            .chatRequest(cohereRequest)
            .build();
        return ChatRequest.builder().body$(chatDetails).build();
    }
}
