/*
 * Decompiled with CFR 0.152.
 */
package com.jxdinfo.hussar.kgbase.algomodel.service.impl;

import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONArray;
import com.alibaba.fastjson.JSONObject;
import com.baomidou.mybatisplus.core.conditions.Wrapper;
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
import com.baomidou.mybatisplus.core.metadata.IPage;
import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
import com.google.common.base.Joiner;
import com.jxdinfo.hussar.common.security.BaseSecurityUtil;
import com.jxdinfo.hussar.common.security.SecurityUser;
import com.jxdinfo.hussar.core.exception.HussarException;
import com.jxdinfo.hussar.kgbase.algomodel.dao.TrainSampleMapper;
import com.jxdinfo.hussar.kgbase.algomodel.dao.TrainTaskMapper;
import com.jxdinfo.hussar.kgbase.algomodel.model.dto.TrainSampleDTO;
import com.jxdinfo.hussar.kgbase.algomodel.model.dto.TrainTaskDTO;
import com.jxdinfo.hussar.kgbase.algomodel.model.po.TrainModel;
import com.jxdinfo.hussar.kgbase.algomodel.model.po.TrainNote;
import com.jxdinfo.hussar.kgbase.algomodel.model.po.TrainSample;
import com.jxdinfo.hussar.kgbase.algomodel.model.po.TrainTask;
import com.jxdinfo.hussar.kgbase.algomodel.model.vo.TrainSampleVO;
import com.jxdinfo.hussar.kgbase.algomodel.model.vo.TrainTaskVO;
import com.jxdinfo.hussar.kgbase.algomodel.service.ISampleService;
import com.jxdinfo.hussar.kgbase.algomodel.service.ITrainModelService;
import com.jxdinfo.hussar.kgbase.algomodel.service.ITrainNoteService;
import com.jxdinfo.hussar.kgbase.algomodel.service.ITrainSampleService;
import com.jxdinfo.hussar.kgbase.algomodel.service.ITrainTaskService;
import com.jxdinfo.hussar.kgbase.algomodel.service.NerService;
import com.jxdinfo.hussar.kgbase.algomodel.service.PeService;
import com.jxdinfo.hussar.kgbase.algomodel.service.ReService;
import com.jxdinfo.hussar.kgbase.bzrw.kgtaggingtask1.service.impl.LabelDocServiceImpl;
import com.jxdinfo.hussar.kgbase.bzrw.kgtaggingtask1.vo.KgAnnotatedCorpusPropertyVO;
import com.jxdinfo.hussar.kgbase.bzrw.kgtaggingtask1.vo.KgAnnotatedCorpusRelationVO;
import com.jxdinfo.hussar.kgbase.bzrw.kgtaggingtask1.vo.KgAnnotatedCorpusVO;
import com.jxdinfo.hussar.kgbase.common.util.ExcelUtil;
import com.jxdinfo.hussar.kgbase.common.util.FileUtil;
import com.jxdinfo.hussar.kgbase.common.util.MinioUtil;
import com.jxdinfo.hussar.kgbase.wdgl.kgdocmanagement.model.SysFileInfo;
import com.jxdinfo.hussar.kgbase.wdgl.kgdocmanagement.service.SysFileInfoService;
import com.jxdinfo.hussar.platform.core.base.apiresult.ApiResponse;
import com.jxdinfo.hussar.support.rmi.core.utils.StringUtils;
import java.io.File;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Date;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import javax.annotation.Resource;
import org.springframework.beans.BeanUtils;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
import org.springframework.web.multipart.MultipartFile;

@Service
public class TrainTaskServiceImpl
extends ServiceImpl<TrainTaskMapper, TrainTask>
implements ITrainTaskService {
    @Resource
    private TrainTaskMapper trainTaskMapper;
    @Resource
    private TrainSampleMapper trainSampleMapper;
    @Resource
    private ITrainSampleService iTrainSampleService;
    @Resource
    private ITrainNoteService iTrainNoteService;
    @Resource
    private ITrainModelService iTrainModelService;
    @Resource
    private LabelDocServiceImpl labelDocService;
    @Resource
    private ISampleService iSampleService;
    @Resource
    private SysFileInfoService sysFileInfoService;
    @Resource
    private NerService nerService;
    @Resource
    private ReService reService;
    @Resource
    private PeService peService;
    @Resource
    private MinioUtil minioUtil;

    /**
     * Creates a new training task (when the DTO has no id) or updates an existing one,
     * stores its samples, converts the sample data into the corpus of the task's model
     * type and, when the requested task state is {@code "1"}, starts training.
     *
     * <p>Sample JSON is resolved per sample: from the annotation task referenced by
     * {@code tagTaskId} if present, otherwise from the uploaded Excel file referenced by
     * {@code fileId}, otherwise whatever JSON the DTO already carries is used.
     *
     * @param trainTaskDTO task definition together with its sample list
     * @return a success response once the task has been saved
     * @throws HussarException if fewer corpus entries than the task's batch size were
     *         collected
     */
    @Transactional
    public ApiResponse saveOrEdit(TrainTaskDTO trainTaskDTO) {
        Date nowDate = new Date();
        SecurityUser user = BaseSecurityUtil.getUser();
        TrainTask trainTask = new TrainTask();
        BeanUtils.copyProperties(trainTaskDTO, trainTask);
        if (trainTaskDTO.getId() == null) {
            trainTask.setTaskProgress("0");
            trainTask.setCreatorName(user.getUserName());
            trainTask.setCreator(user.getId());
            trainTask.setCreateTime(nowDate);
            trainTask.setLastEditor(user.getId());
            trainTask.setLastTime(nowDate);
            this.trainTaskMapper.insert(trainTask);
        } else {
            this.editTrainTask(trainTaskDTO);
        }

        String modelType = trainTaskDTO.getModelType();
        List<TrainSampleDTO> sampleDTOS = trainTaskDTO.getSampleList();
        if (sampleDTOS == null) {
            // A task submitted without samples is handled like an empty sample list.
            sampleDTOS = new ArrayList<TrainSampleDTO>();
        }

        List<TrainSample> samples = new ArrayList<TrainSample>();
        for (TrainSampleDTO sampleDTO : sampleDTOS) {
            try {
                TrainSample sample = new TrainSample();
                BeanUtils.copyProperties(sampleDTO, sample);
                if (sample.getTagTaskId() != null && StringUtils.isNotEmpty((CharSequence) sample.getTagTaskId())) {
                    JSONObject annotated = (JSONObject) this.labelDocService.kgAnnotatedCorpusList(sample.getTagTaskId()).getData();
                    sampleDTO.setSampleJson(JSONArray.parseArray(JSON.toJSONString(annotated.get(annotationListKey(modelType)))));
                } else if (sample.getFileId() != null && StringUtils.isNotEmpty((CharSequence) sample.getFileId())) {
                    JSONArray fileJson = this.readSampleFile(sample.getFileId());
                    if (fileJson != null) {
                        sampleDTO.setSampleJson(fileJson);
                    }
                }
                if (sampleDTO.getSampleJson() != null) {
                    sample.setSampleJson(sampleDTO.getSampleJson().toJSONString());
                }
                sample.setTrainTaskId(trainTask.getId());
                sample.setCreatorName(user.getUserName());
                sample.setCreator(user.getId());
                sample.setCreateTime(nowDate);
                samples.add(sample);
            }
            catch (HussarException e) {
                // Best effort: a sample that fails to load is skipped, the remaining ones are still saved.
                e.printStackTrace();
            }
        }
        this.iTrainSampleService.saveBatch(samples);

        // Any model type other than NER / RE is treated as property extraction here.
        if ("NER".equals(modelType)) {
            List<KgAnnotatedCorpusVO> corpus = this.collectCorpus(sampleDTOS, KgAnnotatedCorpusVO.class);
            this.requireEnoughCorpus(corpus.size(), trainTask);
            this.iSampleService.changeToNerCorpus(corpus, trainTask.getId());
        } else if ("RE".equals(modelType)) {
            List<KgAnnotatedCorpusRelationVO> corpus = this.collectCorpus(sampleDTOS, KgAnnotatedCorpusRelationVO.class);
            this.requireEnoughCorpus(corpus.size(), trainTask);
            this.iSampleService.changeToReCorpus(corpus, trainTask.getId());
        } else {
            List<KgAnnotatedCorpusPropertyVO> corpus = this.collectCorpus(sampleDTOS, KgAnnotatedCorpusPropertyVO.class);
            this.requireEnoughCorpus(corpus.size(), trainTask);
            this.iSampleService.changeToPeCorpus(corpus, trainTask.getId());
        }

        // modelType is null-checked because switching on a null String throws.
        if ("1".equals(trainTaskDTO.getTaskState()) && modelType != null) {
            switch (modelType) {
                case "NER": {
                    this.nerService.namedEntityRecognitionTrain(trainTask.getId());
                    break;
                }
                case "RE": {
                    this.reService.relationRecognitionTrain(trainTask.getId());
                    break;
                }
                case "PE": {
                    this.peService.propertyRecognitionTrain(trainTask.getId());
                    break;
                }
                default: {
                    break;
                }
            }
        }
        return ApiResponse.success("\u8bad\u7ec3\u4efb\u52a1\u4fdd\u5b58\u6210\u529f\uff01");
    }

    /** Returns the key under which the annotation result holds the list matching the model type. */
    private static String annotationListKey(String modelType) {
        if ("NER".equals(modelType)) {
            return "nodeList";
        }
        if ("RE".equals(modelType)) {
            return "relationlist";
        }
        return "propertyList";
    }

    /**
     * Reads the uploaded Excel sample file into a JSON array.
     *
     * @return the parsed rows, or {@code null} when the file record or the file on disk is missing
     */
    private JSONArray readSampleFile(String fileId) {
        SysFileInfo fileInfo = (SysFileInfo) this.sysFileInfoService.getById(fileId);
        if (fileInfo == null) {
            return null;
        }
        File sampleFile = new File(fileInfo.getAttachmentDir() + fileInfo.getAttachmentName());
        if (!sampleFile.exists()) {
            return null;
        }
        MultipartFile upload = FileUtil.fileToMultipartFile(sampleFile);
        List<Map<String, String>> rows = ExcelUtil.readExcel(upload);
        return (JSONArray) JSONArray.toJSON(rows);
    }

    /** Parses the sample JSON of every DTO into corpus entries of the given type, skipping samples without JSON. */
    private <T> List<T> collectCorpus(List<TrainSampleDTO> sampleDTOS, Class<T> corpusType) {
        List<T> corpus = new ArrayList<T>();
        for (TrainSampleDTO sampleDTO : sampleDTOS) {
            if (sampleDTO.getSampleJson() == null) {
                continue;
            }
            corpus.addAll(JSONObject.parseArray(sampleDTO.getSampleJson().toJSONString(), corpusType));
        }
        return corpus;
    }

    /** Rejects the save when fewer corpus entries than the task's batch size are available. */
    private void requireEnoughCorpus(int corpusCount, TrainTask trainTask) {
        if (corpusCount < trainTask.getBatchSize()) {
            throw new HussarException("\u6837\u672c\u6570\u91cf\u8fc7\u5c11\uff01\u8bf7\u81f3\u5c11\u4e0a\u4f20" + trainTask.getBatchSize() + "\u6761\u8bed\u6599");
        }
    }

    /**
     * Loads a training task together with its non-deleted samples and a per-label
     * count of the corpus entries contained in those samples.
     *
     * <p>The statistic is keyed by label type for NER tasks, by relation for RE tasks
     * and by property key for every other model type.
     *
     * @param id id of the training task
     * @return the task view, or a failure response when no task with that id exists
     */
    public ApiResponse<TrainTaskVO> getInfoById(String id) {
        TrainTask task = this.getById(id);
        if (task == null) {
            return ApiResponse.fail("\u672a\u67e5\u8be2\u5230\u8bad\u7ec3\u4efb\u52a1");
        }
        TrainTaskVO taskVO = new TrainTaskVO();
        BeanUtils.copyProperties(task, taskVO);

        QueryWrapper<TrainSample> queryWrapper = new QueryWrapper<TrainSample>();
        queryWrapper.eq("DEL_FLAG", "0");
        queryWrapper.eq("TRAIN_TASK_ID", id);
        List<TrainSample> samples = this.iTrainSampleService.list(queryWrapper);

        List<TrainSampleVO> sampleVOS = new ArrayList<TrainSampleVO>();
        for (TrainSample sample : samples) {
            TrainSampleVO sampleVO = new TrainSampleVO();
            BeanUtils.copyProperties(sample, sampleVO);
            // The entity stores the JSON as text; the view exposes it as a parsed array.
            sampleVO.setSampleJson(JSONArray.parseArray(sample.getSampleJson()));
            sampleVOS.add(sampleVO);
        }
        taskVO.setSampleList(sampleVOS);

        // Constant-first comparison so a task without a model type does not throw.
        String modelType = task.getModelType();
        Map<String, Integer> dataStatis = new HashMap<String, Integer>();
        if ("NER".equals(modelType)) {
            for (KgAnnotatedCorpusVO corpusVO : this.parseStoredSamples(samples, KgAnnotatedCorpusVO.class)) {
                incrementCount(dataStatis, corpusVO.getLabelType());
            }
        } else if ("RE".equals(modelType)) {
            for (KgAnnotatedCorpusRelationVO corpusVO : this.parseStoredSamples(samples, KgAnnotatedCorpusRelationVO.class)) {
                incrementCount(dataStatis, corpusVO.getRel());
            }
        } else {
            for (KgAnnotatedCorpusPropertyVO corpusVO : this.parseStoredSamples(samples, KgAnnotatedCorpusPropertyVO.class)) {
                incrementCount(dataStatis, corpusVO.getPropKey());
            }
        }
        taskVO.setDataStatis(dataStatis);
        return ApiResponse.success(taskVO);
    }

    /** Parses the stored JSON of every sample into corpus entries of the given type. */
    private <T> List<T> parseStoredSamples(List<TrainSample> samples, Class<T> corpusType) {
        List<T> corpus = new ArrayList<T>();
        for (TrainSample sample : samples) {
            // parseArray yields null for a sample saved without JSON; skip it instead of failing in addAll.
            List<T> parsed = JSONObject.parseArray(sample.getSampleJson(), corpusType);
            if (parsed != null) {
                corpus.addAll(parsed);
            }
        }
        return corpus;
    }

    /** Adds one to the counter stored under {@code key}, starting at 1 for a new key. */
    private static void incrementCount(Map<String, Integer> counts, String key) {
        Integer current = counts.get(key);
        counts.put(key, current == null ? 1 : current + 1);
    }

    /**
     * Returns one page of non-deleted training tasks, newest edit first.
     *
     * <p>Task name (fuzzy match), model type and task state are applied as filters only
     * when set on the DTO. A missing page number defaults to 1 and a missing page size
     * to 10; both defaults are written back to the DTO.
     *
     * @param trainTaskDTO paging parameters and optional filters
     * @return a success response wrapping the queried page
     */
    public ApiResponse<Page<TrainTask>> listByPage(TrainTaskDTO trainTaskDTO) {
        if (trainTaskDTO.getCurrent() == null) {
            trainTaskDTO.setCurrent(Integer.valueOf(1));
        }
        if (trainTaskDTO.getSize() == null) {
            trainTaskDTO.setSize(Integer.valueOf(10));
        }

        Page<TrainTask> pageRequest = new Page<TrainTask>();
        pageRequest.setCurrent(trainTaskDTO.getCurrent().longValue());
        pageRequest.setSize(trainTaskDTO.getSize().longValue());

        // Each optional filter is only added when its value is present.
        QueryWrapper<TrainTask> filter = new QueryWrapper<TrainTask>();
        filter.like(trainTaskDTO.getTaskName() != null, "TASK_NAME", trainTaskDTO.getTaskName())
                .eq(trainTaskDTO.getModelType() != null, "MODEL_TYPE", trainTaskDTO.getModelType())
                .eq(trainTaskDTO.getTaskState() != null, "TASK_STATE", trainTaskDTO.getTaskState())
                .eq("DEL_FLAG", "0")
                .orderByDesc("LAST_TIME");

        Page<TrainTask> result = (Page<TrainTask>) this.page(pageRequest, filter);
        return ApiResponse.success(result);
    }

    /**
     * Starts training for an existing task.
     *
     * <p>Refuses to start while another task is in state {@code "1"} (see
     * {@link #judgeBeforeTraining()}). Before training, the task's objects under
     * {@code /logs/} and {@code /vocabs/} are deleted through {@code minioUtil}.
     *
     * @param id id of the training task
     * @return a failure response when the task does not exist or another training is
     *         running, otherwise a success response
     */
    public ApiResponse startTraining(String id) {
        TrainTask task = this.getById(id);
        if (task == null) {
            // Previously an unknown id ended in a NullPointerException at task.getId().
            return ApiResponse.fail("\u672a\u67e5\u8be2\u5230\u8bad\u7ec3\u4efb\u52a1");
        }
        if (!this.judgeBeforeTraining()) {
            return ApiResponse.fail("\u540c\u4e00\u65f6\u95f4\u5185\u53ea\u80fd\u6709\u4e00\u4e2a\u8fdb\u884c\u4e2d\u7684\u8bad\u7ec3\u4efb\u52a1\uff01");
        }
        this.minioUtil.iterateDeleteObjects(task.getId(), "/logs/");
        this.minioUtil.iterateDeleteObjects(task.getId(), "/vocabs/");

        // NOTE(review): a model type other than NER / RE / PE starts nothing but still reports success,
        // as before. Confirm whether that should be a failure.
        String modelType = task.getModelType();
        if (modelType != null) {
            switch (modelType) {
                case "NER": {
                    this.nerService.namedEntityRecognitionTrain(id);
                    break;
                }
                case "RE": {
                    this.reService.relationRecognitionTrain(id);
                    break;
                }
                case "PE": {
                    this.peService.propertyRecognitionTrain(id);
                    break;
                }
                default: {
                    break;
                }
            }
        }
        return ApiResponse.success("\u5f00\u59cb\u8bad\u7ec3");
    }

    /**
     * Checks whether a new training may be started.
     *
     * @return {@code true} when no non-deleted task is currently in state {@code "1"}
     */
    public boolean judgeBeforeTraining() {
        QueryWrapper<TrainTask> runningTasks = new QueryWrapper<TrainTask>();
        runningTasks.eq("TASK_STATE", "1").eq("DEL_FLAG", "0");
        return this.list(runningTasks).isEmpty();
    }

    /**
     * Stops the training of a task.
     *
     * <p>The task is set to state {@code "3"} with progress {@code "0"} and the current
     * time as end time, the trainer matching its model type is destroyed, and the task's
     * objects under {@code /checkpoints/} and {@code /vocabs/} are deleted through
     * {@code minioUtil}.
     *
     * @param id id of the training task
     * @return a success response when the task was stopped; a failure response when the
     *         task does not exist or a {@link HussarException} occurred
     */
    public ApiResponse stopTraining(String id) {
        try {
            TrainTask task = this.getById(id);
            if (task == null) {
                // Previously an unknown id ended in a NullPointerException.
                return ApiResponse.fail("\u672a\u67e5\u8be2\u5230\u8bad\u7ec3\u4efb\u52a1");
            }
            task.setTaskState("3");
            task.setTaskProgress("0");
            task.setTrainEndTime(new Date());
            this.updateById(task);

            String modelType = task.getModelType();
            if (modelType != null) {
                switch (modelType) {
                    case "NER": {
                        this.nerService.destoryNerTrain();
                        break;
                    }
                    case "RE": {
                        this.reService.destoryReTrain();
                        break;
                    }
                    case "PE": {
                        this.peService.destoryPeTrain();
                        break;
                    }
                    default: {
                        break;
                    }
                }
            }
            this.minioUtil.iterateDeleteObjects(task.getId(), "/checkpoints/");
            this.minioUtil.iterateDeleteObjects(task.getId(), "/vocabs/");
            return ApiResponse.success("\u8bad\u7ec3\u4efb\u52a1\u7ec8\u6b62");
        }
        catch (HussarException e) {
            e.printStackTrace();
            // The failure used to be wrapped in ApiResponse.success, so callers saw a successful call.
            return ApiResponse.fail("\u64cd\u4f5c\u5931\u8d25");
        }
    }

    /**
     * Saves a model record from the final training result of a task.
     *
     * <p>The task must be in state {@code "2"} and have a training result. The model
     * takes its metrics from that result and its identity range from the keys of the
     * task's data statistic, joined with commas.
     *
     * @param trainModel model to save; {@code trainTaskId} and {@code modelName} are required
     * @return the saved model on success; a failure response when required fields are
     *         missing, no usable training result exists, the data statistic is empty or
     *         a {@link HussarException} occurred
     */
    public ApiResponse<TrainModel> extractModel(TrainModel trainModel) {
        try {
            if (trainModel.getTrainTaskId() == null || trainModel.getModelName() == null) {
                return ApiResponse.fail("\u8bf7\u586b\u5199\u6a21\u578b\u4fe1\u606f\uff01");
            }
            String taskId = trainModel.getTrainTaskId();
            TrainTaskVO taskVO = (TrainTaskVO) this.iTrainNoteService.getFinalTrainResult(taskId).getData();
            if (taskVO == null) {
                // No result payload at all is treated like a missing training result.
                return ApiResponse.fail("\u6a21\u578b\u63d0\u53d6\u5931\u8d25\uff01");
            }
            TrainNote trainRes = taskVO.getTrainResult();
            if (!"2".equals(taskVO.getTaskState()) || trainRes == null) {
                return ApiResponse.fail("\u6a21\u578b\u63d0\u53d6\u5931\u8d25\uff01");
            }
            // NOTE(review): this timestamp is set on the note, which is not saved here.
            // Possibly meant for trainModel; confirm.
            trainRes.setCreateTime(new Date());
            trainModel.setTrainTaskName(taskVO.getTaskName());
            trainModel.setModelState("0");
            trainModel.setModelType(taskVO.getModelType());
            trainModel.setPrecisionRate(trainRes.getPrecisionRate());
            trainModel.setAccuracy(trainRes.getAccuracy());
            trainModel.setRecall(trainRes.getRecall());
            trainModel.setF1(trainRes.getF1());

            Map<String, ?> dataStatis = taskVO.getDataStatis();
            if (dataStatis == null || dataStatis.isEmpty()) {
                return ApiResponse.fail("\u65e0\u6548\u7684\u8bc6\u522b\u8303\u56f4\uff01");
            }
            trainModel.setIdentityRange(Joiner.on(",").join(dataStatis.keySet()));
            this.iTrainModelService.save(trainModel);
            return ApiResponse.success(trainModel);
        }
        catch (HussarException e) {
            e.printStackTrace();
            return ApiResponse.fail("\u6a21\u578b\u63d0\u53d6\u5931\u8d25\uff01");
        }
    }

    /**
     * Updates an existing training task from the DTO and discards its previous samples.
     *
     * <p>The last editor and last edit time are set to the current user and time, and
     * train/dev sizes are nulled on the entity passed to the update. The task's sample
     * rows are then deleted, and its objects under {@code /sample/} are deleted through
     * {@code minioUtil}; a {@link HussarException} from that deletion is only printed.
     *
     * @param trainTaskDTO task data carrying the id of the task to update
     */
    public void editTrainTask(TrainTaskDTO trainTaskDTO) {
        Date editedAt = new Date();
        SecurityUser editor = BaseSecurityUtil.getUser();

        TrainTask updated = new TrainTask();
        BeanUtils.copyProperties(trainTaskDTO, updated);
        updated.setLastEditor(editor.getId());
        updated.setLastTime(editedAt);
        updated.setTrainSize(null);
        updated.setDevSize(null);
        this.trainTaskMapper.updateById(updated);

        // The samples are replaced wholesale: drop the stored rows, then the stored objects.
        this.trainSampleMapper.deleteSampleByTaskId(updated.getId());
        try {
            this.minioUtil.iterateDeleteObjects(updated.getId(), "/sample/");
        }
        catch (HussarException ex) {
            ex.printStackTrace();
        }
    }
}

