/*
 * Decompiled with CFR 0.152.
 */
package com.jxdinfo.hussar.ai.modelManagement.service.impl;

import com.alibaba.fastjson.JSONObject;
import com.baomidou.mybatisplus.core.conditions.Wrapper;
import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
import com.baomidou.mybatisplus.core.metadata.IPage;
import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
import com.jxdinfo.hussar.ai.datamanager.vo.IndicatorEffectChart;
import com.jxdinfo.hussar.ai.modelManagement.dao.QueueTrainingTaskMapper;
import com.jxdinfo.hussar.ai.modelManagement.dao.TrainingTaskMapper;
import com.jxdinfo.hussar.ai.modelManagement.dto.TrainingTaskDto;
import com.jxdinfo.hussar.ai.modelManagement.dto.TrainingTaskLogDto;
import com.jxdinfo.hussar.ai.modelManagement.model.MediumModel;
import com.jxdinfo.hussar.ai.modelManagement.model.QueueTrainingTask;
import com.jxdinfo.hussar.ai.modelManagement.model.TrainingTask;
import com.jxdinfo.hussar.ai.modelManagement.service.IMediumModelService;
import com.jxdinfo.hussar.ai.modelManagement.service.IQueueTrainingTaskService;
import com.jxdinfo.hussar.ai.modelManagement.service.ITrainingTaskService;
import com.jxdinfo.hussar.ai.modelManagement.vo.TrainingParamsVo;
import com.jxdinfo.hussar.ai.modelManagement.vo.TrainingTaskDetailVo;
import com.jxdinfo.hussar.ai.modelManagement.vo.TrainingTaskLog;
import com.jxdinfo.hussar.ai.qaGroup.config.ModelConfig;
import com.jxdinfo.hussar.ai.ragmanager.model.QARagModel;
import com.jxdinfo.hussar.ai.ragmanager.service.QARagModelService;
import com.jxdinfo.hussar.ai.trainingset.model.QATrainingsetDetail;
import com.jxdinfo.hussar.ai.trainingset.service.QATrainingsetDetailService;
import com.jxdinfo.hussar.ai.util.FileUtils;
import com.jxdinfo.hussar.ai.util.YamlUtil;
import com.jxdinfo.hussar.common.base.HussarBaseEntity;
import com.jxdinfo.hussar.platform.core.base.apiresult.ApiResponse;
import com.jxdinfo.hussar.platform.core.utils.FileUtil;
import com.jxdinfo.hussar.platform.core.utils.HussarUtils;
import com.jxdinfo.hussar.support.exception.HussarException;
import com.jxdinfo.hussar.support.transaction.core.annotation.HussarTransactional;
import java.io.File;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import javax.annotation.Resource;
import org.springframework.stereotype.Service;

@Service
public class TrainingTaskServiceImpl extends ServiceImpl<TrainingTaskMapper, TrainingTask>
        implements ITrainingTaskService {

    /** Training status "0": not started (required by startTraining, restored by cancalQueue). */
    private static final String STATUS_NOT_STARTED = "0";
    /** Training status "1": training in progress (counted against the concurrency threshold). */
    private static final String STATUS_TRAINING = "1";
    /** Training status "2": set by finishTraining. */
    private static final String STATUS_FINISHED = "2";
    /** Training status "3": set by stopTraining; required by recoverTraining. */
    private static final String STATUS_STOPPED = "3";
    /** Training status "4": waiting in the training queue. */
    private static final String STATUS_QUEUED = "4";
    /** Response code checked against ApiResponse.success() before a status transition is applied. */
    private static final long API_SUCCESS_CODE = 10000L;
    /** Key of the status entry returned by startTraining. */
    private static final String TRAINING_STATUS_KEY = "trainingStatus";

    /** "Training task info is empty." */
    private static final String MSG_TASK_INFO_EMPTY = "\u8bad\u7ec3\u4efb\u52a1\u4fe1\u606f\u4e3a\u7a7a";
    /** "Training task id is empty." */
    private static final String MSG_TASK_ID_EMPTY = "\u8bad\u7ec3\u4efb\u52a1id\u4e3a\u7a7a";
    /** "Training task info not found." */
    private static final String MSG_TASK_NOT_FOUND = "\u672a\u67e5\u8be2\u5230\u8bad\u7ec3\u4efb\u52a1\u4fe1\u606f";
    /** "File path not found." */
    private static final String MSG_FILE_NOT_FOUND = "\u672a\u627e\u5230\u6587\u4ef6\u8def\u5f84";

    @Resource
    private TrainingTaskMapper trainingTaskMapper;
    @Resource
    private IMediumModelService mediumModelService;
    @Resource
    private QATrainingsetDetailService qaTrainingsetDetailService;
    @Resource
    private QARagModelService qaRagModelService;
    @Resource
    private ModelConfig modelConfig;
    @Resource
    private IQueueTrainingTaskService queueTrainingTaskService;
    @Resource
    private QueueTrainingTaskMapper queueTrainingTaskMapper;

    /**
     * Creates a training task, or updates the existing one when the DTO carries a task id
     * (persisted via {@code saveOrUpdate}).
     *
     * @param trainingTaskDto task definition; must not be null
     * @return a map holding the persisted task id under {@code "trainingTaskId"}
     */
    @Override
    public Map<String, Object> add(TrainingTaskDto trainingTaskDto) {
        HussarException.throwByNull(trainingTaskDto, MSG_TASK_INFO_EMPTY);
        TrainingTask trainingTask = new TrainingTask();
        if (HussarUtils.isNotEmpty(trainingTaskDto.getTaskId())) {
            trainingTask.setTaskId(trainingTaskDto.getTaskId());
        }
        trainingTask.setTaskName(trainingTaskDto.getTaskName());
        trainingTask.setBasicModelType(trainingTaskDto.getBasicModelType());
        trainingTask.setFineTuningType(trainingTaskDto.getFineTuningType());
        trainingTask.setTrainingSetId(trainingTaskDto.getTrainingSetId());
        trainingTask.setRagModelId(trainingTaskDto.getRagModelId());
        // Training params are stored as a JSON string and parsed back into TrainingParamsVo on read.
        trainingTask.setTrainingParams(JSONObject.toJSONString(trainingTaskDto.getTrainingParams()));
        this.saveOrUpdate(trainingTask);
        Map<String, Object> result = new HashMap<>();
        result.put("trainingTaskId", trainingTask.getTaskId());
        return result;
    }

    /**
     * Loads a training task together with its parsed training parameters.
     *
     * @param taskId task id; must not be null
     * @return the detail view, or {@code null} when no task has this id
     */
    @Override
    public TrainingTaskDetailVo qryTrainingTaskDetail(Long taskId) {
        HussarException.throwByNull(taskId, MSG_TASK_ID_EMPTY);
        TrainingTask trainingTask = this.trainingTaskMapper.selectById(taskId);
        if (trainingTask == null) {
            return null;
        }
        TrainingParamsVo trainingParamsVo =
                JSONObject.parseObject(trainingTask.getTrainingParams(), TrainingParamsVo.class);
        TrainingTaskDetailVo detail = new TrainingTaskDetailVo();
        detail.setTaskId(trainingTask.getTaskId());
        detail.setTaskName(trainingTask.getTaskName());
        detail.setTrainingSetId(trainingTask.getTrainingSetId());
        detail.setRagModelId(trainingTask.getRagModelId());
        detail.setTrainingStatus(trainingTask.getTrainingStatus());
        detail.setPublishStatus(trainingTask.getPublishStatus());
        detail.setBasicModelType(trainingTask.getBasicModelType());
        detail.setFineTuningType(trainingTask.getFineTuningType());
        detail.setTrainingParamsVo(trainingParamsVo);
        return detail;
    }

    /**
     * Deletes the given training tasks and their queue entries.
     *
     * @param taskIds ids to delete; a null or empty list deletes nothing
     * @return the result of {@code removeByIds}, or {@code false} for an empty id list
     * @throws HussarException when any selected task is currently training (status "1")
     */
    @Override
    @HussarTransactional
    public Boolean deleteByIds(List<Long> taskIds) {
        if (taskIds == null || taskIds.isEmpty()) {
            // An empty IN () clause is invalid SQL, and there is nothing to delete anyway.
            return false;
        }
        LambdaQueryWrapper<TrainingTask> trainingQw = new LambdaQueryWrapper<>();
        trainingQw.in(TrainingTask::getTaskId, taskIds).eq(TrainingTask::getTrainingStatus, STATUS_TRAINING);
        Long trainingCount = this.trainingTaskMapper.selectCount(trainingQw);
        // "The selection contains tasks that are training, please reselect."
        HussarException.throwBy(trainingCount != null && trainingCount > 0,
                "\u6240\u9009\u8bad\u7ec3\u4efb\u52a1\u4e2d\u5305\u542b\u8bad\u7ec3\u4e2d\u7684\u4efb\u52a1\uff0c\u8bf7\u91cd\u65b0\u9009\u62e9");
        Boolean flag = this.removeByIds(taskIds);
        LambdaQueryWrapper<QueueTrainingTask> queueQw = new LambdaQueryWrapper<>();
        queueQw.in(QueueTrainingTask::getTaskId, taskIds);
        this.queueTrainingTaskService.remove(queueQw);
        return flag;
    }

    /**
     * Pages through training tasks with delFlag "0", newest first. Each record on the page is
     * annotated with its number of medium models.
     *
     * @param page           paging request
     * @param taskName       optional fuzzy (LIKE) filter on the task name
     * @param basicModelType optional exact filter on the base model type
     * @return the populated page
     */
    @Override
    public Page<TrainingTask> list(Page<TrainingTask> page, String taskName, String basicModelType) {
        LambdaQueryWrapper<TrainingTask> qw = new LambdaQueryWrapper<>();
        // NOTE(review): "0" presumably means "not deleted" — confirm against the entity's delete-flag convention.
        qw.eq(TrainingTask::getDelFlag, "0");
        if (HussarUtils.isNotEmpty(taskName)) {
            qw.like(TrainingTask::getTaskName, taskName);
        }
        if (HussarUtils.isNotEmpty(basicModelType)) {
            qw.eq(TrainingTask::getBasicModelType, basicModelType);
        }
        qw.orderByDesc(HussarBaseEntity::getCreateTime);
        Page<TrainingTask> trainingTaskPage = this.trainingTaskMapper.selectPage(page, qw);
        if (trainingTaskPage != null && trainingTaskPage.getRecords() != null) {
            for (TrainingTask trainingTask : trainingTaskPage.getRecords()) {
                trainingTask.setMediumModelCount(this.countMediumModels(trainingTask.getTaskId()));
            }
        }
        return trainingTaskPage;
    }

    /** Counts the medium models recorded for one training task. */
    private Long countMediumModels(Long taskId) {
        LambdaQueryWrapper<MediumModel> qw = new LambdaQueryWrapper<>();
        qw.eq(MediumModel::getTaskId, taskId);
        return this.mediumModelService.getBaseMapper().selectCount(qw);
    }

    /**
     * Starts a not-yet-started task, or queues it when the number of tasks in training has
     * reached {@code modelConfig.getTrainingTaskThreshold()}.
     *
     * @param taskId task id
     * @return {@code {"trainingStatus": "1"}} when started, {@code {"trainingStatus": "4"}} when queued,
     *         or an empty map when the start call does not report success
     * @throws HussarException when the task does not exist or its status is not "0"
     */
    @Override
    @HussarTransactional
    public Map<String, String> startTraining(Long taskId) {
        TrainingTask trainingTask = this.trainingTaskMapper.selectById(taskId);
        HussarException.throwByNull(trainingTask, MSG_TASK_NOT_FOUND);
        // "This training task cannot be trained at the moment."
        HussarException.throwBy(!STATUS_NOT_STARTED.equals(trainingTask.getTrainingStatus()),
                "\u8be5\u8bad\u7ec3\u4efb\u52a1\u6682\u65f6\u4e0d\u80fd\u8bad\u7ec3");
        return this.launchOrEnqueue(taskId, false);
    }

    /**
     * Launches a task when a training slot is free; otherwise keeps it in (or adds it to) the queue.
     *
     * @param taskId        task to launch
     * @param alreadyQueued true when the task already has a queue row and status "4", so it must not
     *                      be enqueued a second time
     * @return the resulting status under {@code "trainingStatus"}, or an empty map when the start call
     *         does not report success
     */
    private Map<String, String> launchOrEnqueue(Long taskId, boolean alreadyQueued) {
        Map<String, String> result = new HashMap<>();
        LambdaQueryWrapper<TrainingTask> runningQw = new LambdaQueryWrapper<>();
        runningQw.eq(TrainingTask::getTrainingStatus, STATUS_TRAINING);
        Long runningCount = this.trainingTaskMapper.selectCount(runningQw);
        Long trainingTaskThreshold = this.modelConfig.getTrainingTaskThreshold();
        if (runningCount >= trainingTaskThreshold) {
            if (!alreadyQueued) {
                QueueTrainingTask queueTrainingTask = new QueueTrainingTask();
                queueTrainingTask.setTaskId(taskId);
                this.queueTrainingTaskService.save(queueTrainingTask);
                this.updateTrainingStatus(taskId, STATUS_QUEUED);
            }
            result.put(TRAINING_STATUS_KEY, STATUS_QUEUED);
            return result;
        }
        // NOTE(review): ApiResponse.success() stands in for the remote start-training call; the
        // success check below is therefore always true until a real call is wired in.
        ApiResponse apiResponse = ApiResponse.success();
        if (API_SUCCESS_CODE == (long) apiResponse.getCode()) {
            this.updateTrainingStatus(taskId, STATUS_TRAINING);
            LambdaQueryWrapper<QueueTrainingTask> queueQw = new LambdaQueryWrapper<>();
            queueQw.eq(QueueTrainingTask::getTaskId, taskId);
            this.queueTrainingTaskService.remove(queueQw);
            result.put(TRAINING_STATUS_KEY, STATUS_TRAINING);
        }
        return result;
    }

    /** Writes only the training status of one task (all other columns are left null in the update entity). */
    private void updateTrainingStatus(Long taskId, String trainingStatus) {
        TrainingTask update = new TrainingTask();
        update.setTaskId(taskId);
        update.setTrainingStatus(trainingStatus);
        this.trainingTaskMapper.updateById(update);
    }

    /**
     * Stops a task that is currently training (status "1" becomes "3").
     *
     * @param taskId task id
     * @return always {@code true}
     * @throws HussarException when the task does not exist or is not training
     */
    @Override
    public Boolean stopTraining(Long taskId) {
        TrainingTask trainingTask = this.trainingTaskMapper.selectById(taskId);
        HussarException.throwByNull(trainingTask, MSG_TASK_NOT_FOUND);
        // "This task is not training and cannot be stopped!"
        HussarException.throwBy(!STATUS_TRAINING.equals(trainingTask.getTrainingStatus()),
                "\u8be5\u4efb\u52a1\u672a\u5728\u8bad\u7ec3\u4e2d\uff0c\u4e0d\u53ef\u4ee5\u7ec8\u6b62\uff01");
        // NOTE(review): params and stopTrainingUrl are prepared but never sent; ApiResponse.success()
        // stands in for the remote stop call.
        Map<String, Object> params = new HashMap<>();
        params.put("taskId", taskId);
        params.put("taskName", trainingTask.getTaskName());
        params.put("basicModelType", trainingTask.getBasicModelType());
        String stopTrainingUrl = this.modelConfig.getStopTrainingUrl();
        ApiResponse apiResponse = ApiResponse.success();
        if (API_SUCCESS_CODE == (long) apiResponse.getCode()) {
            this.updateTrainingStatus(taskId, STATUS_STOPPED);
        }
        return true;
    }

    /**
     * Resumes a stopped task (status "3" becomes "1").
     *
     * @param taskId         task id
     * @param mediunModelId  medium model to resume from
     * @return always {@code true}
     * @throws HussarException when the task does not exist or is not stopped
     */
    @Override
    public Boolean recoverTraining(Long taskId, Long mediunModelId) {
        TrainingTask trainingTask = this.trainingTaskMapper.selectById(taskId);
        HussarException.throwByNull(trainingTask, MSG_TASK_NOT_FOUND);
        // "This task has not been stopped and cannot be resumed!"
        HussarException.throwBy(!STATUS_STOPPED.equals(trainingTask.getTrainingStatus()),
                "\u8be5\u4efb\u52a1\u672a\u7ec8\u6b62\uff0c\u4e0d\u53ef\u4ee5\u6062\u590d\u8bad\u7ec3\uff01");
        // NOTE(review): mediumModel and recoverTrainingUrl are loaded but never used; ApiResponse.success()
        // stands in for the remote resume call.
        MediumModel mediumModel = this.mediumModelService.getById(mediunModelId);
        String recoverTrainingUrl = this.modelConfig.getRecoverTrainingUrl();
        ApiResponse apiResponse = ApiResponse.success();
        if (API_SUCCESS_CODE == (long) apiResponse.getCode()) {
            this.updateTrainingStatus(taskId, STATUS_TRAINING);
        }
        return true;
    }

    /**
     * Marks a task as finished (status "2"), then hands the freed slot to the oldest queued task.
     *
     * @param trainingTask carries the id of the finished task
     * @return always {@code true}
     * @throws HussarException when the argument, its id, or the stored task is missing
     */
    @Override
    public Boolean finishTraining(TrainingTask trainingTask) {
        HussarException.throwByNull(trainingTask, MSG_TASK_INFO_EMPTY);
        HussarException.throwByNull(trainingTask.getTaskId(), MSG_TASK_ID_EMPTY);
        Long taskId = trainingTask.getTaskId();
        TrainingTask stored = this.trainingTaskMapper.selectById(taskId);
        // "This training task was not found."
        HussarException.throwByNull(stored, "\u672a\u67e5\u8be2\u5230\u8be5\u8bad\u7ec3\u4efb\u52a1");
        this.updateTrainingStatus(taskId, STATUS_FINISHED);
        // Oldest entry first so the queue is served FIFO.
        LambdaQueryWrapper<QueueTrainingTask> qw = new LambdaQueryWrapper<>();
        qw.orderByAsc(HussarBaseEntity::getCreateTime);
        List<QueueTrainingTask> queue = this.queueTrainingTaskService.list(qw);
        if (queue != null && !queue.isEmpty()) {
            // Queued tasks have status "4", which the public startTraining rejects; launch them directly.
            // NOTE(review): called via `this`, so a proxy-based @HussarTransactional would not apply here — confirm.
            this.launchOrEnqueue(queue.get(0).getTaskId(), true);
        }
        return true;
    }

    /**
     * Appends a log fragment to the task's log file. On the first call the file path
     * {@code <user.dir>/qaLogs/trainingTask/<taskId>.log} is created and stored on the task.
     *
     * @param trainingTaskLogDto task id and log text; neither may be null
     * @return always {@code true}
     */
    @Override
    public Boolean saveLogs(TrainingTaskLogDto trainingTaskLogDto) {
        // "Input is empty."
        HussarException.throwByNull(trainingTaskLogDto, "\u5165\u53c2\u4e3a\u7a7a");
        HussarException.throwByNull(trainingTaskLogDto.getTaskId(), MSG_TASK_ID_EMPTY);
        // "Training task log is empty."
        HussarException.throwByNull(trainingTaskLogDto.getLog(), "\u8bad\u7ec3\u4efb\u52a1\u65e5\u5fd7\u4e3a\u7a7a");
        Long taskId = trainingTaskLogDto.getTaskId();
        TrainingTask trainingTask = this.trainingTaskMapper.selectById(taskId);
        HussarException.throwByNull(trainingTask, MSG_TASK_INFO_EMPTY);
        String logPath = trainingTask.getLogPath();
        File file;
        if (HussarUtils.isEmpty(logPath)) {
            String runPath = System.getProperty("user.dir");
            String path = runPath + File.separator + "qaLogs" + File.separator + "trainingTask" + File.separator;
            String fileName = taskId + ".log";
            file = FileUtils.getFile(path, fileName);
            logPath = path + fileName;
            TrainingTask updateTrainingTask = new TrainingTask();
            updateTrainingTask.setTaskId(taskId);
            updateTrainingTask.setLogPath(logPath);
            this.trainingTaskMapper.updateById(updateTrainingTask);
        } else {
            file = new File(logPath);
        }
        FileUtil.writeToFile(file, trainingTaskLogDto.getLog(), true);
        return true;
    }

    /**
     * Reads the whole log file of a task.
     *
     * @param taskId task id; must not be null
     * @return the log, or {@code null} when the task has no log path yet
     * @throws HussarException when the task does not exist or the file cannot be read
     */
    @Override
    public TrainingTaskLog readLogs(Long taskId) {
        HussarException.throwByNull(taskId, MSG_TASK_ID_EMPTY);
        TrainingTask trainingTask = this.trainingTaskMapper.selectById(taskId);
        HussarException.throwByNull(trainingTask, MSG_TASK_INFO_EMPTY);
        String logPath = trainingTask.getLogPath();
        if (HussarUtils.isEmpty(logPath)) {
            return null;
        }
        TrainingTaskLog trainingTaskLog = new TrainingTaskLog();
        trainingTaskLog.setTaskId(taskId);
        trainingTaskLog.setLog(readLogFile(logPath));
        return trainingTaskLog;
    }

    /** Reads a log file, converting any read failure into the "file path not found" HussarException. */
    private static String readLogFile(String logPath) {
        String content = "";
        try {
            content = FileUtil.readToString(new File(logPath));
        } catch (Exception e) {
            HussarException.throwBy(true, MSG_FILE_NOT_FOUND);
        }
        return content;
    }

    /**
     * Builds loss-chart data from a task's log file.
     *
     * @param taskId task id; must not be null
     * @return {@code "chart"} (one point per log line carrying a loss) and {@code "min"} (lowest loss
     *         and its epoch); an empty map when the task has no log path yet
     */
    @Override
    public Map<String, Object> qryIndicatorEffect(Long taskId) {
        Map<String, Object> resultMap = new HashMap<>();
        HussarException.throwByNull(taskId, MSG_TASK_ID_EMPTY);
        TrainingTask trainingTask = this.trainingTaskMapper.selectById(taskId);
        HussarException.throwByNull(trainingTask, MSG_TASK_INFO_EMPTY);
        String logPath = trainingTask.getLogPath();
        if (HussarUtils.isNotEmpty(logPath)) {
            Map<String, Double> min = new HashMap<>();
            List<IndicatorEffectChart> chart = new ArrayList<>();
            collectIndicatorEffect(readLogFile(logPath), min, chart);
            resultMap.put("min", min);
            resultMap.put("chart", chart);
        }
        return resultMap;
    }

    /**
     * Parses newline-separated JSON log records into chart points and tracks the minimum loss.
     * Blank lines and records without a {@code "loss"} value are skipped. Previously they caused
     * a NullPointerException.
     *
     * @param logs  raw log text; null is treated as empty
     * @param min   receives {@code "loss"} and {@code "epoch"} of the lowest-loss record
     * @param chart receives one point per loss record, with step 10, 20, 30, ...
     */
    private static void collectIndicatorEffect(String logs, Map<String, Double> min,
                                               List<IndicatorEffectChart> chart) {
        if (logs == null) {
            return;
        }
        int step = 0;
        for (String rawLine : logs.split("\n")) {
            String line = rawLine.trim();
            if (line.isEmpty()) {
                continue;
            }
            JSONObject record = JSONObject.parseObject(line);
            Double loss = record == null ? null : record.getDouble("loss");
            if (loss == null) {
                continue;
            }
            step += 10;
            IndicatorEffectChart point = new IndicatorEffectChart();
            point.setStep(step);
            point.setLoss(loss);
            chart.add(point);
            Double best = min.get("loss");
            if (best == null || loss.compareTo(best) < 0) {
                min.put("loss", loss);
                min.put("epoch", record.getDouble("epoch"));
            }
        }
    }

    /**
     * Removes a task from the training queue and resets its status to "0".
     *
     * @param taskId task id; must not be null
     * @return always {@code true}
     */
    @Override
    @HussarTransactional
    public Boolean cancalQueue(Long taskId) {
        HussarException.throwByNull(taskId, MSG_TASK_ID_EMPTY);
        LambdaQueryWrapper<QueueTrainingTask> qw = new LambdaQueryWrapper<>();
        qw.eq(QueueTrainingTask::getTaskId, taskId);
        this.queueTrainingTaskService.remove(qw);
        // NOTE(review): the status is reset without checking it is currently "4" (queued); calling this
        // on a training task would reset it to "0" — confirm whether a guard is wanted.
        this.updateTrainingStatus(taskId, STATUS_NOT_STARTED);
        return true;
    }

    /**
     * Developer utility: prints a log file and the minimum-loss record parsed from it.
     *
     * @param args optional; {@code args[0]} overrides the default log file path
     */
    public static void main(String[] args) {
        String logFile = args != null && args.length > 0
                ? args[0]
                : "D:\\workSpace\\ai\\ai-qa\\qaLogs\\trainingTask\\843513821531480064.log";
        String s = FileUtil.readToString(new File(logFile));
        System.out.println(s);
        Map<String, Double> min = new HashMap<>();
        collectIndicatorEffect(s, min, new ArrayList<>());
        System.out.println(min);
    }

    /**
     * Assembles the request payload for a training run: task identity, hyper-parameters,
     * prompt/answer pairs from the training set, and (when a RAG model is configured) its
     * YAML arguments as a map.
     *
     * @param trainingTask task whose params are serialized
     * @return the payload map
     * @throws HussarException when the training set has no details
     */
    private Map<String, Object> comineTrainingParams(TrainingTask trainingTask) {
        Long trainingSetId = trainingTask.getTrainingSetId();
        Long ragModelId = trainingTask.getRagModelId();
        LambdaQueryWrapper<QATrainingsetDetail> qw = new LambdaQueryWrapper<>();
        qw.eq(QATrainingsetDetail::getTrainingsetId, trainingSetId);
        List<QATrainingsetDetail> details = this.qaTrainingsetDetailService.list(qw);
        // list() yields an empty list rather than null, so emptiness must be checked explicitly.
        // "Training set details are empty."
        HussarException.throwBy(details == null || details.isEmpty(), "\u8bad\u7ec3\u96c6\u8be6\u60c5\u4e3a\u7a7a");
        QARagModel qaRagModel = null;
        if (HussarUtils.isNotEmpty(ragModelId)) {
            qaRagModel = this.qaRagModelService.getById(ragModelId);
        }
        Map<String, Object> params = new HashMap<>();
        params.put("taskId", trainingTask.getTaskId());
        params.put("taskName", trainingTask.getTaskName());
        params.put("basicModelType", trainingTask.getBasicModelType());
        TrainingParamsVo trainingParamsVo =
                JSONObject.parseObject(trainingTask.getTrainingParams(), TrainingParamsVo.class);
        Map<String, Number> trainingParams = new HashMap<>();
        trainingParams.put("per_device_train_batch_size", trainingParamsVo.getPer_device_train_batch_size());
        trainingParams.put("learning_rate", Double.parseDouble(trainingParamsVo.getLearning_rate()));
        trainingParams.put("num_train_epochs", trainingParamsVo.getNum_train_epochs());
        trainingParams.put("save_steps", trainingParamsVo.getSave_steps());
        trainingParams.put("max_new_tokens", trainingParamsVo.getMax_new_tokens());
        // Key fixed from the misspelled "utoff_len" to match the cutoff_len value it carries.
        trainingParams.put("cutoff_len", trainingParamsVo.getCutoff_len());
        params.put("trainingParams", trainingParams);
        List<Map<String, String>> trainingSets = new ArrayList<>(details.size());
        for (QATrainingsetDetail detail : details) {
            Map<String, String> pair = new HashMap<>();
            pair.put("prompt", detail.getDetailPrompt());
            pair.put("answer", detail.getDetailAnswer());
            trainingSets.add(pair);
        }
        params.put("trainingSets", trainingSets);
        // The RAG model is optional (ragModelId may be empty); previously this dereferenced null.
        if (qaRagModel != null) {
            Map<String, Object> ragParams = YamlUtil.transferToMap(qaRagModel.getModelArg());
            params.put("ragParams", ragParams);
        }
        return params;
    }
}

