/*
 * Decompiled with CFR 0.152.
 */
package com.jxdinfo.hussar.ai.trainingset.service.impl;

import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
import com.baomidou.mybatisplus.core.conditions.Wrapper;
import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
import com.baomidou.mybatisplus.core.conditions.update.LambdaUpdateWrapper;
import com.baomidou.mybatisplus.core.metadata.IPage;
import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
import com.jxdinfo.hussar.ai.common.QAConstant;
import com.jxdinfo.hussar.ai.config.PromptMap;
import com.jxdinfo.hussar.ai.config.PromptMapProperties;
import com.jxdinfo.hussar.ai.qaGroup.config.ModelConfig;
import com.jxdinfo.hussar.ai.qaGroup.model.QAGroupBase;
import com.jxdinfo.hussar.ai.qaGroup.model.QAGroupPair;
import com.jxdinfo.hussar.ai.qaGroup.service.IQAGroupBaseService;
import com.jxdinfo.hussar.ai.qaGroup.service.IQAGroupPairService;
import com.jxdinfo.hussar.ai.ragmanager.model.QARagModel;
import com.jxdinfo.hussar.ai.ragmanager.service.QARagModelService;
import com.jxdinfo.hussar.ai.trainingset.dao.QATrainingsetMapper;
import com.jxdinfo.hussar.ai.trainingset.dto.GenerateTrainSetDTO;
import com.jxdinfo.hussar.ai.trainingset.dto.ModelDetail;
import com.jxdinfo.hussar.ai.trainingset.model.QAPromptTemplate;
import com.jxdinfo.hussar.ai.trainingset.model.QATrainingset;
import com.jxdinfo.hussar.ai.trainingset.service.QAPromptTemplateService;
import com.jxdinfo.hussar.ai.trainingset.service.QATrainingsetService;
import com.jxdinfo.hussar.ai.trainingset.vo.QATrainingsetVO;
import com.jxdinfo.hussar.ai.util.YamlUtil;
import com.jxdinfo.hussar.common.base.HussarBaseEntity;
import com.jxdinfo.hussar.common.util.IqaHttpClientUtil;
import com.jxdinfo.hussar.platform.core.utils.BeanUtil;
import com.jxdinfo.hussar.platform.core.utils.HussarUtils;
import com.jxdinfo.hussar.support.exception.HussarException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.function.Function;
import java.util.stream.Collectors;
import javax.annotation.Resource;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.scheduling.annotation.Async;
import org.springframework.stereotype.Service;

@Service
public class QATrainingsetServiceImpl
extends ServiceImpl<QATrainingsetMapper, QATrainingset>
implements QATrainingsetService {
    private static final Logger logger = LoggerFactory.getLogger(QATrainingsetServiceImpl.class);
    /** Error text: "training set does not exist". */
    private static final String TRAINING_SET_NOT_FOUND = "\u8bad\u7ec3\u96c6\u4e0d\u5b58\u5728";
    /** Error text: "RAG strategy does not exist". */
    private static final String RAG_STRATEGY_NOT_FOUND = "RAG\u7b56\u7565\u4e0d\u5b58\u5728";
    /** Error text: "prompt template does not exist". */
    private static final String PROMPT_TEMPLATE_NOT_FOUND = "\u63d0\u793a\u8bcd\u6a21\u677f\u4e0d\u5b58\u5728";
    /** Error text: "generate-training-set endpoint returned no response". */
    private static final String EMPTY_RESPONSE = "\u751f\u6210\u8bad\u7ec3\u96c6\u63a5\u53e3\u65e0\u54cd\u5e94";
    @Resource
    private IQAGroupBaseService iqaGroupBaseService;
    @Resource
    private IQAGroupPairService iqaGroupPairService;
    @Resource
    private QAPromptTemplateService qaPromptTemplateService;
    @Resource
    private QARagModelService qaRagModelService;
    @Resource
    private ModelConfig modelConfig;
    @Resource
    private PromptMapProperties promptMapProperties;

    /**
     * Pages training sets and adds display names to each row.
     *
     * <p>The query filters by name (LIKE) and model type (equals) when those filter values are
     * non-empty, and orders by creation time, newest first. Each row then gets the names of its
     * RAG strategy, dataset (QA group) and prompt template.
     *
     * @param page  paging parameters; must be non-empty
     * @param model filter values; must be non-empty
     * @return a page of view objects carrying the paging metadata of the entity page
     * @throws HussarException when {@code page} or {@code model} is empty
     */
    @Override
    public IPage<QATrainingsetVO> listQATrainingset(Page<QATrainingset> page, QATrainingset model) {
        if (!HussarUtils.isNotEmpty(page) || !HussarUtils.isNotEmpty(model)) {
            throw new HussarException("\u67e5\u8be2\u6761\u4ef6\u4e0d\u80fd\u4e3a\u7a7a\uff01");
        }
        LambdaQueryWrapper<QATrainingset> query = new LambdaQueryWrapper<>();
        query.like(HussarUtils.isNotEmpty(model.getTrainingsetName()), QATrainingset::getTrainingsetName, model.getTrainingsetName());
        query.eq(HussarUtils.isNotEmpty(model.getModelType()), QATrainingset::getModelType, model.getModelType());
        query.orderByDesc(HussarBaseEntity::getCreateTime);
        IPage<QATrainingset> trainingsetPage = this.page(page, query);

        Page<QATrainingsetVO> voPage = new Page<>();
        // Copies the paging metadata. Records are always set explicitly below, so the VO page
        // never ends up holding QATrainingset entities.
        BeanUtil.copyProperties(trainingsetPage, voPage);
        List<QATrainingset> records = trainingsetPage.getRecords();
        if (HussarUtils.isEmpty(records)) {
            voPage.setRecords(new ArrayList<>());
            return voPage;
        }

        List<?> ragStrategyIds = distinctIds(records, QATrainingset::getRagStrategyId);
        List<?> datasetIds = distinctIds(records, QATrainingset::getDatasetId);
        List<?> promptIds = distinctIds(records, QATrainingset::getPromptId);

        // Skip a lookup when there are no ids. An empty IN list is a wasted query at best and,
        // depending on the MyBatis-Plus version, may render as invalid SQL.
        List<QARagModel> ragModels = ragStrategyIds.isEmpty()
                ? new ArrayList<>()
                : this.qaRagModelService.list(new LambdaQueryWrapper<QARagModel>().in(QARagModel::getModelId, ragStrategyIds));
        List<QAGroupBase> groupBases = datasetIds.isEmpty()
                ? new ArrayList<>()
                : this.iqaGroupBaseService.list(new LambdaQueryWrapper<QAGroupBase>().in(QAGroupBase::getQaGroupId, datasetIds));
        List<QAPromptTemplate> promptTemplates = promptIds.isEmpty()
                ? new ArrayList<>()
                : this.qaPromptTemplateService.list(new LambdaQueryWrapper<QAPromptTemplate>().in(QAPromptTemplate::getPromptId, promptIds));

        Map<Long, String> ragModelMap = indexNames(ragModels, QARagModel::getModelId, QARagModel::getModelName);
        Map<Long, String> groupBaseMap = indexNames(groupBases, QAGroupBase::getQaGroupId, QAGroupBase::getQaGroupName);
        Map<Long, String> promptTemplateMap = indexNames(promptTemplates, QAPromptTemplate::getPromptId, QAPromptTemplate::getPromptName);

        List<QATrainingsetVO> trainingsetVOs = BeanUtil.copyToList(records, QATrainingsetVO.class);
        for (QATrainingsetVO trainingVO : trainingsetVOs) {
            trainingVO.setRagStrategyName(ragModelMap.get(trainingVO.getRagStrategyId()));
            trainingVO.setDatasetName(groupBaseMap.get(trainingVO.getDatasetId()));
            trainingVO.setPromptName(promptTemplateMap.get(trainingVO.getPromptId()));
        }
        voPage.setRecords(trainingsetVOs);
        return voPage;
    }

    /**
     * Asynchronously asks the model service to generate a training set, running on the
     * {@code ragTaskExecutor} executor.
     *
     * <p>The request is assembled from three sources:
     * <ul>
     *   <li>the training set row;</li>
     *   <li>the QA pairs of its dataset;</li>
     *   <li>its RAG model arguments and its rendered prompt template.</li>
     * </ul>
     * The request is POSTed to {@code modelConfig.getGenerateTrainSetUrl()}.
     *
     * <p>The status is set to {@code TRAINING_TASK_DO_ERROR} in any of these cases:
     * <ul>
     *   <li>a dependency is missing;</li>
     *   <li>the HTTP call fails;</li>
     *   <li>the response is empty;</li>
     *   <li>the response lacks {@code success=true}.</li>
     * </ul>
     * A successful response does not change the status here.
     *
     * @param trainingId id of the training set to generate
     * @throws HussarException when the training set or a dependency is missing, or the service
     *                         reports failure
     */
    @Override
    @Async(value="ragTaskExecutor")
    public void generateTrainSet(Long trainingId) {
        QATrainingset trainingset = this.getById(trainingId);
        if (trainingset == null) {
            // There is no row to flag as failed, so just abort.
            throw new HussarException(TRAINING_SET_NOT_FOUND);
        }
        GenerateTrainSetDTO generateTrainSetDTO = new GenerateTrainSetDTO();
        BeanUtil.copyProperties(trainingset, generateTrainSetDTO);
        generateTrainSetDTO.setTrainingId(String.valueOf(trainingset.getTrainingsetId()));
        generateTrainSetDTO.setTrainingName(trainingset.getTrainingsetName());
        generateTrainSetDTO.setModelDetails(this.loadModelDetails(trainingset));

        QARagModel qaRagModel = trainingset.getRagStrategyId() == null
                ? null
                : this.qaRagModelService.getById(trainingset.getRagStrategyId());
        if (qaRagModel == null) {
            throw this.failTraining(trainingId, RAG_STRATEGY_NOT_FOUND, RAG_STRATEGY_NOT_FOUND);
        }
        generateTrainSetDTO.setModelArg(JSON.parseObject(JSON.toJSONString(YamlUtil.transferToMap(qaRagModel.getModelArg()))));
        generateTrainSetDTO.setNowledgeBaseId(String.valueOf(qaRagModel.getModelId()));

        QAPromptTemplate promptTemplate = trainingset.getPromptId() == null
                ? null
                : this.qaPromptTemplateService.getById(trainingset.getPromptId());
        if (promptTemplate == null || promptTemplate.getTemplateContent() == null) {
            throw this.failTraining(trainingId, PROMPT_TEMPLATE_NOT_FOUND, PROMPT_TEMPLATE_NOT_FOUND);
        }
        generateTrainSetDTO.setPromptTemplate(this.renderPromptTemplate(promptTemplate.getTemplateContent()));

        logger.info("generateTrainSet============\u5f00\u59cb===========>{}", JSON.toJSONString(generateTrainSetDTO));
        String result;
        try {
            result = IqaHttpClientUtil.httpPost(this.modelConfig.getGenerateTrainSetUrl(), JSONObject.toJSONString(generateTrainSetDTO), new HashMap<>());
        } catch (RuntimeException e) {
            // Record the failure so the row does not keep its previous status after a transport error.
            this.updateTrainStatus(trainingId, QAConstant.TRAINING_TASK_DO_ERROR.getStatus(), e.getMessage());
            throw e;
        }
        logger.info("generateTrainSet============\u7ed3\u679c\u8fd4\u56de===========>{}", result);

        // fastjson returns null for a null or empty body.
        JSONObject response = JSONObject.parseObject(result);
        if (response == null) {
            throw this.failTraining(trainingId, EMPTY_RESPONSE, EMPTY_RESPONSE);
        }
        // A missing "success" field counts as failure (the old check unboxed it before testing for null).
        if (!Boolean.TRUE.equals(response.getBoolean("success"))) {
            throw this.failTraining(trainingId, response.toJSONString(), response.getString("msg"));
        }
    }

    /**
     * Sets the status and, when non-empty, the error reason of a training set.
     *
     * @param trainingId id of the training set; must not be null
     * @param status     new status; must not be null
     * @param reason     error reason; ignored (left unchanged) when empty
     * @throws HussarException when {@code trainingId} or {@code status} is null
     */
    @Override
    public void updateTrainStatus(Long trainingId, String status, String reason) {
        HussarException.throwByNull(trainingId, "\u8bad\u7ec3\u96c6id\u4e3a\u7a7a");
        HussarException.throwByNull(status, "\u8bad\u7ec3\u96c6\u72b6\u6001\u4e3a\u7a7a");
        LambdaUpdateWrapper<QATrainingset> update = new LambdaUpdateWrapper<>();
        update.eq(QATrainingset::getTrainingsetId, trainingId);
        update.set(QATrainingset::getTrainingsetStatus, status);
        update.set(HussarUtils.isNotEmpty(reason), QATrainingset::getErrorReason, reason);
        this.update(update);
    }

    /**
     * Finds every training set whose prompt template has the given model type.
     *
     * @param basicModelType model type to match on {@code QAPromptTemplate.modelType}
     * @return matching training sets, or an empty list when no template matches
     */
    @Override
    public List<QATrainingset> findAllByBasicModelType(String basicModelType) {
        LambdaQueryWrapper<QAPromptTemplate> templateQuery = new LambdaQueryWrapper<>();
        templateQuery.eq(QAPromptTemplate::getModelType, basicModelType);
        List<QAPromptTemplate> promptTemplates = this.qaPromptTemplateService.list(templateQuery);
        if (HussarUtils.isEmpty(promptTemplates)) {
            return new ArrayList<>();
        }
        List<?> promptIds = distinctIds(promptTemplates, QAPromptTemplate::getPromptId);
        if (promptIds.isEmpty()) {
            return new ArrayList<>();
        }
        LambdaQueryWrapper<QATrainingset> trainingsetQuery = new LambdaQueryWrapper<>();
        trainingsetQuery.in(QATrainingset::getPromptId, promptIds);
        return this.list(trainingsetQuery);
    }

    /**
     * Loads one training set as a view object, with its RAG strategy name filled in.
     *
     * @param trainingId id of the training set
     * @return the view object; {@code ragStrategyName} is null when the strategy is unset or missing
     * @throws HussarException when no training set has the given id
     */
    @Override
    public QATrainingsetVO detail(Long trainingId) {
        QATrainingset trainingset = this.getById(trainingId);
        if (trainingset == null) {
            throw new HussarException(TRAINING_SET_NOT_FOUND);
        }
        QATrainingsetVO trainingsetVO = new QATrainingsetVO();
        BeanUtil.copyProperties(trainingset, trainingsetVO);
        String ragStrategyName = null;
        if (trainingsetVO.getRagStrategyId() != null) {
            QARagModel qaRagModel = this.qaRagModelService.getById(trainingsetVO.getRagStrategyId());
            ragStrategyName = qaRagModel != null ? qaRagModel.getModelName() : null;
        }
        trainingsetVO.setRagStrategyName(ragStrategyName);
        return trainingsetVO;
    }

    /**
     * Loads the QA pairs of the training set's dataset as model details.
     *
     * <p>Without a dataset id this returns no pairs. Previously the filter was dropped, which
     * selected every QA pair in every group.
     */
    private List<ModelDetail> loadModelDetails(QATrainingset trainingset) {
        if (!HussarUtils.isNotEmpty(trainingset.getDatasetId())) {
            return new ArrayList<>();
        }
        LambdaQueryWrapper<QAGroupPair> pairQuery = new LambdaQueryWrapper<>();
        pairQuery.eq(QAGroupPair::getQaGroupId, trainingset.getDatasetId());
        List<QAGroupPair> qaGroupPairs = this.iqaGroupPairService.list(pairQuery);
        return BeanUtil.copyToList(qaGroupPairs, ModelDetail.class);
    }

    /**
     * Renders a stored prompt template for the model service.
     *
     * <p>The {@code META} markers are restored to braces, then each configured front-end
     * placeholder is replaced with its model-side value.
     */
    private String renderPromptTemplate(String templateContent) {
        String rendered = templateContent.replace("/*#META*/", "{").replace("/*META#*/", "}");
        List<PromptMap> promptMapList = this.promptMapProperties.getPromptMapList();
        if (promptMapList != null) {
            for (PromptMap promptMap : promptMapList) {
                rendered = rendered.replace(promptMap.getFrontValue(), promptMap.getBigModelValue());
            }
        }
        return rendered;
    }

    /**
     * Marks the training set as failed with {@code reason} and returns the exception to throw.
     */
    private HussarException failTraining(Long trainingId, String reason, String message) {
        this.updateTrainStatus(trainingId, QAConstant.TRAINING_TASK_DO_ERROR.getStatus(), reason);
        return new HussarException(message);
    }

    /**
     * Collects the distinct, non-null ids of {@code rows}, in first-seen order.
     */
    private static <T, K> List<K> distinctIds(List<T> rows, Function<T, K> idFn) {
        if (rows == null) {
            return new ArrayList<>();
        }
        return rows.stream().map(idFn).filter(Objects::nonNull).distinct().collect(Collectors.toList());
    }

    /**
     * Builds an id-to-name index.
     *
     * <p>Unlike {@code Collectors.toMap}, this tolerates null names and keeps the first entry
     * for a duplicate id instead of throwing.
     */
    private static <T, K> Map<K, String> indexNames(List<T> rows, Function<T, K> keyFn, Function<T, String> nameFn) {
        Map<K, String> index = new HashMap<>();
        if (rows == null) {
            return index;
        }
        for (T row : rows) {
            index.putIfAbsent(keyFn.apply(row), nameFn.apply(row));
        }
        return index;
    }
}

