/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.graphalgo.beta.modularity;

import com.carrotsearch.hppc.BitSet;
import com.carrotsearch.hppc.cursors.LongLongCursor;
import java.util.ArrayList;
import java.util.Collection;
import java.util.concurrent.ExecutorService;
import java.util.stream.BaseStream;
import java.util.stream.LongStream;
import org.apache.commons.lang3.mutable.MutableDouble;
import org.jetbrains.annotations.Nullable;
import org.neo4j.graphalgo.Algorithm;
import org.neo4j.graphalgo.api.Graph;
import org.neo4j.graphalgo.api.NodeProperties;
import org.neo4j.graphalgo.api.RelationshipIterator;
import org.neo4j.graphalgo.beta.k1coloring.ImmutableK1ColoringStreamConfig;
import org.neo4j.graphalgo.beta.k1coloring.K1Coloring;
import org.neo4j.graphalgo.beta.k1coloring.K1ColoringFactory;
import org.neo4j.graphalgo.beta.k1coloring.K1ColoringStreamConfig;
import org.neo4j.graphalgo.beta.modularity.ModularityOptimizationTask;
import org.neo4j.graphalgo.core.concurrency.ParallelUtil;
import org.neo4j.graphalgo.core.utils.ProgressLogger;
import org.neo4j.graphalgo.core.utils.paged.AllocationTracker;
import org.neo4j.graphalgo.core.utils.paged.HugeAtomicDoubleArray;
import org.neo4j.graphalgo.core.utils.paged.HugeDoubleArray;
import org.neo4j.graphalgo.core.utils.paged.HugeLongArray;
import org.neo4j.graphalgo.core.utils.paged.HugeLongLongMap;
import org.neo4j.graphalgo.core.utils.paged.LongPageCreator;

/**
 * Iterative modularity optimization over a {@link Graph}.
 *
 * <p>{@link #compute()} first colors the graph with {@link K1Coloring}, optionally maps
 * seed communities to dense internal ids, and then repeatedly runs
 * {@link ModularityOptimizationTask}s color by color. After every iteration the modularity
 * is recomputed; the loop stops once the modularity no longer improves by more than
 * {@code tolerance} or after {@code maxIterations} iterations.
 *
 * <p>Results are read through {@link #getCommunityId(long)}, {@link #getModularity()},
 * {@link #getIterations()} and {@link #didConverge()}.
 */
public final class ModularityOptimization
        extends Algorithm<ModularityOptimization, ModularityOptimization> {
    private final int concurrency;
    private final int maxIterations;
    private final long nodeCount;
    private final long batchSize;
    private final double tolerance;
    private final Graph graph;
    private final NodeProperties seedProperty;
    private final ExecutorService executor;
    private final AllocationTracker tracker;
    private int iterationCounter;
    private boolean didConverge = false;
    private double totalNodeWeight = 0.0;
    private double modularity = -1.0;
    private BitSet colorsUsed;
    private HugeLongArray colors;
    private HugeLongArray currentCommunities;
    private HugeLongArray nextCommunities;
    // Maps an internal (dense) community id back to the seed community value it was created from.
    private HugeLongArray reverseSeedCommunityMapping;
    private HugeDoubleArray cumulativeNodeWeights;
    private HugeDoubleArray nodeCommunityInfluences;
    private HugeAtomicDoubleArray communityWeights;
    private HugeAtomicDoubleArray communityWeightUpdates;

    /**
     * Creates a new, not yet computed, modularity optimization.
     *
     * @param graph          graph to run on
     * @param maxIterations  upper bound for the number of iterations; must be at least 1
     * @param tolerance      minimum modularity improvement required to keep iterating
     * @param seedProperty   optional seed communities; {@code null} starts every node in its own community
     * @param concurrency    concurrency passed to the parallel helpers
     * @param minBatchSize   minimum batch size used when deriving the per-task batch size
     * @param executor       executor the optimization tasks are run on
     * @param progressLogger logger for progress messages
     * @param tracker        allocation tracker handed to every array allocation
     * @throws IllegalArgumentException if {@code maxIterations < 1}
     */
    public ModularityOptimization(Graph graph, int maxIterations, double tolerance, @Nullable NodeProperties seedProperty, int concurrency, int minBatchSize, ExecutorService executor, ProgressLogger progressLogger, AllocationTracker tracker) {
        // Fail fast before touching the graph or computing batch sizes.
        if (maxIterations < 1) {
            throw new IllegalArgumentException(String.format("Need to run at least one iteration, but got %d", maxIterations));
        }
        this.graph = graph;
        this.nodeCount = graph.nodeCount();
        this.maxIterations = maxIterations;
        this.tolerance = tolerance;
        this.seedProperty = seedProperty;
        this.executor = executor;
        this.concurrency = concurrency;
        this.progressLogger = progressLogger;
        this.tracker = tracker;
        // Capped at Integer.MAX_VALUE because the value is later narrowed to int for the coloring config.
        this.batchSize = ParallelUtil.adjustedBatchSize(this.nodeCount, concurrency, (long)minBatchSize, (long)Integer.MAX_VALUE);
    }

    /**
     * Runs initialization followed by up to {@code maxIterations} optimization iterations.
     *
     * @return this instance, with results available through the getters
     */
    public ModularityOptimization compute() {
        this.progressLogger.logMessage(":: Start");
        this.progressLogger.logMessage(":: Initialization :: Start");
        this.computeColoring();
        this.initSeeding();
        this.init();
        this.progressLogger.logMessage(":: Initialization :: Finished");

        this.iterationCounter = 0;
        while (this.iterationCounter < this.maxIterations) {
            int iteration = this.iterationCounter + 1;
            this.progressLogger.logMessage(String.format(":: Iteration %d :: Start", iteration));

            this.nodeCommunityInfluences.fill(0.0);
            for (long color = this.colorsUsed.nextSetBit(0); color != -1L; color = this.colorsUsed.nextSetBit(color + 1L)) {
                this.assertRunning();
                this.optimizeForColor(color);
            }

            boolean improved = this.updateModularity();
            this.progressLogger.logMessage(String.format(":: Iteration %d :: Finished", iteration));

            // The finished iteration is counted regardless of whether we stop or continue.
            this.iterationCounter = iteration;
            if (!improved) {
                this.didConverge = true;
                break;
            }
            this.progressLogger.reset(this.graph.relationshipCount());
        }

        this.progressLogger.logMessage(":: Finished");
        return this;
    }

    /**
     * Runs {@link K1Coloring} (5 iterations, same concurrency and batch size as this algorithm)
     * and stores the per-node colors and the set of colors in use.
     */
    private void computeColoring() {
        K1ColoringStreamConfig k1Config = ImmutableK1ColoringStreamConfig.builder()
            .concurrency(this.concurrency)
            .maxIterations(5)
            .batchSize((int)this.batchSize)
            .build();
        K1Coloring coloring = (K1Coloring)new K1ColoringFactory<K1ColoringStreamConfig>().build(this.graph, k1Config, this.tracker, this.progressLogger.getLog()).withTerminationFlag(this.terminationFlag);
        this.colors = coloring.compute();
        this.colorsUsed = coloring.usedColors();
    }

    /**
     * Allocates {@code currentCommunities} and, if a seed property is present, assigns every
     * node a dense internal community id derived from its seed value. The reverse mapping
     * (internal id to seed value) is stored in {@code reverseSeedCommunityMapping}.
     *
     * <p>Nodes whose seed value is negative (including the {@code -1.0} default passed to
     * {@code nodeProperty}) get {@code originalNodeId + maxSeedCommunity} as their community.
     */
    private void initSeeding() {
        this.currentCommunities = HugeLongArray.newArray(this.nodeCount, this.tracker);
        if (this.seedProperty == null) {
            return;
        }

        long maxSeedCommunity = this.seedProperty.getMaxPropertyValue().orElse(0L);
        HugeLongLongMap communityMapping = new HugeLongLongMap(this.nodeCount, this.tracker);
        long nextAvailableInternalCommunityId = -1L;

        for (long nodeId = 0L; nodeId < this.nodeCount; ++nodeId) {
            long seedCommunity = (long)this.seedProperty.nodeProperty(nodeId, -1.0);
            if (seedCommunity < 0L) {
                // NOTE(review): if an unseeded node's original id can be 0 this equals
                // maxSeedCommunity and would merge it into that seeded community - confirm
                // whether "+ 1" is intended here.
                seedCommunity = this.graph.toOriginalNodeId(nodeId) + maxSeedCommunity;
            }

            // Internal ids start at 0, so a negative lookup result means "not mapped yet".
            long internalCommunityId = communityMapping.getOrDefault(seedCommunity, -1L);
            if (internalCommunityId < 0L) {
                internalCommunityId = ++nextAvailableInternalCommunityId;
                communityMapping.addTo(seedCommunity, internalCommunityId);
            }
            this.currentCommunities.set(nodeId, internalCommunityId);
        }

        this.reverseSeedCommunityMapping = HugeLongArray.newArray((long)communityMapping.size(), this.tracker);
        for (LongLongCursor entry : communityMapping) {
            this.reverseSeedCommunityMapping.set(entry.value, entry.key);
        }
    }

    /**
     * Allocates the working arrays and computes, per node, the sum of its relationship
     * weights. That sum is stored in {@code cumulativeNodeWeights} and added to the weight of
     * the node's current community. Without a seed property each node starts in the community
     * named by its own id. {@code totalNodeWeight} becomes half of the sum over all nodes.
     */
    private void init() {
        this.nextCommunities = HugeLongArray.newArray(this.nodeCount, this.tracker);
        this.cumulativeNodeWeights = HugeDoubleArray.newArray(this.nodeCount, this.tracker);
        this.nodeCommunityInfluences = HugeDoubleArray.newArray(this.nodeCount, this.tracker);
        this.communityWeights = HugeAtomicDoubleArray.newArray(this.nodeCount, this.tracker);
        this.communityWeightUpdates = HugeAtomicDoubleArray.newArray(this.nodeCount, this.tracker);

        // One graph copy per worker thread for relationship iteration.
        ThreadLocal<RelationshipIterator> graphCopy = ThreadLocal.withInitial(() -> this.graph.concurrentCopy());

        double doubleTotalNodeWeight = ParallelUtil.parallelStream(
            LongStream.range(0L, this.nodeCount),
            this.concurrency,
            nodeStream -> nodeStream.mapToDouble(nodeId -> {
                if (this.seedProperty == null) {
                    this.currentCommunities.set(nodeId, nodeId);
                }
                MutableDouble cumulativeWeight = new MutableDouble(0.0);
                // 1.0 is presumably the fallback weight for relationships without one - TODO confirm.
                graphCopy.get().forEachRelationship(nodeId, 1.0, (s, t, w) -> {
                    cumulativeWeight.add(w);
                    return true;
                });
                double nodeWeight = cumulativeWeight.doubleValue();
                this.communityWeights.update(this.currentCommunities.get(nodeId), acc -> acc + nodeWeight);
                this.cumulativeNodeWeights.set(nodeId, nodeWeight);
                return nodeWeight;
            }).reduce(Double::sum).orElseThrow(() -> new RuntimeException("Error initializing modularity optimization."))
        );

        this.totalNodeWeight = doubleTotalNodeWeight / 2.0;
        this.currentCommunities.copyTo(this.nextCommunities, this.nodeCount);
    }

    /**
     * Runs the optimization tasks for one color, then publishes their results: the next
     * communities become the current ones and the collected weight updates are added to the
     * community weights.
     *
     * @param currentColor color handed to the optimization tasks
     */
    private void optimizeForColor(long currentColor) {
        ParallelUtil.runWithConcurrency(this.concurrency, this.createModularityOptimizationTasks(currentColor), this.executor);
        this.nextCommunities.copyTo(this.currentCommunities, this.nodeCount);

        ParallelUtil.parallelStreamConsume(
            LongStream.range(0L, this.nodeCount),
            this.concurrency,
            stream -> stream.forEach(nodeId -> {
                double update = this.communityWeightUpdates.get(nodeId);
                this.communityWeights.update(nodeId, w -> w + update);
            })
        );

        // Start the next color with a fresh update buffer.
        // NOTE(review): the previous array is dropped without release(); confirm whether
        // allocation tracking expects it to be released first.
        this.communityWeightUpdates = HugeAtomicDoubleArray.newArray(this.nodeCount, (LongPageCreator)LongPageCreator.passThrough(1), this.tracker);
    }

    /**
     * Splits the node id range {@code [0, nodeCount)} into consecutive batches of
     * {@code batchSize} and creates one task per batch.
     *
     * @param currentColor color handed to every task
     * @return the tasks, in ascending node id order
     */
    private Collection<ModularityOptimizationTask> createModularityOptimizationTasks(long currentColor) {
        ArrayList<ModularityOptimizationTask> tasks = new ArrayList<ModularityOptimizationTask>(this.concurrency);
        for (long batchStart = 0L; batchStart < this.nodeCount; batchStart += this.batchSize) {
            long batchEnd = Math.min(batchStart + this.batchSize, this.nodeCount);
            tasks.add(new ModularityOptimizationTask(this.graph, batchStart, batchEnd, currentColor, this.totalNodeWeight, this.colors, this.currentCommunities, this.nextCommunities, this.cumulativeNodeWeights, this.nodeCommunityInfluences, this.communityWeights, this.communityWeightUpdates, this.getProgressLogger()));
        }
        return tasks;
    }

    /**
     * Recomputes the modularity and reports whether it improved.
     *
     * @return {@code true} if the new modularity is larger than the previous one by more than
     *         {@code tolerance}; {@code false} means the optimization is considered converged
     */
    private boolean updateModularity() {
        double oldModularity = this.modularity;
        this.modularity = this.calculateModularity();
        return this.modularity > oldModularity && Math.abs(this.modularity - oldModularity) > this.tolerance;
    }

    /**
     * Computes {@code ex / (2m) - ax / (2m)^2}, where {@code ex} is the sum of all node
     * community influences, {@code ax} the sum of all squared community weights and
     * {@code m} is {@code totalNodeWeight}.
     *
     * @return the modularity for the current community assignment
     * @throws RuntimeException if a sum cannot be computed because the node range is empty
     */
    private double calculateModularity() {
        double ex = ParallelUtil.parallelStream(
            LongStream.range(0L, this.nodeCount),
            this.concurrency,
            nodeStream -> nodeStream
                .mapToDouble(nodeId -> this.nodeCommunityInfluences.get(nodeId))
                .reduce(Double::sum)
                .orElseThrow(() -> new RuntimeException("Error while computing modularity"))
        );
        double ax = ParallelUtil.parallelStream(
            LongStream.range(0L, this.nodeCount),
            this.concurrency,
            nodeStream -> nodeStream
                .mapToDouble(nodeId -> Math.pow(this.communityWeights.get(nodeId), 2.0))
                .reduce(Double::sum)
                .orElseThrow(() -> new RuntimeException("Error while computing modularity"))
        );
        double doubledTotalWeight = 2.0 * this.totalNodeWeight;
        return ex / doubledTotalWeight - ax / Math.pow(doubledTotalWeight, 2.0);
    }

    /**
     * @return this instance
     */
    public ModularityOptimization me() {
        return this;
    }

    /**
     * Releases the intermediate working arrays. {@code currentCommunities} and the reverse
     * seed mapping are kept because {@link #getCommunityId(long)} reads them.
     *
     * <p>Arrays that were never allocated (for example when this is called before
     * {@link #compute()}) are skipped.
     */
    public void release() {
        if (this.nextCommunities != null) {
            this.nextCommunities.release();
        }
        if (this.communityWeights != null) {
            this.communityWeights.release();
        }
        if (this.communityWeightUpdates != null) {
            this.communityWeightUpdates.release();
        }
        if (this.cumulativeNodeWeights != null) {
            this.cumulativeNodeWeights.release();
        }
        if (this.nodeCommunityInfluences != null) {
            this.nodeCommunityInfluences.release();
        }
        if (this.colors != null) {
            this.colors.release();
        }
        this.colorsUsed = null;
    }

    /**
     * Returns the community of a node. With seeding, the internal community id is translated
     * back through the reverse seed mapping; otherwise the internal id is returned as is.
     *
     * @param nodeId internal node id
     * @return the community id for {@code nodeId}
     */
    public long getCommunityId(long nodeId) {
        long internalCommunityId = this.currentCommunities.get(nodeId);
        if (this.seedProperty == null || this.reverseSeedCommunityMapping == null) {
            return internalCommunityId;
        }
        return this.reverseSeedCommunityMapping.get(internalCommunityId);
    }

    /**
     * @return the number of iterations that were run
     */
    public int getIterations() {
        return this.iterationCounter;
    }

    /**
     * @return the most recently computed modularity, or {@code -1.0} if none was computed yet
     */
    public double getModularity() {
        return this.modularity;
    }

    /**
     * @return {@code true} if an iteration ended without sufficient modularity improvement
     */
    public boolean didConverge() {
        return this.didConverge;
    }

    /**
     * @return the tolerance this instance was configured with
     */
    public double getTolerance() {
        return this.tolerance;
    }
}

