/*
 * Decompiled with CFR 0.152.
 */
package io.github.javpower.vectorex.keynote.bm25;

import io.github.javpower.vectorex.keynote.analysis.ScoredEntity;
import io.github.javpower.vectorex.keynote.analysis.SegMode;
import io.github.javpower.vectorex.keynote.analysis.TextSegmenter;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;
import java.util.stream.Collectors;

/**
 * Character-level BM25 ranker.
 *
 * <p>Each UTF-16 code unit of a text is treated as one term, and terms are
 * compared case-insensitively. A document's score for a query is the sum,
 * over every query term (repeated terms count repeatedly), of
 * {@code tf * idf} where
 * <ul>
 *   <li>{@code freq = occurrences / documentLength},</li>
 *   <li>{@code tf = freq * (k1 + 1) / (freq + k1 * (1 - b + b * documentLength / averageLength))},</li>
 *   <li>{@code idf = ln(1 + (N - n + 0.5) / (n + 0.5))} with {@code N} documents,
 *       {@code n} of which contain the term.</li>
 * </ul>
 * Note that {@code freq} is the length-normalised frequency, not the raw count.
 *
 * <p>All working data lives in method locals; an instance only stores the
 * immutable {@code k1} and {@code b} parameters.
 */
public class BM25 {
    /** Term-frequency saturation parameter; validated to be non-negative. */
    private final double k1;
    /** Length-normalisation weight; validated to lie in {@code [0, 1]}. */
    private final double b;

    /** Creates a ranker with {@code k1 = 1.2} and {@code b = 0.75}. */
    public BM25() {
        this(1.2, 0.75);
    }

    /**
     * Creates a ranker with explicit parameters.
     *
     * @param k1 term-frequency saturation, must be {@code >= 0}
     * @param b  length-normalisation weight, must be within {@code [0, 1]}
     * @throws IllegalArgumentException if either parameter is out of range or NaN
     */
    public BM25(double k1, double b) {
        // Negated comparisons so that NaN is rejected as well.
        if (!(k1 >= 0.0)) {
            throw new IllegalArgumentException("Negative k1 = " + k1);
        }
        if (!(b >= 0.0 && b <= 1.0)) {
            throw new IllegalArgumentException("Invalid b = " + b);
        }
        this.k1 = k1;
        this.b = b;
    }

    /**
     * Scores every document against {@code query} and returns the best ones.
     *
     * @param query     query text
     * @param documents document id to document text; a {@code null} text is scored as empty text
     * @param topNum    maximum number of results; values {@code <= 0} yield an empty map
     * @return a map from document id to score, iterating from highest to lowest score;
     *         equal scores keep the iteration order of {@code documents}
     * @throws NullPointerException if {@code query} or {@code documents} is {@code null}
     */
    public Map<String, Double> rankBM25(String query, Map<String, String> documents, int topNum) {
        if (query == null) {
            throw new NullPointerException("query must not be null");
        }
        if (documents == null) {
            throw new NullPointerException("documents must not be null");
        }

        // Index the corpus once: per-document term counts, lengths and document frequencies.
        List<String> ids = new ArrayList<String>(documents.size());
        List<Map<Character, Integer>> termCounts = new ArrayList<Map<Character, Integer>>(documents.size());
        List<Integer> docLengths = new ArrayList<Integer>(documents.size());
        Map<Character, Integer> docFrequency = new HashMap<Character, Integer>();
        long totalLength = 0L;
        for (Map.Entry<String, String> docEntry : documents.entrySet()) {
            String text = docEntry.getValue() == null ? "" : docEntry.getValue();
            Map<Character, Integer> counts = countTerms(text);
            for (Character term : counts.keySet()) {
                docFrequency.merge(term, 1, Integer::sum);
            }
            ids.add(docEntry.getKey());
            termCounts.add(counts);
            docLengths.add(text.length());
            totalLength += text.length();
        }

        int docCount = ids.size();
        double avgDocLength = docCount == 0 ? 0.0 : (double) totalLength / (double) docCount;

        // IDF depends only on the term, so compute it once per query position.
        char[] queryTerms = new char[query.length()];
        double[] queryIdf = new double[queryTerms.length];
        for (int q = 0; q < queryTerms.length; q++) {
            queryTerms[q] = fold(query.charAt(q));
            Integer containing = docFrequency.get(queryTerms[q]);
            queryIdf[q] = idf(docCount, containing == null ? 0 : containing);
        }

        final double[] scores = new double[docCount];
        List<Integer> order = new ArrayList<Integer>(docCount);
        for (int d = 0; d < docCount; d++) {
            Map<Character, Integer> counts = termCounts.get(d);
            int length = docLengths.get(d);
            double sum = 0.0;
            for (int q = 0; q < queryTerms.length; q++) {
                Integer count = counts.get(queryTerms[q]);
                if (count != null) {
                    sum += tf(count, length, avgDocLength) * queryIdf[q];
                }
            }
            scores[d] = sum;
            order.add(d);
        }

        // List.sort is stable, so equal scores stay in input order.
        order.sort((x, y) -> Double.compare(scores[y], scores[x]));

        // The limit is fixed up front; it must not shrink while results are taken.
        int limit = Math.min(topNum, docCount);
        Map<String, Double> topDocs = new LinkedHashMap<String, Double>();
        for (int i = 0; i < limit; i++) {
            int d = order.get(i);
            topDocs.put(ids.get(d), scores[d]);
        }
        return topDocs;
    }

    /** Saturated, length-normalised term-frequency component for one term in one document. */
    private double tf(int count, int docLength, double avgDocLength) {
        if (count == 0) {
            // Also covers empty documents, which would otherwise produce 0/0 = NaN.
            return 0.0;
        }
        double freq = (double) count / (double) docLength;
        return freq * (this.k1 + 1.0)
                / (freq + this.k1 * (1.0 - this.b + this.b * (double) docLength / avgDocLength));
    }

    /** Inverse document frequency for a term found in {@code docsWithTerm} of {@code docCount} documents. */
    private static double idf(int docCount, int docsWithTerm) {
        return Math.log(1.0 + ((double) docCount - docsWithTerm + 0.5) / (docsWithTerm + 0.5));
    }

    /** Counts the case-folded characters of {@code text}. */
    private static Map<Character, Integer> countTerms(String text) {
        Map<Character, Integer> counts = new HashMap<Character, Integer>();
        for (int i = 0; i < text.length(); i++) {
            counts.merge(fold(text.charAt(i)), 1, Integer::sum);
        }
        return counts;
    }

    /**
     * Maps a character to a canonical case so that two characters fold to the
     * same value exactly when they are equal ignoring case (upper-casing first,
     * then lower-casing, mirrors the comparison done by
     * {@link String#equalsIgnoreCase(String)} on single characters).
     */
    private static char fold(char c) {
        return Character.toLowerCase(Character.toUpperCase(c));
    }
}

