/*
 * Decompiled with CFR 0.152.
 */
package zju.cst.aces.runner;

import java.io.BufferedWriter;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import org.junit.platform.launcher.listeners.TestExecutionSummary;
import zju.cst.aces.api.config.Config;
import zju.cst.aces.api.impl.ChatGenerator;
import zju.cst.aces.api.impl.PromptConstructorImpl;
import zju.cst.aces.api.impl.RepairImpl;
import zju.cst.aces.api.impl.obfuscator.Obfuscator;
import zju.cst.aces.dto.ChatResponse;
import zju.cst.aces.dto.Message;
import zju.cst.aces.dto.MethodInfo;
import zju.cst.aces.dto.PromptInfo;
import zju.cst.aces.dto.RoundRecord;
import zju.cst.aces.dto.TestMessage;
import zju.cst.aces.dto.TestSkeleton;
import zju.cst.aces.runner.ClassRunner;
import zju.cst.aces.util.CodeExtractor;
import zju.cst.aces.util.TestProcessor;

/**
 * Generates unit tests for a single method ({@link #methodInfo}) of the class handled by the
 * parent {@link ClassRunner}.
 *
 * <p>Each attempt ("num") runs up to {@code config.getMaxRounds()} rounds. A round builds a
 * prompt, asks the model through {@link ChatGenerator}, post-processes the returned code, and
 * hands it to {@link RepairImpl}. Rounds stop as soon as the repair step reports success.
 */
public class MethodRunner extends ClassRunner {
    /** The method for which tests are generated. */
    public MethodInfo methodInfo;

    /**
     * @param config        generation configuration
     * @param fullClassName fully qualified name of the class declaring the method
     * @param methodInfo    the method to generate tests for
     * @throws IOException  propagated from the {@link ClassRunner} constructor
     */
    public MethodRunner(Config config, String fullClassName, MethodInfo methodInfo) throws IOException {
        super(config, fullClassName);
        this.methodInfo = methodInfo;
    }

    /**
     * Runs {@code config.getTestNumber()} test-generation attempts for the method.
     *
     * <p>When {@code stopWhenSuccess} is disabled and multithreading is enabled, all attempts are
     * submitted to a fixed-size thread pool and awaited. Otherwise they run sequentially. In the
     * sequential case, if {@code stopWhenSuccess} is enabled, the loop stops after the first
     * successful attempt. A non-positive test number means nothing is run in either mode.
     *
     * @throws IOException if a sequential attempt fails with an I/O error
     */
    @Override
    public void start() throws IOException {
        int testNumber = this.config.getTestNumber();
        if (testNumber <= 0) {
            // Nothing to do; also avoids IllegalArgumentException from newFixedThreadPool(0).
            return;
        }
        if (!this.config.isStopWhenSuccess() && this.config.isEnableMultithreading()) {
            runAttemptsConcurrently(testNumber);
        } else {
            for (int num = 0; num < testNumber; num++) {
                if (this.startRounds(num) && this.config.isStopWhenSuccess()) {
                    break;
                }
            }
        }
    }

    /**
     * Runs attempts {@code 0 .. testNumber-1} on a fixed pool and waits for all of them.
     * A failing attempt is reported and does not stop the others.
     */
    private void runAttemptsConcurrently(int testNumber) {
        ExecutorService executor = Executors.newFixedThreadPool(testNumber);
        // If the JVM is asked to exit mid-run, try to stop the worker threads.
        Thread shutdownHook = new Thread(executor::shutdownNow);
        Runtime.getRuntime().addShutdownHook(shutdownHook);
        try {
            List<Future<String>> futures = new ArrayList<>(testNumber);
            for (int num = 0; num < testNumber; num++) {
                final int attempt = num;
                futures.add(executor.submit(() -> {
                    this.startRounds(attempt);
                    return "";
                }));
            }
            for (Future<String> future : futures) {
                try {
                    String result = future.get();
                    System.out.println(result);
                } catch (ExecutionException e) {
                    // One failed attempt must not abort waiting for the remaining ones.
                    e.printStackTrace();
                } catch (InterruptedException e) {
                    // Restore the interrupt flag and cancel outstanding attempts instead of
                    // looping through get() calls that would each fail immediately.
                    Thread.currentThread().interrupt();
                    executor.shutdownNow();
                    return;
                }
            }
        } finally {
            executor.shutdown();
            // Without this, every call would leave one more hook (and its executor) registered.
            removeShutdownHookQuietly(shutdownHook);
        }
    }

    /** Unregisters {@code hook}, ignoring the case where JVM shutdown is already in progress. */
    private static void removeShutdownHookQuietly(Thread hook) {
        try {
            Runtime.getRuntime().removeShutdownHook(hook);
        } catch (IllegalStateException ignored) {
            // Shutdown already started: the hook can no longer be removed and will run anyway.
        }
    }

    /**
     * Runs one test-generation attempt for the method, iterating generate/repair rounds.
     *
     * <p>Each round does the following:
     * <ul>
     *   <li>Appends a {@link RoundRecord} and builds a prompt, from an obfuscated copy of the
     *       prompt info when obfuscation is enabled.</li>
     *   <li>Asks the model for code and skips the round when no code could be extracted.</li>
     *   <li>Either wraps a bare test method in a {@link TestSkeleton} or applies rule-based
     *       repair, then calls {@link RepairImpl#LLMBasedRepair}.</li>
     * </ul>
     *
     * @param num attempt index; it is embedded in the generated test class name
     * @return {@code true} as soon as the repair step reports success (the record is exported);
     *     {@code false} after {@code config.getMaxRounds()} rounds without success (the prompt
     *     info is exported as well)
     * @throws IOException propagated from the helpers this method calls
     */
    public boolean startRounds(int num) throws IOException {
        // Both names share this suffix; only the (simple vs. fully qualified) class prefix differs.
        String suffix = "_" + this.methodInfo.methodName + "_"
                + this.classInfo.methodSigs.get(this.methodInfo.methodSignature) + "_" + num + "_Test";
        String testName = this.className + suffix;
        String fullTestName = this.fullClassName + suffix;
        this.config.getLog().info("\n==========================\n[ChatUniTest] Generating test for method < " + this.methodInfo.methodName + " > number " + num + "...\n");
        // NOTE(review): this instance is never used; kept only in case its constructor has side effects.
        ChatGenerator generator = new ChatGenerator(this.config);
        PromptConstructorImpl pc = new PromptConstructorImpl(this.config);
        RepairImpl repair = new RepairImpl(this.config, pc);
        if (this.methodInfo.dependentMethods.isEmpty()) {
            pc.setPromptInfoWithoutDep(this.classInfo, this.methodInfo);
        } else {
            pc.setPromptInfoWithDep(this.classInfo, this.methodInfo);
        }
        pc.setFullTestName(fullTestName);
        pc.setTestName(testName);
        PromptInfo promptInfo = pc.getPromptInfo();
        promptInfo.setFullTestName(fullTestName);
        Path savePath = this.config.getTestOutput().resolve(fullTestName.replace(".", File.separator) + ".java");
        promptInfo.setTestPath(savePath);
        for (int rounds = 0; rounds < this.config.getMaxRounds(); ++rounds) {
            promptInfo.addRecord(new RoundRecord(rounds));
            RoundRecord record = promptInfo.getRecords().get(rounds);
            record.setAttempt(num);
            if (rounds == 0) {
                this.config.getLog().info("Generating test for method < " + this.methodInfo.methodName + " > round " + rounds + " ...");
            } else {
                this.config.getLog().info("Fixing test for method < " + this.methodInfo.methodName + " > round " + rounds + " ...");
            }
            // A fresh obfuscator per round: the same instance obfuscates the prompt and
            // de-obfuscates the response of that round.
            Obfuscator obfuscator = new Obfuscator(this.config);
            List<Message> prompt;
            if (this.config.isEnableObfuscate()) {
                PromptInfo obfuscatedPromptInfo = new PromptInfo(promptInfo);
                obfuscator.obfuscatePromptInfo(obfuscatedPromptInfo);
                prompt = this.promptGenerator.generateMessages(obfuscatedPromptInfo);
            } else {
                prompt = this.promptGenerator.generateMessages(promptInfo);
            }
            String code = this.generateTest(prompt, record);
            if (!record.isHasCode()) {
                continue;
            }
            if (this.config.isEnableObfuscate()) {
                code = obfuscator.deobfuscateJava(code);
            }
            if (CodeExtractor.isTestMethod(code)) {
                // The model returned only a test method: embed it in a full test class.
                TestSkeleton skeleton = new TestSkeleton(promptInfo);
                code = skeleton.build(code);
            } else {
                code = repair.ruleBasedRepair(code);
            }
            promptInfo.setUnitTest(code);
            record.setCode(code);
            repair.LLMBasedRepair(code, record.getRound());
            if (repair.isSuccess()) {
                record.setHasError(false);
                this.exportRecord(promptInfo, this.classInfo, record.getAttempt());
                return true;
            }
            record.setHasError(true);
            record.setErrorMsg(promptInfo.getErrorMsg());
        }
        this.exportRecord(pc.getPromptInfo(), this.classInfo, num);
        return false;
    }

    /**
     * Sends {@code prompt} to the model and extracts test code from the reply, filling
     * {@code record} with the prompt, response and token usage.
     *
     * @param prompt chat messages to send
     * @param record round record to update; {@code hasCode} is set to reflect whether code was
     *               obtained
     * @return the extracted code, or {@code ""} when the prompt exceeds the configured token
     *     limit or no code could be extracted
     * @throws IOException propagated from the model call
     */
    public String generateTest(List<Message> prompt, RoundRecord record) throws IOException {
        if (MethodRunner.isExceedMaxTokens(this.config.getMaxPromptTokens(), prompt)) {
            this.config.getLog().error("Exceed max prompt tokens: " + this.methodInfo.methodName + " Skipped.");
            // Mark explicitly so callers checking isHasCode() never see a stale value.
            record.setHasCode(false);
            return "";
        }
        this.config.getLog().debug("[Prompt]:\n" + prompt.toString());
        ChatResponse response = ChatGenerator.chat(this.config, prompt);
        String content = ChatGenerator.getContentByResponse(response);
        this.config.getLog().debug("[Response]:\n" + content);
        String code = ChatGenerator.extractCodeByContent(content);
        record.setPromptToken(response.getUsage().getPromptTokens());
        record.setResponseToken(response.getUsage().getCompletionTokens());
        record.setPrompt(prompt);
        record.setResponse(content);
        if (code.isEmpty()) {
            this.config.getLog().info("Test for method < " + this.methodInfo.methodName + " > extract code failed");
            record.setHasCode(false);
            return "";
        }
        record.setHasCode(true);
        return code;
    }

    /**
     * Sends {@code prompt} to the model and extracts test code from the reply, without
     * recording the round.
     *
     * @param prompt chat messages to send
     * @return the extracted code, or {@code ""} when the prompt exceeds the configured token
     *     limit or no code could be extracted
     * @throws IOException propagated from the model call
     */
    public String generateTest(List<Message> prompt) throws IOException {
        if (MethodRunner.isExceedMaxTokens(this.config.getMaxPromptTokens(), prompt)) {
            this.config.getLog().error("Exceed max prompt tokens: " + this.methodInfo.methodName + " Skipped.");
            return "";
        }
        this.config.getLog().debug("[Prompt]:\n" + prompt.toString());
        ChatResponse response = ChatGenerator.chat(this.config, prompt);
        String content = ChatGenerator.getContentByResponse(response);
        String code = ChatGenerator.extractCodeByContent(content);
        if (code.isEmpty()) {
            this.config.getLog().info("Test for method < " + this.methodInfo.methodName + " > extract code failed");
            return "";
        }
        return code;
    }

    /**
     * Compiles and (unless {@code config.isNoExecution()}) executes the unit test held in
     * {@code promptInfo}, exporting it to the test output directory on success.
     *
     * <p>On compilation failure, the method returns {@code false}. The compiler output goes to
     * {@code <testName>_CompilationError_<rounds>.txt}; this method passes that path to the
     * validator.
     *
     * <p>On failing or empty execution, the method first tries a version with the failing tests
     * removed. If that does not pass, it records a runtime-error {@link TestMessage} on
     * {@code promptInfo}, writes {@code <testName>_ExecutionError_<rounds>.txt}, and returns
     * {@code false}.
     *
     * @param config       generation configuration (validator, output dirs, log)
     * @param fullTestName fully qualified test class name
     * @param promptInfo   prompt state holding the unit test; its test path is set if missing
     * @param rounds       current round; from round 1 on, previously correct tests are merged in
     * @return {@code true} if a test (original or processed) compiled and passed, or compiled
     *     when execution is disabled
     */
    public static boolean runTest(Config config, String fullTestName, PromptInfo promptInfo, int rounds) {
        String testName = fullTestName.substring(fullTestName.lastIndexOf(".") + 1);
        Path savePath = config.getTestOutput().resolve(fullTestName.replace(".", File.separator) + ".java");
        if (promptInfo.getTestPath() == null) {
            promptInfo.setTestPath(savePath);
        }
        TestProcessor testProcessor = new TestProcessor(fullTestName);
        String code = promptInfo.getUnitTest();
        if (rounds >= 1) {
            code = testProcessor.addCorrectTest(promptInfo);
        }
        Path compilationErrorPath = config.getErrorOutput().resolve(testName + "_CompilationError_" + rounds + ".txt");
        Path executionErrorPath = config.getErrorOutput().resolve(testName + "_ExecutionError_" + rounds + ".txt");
        boolean compileResult = config.getValidator().semanticValidate(code, testName, compilationErrorPath, promptInfo);
        if (!compileResult) {
            config.getLog().info("Test for method < " + promptInfo.getMethodInfo().getMethodName() + " > compilation failed round " + rounds);
            return false;
        }
        if (config.isNoExecution()) {
            MethodRunner.exportTest(code, savePath);
            config.getLog().info("Test for method < " + promptInfo.getMethodInfo().getMethodName() + " > generated successfully round " + rounds);
            return true;
        }
        TestExecutionSummary summary = config.getValidator().execute(fullTestName);
        if (summary.getTestsFailedCount() > 0L || summary.getTestsSucceededCount() == 0L) {
            String testProcessed = testProcessor.removeErrorTest(promptInfo, summary);
            if (testProcessed != null) {
                config.getLog().debug("[Original Test]:\n" + code);
                if (config.getValidator().semanticValidate(testProcessed, testName, compilationErrorPath, null) && config.getValidator().runtimeValidate(fullTestName)) {
                    MethodRunner.exportTest(testProcessed, savePath);
                    config.getLog().debug("[Processed Test]:\n" + testProcessed);
                    config.getLog().info("Processed test for method < " + promptInfo.getMethodInfo().getMethodName() + " > generated successfully round " + rounds);
                    return true;
                }
                testProcessor.removeCorrectTest(promptInfo, summary);
            }
            TestMessage testMessage = new TestMessage();
            List<String> errors = new ArrayList<>();
            // Keep only stack frames that point into the generated test class itself.
            summary.getFailures().forEach(failure -> {
                for (StackTraceElement st : failure.getException().getStackTrace()) {
                    if (!st.getClassName().contains(fullTestName)) {
                        continue;
                    }
                    errors.add("Error in " + failure.getTestIdentifier().getLegacyReportingName() + ": line " + st.getLineNumber() + " : " + failure.getException().toString());
                }
            });
            testMessage.setErrorType(TestMessage.ErrorType.RUNTIME_ERROR);
            testMessage.setErrorMessage(errors);
            promptInfo.setErrorMsg(testMessage);
            MethodRunner.exportError(code, errors, executionErrorPath);
            testProcessor.removeCorrectTest(promptInfo, summary);
            config.getLog().info("Test for method < " + promptInfo.getMethodInfo().getMethodName() + " > execution failed round " + rounds);
            return false;
        }
        MethodRunner.exportTest(code, savePath);
        config.getLog().info("Test for method < " + promptInfo.getMethodInfo().getMethodName() + " > compile and execute successfully round " + rounds);
        return true;
    }

    /**
     * Writes {@code code}, a separator line, and the newline-joined {@code errors} to
     * {@code outputPath}, overwriting any existing file.
     *
     * @throws RuntimeException wrapping the underlying {@link IOException} if writing fails
     */
    public static void exportError(String code, List<String> errors, Path outputPath) {
        // try-with-resources closes the writer even when a write fails.
        try (BufferedWriter writer = new BufferedWriter(new FileWriter(outputPath.toFile()))) {
            writer.write(code);
            writer.write("\n--------------------------------------------\n");
            writer.write(String.join("\n", errors));
        } catch (IOException e) {
            // Keep the original message format, but preserve the cause for diagnostics.
            throw new RuntimeException("In TestCompiler.exportError: " + e, e);
        }
    }
}

