/*
 * Decompiled with CFR 0.152.
 */
package org.dromara.mica.mqtt.core.server.session;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ThreadLocalRandom;
import java.util.function.BiFunction;
import java.util.function.BinaryOperator;
import java.util.stream.Collectors;
import org.dromara.mica.mqtt.core.common.TopicFilter;
import org.dromara.mica.mqtt.core.common.TopicFilterType;
import org.dromara.mica.mqtt.core.server.model.Subscribe;
import org.dromara.mica.mqtt.core.util.TopicUtil;
import org.tio.utils.hutool.CollUtil;

/**
 * Subscription index that stores MQTT topic filters in tries, one topic level per trie node.
 *
 * <p>Filters are kept in three separate tries:
 * <ul>
 *   <li>plain filters under {@code root};</li>
 *   <li>{@code $queue/...} filters under {@code queue}, with the {@code $queue/} prefix stripped;</li>
 *   <li>{@code $share/{group}/...} filters in one trie per group held in {@code share}, with the
 *       {@code $share/{group}/} prefix stripped.</li>
 * </ul>
 * The node reached by the last level of a filter maps each subscribing client id to the QoS it
 * subscribed with; subscribing to the same filter again keeps the highest QoS seen.
 */
public class TrieTopicManager {
    /** Merge function that keeps the larger of two QoS values. */
    public static final BinaryOperator<Byte> MAX_QOS = (a, b) -> a > b ? a : b;

    private static final String QUEUE_PREFIX = "$queue/";
    private static final String SHARE_PREFIX = "$share/";
    /** Multi-level wildcard: matches the parent level and any number of child levels. */
    private static final String MULTI_LEVEL_WILDCARD = "#";
    /** Single-level wildcard: matches exactly one level. */
    private static final String SINGLE_LEVEL_WILDCARD = "+";

    private final Node root = Node.getRoot("root");
    private final Map<String, Node> share = new ConcurrentHashMap<String, Node>();
    private final Node queue = Node.getRoot("$queue");

    /**
     * Registers a subscription, parsing {@code topicFilter} into a {@link TopicFilter} first.
     *
     * @param topicFilter raw topic filter, possibly prefixed with {@code $queue/} or {@code $share/{group}/}
     * @param clientId    subscribing client id
     * @param mqttQoS     requested QoS (narrowed to a {@code byte} when stored)
     */
    public void addSubscribe(String topicFilter, String clientId, int mqttQoS) {
        this.addSubscribe(new TopicFilter(topicFilter), clientId, mqttQoS);
    }

    /**
     * Registers a subscription in the trie that matches the filter's type.
     *
     * @param topicFilter parsed topic filter
     * @param clientId    subscribing client id
     * @param mqttQoS     requested QoS (narrowed to a {@code byte} when stored)
     */
    public void addSubscribe(TopicFilter topicFilter, String clientId, int mqttQoS) {
        String topic = topicFilter.getTopic();
        TopicFilterType topicFilterType = topicFilter.getType();
        byte qos = (byte) mqttQoS;
        if (TopicFilterType.NONE == topicFilterType) {
            TrieTopicManager.addSubscribe(this.root, topic, clientId, qos);
        } else if (TopicFilterType.QUEUE == topicFilterType) {
            TrieTopicManager.addSubscribe(this.queue, topic.substring(QUEUE_PREFIX.length()), clientId, qos);
        } else if (TopicFilterType.SHARE == topicFilterType) {
            String groupName = TopicFilterType.getShareGroupName(topic);
            Node groupNode = this.share.computeIfAbsent(groupName, Node::getNode);
            TrieTopicManager.addSubscribe(groupNode, topic.substring(shareTopicOffset(groupName)), clientId, qos);
        }
    }

    /**
     * Walks (creating as needed) the path for {@code topicFilter} below {@code node} and records
     * {@code clientId -> mqttQoS} on the last node, keeping the higher QoS if one is already stored.
     */
    private static void addSubscribe(Node node, String topicFilter, String clientId, byte mqttQoS) {
        String[] topicParts = TopicUtil.getTopicParts(topicFilter);
        if (topicParts.length == 0) {
            return;
        }
        Node current = node;
        for (String part : topicParts) {
            current = current.addChildIfAbsent(part);
        }
        assert (current.subscriptions != null);
        // merge is atomic on ConcurrentHashMap, unlike a get-then-put, so concurrent
        // subscribes cannot lower an already stored QoS.
        current.subscriptions.merge(clientId, mqttQoS, MAX_QOS);
    }

    /**
     * Removes {@code clientId}'s subscription to exactly {@code topicFilter}.
     *
     * @param topicFilter raw topic filter, possibly prefixed with {@code $queue/} or {@code $share/{group}/}
     * @param clientId    client id to unsubscribe
     */
    public void removeSubscribe(String topicFilter, String clientId) {
        this.removeSubscribe(new TopicFilter(topicFilter), clientId);
    }

    private void removeSubscribe(TopicFilter topicFilter, String clientId) {
        String topic = topicFilter.getTopic();
        TopicFilterType topicFilterType = topicFilter.getType();
        if (TopicFilterType.NONE == topicFilterType) {
            TrieTopicManager.removeSubscribe(this.root, topic, clientId);
        } else if (TopicFilterType.QUEUE == topicFilterType) {
            TrieTopicManager.removeSubscribe(this.queue, topic.substring(QUEUE_PREFIX.length()), clientId);
        } else if (TopicFilterType.SHARE == topicFilterType) {
            String groupName = TopicFilterType.getShareGroupName(topic);
            // Look the group up only: creating a trie just to remove from it would leave an
            // empty group behind that every later search would iterate.
            Node groupNode = this.share.get(groupName);
            if (groupNode != null) {
                TrieTopicManager.removeSubscribe(groupNode, topic.substring(shareTopicOffset(groupName)), clientId);
            }
        }
    }

    /** Removes {@code clientId} from the node reached by {@code topicFilter}, if that path exists. */
    private static void removeSubscribe(Node node, String topicFilter, String clientId) {
        Node current = node;
        for (String part : TopicUtil.getTopicParts(topicFilter)) {
            current = current.findNodeByPart(part);
            if (current == null) {
                return;
            }
        }
        // Trie roots carry no subscription map; an empty filter would otherwise land on one.
        if (current.subscriptions != null) {
            current.subscriptions.remove(clientId);
        }
    }

    /**
     * Removes every subscription held by {@code clientId} from all tries.
     * Trie nodes themselves are not pruned.
     *
     * @param clientId client id to unsubscribe everywhere
     */
    public void removeSubscribe(String clientId) {
        TrieTopicManager.removeSubscribe(this.root, clientId);
        TrieTopicManager.removeSubscribe(this.queue, clientId);
        for (Node node : this.share.values()) {
            TrieTopicManager.removeSubscribe(node, clientId);
        }
    }

    private static void removeSubscribe(Node node, String clientId) {
        assert (node.children != null);
        for (Node child : node.children.values()) {
            TrieTopicManager.removeSubscribeRecursively(child, clientId);
        }
    }

    private static void removeSubscribeRecursively(Node child, String clientId) {
        assert (child.subscriptions != null);
        child.subscriptions.remove(clientId);
        assert (child.children != null);
        for (Node node : child.children.values()) {
            TrieTopicManager.removeSubscribeRecursively(node, clientId);
        }
    }

    /**
     * Lists the filters {@code clientId} is subscribed to, with {@code $queue/} and
     * {@code $share/{group}/} prefixes restored, without duplicates.
     *
     * @param clientId client id to look up
     * @return a new mutable list of that client's subscriptions
     */
    public List<Subscribe> getSubscriptions(String clientId) {
        List<Subscribe> subscribeList = TrieTopicManager.getSubscriptions(this.root, null, clientId);
        subscribeList.addAll(TrieTopicManager.getSubscriptions(this.queue, QUEUE_PREFIX, clientId));
        for (Map.Entry<String, Node> entry : this.share.entrySet()) {
            String prefix = SHARE_PREFIX + entry.getKey() + "/";
            subscribeList.addAll(TrieTopicManager.getSubscriptions(entry.getValue(), prefix, clientId));
        }
        return subscribeList.stream().distinct().collect(Collectors.toList());
    }

    private static List<Subscribe> getSubscriptions(Node node, String prefix, String clientId) {
        List<Subscribe> subscribeList = new ArrayList<Subscribe>();
        for (Node child : node.children.values()) {
            String topicPrefix = prefix == null ? child.part : prefix + child.part;
            TrieTopicManager.getSubscribeRecursively(subscribeList, child, topicPrefix, clientId);
        }
        return subscribeList;
    }

    /** Depth-first collection of {@code clientId}'s subscriptions, rebuilding each filter string. */
    private static void getSubscribeRecursively(List<Subscribe> subscribeList, Node child, String childPart, String clientId) {
        assert (child.subscriptions != null);
        Byte qos = child.subscriptions.get(clientId);
        if (qos != null) {
            subscribeList.add(new Subscribe(childPart, clientId, qos.byteValue()));
        }
        assert (child.children != null);
        for (Node node : child.children.values()) {
            String topicPrefix = TrieTopicManager.isNotNeedAppendTopicLayer(childPart, node.part)
                ? childPart + node.part
                : childPart + '/' + node.part;
            TrieTopicManager.getSubscribeRecursively(subscribeList, node, topicPrefix, clientId);
        }
    }

    /** Whether two parts are joined without inserting a {@code '/'} separator. */
    private static boolean isNotNeedAppendTopicLayer(String prefix, String suffix) {
        return "/".equals(prefix) || prefix.endsWith("//") || "/".equals(suffix);
    }

    /**
     * Finds the QoS with which {@code clientId} is subscribed to {@code topicName}.
     *
     * <p>Plain filters are consulted first, then {@code $queue}, then every {@code $share} group;
     * the first stage that yields a match for the client decides. Unlike
     * {@link #searchSubscribe(String)}, no random selection is applied to queue/share matches.
     *
     * @param topicName published topic name
     * @param clientId  client id to look up
     * @return the highest matching QoS at the deciding stage, or {@code null} if none matches
     */
    public Byte searchSubscribe(String topicName, String clientId) {
        String[] topicParts = TopicUtil.getTopicParts(topicName);
        Map<String, Byte> subscribeMap = new HashMap<String, Byte>(32);
        TrieTopicManager.searchSubscribeRecursively(this.root, subscribeMap, topicParts, 0);
        Byte qos = subscribeMap.get(clientId);
        if (qos != null) {
            return qos;
        }
        TrieTopicManager.searchSubscribeRecursively(this.queue, subscribeMap, topicParts, 0);
        qos = subscribeMap.get(clientId);
        if (qos != null) {
            return qos;
        }
        for (Node node : this.share.values()) {
            TrieTopicManager.searchSubscribeRecursively(node, subscribeMap, topicParts, 0);
        }
        return subscribeMap.get(clientId);
    }

    /**
     * Resolves the clients that should receive a message published on {@code topicName}.
     *
     * <p>All matching plain subscribers are included. For {@code $queue} and for each
     * {@code $share} group that has at least one match, a single matching client is picked at
     * random. When a client is selected more than once, its highest QoS wins.
     *
     * @param topicName published topic name
     * @return a new list of {@code Subscribe(clientId, qos)} entries
     */
    public List<Subscribe> searchSubscribe(String topicName) {
        String[] topicParts = TopicUtil.getTopicParts(topicName);
        Map<String, Byte> subscribeMap = new HashMap<String, Byte>(32);
        TrieTopicManager.searchSubscribeRecursively(this.root, subscribeMap, topicParts, 0);
        Map<String, Byte> queueSubscribeMap = new HashMap<String, Byte>(8);
        TrieTopicManager.searchSubscribeRecursively(this.queue, queueSubscribeMap, topicParts, 0);
        if (!queueSubscribeMap.isEmpty()) {
            TrieTopicManager.randomStrategy(subscribeMap, queueSubscribeMap);
        }
        for (Node node : this.share.values()) {
            Map<String, Byte> shareSubscribeMap = new HashMap<String, Byte>(8);
            TrieTopicManager.searchSubscribeRecursively(node, shareSubscribeMap, topicParts, 0);
            if (!shareSubscribeMap.isEmpty()) {
                TrieTopicManager.randomStrategy(subscribeMap, shareSubscribeMap);
            }
        }
        List<Subscribe> subscribeList = new ArrayList<Subscribe>(subscribeMap.size());
        subscribeMap.forEach((clientId, qos) -> subscribeList.add(new Subscribe(clientId, qos.byteValue())));
        return subscribeList;
    }

    /**
     * Merges into {@code subscribeMap} the subscriptions of every filter below {@code node} that
     * matches {@code topicParts[index..]}.
     */
    private static void searchSubscribeRecursively(Node node, Map<String, Byte> subscribeMap, String[] topicParts, int index) {
        if (index >= topicParts.length) {
            return;
        }
        // "#" at this level swallows the current level and everything after it.
        Node nodeMore = node.findNodeByPart(MULTI_LEVEL_WILDCARD);
        if (nodeMore != null) {
            TrieTopicManager.mergeSubscriptions(nodeMore, subscribeMap);
        }
        // "+" matches exactly the current level, whatever its value.
        Node nodeOne = node.findNodeByPart(SINGLE_LEVEL_WILDCARD);
        if (nodeOne != null) {
            TrieTopicManager.matchLevel(nodeOne, subscribeMap, topicParts, index);
        }
        Node nodePart = node.findNodeByPart(topicParts[index]);
        if (nodePart != null) {
            TrieTopicManager.matchLevel(nodePart, subscribeMap, topicParts, index);
        }
    }

    /**
     * Continues matching after {@code matched} consumed level {@code index}: on the last level,
     * collects its subscriptions plus those of a trailing {@code #} child (a multi-level wildcard
     * also matches its parent level, e.g. {@code +/#} matches {@code a}); otherwise descends.
     */
    private static void matchLevel(Node matched, Map<String, Byte> subscribeMap, String[] topicParts, int index) {
        if (index == topicParts.length - 1) {
            TrieTopicManager.mergeSubscriptions(matched, subscribeMap);
            Node trailingMore = matched.findNodeByPart(MULTI_LEVEL_WILDCARD);
            if (trailingMore != null) {
                TrieTopicManager.mergeSubscriptions(trailingMore, subscribeMap);
            }
        } else {
            TrieTopicManager.searchSubscribeRecursively(matched, subscribeMap, topicParts, index + 1);
        }
    }

    /** Merges all of {@code source}'s subscriptions into {@code target}, keeping the max QoS per client. */
    private static void mergeSubscriptions(Node source, Map<String, Byte> target) {
        source.subscriptions.forEach((clientId, qos) -> target.merge(clientId, qos, MAX_QOS));
    }

    /** Index at which the actual filter starts inside {@code $share/{group}/filter}. */
    private static int shareTopicOffset(String groupName) {
        return SHARE_PREFIX.length() + groupName.length() + 1;
    }

    /** Drops every subscription in all tries. */
    public void clear() {
        this.root.children.clear();
        this.queue.children.clear();
        this.share.clear();
    }

    @Override
    public String toString() {
        return "TrieTopicManager{root=" + this.root + ", share=" + this.share + ", queue=" + this.queue + '}';
    }

    /**
     * Picks one client from the non-empty {@code randomSubscribeMap} uniformly at random and
     * merges it into {@code subscribeMap}.
     */
    private static void randomStrategy(Map<String, Byte> subscribeMap, Map<String, Byte> randomSubscribeMap) {
        String[] keys = randomSubscribeMap.keySet().toArray(new String[0]);
        int keyLength = keys.length;
        String key = keyLength > 1 ? keys[ThreadLocalRandom.current().nextInt(keyLength)] : keys[0];
        subscribeMap.merge(key, randomSubscribeMap.get(key), MAX_QOS);
    }

    /** One topic level in a subscription trie. */
    private static class Node {
        /** Topic level this node represents (may be a wildcard). */
        private final String part;
        /** client id -> QoS for filters ending here; {@code null} on trie roots. */
        private final Map<String, Byte> subscriptions;
        /** Child levels keyed by their part. */
        private final Map<String, Node> children;

        private Node(String part, Map<String, Byte> subscriptions, Map<String, Node> children) {
            this.part = part;
            this.subscriptions = subscriptions;
            this.children = children;
        }

        /** Creates a trie root, which holds children but no subscriptions. */
        protected static Node getRoot(String name) {
            return new Node(name, null, new ConcurrentHashMap<String, Node>(8));
        }

        /** Creates a regular node with empty subscription and child maps. */
        protected static Node getNode(String part) {
            return new Node(part, new ConcurrentHashMap<String, Byte>(16), new ConcurrentHashMap<String, Node>(16));
        }

        protected Node addChildIfAbsent(String nodePart) {
            assert (this.children != null);
            return CollUtil.computeIfAbsent(this.children, nodePart, Node::getNode);
        }

        protected Node findNodeByPart(String nodePart) {
            assert (this.children != null);
            return this.children.get(nodePart);
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || Node.class != o.getClass() || this.part == null) {
                return false;
            }
            return this.part.equals(((Node) o).part);
        }

        @Override
        public int hashCode() {
            return this.part == null ? 0 : this.part.hashCode();
        }

        @Override
        public String toString() {
            return "Node{part='" + this.part + '\'' + ", subscriptions=" + this.subscriptions + ", children=" + this.children + '}';
        }
    }
}

