/*
 * Decompiled with CFR 0.152.
 */
package com.jxdinfo.hussar.kgbase.algomodel.service.impl;

import cn.hutool.core.util.IdUtil;
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONArray;
import com.alibaba.fastjson.JSONObject;
import com.baomidou.mybatisplus.core.conditions.Wrapper;
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
import com.baomidou.mybatisplus.core.conditions.update.UpdateWrapper;
import com.jxdinfo.hussar.core.exception.HussarException;
import com.jxdinfo.hussar.kgbase.algomodel.dao.TrainNoteMapper;
import com.jxdinfo.hussar.kgbase.algomodel.model.dto.NerDTO;
import com.jxdinfo.hussar.kgbase.algomodel.model.po.TrainModel;
import com.jxdinfo.hussar.kgbase.algomodel.model.po.TrainNote;
import com.jxdinfo.hussar.kgbase.algomodel.model.po.TrainTask;
import com.jxdinfo.hussar.kgbase.algomodel.model.vo.NerVO;
import com.jxdinfo.hussar.kgbase.algomodel.model.vo.TrainTaskVO;
import com.jxdinfo.hussar.kgbase.algomodel.service.ITrainModelService;
import com.jxdinfo.hussar.kgbase.algomodel.service.ITrainNoteService;
import com.jxdinfo.hussar.kgbase.algomodel.service.ITrainTaskService;
import com.jxdinfo.hussar.kgbase.algomodel.service.NerService;
import com.jxdinfo.hussar.kgbase.algomodel.util.NerWebSocket;
import com.jxdinfo.hussar.kgbase.common.util.HttpUtil;
import com.jxdinfo.hussar.kgbase.neo4j.model.Neo4jBasicNode;
import com.jxdinfo.hussar.kgbase.neo4j.model.Neo4jBasicRelationShip;
import com.jxdinfo.hussar.platform.core.base.apiresult.ApiResponse;
import com.jxdinfo.hussar.platform.core.utils.StringUtil;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.Serializable;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.Date;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import javax.annotation.Resource;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;

/**
 * Named-entity-recognition (NER) service.
 *
 * <p>Predictions are delegated to the HTTP endpoint configured as {@code model-config.ner-api}.
 * Training ({@code main.py --mode=train}) and serving ({@code flask_service.py}) run as Python
 * child processes started in the directory {@code model-config.ner-location}. Each child is
 * monitored on its own background thread by reading its stderr.
 */
@Service
public class NerServiceImpl implements NerService {

    /**
     * Charset used to decode the child process's stderr.
     * NOTE(review): GBK presumably matches a Chinese-locale Windows host; confirm.
     */
    private static final String STDERR_CHARSET = "gbk";

    /** A stderr line starting with this marker is followed by one JSON line of epoch metrics. */
    private static final String EPOCH_FINISHED_MARKER = "epoch finished";

    /** When a service stderr line starts with this marker, the model state is set to "1". */
    private static final String SERVICE_LAUNCHED_MARKER = "service launched";

    // Display names (node types) for the entity labels returned by the NER API.
    /** "Unit / organization". */
    private static final String TYPE_UNIT = "\u5355\u4f4d";
    /** "Person". */
    private static final String TYPE_PERSON = "\u4eba";
    /** "Power equipment". */
    private static final String TYPE_DEVICE = "\u7535\u529b\u8bbe\u5907";
    /** "Power equipment part". */
    private static final String TYPE_PARTS = "\u7535\u529b\u8bbe\u5907\u96f6\u4ef6";
    /** "Rules and regulations". */
    private static final String TYPE_RULES = "\u89c4\u7ae0\u5236\u5ea6";
    /** "Hazard professional category". */
    private static final String TYPE_P_CLASS = "\u9690\u60a3\u4e13\u4e1a\u5206\u7c7b";
    /** "Hazard detailed category". */
    private static final String TYPE_D_CLASS = "\u9690\u60a3\u8be6\u7ec6\u5206\u7c7b";
    /** "Hazard feature". */
    private static final String TYPE_PROP = "\u9690\u60a3\u7279\u5f81";
    /** "Location". */
    private static final String TYPE_LOCATION = "\u5730\u70b9";

    /** Maps an NER API label (e.g. {@code "PERSON"}) to its node-type display name. */
    private static final Map<String, String> ENTITY_TYPE_NAMES = new HashMap<String, String>();

    static {
        ENTITY_TYPE_NAMES.put("UNIT", TYPE_UNIT);
        ENTITY_TYPE_NAMES.put("PERSON", TYPE_PERSON);
        ENTITY_TYPE_NAMES.put("DEVICE", TYPE_DEVICE);
        ENTITY_TYPE_NAMES.put("PARTS", TYPE_PARTS);
        ENTITY_TYPE_NAMES.put("RULES", TYPE_RULES);
        ENTITY_TYPE_NAMES.put("P_CLASS", TYPE_P_CLASS);
        ENTITY_TYPE_NAMES.put("D_CLASS", TYPE_D_CLASS);
        ENTITY_TYPE_NAMES.put("PROP", TYPE_PROP);
        ENTITY_TYPE_NAMES.put("LOCATION", TYPE_LOCATION);
    }

    /**
     * Relation rules as {source node type, target node type, edge label}. For every ordered pair
     * of recognised nodes, each matching rule produces one edge, in the order listed here.
     */
    private static final String[][] RELATION_RULES = {
        {TYPE_PERSON, TYPE_UNIT, "\u4efb\u804c\u5355\u4f4d"},     // "employer"
        {TYPE_DEVICE, TYPE_PARTS, "\u96f6\u90e8\u4ef6"},          // "component"
        {TYPE_P_CLASS, TYPE_D_CLASS, "\u8be6\u7ec6\u5206\u7c7b"}, // "detailed category"
        {TYPE_D_CLASS, TYPE_PROP, "\u9690\u60a3\u7279\u5f81"},    // "hazard feature"
        {TYPE_D_CLASS, TYPE_RULES, "\u76f8\u5173\u6587\u6863"},   // "related document"
        {TYPE_DEVICE, TYPE_LOCATION, "\u6240\u5728\u5730\u70b9"}, // "located at"
        {TYPE_DEVICE, TYPE_UNIT, "\u8d23\u4efb\u5355\u4f4d"},     // "responsible unit"
    };

    @Resource
    private ITrainTaskService iTrainTaskService;
    @Resource
    private ITrainNoteService iTrainNoteService;
    @Resource
    private ITrainModelService iTrainModelService;
    @Resource
    private TrainNoteMapper trainNoteMapper;
    @Resource
    private NerWebSocket nerWebSocket;
    @Value(value="${model-config.ner-api}")
    private String nerApi;
    @Value(value="${model-config.ner-location}")
    private String nerLocation;
    @Value(value="${model-config.ip}")
    private String ip;
    @Value(value="${minio.location}")
    private String minioLocation;

    // Written by the background worker threads and read by the destroy methods on request
    // threads, hence volatile.
    private volatile Process trainProc;
    private volatile Process serviceProc;

    /**
     * Runs NER on the given input and builds a small graph from the recognised entities.
     *
     * @param nerDTO request payload, serialised to JSON and POSTed to the configured NER API
     * @return a failure response when the API returns nothing. Otherwise a success response with
     *     a JSON object holding {@code "nodes"} (one per entity, with a random id) and
     *     {@code "edges"} (derived from {@link #RELATION_RULES}).
     */
    public ApiResponse namedEntityRecognitionPedict(NerDTO nerDTO) {
        String strRes = HttpUtil.doPostJson(this.nerApi, null, JSON.toJSONString((Object) nerDTO));
        if (StringUtil.isEmpty((Object) strRes)) {
            return ApiResponse.fail((String) "\u672a\u8bc6\u522b\u5230\u5b9e\u4f53");
        }
        JSONArray jsonArrayRes = JSONArray.parseArray(strRes);
        if (jsonArrayRes == null) {
            // parseArray yields null for a literal "null" body; treat it like an empty response.
            return ApiResponse.fail((String) "\u672a\u8bc6\u522b\u5230\u5b9e\u4f53");
        }
        List<NerVO> nerVOS = jsonArrayRes.toJavaList(NerVO.class);
        List<Neo4jBasicNode> nodeList = new ArrayList<Neo4jBasicNode>(nerVOS.size());
        for (NerVO nerVO : nerVOS) {
            Neo4jBasicNode basicNode = new Neo4jBasicNode();
            basicNode.setId(IdUtil.fastSimpleUUID());
            basicNode.setLabel(nerVO.getName());
            // Unknown labels map to a null node type; such nodes simply take part in no relation.
            basicNode.setNodeType(ENTITY_TYPE_NAMES.get(nerVO.getNodeLabel()));
            nodeList.add(basicNode);
        }
        JSONObject resultObj = new JSONObject();
        resultObj.put("nodes", nodeList);
        resultObj.put("edges", buildRelationships(nodeList));
        return ApiResponse.success((Object) resultObj);
    }

    /** Creates one edge for every ordered node pair and every matching relation rule. */
    private static List<Neo4jBasicRelationShip> buildRelationships(List<Neo4jBasicNode> nodes) {
        List<Neo4jBasicRelationShip> edges = new ArrayList<Neo4jBasicRelationShip>();
        for (Neo4jBasicNode source : nodes) {
            String sourceType = source.getNodeType();
            for (Neo4jBasicNode target : nodes) {
                String targetType = target.getNodeType();
                for (String[] rule : RELATION_RULES) {
                    // Constant-first equals: node types may be null.
                    if (rule[0].equals(sourceType) && rule[1].equals(targetType)) {
                        Neo4jBasicRelationShip edge = new Neo4jBasicRelationShip();
                        edge.setId(IdUtil.fastSimpleUUID());
                        edge.setLabel(rule[2]);
                        edge.setSource(source.getId());
                        edge.setTarget(target.getId());
                        edges.add(edge);
                    }
                }
            }
        }
        return edges;
    }

    /**
     * Starts training for the given task on a background thread and returns immediately.
     *
     * @param trainTaskId id of the training task to run
     * @return success once the worker thread is started. If starting it fails, the task is
     *     stopped and a failure carrying the exception message is returned.
     */
    public ApiResponse namedEntityRecognitionTrain(final String trainTaskId) {
        final TrainTaskVO taskVO = (TrainTaskVO) this.iTrainTaskService.getInfoById(trainTaskId).getData();
        final String modelPath = this.taskDir(trainTaskId) + "/checkpoints";
        final String vocabsPath = this.taskDir(trainTaskId) + "/vocabs";
        final String logPath = this.taskDir(trainTaskId) + "/logs";
        final String nodeLabels = joinLabelKeys(taskVO);
        try {
            Thread thread = new Thread(
                    () -> this.runTraining(trainTaskId, taskVO, modelPath, vocabsPath, logPath, nodeLabels),
                    "ner-train-" + trainTaskId);
            thread.start();
            return ApiResponse.success((String) "\u5f00\u59cb\u8bad\u7ec3");
        } catch (Exception e) {
            e.printStackTrace();
            this.iTrainTaskService.stopTraining(trainTaskId);
            return ApiResponse.fail((String) e.getMessage());
        }
    }

    /**
     * Worker-thread body for training.
     *
     * <p>Marks the task as started and clears its previous notes, then runs
     * {@code main.py --mode=train}. Every stderr line is echoed to stdout and the websocket; the
     * line after an {@link #EPOCH_FINISHED_MARKER} line is recorded as epoch metrics. On exit
     * code 0 the task state becomes "2"; on any other exit code or an error, the task is stopped.
     */
    private void runTraining(String trainTaskId, TrainTaskVO taskVO, String modelPath,
            String vocabsPath, String logPath, String nodeLabels) {
        try {
            UpdateWrapper<TrainTask> taskUpdateWrapper = new UpdateWrapper<TrainTask>();
            taskUpdateWrapper.eq("ID", trainTaskId);
            taskUpdateWrapper.eq("DEL_FLAG", "0");
            taskUpdateWrapper.set("TASK_STATE", "1");
            taskUpdateWrapper.set("MODEL_PATH", modelPath);
            taskUpdateWrapper.set("LOG_PATH", logPath);
            taskUpdateWrapper.set("TRAIN_START_TIME", new Date());
            taskUpdateWrapper.set("TRAIN_END_TIME", null);
            this.iTrainTaskService.update((Wrapper) taskUpdateWrapper);
            this.trainNoteMapper.deleteNotesByTaskId(trainTaskId);
            TrainTask task = (TrainTask) this.iTrainTaskService.getById(trainTaskId);

            List<String> command = this.buildPythonCommand("main.py", taskVO, vocabsPath, logPath,
                    modelPath, nodeLabels, "--mode=train");
            Process proc = this.startProcess(command);
            this.trainProc = proc;
            try (BufferedReader in = new BufferedReader(
                    new InputStreamReader(proc.getErrorStream(), STDERR_CHARSET))) {
                String line;
                String lastLine = "";
                while ((line = in.readLine()) != null) {
                    System.out.println(line);
                    if (lastLine.startsWith(EPOCH_FINISHED_MARKER)) {
                        this.recordEpoch(trainTaskId, task, line);
                    }
                    this.nerWebSocket.sendMessage(line);
                    lastLine = line;
                }
            }
            if (proc.waitFor() == 0) {
                task.setTaskState("2");
                task.setTrainEndTime(new Date());
                this.iTrainTaskService.updateById(task);
            } else {
                this.iTrainTaskService.stopTraining(trainTaskId);
            }
        } catch (InterruptedException e) {
            // Preserve the interrupt status for whoever interrupted this worker.
            Thread.currentThread().interrupt();
            e.printStackTrace();
            this.iTrainTaskService.stopTraining(trainTaskId);
        } catch (Exception e) {
            e.printStackTrace();
            this.iTrainTaskService.stopTraining(trainTaskId);
        }
    }

    /**
     * Persists one epoch's metrics and updates the task progress.
     *
     * @param metricsLine JSON object with keys {@code epoch}, {@code precision}, {@code recall},
     *     {@code f1}, {@code accuracy}, {@code isBest} and {@code timeConsumption}
     */
    private void recordEpoch(String trainTaskId, TrainTask task, String metricsLine) {
        JSONObject epochRes = JSONObject.parseObject(metricsLine);
        TrainNote trainNote = new TrainNote();
        trainNote.setTrainTaskId(trainTaskId);
        trainNote.setCreateTime(new Date());
        trainNote.setCurrentEpoch(Integer.parseInt(epochRes.getString("epoch")));
        trainNote.setPrecisionRate(Double.parseDouble(epochRes.getString("precision")));
        trainNote.setRecall(Double.parseDouble(epochRes.getString("recall")));
        trainNote.setF1(Double.parseDouble(epochRes.getString("f1")));
        trainNote.setAccuracy(Double.parseDouble(epochRes.getString("accuracy")));
        trainNote.setIsBest(epochRes.getString("isBest"));
        if ("1".equals(trainNote.getIsBest())) {
            // Clear the flag on earlier notes so that only the newest best epoch keeps it.
            this.trainNoteMapper.setIsBestToZero(trainTaskId);
        }
        trainNote.setTimeConsumption(epochRes.getString("timeConsumption"));
        this.iTrainNoteService.save(trainNote);
        // DecimalFormat is not thread-safe; a local instance keeps concurrent trainings independent.
        DecimalFormat decimalFormat = new DecimalFormat("0.00");
        String progress = decimalFormat.format(
                (float) trainNote.getCurrentEpoch() / (float) task.getEpoch().intValue());
        task.setTaskProgress(progress);
        this.iTrainTaskService.updateById(task);
    }

    /**
     * Starts the Python prediction service for a trained task on a background thread.
     *
     * @param trainTaskId id of the training task whose model should be served
     * @return {@code false} if no model record exists for the task or if starting the thread
     *     raises a {@link HussarException}. Otherwise {@code true}; the model state is then set
     *     to "1" once the service reports it has launched, or to "3" if running it fails.
     */
    public boolean nerSeviceStart(String trainTaskId) {
        QueryWrapper<TrainModel> queryWrapper = new QueryWrapper<TrainModel>();
        queryWrapper.eq("TRAIN_TASK_ID", trainTaskId);
        queryWrapper.eq("DEL_FLAG", "0");
        final TrainModel model = (TrainModel) this.iTrainModelService.getOne((Wrapper) queryWrapper);
        if (model == null) {
            // There is no record to report the service state on; the worker would only fail with an NPE.
            return false;
        }
        final TrainTaskVO taskVO = (TrainTaskVO) this.iTrainTaskService.getInfoById(trainTaskId).getData();
        final String modelPath = this.taskDir(trainTaskId) + "/checkpoints";
        final String vocabsPath = this.taskDir(trainTaskId) + "/vocabs";
        final String logPath = this.taskDir(trainTaskId) + "/service_logs";
        final String nodeLabels = joinLabelKeys(taskVO);
        try {
            Thread thread = new Thread(
                    () -> this.runService(model, taskVO, modelPath, vocabsPath, logPath, nodeLabels),
                    "ner-service-" + trainTaskId);
            thread.start();
            return true;
        } catch (HussarException e) {
            e.printStackTrace();
            return false;
        }
    }

    /**
     * Worker-thread body for the prediction service. Runs {@code flask_service.py}, echoes its
     * stderr to stdout, and sets the model state to "1" on {@link #SERVICE_LAUNCHED_MARKER}, or
     * to "3" on any error.
     */
    private void runService(TrainModel model, TrainTaskVO taskVO, String modelPath,
            String vocabsPath, String logPath, String nodeLabels) {
        try {
            List<String> command = this.buildPythonCommand("flask_service.py", taskVO, vocabsPath,
                    logPath, modelPath, nodeLabels);
            Process proc = this.startProcess(command);
            this.serviceProc = proc;
            try (BufferedReader in = new BufferedReader(
                    new InputStreamReader(proc.getErrorStream(), STDERR_CHARSET))) {
                String line;
                while ((line = in.readLine()) != null) {
                    System.out.println(line);
                    if (line.startsWith(SERVICE_LAUNCHED_MARKER)) {
                        model.setModelState("1");
                        this.iTrainModelService.updateById(model);
                    }
                }
            }
        } catch (Exception e) {
            model.setModelState("3");
            this.iTrainModelService.updateById(model);
            e.printStackTrace();
        }
    }

    /** Stops the prediction service process, if one has been started. */
    public void destoryNerService() {
        Process proc = this.serviceProc;
        if (proc != null) {
            proc.destroy();
        }
    }

    /** Stops the training process, if one has been started. */
    public void destoryNerTrain() {
        Process proc = this.trainProc;
        if (proc != null) {
            proc.destroy();
        }
    }

    /** Returns the per-task storage directory below {@code minio.location}. */
    private String taskDir(String trainTaskId) {
        return this.minioLocation + "/" + trainTaskId;
    }

    /** Joins the keys of the task's data statistics with ',' (passed to the scripts as --suffix). */
    private static String joinLabelKeys(TrainTaskVO taskVO) {
        Map<String, ?> labelMap = taskVO.getDataStatis();
        return String.join(",", labelMap.keySet());
    }

    /**
     * Builds the argument list for a NER Python script. Each option is a single argument, so
     * paths containing spaces are passed through intact (unlike {@code Runtime.exec(String)}).
     *
     * @param script script file name, relative to {@code nerLocation}
     * @param extraArgs arguments inserted directly after the script path
     */
    private List<String> buildPythonCommand(String script, TrainTaskVO taskVO, String vocabsPath,
            String logPath, String modelPath, String nodeLabels, String... extraArgs) {
        List<String> command = new ArrayList<String>();
        command.add("python");
        command.add(this.nerLocation + "/" + script);
        for (String arg : extraArgs) {
            command.add(arg);
        }
        command.add("--datasets_fold=" + taskVO.getSamplePath());
        command.add("--vocabs_dir=" + vocabsPath);
        command.add("--log_dir=" + logPath);
        command.add("--checkpoints_dir=" + modelPath);
        command.add("--suffix=" + nodeLabels);
        command.add("--epoch=" + taskVO.getEpoch());
        command.add("--batch_size=" + taskVO.getBatchSize());
        command.add("--learning_rate=" + taskVO.getLearningRate());
        return command;
    }

    /**
     * Starts {@code command} with {@code nerLocation} as the working directory. Only stderr is
     * piped back to the caller. Stdout is inherited, so a script that prints heavily cannot fill
     * an unread pipe and block.
     */
    private Process startProcess(List<String> command) throws IOException {
        ProcessBuilder builder = new ProcessBuilder(command);
        builder.directory(new File(this.nerLocation));
        builder.redirectOutput(ProcessBuilder.Redirect.INHERIT);
        return builder.start();
    }
}

