/*
 * Decompiled with CFR 0.152.
 */
package edu.stanford.nlp.parser.metrics;

import edu.stanford.nlp.international.Language;
import edu.stanford.nlp.ling.CoreLabel;
import edu.stanford.nlp.ling.HasTag;
import edu.stanford.nlp.ling.Label;
import edu.stanford.nlp.parser.lexparser.EnglishTreebankParserParams;
import edu.stanford.nlp.parser.lexparser.Lexicon;
import edu.stanford.nlp.parser.lexparser.TreebankLangParserParams;
import edu.stanford.nlp.parser.metrics.AbstractEval;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.trees.DiskTreebank;
import edu.stanford.nlp.trees.Tree;
import edu.stanford.nlp.trees.TreeTransformer;
import edu.stanford.nlp.util.Generics;
import edu.stanford.nlp.util.StringUtils;
import java.io.PrintWriter;
import java.text.DecimalFormat;
import java.util.AbstractCollection;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.TreeMap;

/**
 * Tagging evaluation over the tagged, labeled yields of guess and gold trees.
 *
 * <p>The overall score is computed by {@link AbstractEval} over the sets returned by
 * {@link #makeObjects(Tree)}. When {@code doCatLevelEval} is set (command-line flag {@code -c}),
 * per-category (per-tag) statistics are accumulated as well. If a {@link Lexicon} is supplied,
 * the fraction of each gold tag's tagging errors that fall on words unknown to the lexicon is
 * also reported.
 */
public class TaggingEval
extends AbstractEval {
    /** Lexicon used to decide whether a mistagged word is known; may be {@code null}. */
    private final Lexicon lex;
    /** When true, per-category statistics are collected and displayed. Set by {@code -c}. */
    private static boolean doCatLevelEval = false;
    // Per-category sums of per-sentence precision / recall / F1 (divided by num for sentence averages).
    // Always allocated so that toggling doCatLevelEval after construction cannot cause an NPE.
    private final Counter<String> precisions = new ClassicCounter<String>();
    private final Counter<String> recalls = new ClassicCounter<String>();
    private final Counter<String> f1s = new ClassicCounter<String>();
    // Per-category corpus-level ("evalb") accumulators: size-weighted precision/recall sums
    // and the number of guessed (pnums2) / gold (rnums2) items they are normalized by.
    private final Counter<String> precisions2 = new ClassicCounter<String>();
    private final Counter<String> recalls2 = new ClassicCounter<String>();
    private final Counter<String> pnums2 = new ClassicCounter<String>();
    private final Counter<String> rnums2 = new ClassicCounter<String>();
    // Keyed by gold tag: tagging errors on words the lexicon does not know (percentOOV)
    // and all tagging errors (percentOOV2).
    private final Counter<String> percentOOV = new ClassicCounter<String>();
    private final Counter<String> percentOOV2 = new ClassicCounter<String>();
    /** Number of positional arguments (gold file, guess file) required by {@link #main}. */
    private static final int minArgs = 2;
    private static final StringBuilder usage = new StringBuilder();
    /** Number of arguments taken by each command-line flag, for {@link StringUtils#argsToMap}. */
    public static final Map<String, Integer> optionArgDefs;

    /**
     * Creates an evaluator with running averages enabled and no lexicon.
     *
     * @param str label printed in front of scores
     */
    public TaggingEval(String str) {
        this(str, true, null);
    }

    /**
     * Creates an evaluator.
     *
     * @param str label printed in front of scores
     * @param runningAverages whether per-tree running averages are printed during evaluation
     * @param lex lexicon used for OOV error analysis, or {@code null} to skip it
     */
    public TaggingEval(String str, boolean runningAverages, Lexicon lex) {
        super(str, runningAverages);
        this.lex = lex;
    }

    /**
     * Returns the tagged, labeled yield of {@code tree} as a set (empty for a {@code null} tree).
     */
    protected Set<HasTag> makeObjects(Tree tree) {
        if (tree == null) {
            return Generics.<HasTag>newHashSet();
        }
        return Generics.<HasTag>newHashSet(tree.taggedLabeledYield());
    }

    /**
     * Groups the tagged, labeled yield of {@code t} by each token's label value.
     *
     * @return map from label value (category) to the tokens carrying it
     */
    private static Map<String, Set<Label>> makeObjectsByCat(Tree t) {
        Map<String, Set<Label>> catMap = Generics.newHashMap();
        for (CoreLabel label : t.taggedLabeledYield()) {
            Set<Label> catSet = catMap.get(label.value());
            if (catSet == null) {
                catSet = Generics.newHashSet();
                catMap.put(label.value(), catSet);
            }
            catSet.add(label);
        }
        return catMap;
    }

    /**
     * Scores {@code guess} against {@code gold}: the overall score via {@link AbstractEval},
     * plus per-category statistics when category-level evaluation is enabled.
     *
     * @param guess the hypothesized tree; a {@code null} tree is reported and skipped
     * @param gold the reference tree; a {@code null} tree is reported and skipped
     * @param pw destination for running output, or {@code null} for none
     */
    @Override
    public void evaluate(Tree guess, Tree gold, PrintWriter pw) {
        if (gold == null || guess == null) {
            System.err.printf("%s: Cannot compare against a null gold or guess tree!\n", this.getClass().getName());
            return;
        }
        super.evaluate(guess, gold, pw);
        if (!doCatLevelEval) {
            return;
        }
        // OOV error statistics are a property of the whole tree pair: measure them exactly once,
        // not once per category.
        if (this.lex != null) {
            this.measureOOV(guess, gold);
        }
        Map<String, Set<Label>> guessCats = makeObjectsByCat(guess);
        Map<String, Set<Label>> goldCats = makeObjectsByCat(gold);
        Set<String> allCats = Generics.newHashSet();
        allCats.addAll(guessCats.keySet());
        allCats.addAll(goldCats.keySet());
        boolean printRunning = pw != null && this.runningAverages;
        for (String cat : allCats) {
            Set<Label> thisGuessCats = guessCats.get(cat);
            Set<Label> thisGoldCats = goldCats.get(cat);
            if (thisGuessCats == null) {
                thisGuessCats = Generics.newHashSet();
            }
            if (thisGoldCats == null) {
                thisGoldCats = Generics.newHashSet();
            }
            double currentPrecision = precision(thisGuessCats, thisGoldCats);
            // Recall is precision with the roles of guess and gold swapped.
            double currentRecall = precision(thisGoldCats, thisGuessCats);
            double currentF1 = currentPrecision > 0.0 && currentRecall > 0.0
                    ? 2.0 / (1.0 / currentPrecision + 1.0 / currentRecall)
                    : 0.0;
            this.precisions.incrementCount(cat, currentPrecision);
            this.recalls.incrementCount(cat, currentRecall);
            this.f1s.incrementCount(cat, currentF1);
            this.precisions2.incrementCount(cat, (double) thisGuessCats.size() * currentPrecision);
            this.pnums2.incrementCount(cat, thisGuessCats.size());
            this.recalls2.incrementCount(cat, (double) thisGoldCats.size() * currentRecall);
            this.rnums2.incrementCount(cat, thisGoldCats.size());
            if (printRunning) {
                this.printCategoryScores(pw, cat, currentPrecision, currentRecall, currentF1);
            }
        }
        if (printRunning) {
            pw.println("========================================");
        }
    }

    /** Truncates (not rounds) a fraction to a percentage with two decimals, e.g. 0.12345 -> 12.34. */
    private static double percent2(double fraction) {
        return (double) ((int) (fraction * 10000.0)) / 100.0;
    }

    /** Prints this tree's and the running per-category precision, recall and F1 for {@code cat}. */
    private void printCategoryScores(PrintWriter pw, String cat, double currentPrecision, double currentRecall, double currentF1) {
        double pnum2 = this.pnums2.getCount(cat);
        double rnum2 = this.rnums2.getCount(cat);
        pw.println(cat + "\tP: " + percent2(currentPrecision)
                + " (sent ave " + percent2(this.precisions.getCount(cat) / this.num)
                + ") (evalb " + percent2(this.precisions2.getCount(cat) / pnum2) + ")");
        pw.println("\tR: " + percent2(currentRecall)
                + " (sent ave " + percent2(this.recalls.getCount(cat) / this.num)
                + ") (evalb " + percent2(this.recalls2.getCount(cat) / rnum2) + ")");
        double cF1 = 2.0 / (rnum2 / this.recalls2.getCount(cat) + pnum2 / this.precisions2.getCount(cat));
        pw.println(this.str + " F1: " + percent2(currentF1)
                + " (sent ave " + percent2(this.f1s.getCount(cat) / this.num)
                + ", evalb " + percent2(cF1) + ")");
    }

    /**
     * Counts tagging errors per gold tag, and separately those on words unknown to {@link #lex}.
     * A token is an error when its guess tag differs from its gold tag.
     */
    private void measureOOV(Tree guess, Tree gold) {
        List<CoreLabel> goldTagging = gold.taggedLabeledYield();
        List<CoreLabel> guessTagging = guess.taggedLabeledYield();
        assert goldTagging.size() == guessTagging.size();
        // Bound by the shorter yield so a length mismatch cannot throw when assertions are off.
        int length = Math.min(goldTagging.size(), guessTagging.size());
        for (int i = 0; i < length; ++i) {
            CoreLabel goldToken = goldTagging.get(i);
            String goldTag = goldToken.tag();
            String guessTag = guessTagging.get(i).tag();
            // Compare tag values; the labels themselves are distinct objects, so == never matches.
            boolean sameTag = goldTag == null ? guessTag == null : goldTag.equals(guessTag);
            if (sameTag) {
                continue;
            }
            this.percentOOV2.incrementCount(goldTag);
            if (!this.lex.isKnown(goldToken.word())) {
                this.percentOOV.incrementCount(goldTag);
            }
        }
    }

    /**
     * Prints the overall summary and, when category-level evaluation is enabled, a per-category
     * table ordered by ascending corpus-level F1.
     */
    @Override
    public void display(boolean verbose, PrintWriter pw) {
        super.display(verbose, pw);
        if (!doCatLevelEval) {
            return;
        }
        DecimalFormat nf = new DecimalFormat("0.00");
        pw.println("============================================================");
        pw.println("Tagging Performance by Category -- final statistics");
        pw.println("============================================================");
        for (String cat : this.categoriesByF1()) {
            double pnum2 = this.pnums2.getCount(cat);
            double rnum2 = this.rnums2.getCount(cat);
            double prec = this.precisions2.getCount(cat) / pnum2 * 100.0;
            double rec = this.recalls2.getCount(cat) / rnum2 * 100.0;
            double f1 = 2.0 / (1.0 / prec + 1.0 / rec);
            String oov;
            if (this.lex == null) {
                oov = " N/A";
            } else {
                double errors = this.percentOOV2.getCount(cat);
                // No tagging errors for this tag means the OOV share is undefined (0/0).
                oov = errors == 0.0 ? " N/A" : nf.format(this.percentOOV.getCount(cat) / errors);
            }
            pw.println(cat + "\tLP: " + (pnum2 == 0.0 ? " N/A" : nf.format(prec))
                    + "\tguessed: " + (int) pnum2
                    + "\tLR: " + (rnum2 == 0.0 ? " N/A" : nf.format(rec))
                    + "\tgold:  " + (int) rnum2
                    + "\tF1: " + (pnum2 == 0.0 || rnum2 == 0.0 ? " N/A" : nf.format(f1))
                    + "\tOOV: " + oov);
        }
        pw.println("============================================================");
    }

    /**
     * Returns every category seen so far, sorted by ascending corpus-level F1 and then by name,
     * so ties are ordered deterministically and no category is ever dropped.
     */
    private List<String> categoriesByF1() {
        Set<String> cats = Generics.newHashSet();
        cats.addAll(this.precisions.keySet());
        cats.addAll(this.recalls.keySet());
        final Map<String, Double> f1ByCat = Generics.newHashMap();
        for (String cat : cats) {
            double prec = this.precisions2.getCount(cat) / this.pnums2.getCount(cat);
            double rec = this.recalls2.getCount(cat) / this.rnums2.getCount(cat);
            double f1 = 2.0 / (1.0 / prec + 1.0 / rec);
            // Undefined F1 (e.g. no guessed or no gold items) sorts before every real score.
            f1ByCat.put(cat, Double.isNaN(f1) ? -1.0 : f1);
        }
        List<String> sorted = new ArrayList<String>(f1ByCat.keySet());
        Collections.sort(sorted, new Comparator<String>() {
            @Override
            public int compare(String a, String b) {
                int byF1 = Double.compare(f1ByCat.get(a), f1ByCat.get(b));
                return byF1 != 0 ? byF1 : a.compareTo(b);
            }
        });
        return sorted;
    }

    /**
     * Command-line entry point: {@code TaggingEval [OPTS] gold guess}. See the usage text for options.
     */
    public static void main(String[] args) {
        if (args.length < minArgs) {
            System.out.println(usage.toString());
            System.exit(-1);
        }
        TreebankLangParserParams tlpp = new EnglishTreebankParserParams();
        int maxGoldYield = Integer.MAX_VALUE;
        boolean verbose = false;
        String encoding = "UTF-8";
        Map<String, String[]> argsMap = StringUtils.argsToMap(args, optionArgDefs);
        for (Map.Entry<String, String[]> opt : argsMap.entrySet()) {
            String key = opt.getKey();
            if (key == null) {
                continue;
            }
            if (key.equals("-l")) {
                Language lang = Language.valueOf(opt.getValue()[0].trim());
                tlpp = lang.params;
            } else if (key.equals("-y")) {
                maxGoldYield = Integer.parseInt(opt.getValue()[0].trim());
            } else if (key.equals("-v")) {
                verbose = true;
            } else if (key.equals("-c")) {
                doCatLevelEval = true;
            } else if (key.equals("-e")) {
                encoding = opt.getValue()[0];
            } else {
                System.err.println(usage.toString());
                System.exit(-1);
            }
        }
        // Non-option arguments are stored under the null key. This must be read after the loop,
        // otherwise a command line with no flags never sets the file names.
        String[] rest = argsMap.get(null);
        if (rest == null || rest.length < minArgs) {
            System.err.println(usage.toString());
            System.exit(-1);
        }
        String goldFile = rest[0];
        String guessFile = rest[1];

        tlpp.setInputEncoding(encoding);
        PrintWriter pwOut = tlpp.pw();
        DiskTreebank guessTreebank = tlpp.diskTreebank();
        guessTreebank.loadPath(guessFile);
        pwOut.println("GUESS TREEBANK:");
        pwOut.println(guessTreebank.textualSummary());
        DiskTreebank goldTreebank = tlpp.diskTreebank();
        goldTreebank.loadPath(goldFile);
        pwOut.println("GOLD TREEBANK:");
        pwOut.println(goldTreebank.textualSummary());
        TaggingEval metric = new TaggingEval("Tagging LP/LR");
        TreeTransformer tc = tlpp.collinizer();
        Iterator<Tree> goldItr = goldTreebank.iterator();
        Iterator<Tree> guessItr = guessTreebank.iterator();
        int goldLineId = 0;
        int guessLineId = 0;
        int skippedGuessTrees = 0;
        while (guessItr.hasNext() && goldItr.hasNext()) {
            Tree guessTree = guessItr.next();
            List<Label> guessYield = guessTree.yield();
            ++guessLineId;
            Tree goldTree = goldItr.next();
            List<Label> goldYield = goldTree.yield();
            ++goldLineId;
            if (goldYield.size() > maxGoldYield) {
                ++skippedGuessTrees;
                continue;
            }
            if (goldYield.size() != guessYield.size()) {
                pwOut.printf("Yield mismatch gold: %d tokens vs. guess: %d tokens (lines: gold %d guess %d)%n", goldYield.size(), guessYield.size(), goldLineId, guessLineId);
                ++skippedGuessTrees;
                continue;
            }
            Tree evalGuess = tc.transformTree(guessTree);
            Tree evalGold = tc.transformTree(goldTree);
            metric.evaluate(evalGuess, evalGold, verbose ? pwOut : null);
        }
        if (guessItr.hasNext() || goldItr.hasNext()) {
            System.err.printf("Guess/gold files do not have equal lengths (guess: %d gold: %d).%n", guessLineId, goldLineId);
        }
        pwOut.println("================================================================================");
        if (skippedGuessTrees != 0) {
            pwOut.printf("%s %d guess trees\n", "Unable to evaluate", skippedGuessTrees);
        }
        metric.display(true, pwOut);
        pwOut.println();
        pwOut.close();
    }

    static {
        usage.append(String.format("Usage: java %s [OPTS] gold guess\n\n", TaggingEval.class.getName()));
        usage.append("Options:\n");
        usage.append("  -v         : Verbose mode.\n");
        usage.append("  -l lang    : Select language settings from " + Language.langList + "\n");
        usage.append("  -y num     : Skip gold trees with yields longer than num.\n");
        usage.append("  -c         : Compute LP/LR/F1 by category.\n");
        usage.append("  -e enc     : Input encoding.\n");
        optionArgDefs = Generics.newHashMap();
        optionArgDefs.put("-v", 0);
        optionArgDefs.put("-l", 1);
        optionArgDefs.put("-y", 1);
        optionArgDefs.put("-c", 0);
        // -e takes the encoding name as its single argument (read via opt.getValue()[0]).
        optionArgDefs.put("-e", 1);
    }
}

