package com.jxdinfo.hussar.kgbase.algomodel.service.impl;

import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONArray;
import com.alibaba.fastjson.JSONObject;
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
import com.google.common.base.Joiner;
import com.jxdinfo.hussar.common.security.BaseSecurityUtil;
import com.jxdinfo.hussar.common.security.SecurityUser;
import com.jxdinfo.hussar.core.exception.HussarException;
import com.jxdinfo.hussar.engine.metadata.util.StringUtils;
import com.jxdinfo.hussar.kgbase.algomodel.dao.TrainSampleMapper;
import com.jxdinfo.hussar.kgbase.algomodel.dao.TrainTaskMapper;
import com.jxdinfo.hussar.kgbase.algomodel.model.dto.TrainSampleDTO;
import com.jxdinfo.hussar.kgbase.algomodel.model.dto.TrainTaskDTO;
import com.jxdinfo.hussar.kgbase.algomodel.model.po.TrainModel;
import com.jxdinfo.hussar.kgbase.algomodel.model.po.TrainNote;
import com.jxdinfo.hussar.kgbase.algomodel.model.po.TrainSample;
import com.jxdinfo.hussar.kgbase.algomodel.model.po.TrainTask;
import com.jxdinfo.hussar.kgbase.algomodel.model.vo.TrainSampleVO;
import com.jxdinfo.hussar.kgbase.algomodel.model.vo.TrainTaskVO;
import com.jxdinfo.hussar.kgbase.algomodel.service.ISampleService;
import com.jxdinfo.hussar.kgbase.algomodel.service.ITrainModelService;
import com.jxdinfo.hussar.kgbase.algomodel.service.ITrainNoteService;
import com.jxdinfo.hussar.kgbase.algomodel.service.ITrainSampleService;
import com.jxdinfo.hussar.kgbase.algomodel.service.ITrainTaskService;
import com.jxdinfo.hussar.kgbase.algomodel.service.NerService;
import com.jxdinfo.hussar.kgbase.algomodel.service.PeService;
import com.jxdinfo.hussar.kgbase.algomodel.service.ReService;
import com.jxdinfo.hussar.kgbase.bzrw.kgtaggingtask1.service.impl.KgTaggingTask1ServiceImpl;
import com.jxdinfo.hussar.kgbase.bzrw.kgtaggingtask1.service.impl.LabelDocServiceImpl;
import com.jxdinfo.hussar.kgbase.bzrw.kgtaggingtask1.vo.KgAnnotatedCorpusPropertyVO;
import com.jxdinfo.hussar.kgbase.bzrw.kgtaggingtask1.vo.KgAnnotatedCorpusRelationVO;
import com.jxdinfo.hussar.kgbase.bzrw.kgtaggingtask1.vo.KgAnnotatedCorpusVO;
import com.jxdinfo.hussar.kgbase.common.util.ExcelUtil;
import com.jxdinfo.hussar.kgbase.common.util.FileUtil;
import com.jxdinfo.hussar.kgbase.common.util.MinioUtil;
import com.jxdinfo.hussar.kgbase.wdgl.kgdocmanagement.model.SysFileInfo;
import com.jxdinfo.hussar.kgbase.wdgl.kgdocmanagement.service.SysFileInfoService;
import com.jxdinfo.hussar.platform.core.base.apiresult.ApiResponse;
import java.io.File;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Date;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import javax.annotation.Resource;
import org.springframework.beans.BeanUtils;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;

/**
 * MyBatis-Plus backed implementation of {@link ITrainTaskService}.
 *
 * <p>Manages model training tasks of three model types ({@code NER}, {@code RE}, {@code PE}):
 * saving/editing a task together with its training samples, converting those samples into the
 * corpus objects handed to {@link ISampleService}, starting and stopping training through the
 * type-specific {@link NerService} / {@link ReService} / {@link PeService}, and extracting a
 * {@link TrainModel} record from a task's final training result.
 */
@Service
public class TrainTaskServiceImpl extends ServiceImpl<TrainTaskMapper, TrainTask> implements ITrainTaskService {

    /** Model type whose training is handled by {@link NerService}. */
    private static final String MODEL_TYPE_NER = "NER";

    /** Model type whose training is handled by {@link ReService}. */
    private static final String MODEL_TYPE_RE = "RE";

    /** Model type whose training is handled by {@link PeService}. */
    private static final String MODEL_TYPE_PE = "PE";

    /** {@code TASK_STATE} of an in-progress training task; only one such task may exist at a time. */
    private static final String TASK_STATE_RUNNING = "1";

    /** {@code TASK_STATE} a task must have before a model can be extracted from it. */
    private static final String TASK_STATE_EXTRACTABLE = "2";

    @Resource
    private TrainTaskMapper trainTaskMapper;

    @Resource
    private TrainSampleMapper trainSampleMapper;

    @Resource
    private ITrainSampleService iTrainSampleService;

    @Resource
    private ITrainNoteService iTrainNoteService;

    @Resource
    private ITrainModelService iTrainModelService;

    @Resource
    private LabelDocServiceImpl labelDocService;

    @Resource
    private ISampleService iSampleService;

    @Resource
    private SysFileInfoService sysFileInfoService;

    @Resource
    private NerService nerService;

    @Resource
    private ReService reService;

    @Resource
    private PeService peService;

    @Resource
    private MinioUtil minioUtil;

    /**
     * Creates a new training task (when {@code trainTaskDTO.getId()} is {@code null}) or updates an
     * existing one, saves its training samples, converts them into the model-type specific corpus
     * and, when the requested task state is {@value #TASK_STATE_RUNNING}, starts training.
     *
     * @param trainTaskDTO task fields plus the samples to attach to the task
     * @return a success response once the task, its samples and its corpus are stored
     * @throws HussarException if fewer corpus entries than the task's batch size were collected
     *     (unchecked, so the surrounding transaction is rolled back)
     */
    @Transactional
    public ApiResponse saveOrEdit(TrainTaskDTO trainTaskDTO) {
        Date now = new Date();
        SecurityUser user = BaseSecurityUtil.getUser();
        TrainTask trainTask = new TrainTask();
        BeanUtils.copyProperties(trainTaskDTO, trainTask);
        if (trainTaskDTO.getId() == null) {
            trainTask.setTaskProgress(KgTaggingTask1ServiceImpl.TASK_USE_ENTER_GRAPH);
            trainTask.setCreatorName(user.getUserName());
            trainTask.setCreator(user.getId());
            trainTask.setCreateTime(now);
            trainTask.setLastEditor(user.getId());
            trainTask.setLastTime(now);
            // insert() fills in the generated id used for the samples below.
            this.trainTaskMapper.insert(trainTask);
        } else {
            // Updates the task and drops its previous samples; new ones are saved below.
            editTrainTask(trainTaskDTO);
        }

        String modelType = trainTaskDTO.getModelType();
        List<TrainSampleDTO> sampleList = trainTaskDTO.getSampleList() != null
                ? trainTaskDTO.getSampleList()
                : Collections.<TrainSampleDTO>emptyList();

        List<TrainSample> samples = new ArrayList<>(sampleList.size());
        for (TrainSampleDTO sampleDTO : sampleList) {
            try {
                TrainSample sample = new TrainSample();
                BeanUtils.copyProperties(sampleDTO, sample);
                loadSampleJson(sampleDTO, sample, modelType);
                if (sampleDTO.getSampleJson() != null) {
                    sample.setSampleJson(sampleDTO.getSampleJson().toJSONString());
                }
                sample.setTrainTaskId(trainTask.getId());
                sample.setCreatorName(user.getUserName());
                sample.setCreator(user.getId());
                sample.setCreateTime(now);
                samples.add(sample);
            } catch (HussarException e) {
                // Best effort: a sample whose source cannot be read is skipped, not fatal.
                e.printStackTrace();
            }
        }
        this.iTrainSampleService.saveBatch(samples);

        if (MODEL_TYPE_NER.equals(modelType)) {
            List<KgAnnotatedCorpusVO> corpus = collectCorpus(sampleList, KgAnnotatedCorpusVO.class);
            requireMinimumCorpus(corpus.size(), trainTask.getBatchSize());
            this.iSampleService.changeToNerCorpus(corpus, trainTask.getId());
        } else if (MODEL_TYPE_RE.equals(modelType)) {
            List<KgAnnotatedCorpusRelationVO> corpus = collectCorpus(sampleList, KgAnnotatedCorpusRelationVO.class);
            requireMinimumCorpus(corpus.size(), trainTask.getBatchSize());
            this.iSampleService.changeToReCorpus(corpus, trainTask.getId());
        } else {
            List<KgAnnotatedCorpusPropertyVO> corpus = collectCorpus(sampleList, KgAnnotatedCorpusPropertyVO.class);
            requireMinimumCorpus(corpus.size(), trainTask.getBatchSize());
            this.iSampleService.changeToPeCorpus(corpus, trainTask.getId());
        }

        if (TASK_STATE_RUNNING.equals(trainTaskDTO.getTaskState())) {
            startModelTraining(modelType, trainTask.getId());
        }
        return ApiResponse.success("训练任务保存成功！");
    }

    /**
     * Loads a sample's JSON payload into {@code sampleDTO} from its source. The source is either the
     * annotated corpus of a tagging task (when {@code tagTaskId} is set) or an uploaded Excel file
     * (when {@code fileId} is set). Samples with neither keep whatever JSON the DTO already carries.
     */
    private void loadSampleJson(TrainSampleDTO sampleDTO, TrainSample sample, String modelType) {
        if (sample.getTagTaskId() != null && StringUtils.isNotEmpty(sample.getTagTaskId())) {
            JSONObject corpus = (JSONObject) this.labelDocService.kgAnnotatedCorpusList(sample.getTagTaskId()).getData();
            if (corpus != null) {
                sampleDTO.setSampleJson(JSONArray.parseArray(JSON.toJSONString(corpus.get(corpusKey(modelType)))));
            }
        } else if (sample.getFileId() != null && StringUtils.isNotEmpty(sample.getFileId())) {
            SysFileInfo fileInfo = (SysFileInfo) this.sysFileInfoService.getById(sample.getFileId());
            if (fileInfo == null) {
                // Unknown file id: nothing to load (previously crashed with new File(null)).
                return;
            }
            File file = new File(fileInfo.getAttachmentDir() + fileInfo.getAttachmentName());
            if (file.exists()) {
                sampleDTO.setSampleJson((JSONArray) JSONArray.toJSON(ExcelUtil.readExcel(FileUtil.fileToMultipartFile(file))));
            }
        }
    }

    /**
     * Returns the key of the annotated-corpus JSON object that holds the entries for {@code modelType};
     * anything other than NER/RE falls back to the property list.
     */
    private static String corpusKey(String modelType) {
        if (MODEL_TYPE_NER.equals(modelType)) {
            return "nodeList";
        }
        if (MODEL_TYPE_RE.equals(modelType)) {
            // NOTE(review): lower-case "l" unlike "nodeList"/"propertyList" — confirm against
            // the keys produced by LabelDocServiceImpl.kgAnnotatedCorpusList.
            return "relationlist";
        }
        return "propertyList";
    }

    /** Parses every non-null sample JSON array into {@code type} and concatenates the results. */
    private static <T> List<T> collectCorpus(List<TrainSampleDTO> samples, Class<T> type) {
        List<T> corpus = new ArrayList<>();
        for (TrainSampleDTO sample : samples) {
            if (sample.getSampleJson() == null) {
                continue;
            }
            List<T> parsed = JSON.parseArray(sample.getSampleJson().toJSONString(), type);
            if (parsed != null) {
                corpus.addAll(parsed);
            }
        }
        return corpus;
    }

    /** Throws when fewer than {@code batchSize} corpus entries were collected. */
    private static void requireMinimumCorpus(int count, Integer batchSize) {
        if (count < batchSize) {
            throw new HussarException("样本数量过少！请至少上传" + batchSize + "条语料");
        }
    }

    /**
     * Loads a training task together with its non-deleted samples and per-label statistics.
     *
     * <p>The statistics map counts corpus entries by label type (NER), relation (RE) or property
     * key (any other model type).
     *
     * @param str id of the training task
     * @return the task view, or a failure response when no task has this id
     */
    public ApiResponse<TrainTaskVO> getInfoById(String str) {
        TrainTask trainTask = getById(str);
        if (trainTask == null) {
            return ApiResponse.fail("未查询到训练任务");
        }
        TrainTaskVO trainTaskVO = new TrainTaskVO();
        BeanUtils.copyProperties(trainTask, trainTaskVO);

        QueryWrapper<TrainSample> sampleQuery = new QueryWrapper<>();
        sampleQuery.eq("DEL_FLAG", KgTaggingTask1ServiceImpl.TASK_USE_ENTER_GRAPH);
        sampleQuery.eq("TRAIN_TASK_ID", str);
        List<TrainSample> samples = this.iTrainSampleService.list(sampleQuery);

        List<TrainSampleVO> sampleVOs = new ArrayList<>(samples.size());
        for (TrainSample sample : samples) {
            TrainSampleVO sampleVO = new TrainSampleVO();
            BeanUtils.copyProperties(sample, sampleVO);
            sampleVO.setSampleJson(JSONArray.parseArray(sample.getSampleJson()));
            sampleVOs.add(sampleVO);
        }
        trainTaskVO.setSampleList(sampleVOs);

        String modelType = trainTask.getModelType();
        if (MODEL_TYPE_NER.equals(modelType)) {
            trainTaskVO.setDataStatis(
                    countCorpusByKey(samples, KgAnnotatedCorpusVO.class, KgAnnotatedCorpusVO::getLabelType));
        } else if (MODEL_TYPE_RE.equals(modelType)) {
            trainTaskVO.setDataStatis(
                    countCorpusByKey(samples, KgAnnotatedCorpusRelationVO.class, KgAnnotatedCorpusRelationVO::getRel));
        } else {
            trainTaskVO.setDataStatis(
                    countCorpusByKey(samples, KgAnnotatedCorpusPropertyVO.class, KgAnnotatedCorpusPropertyVO::getPropKey));
        }
        return ApiResponse.success(trainTaskVO);
    }

    /**
     * Parses each sample's JSON into {@code type} and counts the entries per {@code keyOf} value.
     * Samples without JSON are skipped. The return type is {@code HashMap} to stay assignable to
     * whatever map type the view setter declares.
     */
    private static <T, K> HashMap<K, Integer> countCorpusByKey(
            List<TrainSample> samples, Class<T> type, Function<? super T, ? extends K> keyOf) {
        HashMap<K, Integer> counts = new HashMap<>();
        for (TrainSample sample : samples) {
            List<T> entries = JSON.parseArray(sample.getSampleJson(), type);
            if (entries == null) {
                continue;
            }
            for (T entry : entries) {
                counts.merge(keyOf.apply(entry), 1, Integer::sum);
            }
        }
        return counts;
    }

    /**
     * Pages through non-deleted training tasks, newest {@code LAST_TIME} first.
     *
     * <p>Defaults to page 1 / size 10 and writes those defaults back into {@code trainTaskDTO}.
     * The task name is matched with LIKE; model type and task state are exact filters.
     *
     * @param trainTaskDTO paging parameters and optional filters
     * @return the requested page of tasks
     */
    public ApiResponse<Page<TrainTask>> listByPage(TrainTaskDTO trainTaskDTO) {
        if (trainTaskDTO.getCurrent() == null) {
            trainTaskDTO.setCurrent(1);
        }
        if (trainTaskDTO.getSize() == null) {
            trainTaskDTO.setSize(10);
        }
        Page<TrainTask> page = new Page<>();
        page.setCurrent(trainTaskDTO.getCurrent().intValue());
        page.setSize(trainTaskDTO.getSize().intValue());

        QueryWrapper<TrainTask> query = new QueryWrapper<>();
        if (trainTaskDTO.getTaskName() != null) {
            query.like("TASK_NAME", trainTaskDTO.getTaskName());
        }
        if (trainTaskDTO.getModelType() != null) {
            query.eq("MODEL_TYPE", trainTaskDTO.getModelType());
        }
        if (trainTaskDTO.getTaskState() != null) {
            query.eq("TASK_STATE", trainTaskDTO.getTaskState());
        }
        query.eq("DEL_FLAG", KgTaggingTask1ServiceImpl.TASK_USE_ENTER_GRAPH);
        query.orderByDesc("LAST_TIME");
        return ApiResponse.success(page(page, query));
    }

    /**
     * Starts training an existing task after clearing its {@code /logs/} and {@code /vocabs/}
     * objects via {@link MinioUtil}.
     *
     * @param str id of the training task
     * @return success, or a failure when the task does not exist or another task is running
     */
    public ApiResponse startTraining(String str) {
        TrainTask trainTask = getById(str);
        if (trainTask == null) {
            return ApiResponse.fail("未查询到训练任务");
        }
        if (!judgeBeforeTraining()) {
            return ApiResponse.fail("同一时间内只能有一个进行中的训练任务！");
        }
        this.minioUtil.iterateDeleteObjects(trainTask.getId(), "/logs/");
        this.minioUtil.iterateDeleteObjects(trainTask.getId(), "/vocabs/");
        startModelTraining(trainTask.getModelType(), str);
        return ApiResponse.success("开始训练");
    }

    /**
     * Checks whether a new training run may start.
     *
     * @return {@code true} when no non-deleted task is in state {@value #TASK_STATE_RUNNING}
     */
    public boolean judgeBeforeTraining() {
        QueryWrapper<TrainTask> query = new QueryWrapper<>();
        query.eq("TASK_STATE", TASK_STATE_RUNNING);
        query.eq("DEL_FLAG", KgTaggingTask1ServiceImpl.TASK_USE_ENTER_GRAPH);
        return list(query).isEmpty();
    }

    /**
     * Stops a task's training run. It marks the task as stopped, tears down the model-type
     * specific trainer, and removes the task's {@code /checkpoints/} and {@code /vocabs/} objects.
     *
     * @param str id of the training task
     * @return success, or a failure when the task does not exist or an operation fails
     */
    public ApiResponse stopTraining(String str) {
        try {
            TrainTask trainTask = getById(str);
            if (trainTask == null) {
                return ApiResponse.fail("未查询到训练任务");
            }
            trainTask.setTaskState(KgTaggingTask1ServiceImpl.PASS_STATUS_FLAG);
            trainTask.setTaskProgress(KgTaggingTask1ServiceImpl.TASK_USE_ENTER_GRAPH);
            trainTask.setTrainEndTime(new Date());
            updateById(trainTask);
            stopModelTraining(trainTask.getModelType());
            this.minioUtil.iterateDeleteObjects(trainTask.getId(), "/checkpoints/");
            this.minioUtil.iterateDeleteObjects(trainTask.getId(), "/vocabs/");
            return ApiResponse.success("训练任务终止");
        } catch (HussarException e) {
            e.printStackTrace();
            // Previously reported as success despite the "operation failed" message.
            return ApiResponse.fail("操作失败");
        }
    }

    /**
     * Starts training {@code taskId} on the service matching {@code modelType}. Unknown model types
     * are ignored; a {@code null} model type throws {@link NullPointerException}, as any String
     * switch does.
     */
    private void startModelTraining(String modelType, String taskId) {
        switch (modelType) {
            case MODEL_TYPE_NER:
                this.nerService.namedEntityRecognitionTrain(taskId);
                break;
            case MODEL_TYPE_RE:
                this.reService.relationRecognitionTrain(taskId);
                break;
            case MODEL_TYPE_PE:
                this.peService.propertyRecognitionTrain(taskId);
                break;
            default:
                break;
        }
    }

    /**
     * Tears down the running trainer of the service matching {@code modelType}. Unknown model types
     * are ignored; a {@code null} model type throws {@link NullPointerException}.
     */
    private void stopModelTraining(String modelType) {
        switch (modelType) {
            case MODEL_TYPE_NER:
                this.nerService.destoryNerTrain();
                break;
            case MODEL_TYPE_RE:
                this.reService.destoryReTrain();
                break;
            case MODEL_TYPE_PE:
                this.peService.destoryPeTrain();
                break;
            default:
                break;
        }
    }

    /**
     * Saves a {@link TrainModel} built from the final training result of {@code trainTaskId}.
     *
     * <p>The model copies the task's name, its model type, and the precision/accuracy/recall/F1 of
     * the final result. Its identity range is the comma-joined keys of the task's data statistics.
     *
     * @param trainModel model to save; {@code trainTaskId} and {@code modelName} are required
     * @return the saved model, or a failure when input is missing, the task is not extractable,
     *     or it has no statistics keys
     */
    public ApiResponse<TrainModel> extractModel(TrainModel trainModel) {
        try {
            if (trainModel.getTrainTaskId() == null || trainModel.getModelName() == null) {
                return ApiResponse.fail("请填写模型信息！");
            }
            TrainTaskVO trainTaskVO =
                    (TrainTaskVO) this.iTrainNoteService.getFinalTrainResult(trainModel.getTrainTaskId()).getData();
            TrainNote trainResult = trainTaskVO == null ? null : trainTaskVO.getTrainResult();
            if (trainResult == null || !TASK_STATE_EXTRACTABLE.equals(trainTaskVO.getTaskState())) {
                return ApiResponse.fail("模型提取失败！");
            }
            // Kept from the original; trainResult itself is not persisted here.
            trainResult.setCreateTime(new Date());
            trainModel.setTrainTaskName(trainTaskVO.getTaskName());
            trainModel.setModelState(KgTaggingTask1ServiceImpl.TASK_USE_ENTER_GRAPH);
            trainModel.setModelType(trainTaskVO.getModelType());
            trainModel.setPrecisionRate(trainResult.getPrecisionRate());
            trainModel.setAccuracy(trainResult.getAccuracy());
            trainModel.setRecall(trainResult.getRecall());
            trainModel.setF1(trainResult.getF1());

            List<String> identityRange = new ArrayList<>();
            Map<?, ?> dataStatis = trainTaskVO.getDataStatis();
            if (dataStatis != null) {
                for (Object key : dataStatis.keySet()) {
                    // Joiner rejects null elements, so unlabeled entries are left out.
                    if (key != null) {
                        identityRange.add(String.valueOf(key));
                    }
                }
            }
            if (identityRange.isEmpty()) {
                return ApiResponse.fail("无效的识别范围！");
            }
            trainModel.setIdentityRange(Joiner.on(",").join(identityRange));
            this.iTrainModelService.save(trainModel);
            return ApiResponse.success(trainModel);
        } catch (HussarException e) {
            e.printStackTrace();
            return ApiResponse.fail("模型提取失败！");
        }
    }

    /**
     * Updates an existing training task from {@code trainTaskDTO} and discards its previous samples.
     * Both the sample rows and the task's {@code /sample/} objects are removed so the caller can
     * re-save them.
     *
     * @param trainTaskDTO task fields; {@code getId()} identifies the task to update
     */
    public void editTrainTask(TrainTaskDTO trainTaskDTO) {
        TrainTask trainTask = new TrainTask();
        BeanUtils.copyProperties(trainTaskDTO, trainTask);
        trainTask.setLastEditor(BaseSecurityUtil.getUser().getId());
        trainTask.setLastTime(new Date());
        // Presumably reset because the samples (and thus the train/dev split) are replaced.
        // NOTE(review): under MyBatis-Plus's default NOT_NULL update strategy, updateById skips
        // null fields, so these may not clear the stored values — confirm the entity's config.
        trainTask.setTrainSize(null);
        trainTask.setDevSize(null);
        this.trainTaskMapper.updateById(trainTask);
        this.trainSampleMapper.deleteSampleByTaskId(trainTask.getId());
        try {
            this.minioUtil.iterateDeleteObjects(trainTask.getId(), "/sample/");
        } catch (HussarException e) {
            // Best effort: failing to remove stale sample objects must not abort the edit.
            e.printStackTrace();
        }
    }
}
