package com.jxdinfo.hussar.kgbase.algomodel.service.impl;

import cn.hutool.core.util.IdUtil;
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONArray;
import com.alibaba.fastjson.JSONObject;
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
import com.baomidou.mybatisplus.core.conditions.update.UpdateWrapper;
import com.jxdinfo.hussar.core.exception.HussarException;
import com.jxdinfo.hussar.kgbase.algomodel.dao.TrainNoteMapper;
import com.jxdinfo.hussar.kgbase.algomodel.model.dto.NerDTO;
import com.jxdinfo.hussar.kgbase.algomodel.model.po.TrainModel;
import com.jxdinfo.hussar.kgbase.algomodel.model.po.TrainNote;
import com.jxdinfo.hussar.kgbase.algomodel.model.po.TrainTask;
import com.jxdinfo.hussar.kgbase.algomodel.model.vo.NerVO;
import com.jxdinfo.hussar.kgbase.algomodel.model.vo.TrainTaskVO;
import com.jxdinfo.hussar.kgbase.algomodel.service.ITrainModelService;
import com.jxdinfo.hussar.kgbase.algomodel.service.ITrainNoteService;
import com.jxdinfo.hussar.kgbase.algomodel.service.ITrainTaskService;
import com.jxdinfo.hussar.kgbase.algomodel.service.NerService;
import com.jxdinfo.hussar.kgbase.algomodel.util.NerWebSocket;
import com.jxdinfo.hussar.kgbase.bzrw.kgtaggingtask1.service.impl.KgTaggingTask1ServiceImpl;
import com.jxdinfo.hussar.kgbase.common.util.HttpUtil;
import com.jxdinfo.hussar.kgbase.neo4j.model.Neo4jBasicNode;
import com.jxdinfo.hussar.kgbase.neo4j.model.Neo4jBasicRelationShip;
import com.jxdinfo.hussar.platform.core.base.apiresult.ApiResponse;
import com.jxdinfo.hussar.platform.core.utils.StringUtil;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Date;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import javax.annotation.Resource;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;

/**
 * Named-entity-recognition (NER) service.
 *
 * <p>Prediction is delegated to an HTTP endpoint ({@code model-config.ner-api}). Training and
 * serving run as child {@code python} processes launched from {@code model-config.ner-location}.
 * This bean keeps one handle for the most recently started training process and one for the most
 * recently started serving process. Starting another one overwrites the stored handle.
 */
@Service
public class NerServiceImpl implements NerService {

    /** Charset used to decode the stderr of the Python child processes. */
    private static final String PROCESS_OUTPUT_CHARSET = "gbk";

    /** Maps entity labels returned by the model to the node types shown in the graph. */
    private static final Map<String, String> NODE_TYPE_BY_LABEL;

    /**
     * Relation rules as {source node type, target node type, relation label}.
     * They are checked in this order for every ordered pair of nodes.
     */
    private static final String[][] RELATION_RULES = {
        {"人", "单位", "任职单位"},
        {"电力设备", "电力设备零件", "零部件"},
        {"隐患专业分类", "隐患详细分类", "详细分类"},
        {"隐患详细分类", "隐患特征", "隐患特征"},
        {"隐患详细分类", "规章制度", "相关文档"},
        {"电力设备", "地点", "所在地点"},
        {"电力设备", "单位", "责任单位"}
    };

    static {
        Map<String, String> labels = new HashMap<>();
        labels.put("UNIT", "单位");
        labels.put("PERSON", "人");
        labels.put("DEVICE", "电力设备");
        labels.put("PARTS", "电力设备零件");
        labels.put("RULES", "规章制度");
        labels.put("P_CLASS", "隐患专业分类");
        labels.put("D_CLASS", "隐患详细分类");
        labels.put("PROP", "隐患特征");
        labels.put("LOCATION", "地点");
        NODE_TYPE_BY_LABEL = Collections.unmodifiableMap(labels);
    }

    @Resource
    private ITrainTaskService iTrainTaskService;

    @Resource
    private ITrainNoteService iTrainNoteService;

    @Resource
    private ITrainModelService iTrainModelService;

    @Resource
    private TrainNoteMapper trainNoteMapper;

    @Resource
    private NerWebSocket nerWebSocket;

    @Value("${model-config.ner-api}")
    private String nerApi;

    @Value("${model-config.ner-location}")
    private String nerLocation;

    @Value("${model-config.ip}")
    private String ip;

    @Value("${minio.location}")
    private String minioLocation;

    // Written by the worker threads and read by request threads calling destory*(); hence volatile.
    private volatile Process trainProc;
    private volatile Process serviceProc;

    /**
     * Runs entity recognition on the given input and builds a small graph from the result.
     *
     * @param nerDTO request payload, serialized to JSON and posted to the NER endpoint
     * @return a failure response when the endpoint returns an empty body; otherwise a success
     *     response wrapping a JSON object with {@code nodes} (one per recognized entity) and
     *     {@code edges} (relations derived from {@link #RELATION_RULES})
     */
    public ApiResponse namedEntityRecognitionPedict(NerDTO nerDTO) {
        String doPostJson = HttpUtil.doPostJson(this.nerApi, null, JSON.toJSONString(nerDTO));
        if (StringUtil.isEmpty(doPostJson)) {
            return ApiResponse.fail("未识别到实体");
        }
        JSONArray entities = JSONArray.parseArray(doPostJson);
        List<NerVO> recognized = entities == null
                ? Collections.<NerVO>emptyList()
                : entities.toJavaList(NerVO.class);

        List<Neo4jBasicNode> nodes = new ArrayList<>();
        for (NerVO nerVO : recognized) {
            Neo4jBasicNode node = new Neo4jBasicNode();
            node.setId(IdUtil.fastSimpleUUID());
            node.setLabel(nerVO.getName());
            // Labels missing from the map leave the node type null.
            node.setNodeType(NODE_TYPE_BY_LABEL.get(nerVO.getNodeLabel()));
            nodes.add(node);
        }

        List<Neo4jBasicRelationShip> edges = new ArrayList<>();
        for (Neo4jBasicNode source : nodes) {
            for (Neo4jBasicNode target : nodes) {
                for (String[] rule : RELATION_RULES) {
                    // Constant-first equals so that a null node type simply matches no rule.
                    if (rule[0].equals(source.getNodeType()) && rule[1].equals(target.getNodeType())) {
                        edges.add(newRelationship(rule[2], source, target));
                    }
                }
            }
        }

        JSONObject graph = new JSONObject();
        graph.put("nodes", nodes);
        graph.put("edges", edges);
        return ApiResponse.success(graph);
    }

    /**
     * Starts training for a task on a background thread and returns immediately.
     *
     * <p>The worker first sets the task row's TASK_STATE to "1", records the model and log paths,
     * and deletes the task's previous training notes. It then runs {@code main.py --mode=train}.
     * While the process runs, it stores one training note per "epoch finished" record, forwards
     * every stderr line to the web socket, and on exit code 0 sets TASK_STATE to "2". Any other
     * exit code, or an exception, calls {@code stopTraining}.
     *
     * @param taskId id of the training task
     * @return success once the worker thread has been started, or a failure carrying the
     *     exception message if starting it failed
     */
    public ApiResponse namedEntityRecognitionTrain(final String taskId) {
        final TrainTaskVO trainTaskVO = (TrainTaskVO) this.iTrainTaskService.getInfoById(taskId).getData();
        final String checkpointsDir = taskDir(taskId, "checkpoints");
        final String vocabsDir = taskDir(taskId, "vocabs");
        final String logDir = taskDir(taskId, "logs");
        final String suffix = joinKeys(trainTaskVO.getDataStatis());
        try {
            new Thread(new Runnable() {
                @Override
                public void run() {
                    runTraining(taskId, trainTaskVO, checkpointsDir, vocabsDir, logDir, suffix);
                }
            }, "ner-train-" + taskId).start();
            return ApiResponse.success("开始训练");
        } catch (Exception e) {
            e.printStackTrace();
            this.iTrainTaskService.stopTraining(taskId);
            return ApiResponse.fail(e.getMessage());
        }
    }

    /** Body of the training worker thread; see {@link #namedEntityRecognitionTrain(String)}. */
    private void runTraining(String taskId, TrainTaskVO trainTaskVO, String checkpointsDir,
            String vocabsDir, String logDir, String suffix) {
        try {
            UpdateWrapper updateWrapper = new UpdateWrapper();
            updateWrapper.eq("ID", taskId);
            updateWrapper.eq("DEL_FLAG", KgTaggingTask1ServiceImpl.TASK_USE_ENTER_GRAPH);
            updateWrapper.set("TASK_STATE", "1");
            updateWrapper.set("MODEL_PATH", checkpointsDir);
            updateWrapper.set("LOG_PATH", logDir);
            updateWrapper.set("TRAIN_START_TIME", new Date());
            updateWrapper.set("TRAIN_END_TIME", (Object) null);
            this.iTrainTaskService.update(updateWrapper);
            this.trainNoteMapper.deleteNotesByTaskId(taskId);
            TrainTask trainTask = (TrainTask) this.iTrainTaskService.getById(taskId);

            List<String> command = buildPythonCommand("/main.py", "--mode=train", trainTaskVO,
                    vocabsDir, logDir, checkpointsDir, suffix);
            // Use a local handle: the shared field may be overwritten if another training starts.
            Process proc = startPythonProcess(command);
            this.trainProc = proc;

            try (BufferedReader reader = new BufferedReader(
                    new InputStreamReader(proc.getErrorStream(), PROCESS_OUTPUT_CHARSET))) {
                String previousLine = "";
                String line;
                while ((line = reader.readLine()) != null) {
                    System.out.println(line);
                    // The line right after an "epoch finished" marker is parsed as a JSON metrics record.
                    if (previousLine.startsWith("epoch finished")) {
                        recordEpoch(taskId, trainTask, JSONObject.parseObject(line));
                    }
                    this.nerWebSocket.sendMessage(line);
                    previousLine = line;
                }
            }

            if (proc.waitFor() == 0) {
                trainTask.setTaskState("2");
                trainTask.setTrainEndTime(new Date());
                this.iTrainTaskService.updateById(trainTask);
            } else {
                this.iTrainTaskService.stopTraining(taskId);
            }
        } catch (Exception e) {
            if (e instanceof InterruptedException) {
                Thread.currentThread().interrupt();
            }
            e.printStackTrace();
            this.iTrainTaskService.stopTraining(taskId);
        }
    }

    /**
     * Saves one epoch's metrics as a training note and updates the task's progress.
     *
     * @param taskId id of the training task
     * @param trainTask task entity whose progress is updated and written back
     * @param metrics parsed JSON with epoch, precision, recall, f1, accuracy, isBest and
     *     timeConsumption fields
     */
    private void recordEpoch(String taskId, TrainTask trainTask, JSONObject metrics) {
        TrainNote trainNote = new TrainNote();
        trainNote.setTrainTaskId(taskId);
        trainNote.setCreateTime(new Date());
        trainNote.setCurrentEpoch(Integer.parseInt(metrics.getString("epoch")));
        trainNote.setPrecisionRate(Double.parseDouble(metrics.getString("precision")));
        trainNote.setRecall(Double.parseDouble(metrics.getString("recall")));
        trainNote.setF1(Double.parseDouble(metrics.getString("f1")));
        trainNote.setAccuracy(Double.parseDouble(metrics.getString("accuracy")));
        trainNote.setIsBest(metrics.getString("isBest"));
        if ("1".equals(trainNote.getIsBest())) {
            // Clear the previous "best" flag before saving the new best epoch.
            this.trainNoteMapper.setIsBestToZero(taskId);
        }
        trainNote.setTimeConsumption(metrics.getString("timeConsumption"));
        this.iTrainNoteService.save(trainNote);
        trainTask.setTaskProgress(formatProgress(trainNote.getCurrentEpoch(), trainTask.getEpoch().intValue()));
        this.iTrainTaskService.updateById(trainTask);
    }

    /**
     * Starts the model-serving process ({@code flask_service.py}) for a task on a background thread.
     *
     * <p>When a stderr line starts with "service launched", the model's state is set to "1". If the
     * worker fails, the state is set to {@code PASS_STATUS_FLAG}.
     *
     * @param taskId id of the training task whose model is served
     * @return {@code false} if no model record exists for the task or the thread could not be
     *     started; {@code true} once the worker thread has been started
     */
    public boolean nerSeviceStart(final String taskId) {
        QueryWrapper queryWrapper = new QueryWrapper();
        queryWrapper.eq("TRAIN_TASK_ID", taskId);
        queryWrapper.eq("DEL_FLAG", KgTaggingTask1ServiceImpl.TASK_USE_ENTER_GRAPH);
        final TrainModel trainModel = (TrainModel) this.iTrainModelService.getOne(queryWrapper);
        if (trainModel == null) {
            // The worker reports state through this record; without it the worker would NPE.
            return false;
        }
        final TrainTaskVO trainTaskVO = (TrainTaskVO) this.iTrainTaskService.getInfoById(taskId).getData();
        final String checkpointsDir = taskDir(taskId, "checkpoints");
        final String vocabsDir = taskDir(taskId, "vocabs");
        final String logDir = taskDir(taskId, "service_logs");
        final String suffix = joinKeys(trainTaskVO.getDataStatis());
        try {
            new Thread(new Runnable() {
                @Override
                public void run() {
                    runService(trainModel, trainTaskVO, checkpointsDir, vocabsDir, logDir, suffix);
                }
            }, "ner-service-" + taskId).start();
            return true;
        } catch (HussarException e) {
            e.printStackTrace();
            return false;
        }
    }

    /** Body of the serving worker thread; see {@link #nerSeviceStart(String)}. */
    private void runService(TrainModel trainModel, TrainTaskVO trainTaskVO, String checkpointsDir,
            String vocabsDir, String logDir, String suffix) {
        try {
            List<String> command = buildPythonCommand("/flask_service.py", null, trainTaskVO,
                    vocabsDir, logDir, checkpointsDir, suffix);
            Process proc = startPythonProcess(command);
            this.serviceProc = proc;
            try (BufferedReader reader = new BufferedReader(
                    new InputStreamReader(proc.getErrorStream(), PROCESS_OUTPUT_CHARSET))) {
                String line;
                while ((line = reader.readLine()) != null) {
                    System.out.println(line);
                    if (line.startsWith("service launched")) {
                        trainModel.setModelState("1");
                        this.iTrainModelService.updateById(trainModel);
                    }
                }
            }
        } catch (Exception e) {
            trainModel.setModelState(KgTaggingTask1ServiceImpl.PASS_STATUS_FLAG);
            this.iTrainModelService.updateById(trainModel);
            e.printStackTrace();
        }
    }

    /** Destroys the most recently started serving process, if any. */
    public void destoryNerService() {
        Process proc = this.serviceProc;
        if (proc != null) {
            proc.destroy();
        }
    }

    /** Destroys the most recently started training process, if any. */
    public void destoryNerTrain() {
        Process proc = this.trainProc;
        if (proc != null) {
            proc.destroy();
        }
    }

    /** Returns {@code <minioLocation>/<taskId>/<child>}. */
    private String taskDir(String taskId, String child) {
        return this.minioLocation + "/" + taskId + "/" + child;
    }

    /**
     * Builds the Python command as an explicit argument list.
     *
     * <p>Each value stays a single argument even if it contains spaces. A whitespace-split
     * {@code Runtime.exec(String)} would break such a value into several arguments.
     *
     * @param script script path relative to {@code nerLocation}, e.g. {@code "/main.py"}
     * @param modeFlag optional flag placed right after the script, or {@code null} for none
     */
    private List<String> buildPythonCommand(String script, String modeFlag, TrainTaskVO trainTaskVO,
            String vocabsDir, String logDir, String checkpointsDir, String suffix) {
        List<String> command = new ArrayList<>();
        command.add("python");
        command.add(this.nerLocation + script);
        if (modeFlag != null) {
            command.add(modeFlag);
        }
        command.add("--datasets_fold=" + trainTaskVO.getSamplePath());
        command.add("--vocabs_dir=" + vocabsDir);
        command.add("--log_dir=" + logDir);
        command.add("--checkpoints_dir=" + checkpointsDir);
        command.add("--suffix=" + suffix);
        command.add("--epoch=" + trainTaskVO.getEpoch());
        command.add("--batch_size=" + trainTaskVO.getBatchSize());
        command.add("--learning_rate=" + trainTaskVO.getLearningRate());
        return command;
    }

    /**
     * Starts a child process in {@code nerLocation}.
     *
     * <p>Only stderr is read by the callers. Stdout is inherited so the child cannot block
     * writing to a pipe that nobody drains.
     */
    private Process startPythonProcess(List<String> command) throws IOException {
        ProcessBuilder builder = new ProcessBuilder(command);
        builder.directory(new File(this.nerLocation));
        builder.redirectOutput(ProcessBuilder.Redirect.INHERIT);
        return builder.start();
    }

    /** Joins the map's keys with commas in iteration order; a {@code null} map yields "". */
    private static String joinKeys(Map<?, ?> dataStatis) {
        if (dataStatis == null) {
            return "";
        }
        StringBuilder joined = new StringBuilder();
        boolean first = true;
        for (Object key : dataStatis.keySet()) {
            if (!first) {
                joined.append(',');
            }
            joined.append((String) key);
            first = false;
        }
        return joined.toString();
    }

    /**
     * Formats training progress as a fraction with two decimals, e.g. 3 of 10 epochs gives "0.30".
     * Floating-point division is used so that intermediate epochs do not truncate to zero.
     * A non-positive total yields "0.00".
     */
    private static String formatProgress(double currentEpoch, double totalEpochs) {
        double ratio = totalEpochs > 0 ? currentEpoch / totalEpochs : 0.0;
        return new DecimalFormat("0.00").format(ratio);
    }

    /** Creates a relationship with a fresh id from {@code source} to {@code target}. */
    private static Neo4jBasicRelationShip newRelationship(String label, Neo4jBasicNode source,
            Neo4jBasicNode target) {
        Neo4jBasicRelationShip relationship = new Neo4jBasicRelationShip();
        relationship.setId(IdUtil.fastSimpleUUID());
        relationship.setLabel(label);
        relationship.setSource(source.getId());
        relationship.setTarget(target.getId());
        return relationship;
    }
}
