package org.springframework.ai.ollama;

import java.time.Duration;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.IntStream;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.metadata.EmptyUsage;
import org.springframework.ai.document.Document;
import org.springframework.ai.embedding.AbstractEmbeddingModel;
import org.springframework.ai.embedding.Embedding;
import org.springframework.ai.embedding.EmbeddingOptions;
import org.springframework.ai.embedding.EmbeddingRequest;
import org.springframework.ai.embedding.EmbeddingResponse;
import org.springframework.ai.embedding.EmbeddingResponseMetadata;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.ollama.api.OllamaApi;
import org.springframework.ai.ollama.api.OllamaOptions;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;

/* loaded from: input_file:org/springframework/ai/ollama/OllamaEmbeddingModel.class */
public class OllamaEmbeddingModel extends AbstractEmbeddingModel {
    private final Logger logger;
    private final OllamaApi ollamaApi;
    private OllamaOptions defaultOptions;

    /* loaded from: input_file:org/springframework/ai/ollama/OllamaEmbeddingModel$DurationParser.class */
    public static class DurationParser {
        private static Pattern PATTERN = Pattern.compile("(\\d+)(ms|s|m|h)");

        public static Duration parse(String str) {
            if (!StringUtils.hasText(str)) {
                return null;
            }
            Matcher matcher = PATTERN.matcher(str);
            if (!matcher.matches()) {
                throw new IllegalArgumentException("Invalid duration format: " + str);
            }
            long parseLong = Long.parseLong(matcher.group(1));
            String group = matcher.group(2);
            boolean z = -1;
            switch (group.hashCode()) {
                case 104:
                    if (group.equals("h")) {
                        z = 3;
                        break;
                    }
                    break;
                case 109:
                    if (group.equals("m")) {
                        z = 2;
                        break;
                    }
                    break;
                case 115:
                    if (group.equals("s")) {
                        z = true;
                        break;
                    }
                    break;
                case 3494:
                    if (group.equals("ms")) {
                        z = false;
                        break;
                    }
                    break;
            }
            switch (z) {
                case false:
                    return Duration.ofMillis(parseLong);
                case true:
                    return Duration.ofSeconds(parseLong);
                case true:
                    return Duration.ofMinutes(parseLong);
                case true:
                    return Duration.ofHours(parseLong);
                default:
                    throw new IllegalArgumentException("Unsupported time unit: " + group);
            }
        }
    }

    public OllamaEmbeddingModel(OllamaApi ollamaApi) {
        this.logger = LoggerFactory.getLogger(getClass());
        this.defaultOptions = OllamaOptions.create().withModel(OllamaOptions.DEFAULT_MODEL);
        this.ollamaApi = ollamaApi;
    }

    public OllamaEmbeddingModel(OllamaApi ollamaApi, OllamaOptions ollamaOptions) {
        this.logger = LoggerFactory.getLogger(getClass());
        this.defaultOptions = OllamaOptions.create().withModel(OllamaOptions.DEFAULT_MODEL);
        this.ollamaApi = ollamaApi;
        this.defaultOptions = ollamaOptions;
    }

    public float[] embed(Document document) {
        return embed(document.getContent());
    }

    public EmbeddingResponse call(EmbeddingRequest embeddingRequest) {
        Assert.notEmpty(embeddingRequest.getInstructions(), "At least one text is required!");
        OllamaApi.EmbeddingsResponse embed = this.ollamaApi.embed(ollamaEmbeddingRequest(embeddingRequest.getInstructions(), embeddingRequest.getOptions()));
        AtomicInteger atomicInteger = new AtomicInteger(0);
        return new EmbeddingResponse(embed.embeddings().stream().map(fArr -> {
            return new Embedding(fArr, Integer.valueOf(atomicInteger.getAndIncrement()));
        }).toList(), new EmbeddingResponseMetadata(embed.model(), new EmptyUsage()));
    }

    OllamaApi.EmbeddingsRequest ollamaEmbeddingRequest(List<String> list, EmbeddingOptions embeddingOptions) {
        OllamaOptions ollamaOptions = null;
        if (embeddingOptions != null && (embeddingOptions instanceof OllamaOptions)) {
            ollamaOptions = (OllamaOptions) embeddingOptions;
        }
        OllamaOptions ollamaOptions2 = (OllamaOptions) ModelOptionsUtils.merge(ollamaOptions, this.defaultOptions, OllamaOptions.class);
        if (StringUtils.hasText(ollamaOptions2.getModel())) {
            return new OllamaApi.EmbeddingsRequest(ollamaOptions2.getModel(), list, DurationParser.parse(ollamaOptions2.getKeepAlive()), OllamaOptions.filterNonSupportedFields(ollamaOptions2.toMap()), ollamaOptions2.getTruncate());
        }
        throw new IllegalArgumentException("Model is not set!");
    }
}
