package org.springframework.ai.vertexai.embedding.text;

import com.google.cloud.aiplatform.v1.EndpointName;
import com.google.cloud.aiplatform.v1.PredictRequest;
import com.google.cloud.aiplatform.v1.PredictResponse;
import com.google.cloud.aiplatform.v1.PredictionServiceClient;
import com.google.protobuf.Value;
import io.micrometer.observation.ObservationRegistry;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.springframework.ai.chat.metadata.DefaultUsage;
import org.springframework.ai.document.Document;
import org.springframework.ai.embedding.AbstractEmbeddingModel;
import org.springframework.ai.embedding.Embedding;
import org.springframework.ai.embedding.EmbeddingRequest;
import org.springframework.ai.embedding.EmbeddingResponse;
import org.springframework.ai.embedding.EmbeddingResponseMetadata;
import org.springframework.ai.embedding.observation.DefaultEmbeddingModelObservationConvention;
import org.springframework.ai.embedding.observation.EmbeddingModelObservationContext;
import org.springframework.ai.embedding.observation.EmbeddingModelObservationConvention;
import org.springframework.ai.embedding.observation.EmbeddingModelObservationDocumentation;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.observation.conventions.AiProvider;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingConnectionDetails;
import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingUtils;
import org.springframework.retry.support.RetryTemplate;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;

/* loaded from: input_file:org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModel.class */
public class VertexAiTextEmbeddingModel extends AbstractEmbeddingModel {
    private static final EmbeddingModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultEmbeddingModelObservationConvention();
    private static final Map<String, Integer> KNOWN_EMBEDDING_DIMENSIONS = (Map) Stream.of((Object[]) VertexAiTextEmbeddingModelName.values()).collect(Collectors.toMap((v0) -> {
        return v0.getName();
    }, (v0) -> {
        return v0.getDimensions();
    }));
    public final VertexAiTextEmbeddingOptions defaultOptions;
    private final VertexAiEmbeddingConnectionDetails connectionDetails;
    private final RetryTemplate retryTemplate;
    private final ObservationRegistry observationRegistry;
    private EmbeddingModelObservationConvention observationConvention;

    public VertexAiTextEmbeddingModel(VertexAiEmbeddingConnectionDetails vertexAiEmbeddingConnectionDetails, VertexAiTextEmbeddingOptions vertexAiTextEmbeddingOptions) {
        this(vertexAiEmbeddingConnectionDetails, vertexAiTextEmbeddingOptions, RetryUtils.DEFAULT_RETRY_TEMPLATE);
    }

    public VertexAiTextEmbeddingModel(VertexAiEmbeddingConnectionDetails vertexAiEmbeddingConnectionDetails, VertexAiTextEmbeddingOptions vertexAiTextEmbeddingOptions, RetryTemplate retryTemplate) {
        this(vertexAiEmbeddingConnectionDetails, vertexAiTextEmbeddingOptions, retryTemplate, ObservationRegistry.NOOP);
    }

    public VertexAiTextEmbeddingModel(VertexAiEmbeddingConnectionDetails vertexAiEmbeddingConnectionDetails, VertexAiTextEmbeddingOptions vertexAiTextEmbeddingOptions, RetryTemplate retryTemplate, ObservationRegistry observationRegistry) {
        this.observationConvention = DEFAULT_OBSERVATION_CONVENTION;
        Assert.notNull(vertexAiTextEmbeddingOptions, "VertexAiTextEmbeddingOptions must not be null");
        Assert.notNull(retryTemplate, "retryTemplate must not be null");
        Assert.notNull(observationRegistry, "observationRegistry must not be null");
        this.defaultOptions = vertexAiTextEmbeddingOptions.initializeDefaults();
        this.connectionDetails = vertexAiEmbeddingConnectionDetails;
        this.retryTemplate = retryTemplate;
        this.observationRegistry = observationRegistry;
    }

    public float[] embed(Document document) {
        Assert.notNull(document, "Document must not be null");
        return embed(document.getFormattedContent());
    }

    public EmbeddingResponse call(EmbeddingRequest embeddingRequest) {
        VertexAiTextEmbeddingOptions mergedOptions = mergedOptions(embeddingRequest);
        EmbeddingModelObservationContext build = EmbeddingModelObservationContext.builder().embeddingRequest(embeddingRequest).provider(AiProvider.VERTEX_AI.value()).requestOptions(mergedOptions).build();
        return (EmbeddingResponse) EmbeddingModelObservationDocumentation.EMBEDDING_MODEL_OPERATION.observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> {
            return build;
        }, this.observationRegistry).observe(() -> {
            PredictionServiceClient createPredictionServiceClient = createPredictionServiceClient();
            PredictRequest.Builder predictRequestBuilder = getPredictRequestBuilder(embeddingRequest, this.connectionDetails.getEndpointName(mergedOptions.getModel()), mergedOptions);
            PredictResponse predictResponse = (PredictResponse) this.retryTemplate.execute(retryContext -> {
                return getPredictResponse(createPredictionServiceClient, predictRequestBuilder);
            });
            int i = 0;
            int i2 = 0;
            ArrayList arrayList = new ArrayList();
            Iterator it = predictResponse.getPredictionsList().iterator();
            while (it.hasNext()) {
                Value fieldsOrThrow = ((Value) it.next()).getStructValue().getFieldsOrThrow("embeddings");
                i2 += (int) fieldsOrThrow.getStructValue().getFieldsOrThrow("statistics").getStructValue().getFieldsOrThrow("token_count").getNumberValue();
                int i3 = i;
                i++;
                arrayList.add(new Embedding(VertexAiEmbeddingUtils.toVector(fieldsOrThrow.getStructValue().getFieldsOrThrow("values")), Integer.valueOf(i3)));
            }
            EmbeddingResponse embeddingResponse = new EmbeddingResponse(arrayList, generateResponseMetadata(mergedOptions.getModel(), Integer.valueOf(i2)));
            build.setResponse(embeddingResponse);
            return embeddingResponse;
        });
    }

    private VertexAiTextEmbeddingOptions mergedOptions(EmbeddingRequest embeddingRequest) {
        VertexAiTextEmbeddingOptions vertexAiTextEmbeddingOptions = this.defaultOptions;
        if (embeddingRequest.getOptions() != null) {
            vertexAiTextEmbeddingOptions = (VertexAiTextEmbeddingOptions) ModelOptionsUtils.merge(embeddingRequest.getOptions(), VertexAiTextEmbeddingOptions.builder().from(this.defaultOptions).build(), VertexAiTextEmbeddingOptions.class);
        }
        return vertexAiTextEmbeddingOptions;
    }

    protected PredictRequest.Builder getPredictRequestBuilder(EmbeddingRequest embeddingRequest, EndpointName endpointName, VertexAiTextEmbeddingOptions vertexAiTextEmbeddingOptions) {
        PredictRequest.Builder endpoint = PredictRequest.newBuilder().setEndpoint(endpointName.toString());
        VertexAiEmbeddingUtils.TextParametersBuilder of = VertexAiEmbeddingUtils.TextParametersBuilder.of();
        if (vertexAiTextEmbeddingOptions.getAutoTruncate() != null) {
            of.autoTruncate(vertexAiTextEmbeddingOptions.getAutoTruncate());
        }
        if (vertexAiTextEmbeddingOptions.getDimensions() != null) {
            of.outputDimensionality(vertexAiTextEmbeddingOptions.getDimensions());
        }
        endpoint.setParameters(VertexAiEmbeddingUtils.valueOf(of.build()));
        for (int i = 0; i < embeddingRequest.getInstructions().size(); i++) {
            VertexAiEmbeddingUtils.TextInstanceBuilder taskType = VertexAiEmbeddingUtils.TextInstanceBuilder.of((String) embeddingRequest.getInstructions().get(i)).taskType(vertexAiTextEmbeddingOptions.getTaskType().name());
            if (StringUtils.hasText(vertexAiTextEmbeddingOptions.getTitle())) {
                taskType.title(vertexAiTextEmbeddingOptions.getTitle());
            }
            endpoint.addInstances(VertexAiEmbeddingUtils.valueOf(taskType.build()));
        }
        return endpoint;
    }

    PredictionServiceClient createPredictionServiceClient() {
        try {
            return PredictionServiceClient.create(this.connectionDetails.getPredictionServiceSettings());
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    PredictResponse getPredictResponse(PredictionServiceClient predictionServiceClient, PredictRequest.Builder builder) {
        return predictionServiceClient.predict(builder.build());
    }

    private EmbeddingResponseMetadata generateResponseMetadata(String str, Integer num) {
        EmbeddingResponseMetadata embeddingResponseMetadata = new EmbeddingResponseMetadata();
        embeddingResponseMetadata.setModel(str);
        embeddingResponseMetadata.setUsage(getDefaultUsage(num));
        return embeddingResponseMetadata;
    }

    private DefaultUsage getDefaultUsage(Integer num) {
        return new DefaultUsage(0, 0, num);
    }

    public int dimensions() {
        return KNOWN_EMBEDDING_DIMENSIONS.getOrDefault(this.defaultOptions.getModel(), Integer.valueOf(super.dimensions())).intValue();
    }

    public void setObservationConvention(EmbeddingModelObservationConvention embeddingModelObservationConvention) {
        Assert.notNull(embeddingModelObservationConvention, "observationConvention cannot be null");
        this.observationConvention = embeddingModelObservationConvention;
    }
}
