/*
 * Decompiled with CFR 0.152.
 */
package com.jxdinfo.hussar.vector.milvus.service.impl;

import com.google.gson.Gson;
import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import com.google.gson.JsonPrimitive;
import com.jxdinfo.hussar.core.exception.HussarException;
import com.jxdinfo.hussar.vector.milvus.annotation.MilvusBM25;
import com.jxdinfo.hussar.vector.milvus.annotation.MilvusCollection;
import com.jxdinfo.hussar.vector.milvus.annotation.MilvusField;
import com.jxdinfo.hussar.vector.milvus.annotation.MilvusIndex;
import com.jxdinfo.hussar.vector.milvus.dto.QueryCollectionData;
import com.jxdinfo.hussar.vector.milvus.service.IMilvusAnnotationService;
import com.jxdinfo.hussar.vector.milvus.util.MilvusUtil;
import com.jxdinfo.hussar.vector.milvus.wrapper.MilvusHybridSearchWrapper;
import com.jxdinfo.hussar.vector.milvus.wrapper.MilvusQueryWrapper;
import com.jxdinfo.hussar.vector.milvus.wrapper.MilvusSearchWrapper;
import io.milvus.common.clientenum.FunctionType;
import io.milvus.v2.client.MilvusClientV2;
import io.milvus.v2.common.DataType;
import io.milvus.v2.common.IndexParam;
import io.milvus.v2.service.collection.request.AddFieldReq;
import io.milvus.v2.service.collection.request.CreateCollectionReq;
import io.milvus.v2.service.collection.request.DropCollectionReq;
import io.milvus.v2.service.collection.request.GetCollectionStatsReq;
import io.milvus.v2.service.collection.response.GetCollectionStatsResp;
import io.milvus.v2.service.vector.request.InsertReq;
import io.milvus.v2.service.vector.request.QueryReq;
import io.milvus.v2.service.vector.request.UpsertReq;
import io.milvus.v2.service.vector.response.DeleteResp;
import io.milvus.v2.service.vector.response.InsertResp;
import io.milvus.v2.service.vector.response.QueryResp;
import io.milvus.v2.service.vector.response.SearchResp;
import io.milvus.v2.service.vector.response.UpsertResp;
import java.lang.reflect.Field;
import java.lang.reflect.Modifier;
import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;

/**
 * Reflection-driven implementation of {@link IMilvusAnnotationService}.
 *
 * <p>Entities are mapped to Milvus rows through the {@link MilvusCollection},
 * {@link MilvusField}, {@link MilvusIndex} and {@link MilvusBM25} annotations. The
 * per-class field mapping is resolved once and cached in {@link #CLASS_FIELD_CACHE}.
 *
 * <p>The Milvus client is injected with {@code required=false}; when it is absent every
 * operation fails with an {@link IllegalStateException} instead of a bare
 * {@link NullPointerException}.
 */
@Component
public class MilvusAnnotationServiceImpl
implements IMilvusAnnotationService {
    private static final Logger logger = LoggerFactory.getLogger(MilvusAnnotationServiceImpl.class);

    /** Milvus field name -> reflective field, per entity class. Filled lazily by {@link #getFieldInfoCache}. */
    private static final Map<Class<?>, Map<String, Field>> CLASS_FIELD_CACHE = new ConcurrentHashMap<>();

    /** Shared Gson instance used for values that are not plain JSON primitives. */
    private static final Gson GSON = new Gson();

    private final MilvusClientV2 milvusClientV2;

    /**
     * @param milvusClientV2 the Milvus client; may be {@code null} when no client bean is defined
     */
    public MilvusAnnotationServiceImpl(@Autowired(required=false) MilvusClientV2 milvusClientV2) {
        this.milvusClientV2 = milvusClientV2;
    }

    /**
     * Returns the injected client or fails with a descriptive error when none was injected.
     *
     * @throws IllegalStateException if the client is {@code null}
     */
    private MilvusClientV2 client() {
        if (this.milvusClientV2 == null) {
            throw new IllegalStateException("MilvusClientV2 bean is not available; check the Milvus client configuration");
        }
        return this.milvusClientV2;
    }

    /**
     * Runs a hybrid search and returns the entities of the first result group as raw maps.
     *
     * @return the hit entities, or an empty list when the response has no result groups
     */
    @Override
    public <T> List<Map<String, Object>> hybridSearchMap(MilvusHybridSearchWrapper<T> wrapper) {
        SearchResp searchResp = this.client().hybridSearch(wrapper.buildHybridSearchReq());
        return MilvusAnnotationServiceImpl.firstGroupEntities(searchResp);
    }

    /**
     * Runs a query and returns the matching entities as raw maps.
     *
     * @return the entities, or an empty list when nothing matched
     */
    @Override
    public <T> List<Map<String, Object>> queryMapList(MilvusQueryWrapper<T> wrapper) {
        QueryResp queryResp = this.client().query(wrapper.buildQueryReq());
        return MilvusAnnotationServiceImpl.queryEntities(queryResp);
    }

    /**
     * Runs a query and maps every returned entity onto a new instance of {@code clazz}.
     *
     * @param wrapper query definition
     * @param clazz   result type; needs a no-argument constructor
     * @return the mapped objects, empty when nothing matched
     */
    @Override
    public <W, R> List<R> queryList(MilvusQueryWrapper<W> wrapper, Class<R> clazz) {
        QueryResp queryResp = this.client().query(wrapper.buildQueryReq());
        return this.convertDataListToObjectList(MilvusAnnotationServiceImpl.queryEntities(queryResp), clazz);
    }

    /**
     * Upserts all entities into the collection resolved from the first entity's class and {@code id}.
     *
     * @param entityList entities to upsert; {@code null} or empty is a no-op
     * @param id         passed to {@code MilvusUtil.getCollectionName} together with the entity class
     * @return the upsert count reported by Milvus, or 0 when there was nothing to upsert
     */
    @Override
    public <T> long upsertBatch(List<T> entityList, String id) {
        if (entityList == null || entityList.isEmpty()) {
            return 0L;
        }
        String collectionName = MilvusUtil.getCollectionName(entityList.get(0).getClass(), id);
        return this.upsertToVectorDB(this.toGsonRows(entityList), collectionName);
    }

    /**
     * Upserts a single entity.
     *
     * @param entity entity to upsert; {@code null} is a no-op
     * @param id     passed to {@code MilvusUtil.getCollectionName} together with the entity class
     * @return the upsert count reported by Milvus, or 0 for a {@code null} entity
     */
    @Override
    public <T> long upsert(T entity, String id) {
        if (entity == null) {
            return 0L;
        }
        String collectionName = MilvusUtil.getCollectionName(entity.getClass(), id);
        return this.upsertToVectorDB(Collections.singletonList(this.convertEntityToGsonData(entity)), collectionName);
    }

    /** Sends the rows to Milvus as an upsert and logs/returns the reported count. */
    private long upsertToVectorDB(List<JsonObject> data, String collectionName) {
        UpsertReq upsertReq = UpsertReq.builder().collectionName(collectionName).data(data).build();
        UpsertResp upsertResp = this.client().upsert(upsertReq);
        long upsertCount = upsertResp.getUpsertCnt();
        logger.info("\u66f4\u65b0\u6210\u529f\uff0c\u66f4\u65b0\u8bb0\u5f55\u6570: {}", upsertCount);
        return upsertCount;
    }

    /**
     * Inserts all entities into the collection resolved from the first entity's class and {@code id}.
     *
     * @param entityList entities to insert; {@code null} or empty is a no-op
     * @param id         passed to {@code MilvusUtil.getCollectionName} together with the entity class
     * @return the insert count reported by Milvus, or 0 when there was nothing to insert
     */
    @Override
    public <T> long insertBatch(List<T> entityList, String id) {
        if (entityList == null || entityList.isEmpty()) {
            return 0L;
        }
        String collectionName = MilvusUtil.getCollectionName(entityList.get(0).getClass(), id);
        return this.insertToVectorDB(collectionName, this.toGsonRows(entityList));
    }

    /**
     * Inserts a single entity.
     *
     * @param entity entity to insert; {@code null} is a no-op
     * @param id     passed to {@code MilvusUtil.getCollectionName} together with the entity class
     * @return the insert count reported by Milvus, or 0 for a {@code null} entity
     */
    @Override
    public <T> long insert(T entity, String id) {
        if (entity == null) {
            return 0L;
        }
        String collectionName = MilvusUtil.getCollectionName(entity.getClass(), id);
        return this.insertToVectorDB(collectionName, Collections.singletonList(this.convertEntityToGsonData(entity)));
    }

    /** Sends the rows to Milvus as an insert and logs/returns the reported count. */
    private long insertToVectorDB(String collectionName, List<JsonObject> data) {
        InsertReq insertReq = InsertReq.builder().collectionName(collectionName).data(data).build();
        InsertResp insertResp = this.client().insert(insertReq);
        long insertCount = insertResp.getInsertCnt();
        logger.info("\u63d2\u5165\u6210\u529f\uff0c\u63d2\u5165\u8bb0\u5f55\u6570: {}", insertCount);
        return insertCount;
    }

    /** Converts every entity of the list into its JSON row representation. */
    private <T> List<JsonObject> toGsonRows(List<T> entityList) {
        List<JsonObject> rows = new ArrayList<>(entityList.size());
        for (T entity : entityList) {
            rows.add(this.convertEntityToGsonData(entity));
        }
        return rows;
    }

    /**
     * Builds the JSON row for one entity from its cached field mapping.
     *
     * <p>Fields whose value is {@code null} are omitted, and so is a primary key field that is
     * declared with {@code autoID}.
     *
     * @throws RuntimeException if a field cannot be read reflectively
     */
    private <T> JsonObject convertEntityToGsonData(T entity) {
        JsonObject row = new JsonObject();
        for (Map.Entry<String, Field> mapping : this.getFieldInfoCache(entity.getClass()).entrySet()) {
            Field field = mapping.getValue();
            try {
                Object fieldValue = field.get(entity);
                if (fieldValue == null) {
                    continue;
                }
                MilvusField milvusField = field.getAnnotation(MilvusField.class);
                if (milvusField != null && milvusField.isPrimaryKey() && milvusField.autoID()) {
                    continue;
                }
                row.add(mapping.getKey(), this.convertValueToJsonElement(fieldValue));
            }
            catch (IllegalAccessException e) {
                throw new RuntimeException("\u65e0\u6cd5\u8bbf\u95ee\u5b57\u6bb5: " + field.getName(), e);
            }
        }
        return row;
    }

    /**
     * Wraps a Java value as a Gson element: strings, numbers, booleans and characters become
     * {@link JsonPrimitive}s, anything else is serialized through Gson.
     *
     * @return the element, or {@code null} for a {@code null} value
     */
    private JsonElement convertValueToJsonElement(Object value) {
        if (value == null) {
            return null;
        }
        if (value instanceof String) {
            return new JsonPrimitive((String)value);
        }
        if (value instanceof Number) {
            return new JsonPrimitive((Number)value);
        }
        if (value instanceof Boolean) {
            return new JsonPrimitive((Boolean)value);
        }
        if (value instanceof Character) {
            return new JsonPrimitive((Character)value);
        }
        return GSON.toJsonTree(value);
    }

    /**
     * Deletes the rows selected by the wrapper.
     *
     * @return the delete count reported by Milvus
     */
    @Override
    public <T> long delete(MilvusQueryWrapper<T> wrapper) {
        DeleteResp deleteResp = this.client().delete(wrapper.buildDeleteReq());
        long deleteCount = deleteResp.getDeleteCnt();
        logger.info("\u5220\u9664\u6210\u529f\uff0c\u5220\u9664\u8bb0\u5f55\u6570: {}", deleteCount);
        return deleteCount;
    }

    /**
     * Runs a search and returns the entities of the first result group as raw maps.
     *
     * @return the hit entities, or an empty list when the response has no result groups
     */
    @Override
    public <T> List<Map<String, Object>> searchMap(MilvusSearchWrapper<T> wrapper) {
        SearchResp searchResp = this.client().search(wrapper.buildSearchReq());
        return MilvusAnnotationServiceImpl.firstGroupEntities(searchResp);
    }

    /**
     * Runs a search and maps the entities of the first result group onto the wrapper's class.
     *
     * @return the mapped hits, empty when there were none
     */
    @Override
    public <T> List<T> search(MilvusSearchWrapper<T> wrapper) {
        SearchResp searchResp = this.client().search(wrapper.buildSearchReq());
        return this.convertDataListToObjectList(MilvusAnnotationServiceImpl.firstGroupEntities(searchResp), wrapper.getClazz());
    }

    /**
     * Runs a query and maps every returned entity onto the wrapper's class.
     *
     * @return the mapped objects, empty when nothing matched
     */
    @Override
    public <T> List<T> queryList(MilvusQueryWrapper<T> wrapper) {
        QueryResp queryResp = this.client().query(wrapper.buildQueryReq());
        return this.convertDataListToObjectList(MilvusAnnotationServiceImpl.queryEntities(queryResp), wrapper.getClazz());
    }

    /** Collects the entity map of every query result; empty list when there are no results. */
    private static List<Map<String, Object>> queryEntities(QueryResp queryResp) {
        List<QueryResp.QueryResult> results = queryResp.getQueryResults();
        if (results == null || results.isEmpty()) {
            return Collections.emptyList();
        }
        List<Map<String, Object>> entities = new ArrayList<>(results.size());
        for (QueryResp.QueryResult result : results) {
            entities.add(result.getEntity());
        }
        return entities;
    }

    /**
     * Collects the entity map of every hit in the first result group of a search response.
     * Only index 0 of the outer result list is read; further groups are ignored.
     */
    private static List<Map<String, Object>> firstGroupEntities(SearchResp searchResp) {
        List<List<SearchResp.SearchResult>> groups = searchResp.getSearchResults();
        if (groups == null || groups.isEmpty()) {
            return Collections.emptyList();
        }
        List<SearchResp.SearchResult> hits = groups.get(0);
        List<Map<String, Object>> entities = new ArrayList<>(hits.size());
        for (SearchResp.SearchResult hit : hits) {
            entities.add(hit.getEntity());
        }
        return entities;
    }

    /**
     * Instantiates {@code clazz} once per row and copies each row value into the field mapped to
     * its key. Keys without a mapped field are ignored. Values are passed through
     * {@link #convertValue} first so that, for example, a {@code Long} can populate an
     * {@code Integer} field and a {@code null} can populate a primitive field.
     *
     * @throws RuntimeException if instantiation, conversion or assignment fails
     */
    private <T> List<T> convertDataListToObjectList(List<Map<String, Object>> dataList, Class<T> clazz) {
        if (dataList.isEmpty()) {
            // Nothing to map: avoid touching the class (and the cache) at all.
            return new ArrayList<>();
        }
        List<T> resultList = new ArrayList<>(dataList.size());
        Map<String, Field> fieldInfoCache = this.getFieldInfoCache(clazz);
        for (Map<String, Object> data : dataList) {
            try {
                T instance = clazz.getDeclaredConstructor().newInstance();
                for (Map.Entry<String, Object> entry : data.entrySet()) {
                    Field target = fieldInfoCache.get(entry.getKey());
                    if (target == null) {
                        continue;
                    }
                    target.set(instance, MilvusAnnotationServiceImpl.convertValue(entry.getValue(), target.getType()));
                }
                resultList.add(instance);
            }
            catch (ReflectiveOperationException | RuntimeException e) {
                throw new RuntimeException("\u8f6c\u6362\u5931\u8d25", e);
            }
        }
        return resultList;
    }

    /** Returns the cached field mapping of {@code clazz}, computing it on first use. */
    private <T> Map<String, Field> getFieldInfoCache(Class<T> clazz) {
        return CLASS_FIELD_CACHE.computeIfAbsent(clazz, MilvusAnnotationServiceImpl::scanFields);
    }

    /**
     * Builds the name-to-field mapping of a class from its declared instance fields.
     *
     * <p>A field annotated with {@link MilvusField} is keyed by the annotation's {@code name}
     * (falling back to the Java field name when that is empty); an unannotated field is keyed by
     * its Java name. Static and synthetic fields are skipped because they are not row data.
     */
    private static Map<String, Field> scanFields(Class<?> type) {
        Map<String, Field> fieldMap = new HashMap<>();
        for (Field field : type.getDeclaredFields()) {
            if (field.isSynthetic() || Modifier.isStatic(field.getModifiers())) {
                continue;
            }
            field.setAccessible(true);
            MilvusField milvusField = field.getAnnotation(MilvusField.class);
            String key = milvusField == null || milvusField.name().isEmpty() ? field.getName() : milvusField.name();
            fieldMap.put(key, field);
        }
        return Collections.unmodifiableMap(fieldMap);
    }

    /**
     * Coerces a value read from Milvus to the declared type of the target field.
     *
     * <p>{@code null} becomes the primitive default for primitive targets. Values that are
     * already instances of the target type are returned unchanged. Numeric targets accept any
     * {@link Number} (narrowing silently) or a parseable string form. When no rule applies the
     * value is returned as is and the subsequent assignment decides whether it fits.
     *
     * @throws NumberFormatException if a non-numeric value cannot be parsed for a numeric target
     */
    private static Object convertValue(Object value, Class<?> targetType) {
        if (value == null) {
            return targetType.isPrimitive() ? MilvusAnnotationServiceImpl.getPrimitiveDefaultValue(targetType) : null;
        }
        if (targetType.isInstance(value)) {
            return value;
        }
        if (targetType == String.class) {
            return value.toString();
        }
        if (targetType == Integer.class || targetType == Integer.TYPE) {
            if (value instanceof Number) {
                return ((Number)value).intValue();
            }
            return Integer.valueOf(value.toString());
        }
        if (targetType == Long.class || targetType == Long.TYPE) {
            if (value instanceof Number) {
                return ((Number)value).longValue();
            }
            return Long.valueOf(value.toString());
        }
        if (targetType == Short.class || targetType == Short.TYPE) {
            if (value instanceof Number) {
                return ((Number)value).shortValue();
            }
            return Short.valueOf(value.toString());
        }
        if (targetType == Byte.class || targetType == Byte.TYPE) {
            if (value instanceof Number) {
                return ((Number)value).byteValue();
            }
            return Byte.valueOf(value.toString());
        }
        if (targetType == Double.class || targetType == Double.TYPE) {
            if (value instanceof Number) {
                return ((Number)value).doubleValue();
            }
            return Double.valueOf(value.toString());
        }
        if (targetType == Float.class || targetType == Float.TYPE) {
            if (value instanceof Number) {
                return ((Number)value).floatValue();
            }
            return Float.valueOf(value.toString());
        }
        if (targetType == Boolean.class || targetType == Boolean.TYPE) {
            if (value instanceof Boolean) {
                return value;
            }
            return Boolean.valueOf(value.toString());
        }
        if (targetType == BigDecimal.class) {
            // Parse the decimal string form rather than going through double, which would
            // lose precision for large integral values.
            return new BigDecimal(value.toString());
        }
        if (targetType == List.class && value instanceof Object[]) {
            return Arrays.asList((Object[])value);
        }
        return value;
    }

    /** Returns the zero value of a primitive type, or {@code null} for any other type. */
    private static Object getPrimitiveDefaultValue(Class<?> primitiveType) {
        if (primitiveType == Integer.TYPE) {
            return 0;
        }
        if (primitiveType == Long.TYPE) {
            return 0L;
        }
        if (primitiveType == Double.TYPE) {
            return 0.0;
        }
        if (primitiveType == Float.TYPE) {
            return 0.0f;
        }
        if (primitiveType == Boolean.TYPE) {
            return false;
        }
        if (primitiveType == Byte.TYPE) {
            return (byte)0;
        }
        if (primitiveType == Short.TYPE) {
            return (short)0;
        }
        if (primitiveType == Character.TYPE) {
            return '\u0000';
        }
        return null;
    }

    /**
     * Drops the collection resolved from {@code clazz} and {@code id}.
     */
    @Override
    public void dropCollection(Class<?> clazz, String id) {
        this.dropCollection(MilvusUtil.getCollectionName(clazz, id));
    }

    /** Drops the collection with the given name. */
    private void dropCollection(String collectionName) {
        this.client().dropCollection(DropCollectionReq.builder().collectionName(collectionName).build());
    }

    /**
     * Creates the collection described by the annotations on {@code clazz}.
     *
     * @throws RuntimeException if {@code clazz} is not annotated with {@link MilvusCollection},
     *                          or a field declaration is incomplete (see {@link #processFields})
     */
    @Override
    public void createCollection(Class<?> clazz, String id) {
        MilvusCollection mc = clazz.getAnnotation(MilvusCollection.class);
        if (mc == null) {
            // Validate before asking the client for a schema.
            throw new RuntimeException("\u7f3a\u5c11'@MilvusCollection'\u6ce8\u89e3\uff0c\u65e0\u6cd5\u521b\u5efa\u96c6\u5408");
        }
        CreateCollectionReq.CollectionSchema schema = this.client().createSchema();
        List<IndexParam> indexParams = new ArrayList<>();
        this.processFields(clazz, schema, indexParams);
        this.createCollection(MilvusUtil.getCollectionName(clazz, id), mc.description(), schema, indexParams);
    }

    /**
     * Returns the entity count that the collection statistics report for the collection resolved
     * from {@code clazz} and {@code id}.
     */
    @Override
    public long count(Class<?> clazz, String id) {
        GetCollectionStatsReq statsReq = GetCollectionStatsReq.builder().collectionName(MilvusUtil.getCollectionName(clazz, id)).build();
        GetCollectionStatsResp statsResp = this.client().getCollectionStats(statsReq);
        return statsResp.getNumOfEntities();
    }

    /**
     * Queries a collection by name with an optional filter, output field list and paging.
     *
     * <p>Paging is applied only when both {@code current} and {@code size} are present and
     * positive; the offset is {@code (current - 1) * size}.
     *
     * @return the matching entities as raw maps, empty when nothing matched
     * @throws HussarException if the argument is {@code null} or carries no collection name
     */
    @Override
    public List<Map<String, Object>> queryCollectionData(QueryCollectionData queryCollectionData) {
        if (queryCollectionData == null) {
            throw new HussarException("\u53c2\u6570\u5f02\u5e38");
        }
        String collectionName = queryCollectionData.getCollectionName();
        if (collectionName == null || collectionName.isEmpty()) {
            throw new HussarException("\u96c6\u5408\u540d\u79f0\u4e0d\u80fd\u4e3a\u7a7a");
        }
        QueryReq queryReq = QueryReq.builder().collectionName(collectionName).filter(queryCollectionData.getFilter()).build();
        if (queryCollectionData.getFields() != null && !queryCollectionData.getFields().isEmpty()) {
            queryReq.setOutputFields(queryCollectionData.getFields());
        }
        Long current = queryCollectionData.getCurrent();
        Long size = queryCollectionData.getSize();
        if (current != null && size != null && current > 0L && size > 0L) {
            queryReq.setOffset((current - 1L) * size);
            queryReq.setLimit(size.longValue());
        }
        return MilvusAnnotationServiceImpl.queryEntities(this.client().query(queryReq));
    }

    /**
     * Translates every {@link MilvusField}-annotated declared field of {@code clazz} into a
     * schema field, plus an index for {@link MilvusIndex} and a BM25 output field, function and
     * index for {@link MilvusBM25}. Fields without {@link MilvusField} are ignored.
     *
     * @throws RuntimeException if a VarChar field has no positive {@code maxLength} or a
     *                          FloatVector field has no positive {@code dimension}
     */
    private void processFields(Class<?> clazz, CreateCollectionReq.CollectionSchema schema, List<IndexParam> indexParams) {
        for (Field field : clazz.getDeclaredFields()) {
            MilvusField mf = field.getAnnotation(MilvusField.class);
            if (mf == null) {
                continue;
            }
            String fieldName = mf.name().isEmpty() ? field.getName() : mf.name();
            schema.addField(MilvusAnnotationServiceImpl.buildFieldRequest(mf, fieldName));

            MilvusIndex mi = field.getAnnotation(MilvusIndex.class);
            if (mi != null) {
                indexParams.add(MilvusAnnotationServiceImpl.buildIndexParam(mf, mi, fieldName));
            }
            if (field.getAnnotation(MilvusBM25.class) != null) {
                MilvusAnnotationServiceImpl.registerBm25(schema, indexParams, fieldName);
            }
        }
    }

    /** Builds the schema field request for one annotated field, validating type-specific settings. */
    private static AddFieldReq buildFieldRequest(MilvusField mf, String fieldName) {
        AddFieldReq addFieldReq = AddFieldReq.builder()
                .fieldName(fieldName)
                .description(mf.description())
                .dataType(mf.dataType())
                .enableAnalyzer(mf.enableAnalyzer())
                .isPrimaryKey(mf.isPrimaryKey())
                .build();
        // autoID is only set on the primary key; nullability only on the other fields.
        if (mf.isPrimaryKey()) {
            addFieldReq.setAutoID(mf.autoID());
        } else {
            addFieldReq.setIsNullable(mf.isNullable());
        }
        if (mf.dataType() == DataType.VarChar) {
            if (mf.maxLength() <= 0) {
                throw new RuntimeException("VarChar\u7c7b\u578b\u7684\u5b57\u6bb5\u5fc5\u987b\u8bbe\u7f6emaxLength");
            }
            addFieldReq.setMaxLength(mf.maxLength());
            if (mf.enableAnalyzer()) {
                addFieldReq.setEnableAnalyzer(true);
                addFieldReq.setAnalyzerParams(mf.analyzerType().getAnalyzerParams());
            }
        } else if (mf.dataType() == DataType.FloatVector) {
            if (mf.dimension() <= 0) {
                throw new RuntimeException("FloatVector\u7c7b\u578b\u7684\u5b57\u6bb5\u5fc5\u987b\u8bbe\u7f6edimension");
            }
            addFieldReq.setDimension(mf.dimension());
        }
        return addFieldReq;
    }

    /** Builds the index definition for a field; the metric type is set only for vector fields. */
    private static IndexParam buildIndexParam(MilvusField mf, MilvusIndex mi, String fieldName) {
        IndexParam indexParam = IndexParam.builder()
                .fieldName(fieldName)
                .indexName(MilvusUtil.generateIndexName(fieldName))
                .indexType(mi.indexType())
                .build();
        if (MilvusUtil.isVectorType(mf.dataType())) {
            indexParam.setMetricType(mi.metricType());
        }
        return indexParam;
    }

    /**
     * Adds the sparse-vector output field, the BM25 function feeding it from {@code fieldName},
     * and the BM25 index on that output field.
     */
    private static void registerBm25(CreateCollectionReq.CollectionSchema schema, List<IndexParam> indexParams, String fieldName) {
        String bm25FieldName = MilvusUtil.generateBM25FieldName(fieldName);
        schema.addField(AddFieldReq.builder().fieldName(bm25FieldName).dataType(DataType.SparseFloatVector).build());
        schema.addFunction(CreateCollectionReq.Function.builder()
                .functionType(FunctionType.BM25)
                .name(MilvusUtil.generateBM25FunctionName(fieldName))
                .inputFieldNames(Collections.singletonList(fieldName))
                .outputFieldNames(Collections.singletonList(bm25FieldName))
                .build());
        indexParams.add(IndexParam.builder()
                .fieldName(bm25FieldName)
                .indexName(MilvusUtil.generateBM25IndexName(fieldName))
                .indexType(IndexParam.IndexType.AUTOINDEX)
                .metricType(IndexParam.MetricType.BM25)
                .build());
    }

    /** Issues the create-collection request with the prepared schema and index definitions. */
    private void createCollection(String collectionName, String description, CreateCollectionReq.CollectionSchema schema, List<IndexParam> indexParams) {
        this.client().createCollection(CreateCollectionReq.builder().collectionName(collectionName).description(description).collectionSchema(schema).indexParams(indexParams).build());
    }
}

