/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.sphinx.result;

import edu.cmu.sphinx.decoder.search.AlternateHypothesisManager;
import edu.cmu.sphinx.decoder.search.Token;
import edu.cmu.sphinx.linguist.WordSearchState;
import edu.cmu.sphinx.linguist.dictionary.Pronunciation;
import edu.cmu.sphinx.linguist.dictionary.Word;
import edu.cmu.sphinx.result.Edge;
import edu.cmu.sphinx.result.Node;
import edu.cmu.sphinx.result.Result;
import edu.cmu.sphinx.util.LogMath;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.io.LineNumberReader;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.ListIterator;
import java.util.Map;
import java.util.Set;
import java.util.StringTokenizer;

/**
 * A directed graph of {@link Node}s connected by scored {@link Edge}s, with a designated
 * initial node and terminal node.
 *
 * <p>A lattice can be built empty (and populated through the {@code addNode}/{@code addEdge}
 * methods), from a recognition {@link Result}, or from a text file in the format written by
 * {@link #dump(String)}.
 *
 * <p>Note that {@link #getNodes()} and {@link #getEdges()} return the live internal
 * collections, not copies.
 */
public class Lattice {
    /** Node from which path enumeration and node sorting start. */
    protected Node initialNode;
    /** Node at which paths end. */
    protected Node terminalNode;
    /** Every edge registered in this lattice. */
    protected Set<Edge> edges;
    /** Every node registered in this lattice, keyed by node ID. */
    protected Map<String, Node> nodes;
    /**
     * Log base read from a {@code logBase:} line by {@link #Lattice(String)}. Within this class
     * it is only written; {@link #getLogBase()} reads from {@link #logMath} instead.
     */
    protected double logBase;
    /** Log-domain arithmetic helper; {@code null} until supplied by a constructor or setter. */
    protected LogMath logMath;
    /** Word tokens already expanded by {@link #collapseWordToken(Token)}. */
    private Set<Token> visitedWordTokens;
    /** Source of alternate predecessor tokens; may be {@code null}. */
    private AlternateHypothesisManager loserManager;

    /** Creates an empty lattice with no nodes, no edges and no {@link LogMath}. */
    protected Lattice() {
        this.edges = new HashSet<Edge>();
        this.nodes = new HashMap<String, Node>();
    }

    /**
     * Creates an empty lattice that uses the given log-math helper.
     *
     * @param logMath the log-math helper used by the posterior computation
     */
    public Lattice(LogMath logMath) {
        this();
        this.logMath = logMath;
    }

    /**
     * Builds a lattice from a recognition result by walking token predecessor chains (and any
     * alternate predecessors known to the result's {@link AlternateHypothesisManager}).
     *
     * <p>If the result has no best final token, its active tokens are used as starting points;
     * otherwise its result tokens are used.
     *
     * @param result the result to convert
     */
    public Lattice(Result result) {
        this(result.getLogMath());
        this.visitedWordTokens = new HashSet<Token>();
        this.loserManager = result.getAlternateHypothesisManager();
        if (this.loserManager != null) {
            this.loserManager.purge();
        }
        Iterable<Token> tokens = result.getBestFinalToken() == null ? result.getActiveTokens() : result.getResultTokens();
        for (Token startToken : tokens) {
            // Back up along the predecessor chain to the nearest word token.
            Token token = startToken;
            while (token != null && !token.isWord()) {
                token = token.getPredecessor();
            }
            assert (token != null && token.getWord().isSentenceEndWord());
            if (this.terminalNode == null) {
                // The first end-of-sentence token creates the shared terminal node. Until a
                // sentence-start token is found it also serves as the initial node.
                this.terminalNode = new Node(getNodeID(result.getBestToken()), token.getWord(), -1, -1);
                this.initialNode = this.terminalNode;
                addNode(this.terminalNode);
            }
            collapseWordToken(token);
        }
    }

    /**
     * Returns the node for a word token, creating and registering it on first use. Every
     * sentence-end token maps to the single terminal node.
     *
     * @param token a word token
     * @return the node representing {@code token}
     */
    private Node getNode(Token token) {
        if (token.getWord().isSentenceEndWord()) {
            return this.terminalNode;
        }
        String id = getNodeID(token);
        Node node = this.nodes.get(id);
        if (node == null) {
            WordSearchState wordState = (WordSearchState)token.getSearchState();
            int startFrame = -1;
            int endFrame = -1;
            // The token's frame number is recorded as either the start or the end of the word,
            // depending on what the search state reports; the other bound stays -1.
            if (wordState.isWordStart()) {
                startFrame = token.getFrameNumber();
            } else {
                endFrame = token.getFrameNumber();
            }
            node = new Node(id, token.getWord(), startFrame, endFrame);
            addNode(node);
        }
        return node;
    }

    /**
     * Expands a word token once: follows its predecessor, and each alternate predecessor, back
     * to the preceding word tokens and adds the connecting edges.
     *
     * @param token the word token to expand
     */
    private void collapseWordToken(Token token) {
        assert (token != null);
        // Set.add returns false when the token was already expanded.
        if (!this.visitedWordTokens.add(token)) {
            return;
        }
        collapseWordPath(getNode(token), token.getPredecessor(), token.getAcousticScore() + token.getInsertionScore(), token.getLanguageScore());
        if (this.loserManager != null && this.loserManager.hasAlternatePredecessors(token)) {
            for (Token loser : this.loserManager.getAlternatePredecessors(token)) {
                // NOTE(review): unlike the call above, the insertion score is not added here.
                // Kept as-is; confirm whether this asymmetry is intended.
                collapseWordPath(getNode(token), loser, token.getAcousticScore(), token.getLanguageScore());
            }
        }
    }

    /**
     * Walks backwards from {@code token} to the previous word token(s), accumulating scores
     * along the way, and adds an edge from each word reached to {@code parentWordNode}.
     *
     * @param parentWordNode the node the new edges lead to
     * @param token          the token to start walking back from; {@code null} is ignored
     * @param acousticScore  acoustic score accumulated so far
     * @param languageScore  language score accumulated so far
     */
    private void collapseWordPath(Node parentWordNode, Token token, float acousticScore, float languageScore) {
        if (token == null) {
            return;
        }
        if (token.isWord()) {
            Node fromNode = getNode(token);
            addEdge(fromNode, parentWordNode, acousticScore, languageScore);
            if (token.getPredecessor() != null) {
                collapseWordToken(token);
            } else {
                // A word token with no predecessor is taken to be the start of the sentence.
                assert (token.getWord().isSentenceStartWord());
                this.initialNode = fromNode;
            }
            return;
        }
        // Skip over non-word tokens, summing their scores, until the next predecessor is a word
        // or the current token has alternate predecessors that need their own branches.
        while (true) {
            acousticScore += token.getAcousticScore() + token.getInsertionScore();
            languageScore += token.getLanguageScore();
            Token preToken = token.getPredecessor();
            if (preToken == null) {
                // The chain ended without reaching a word; nothing to connect.
                return;
            }
            if (preToken.isWord() || this.loserManager != null && this.loserManager.hasAlternatePredecessors(token)) {
                break;
            }
            token = preToken;
        }
        collapseWordPath(parentWordNode, token.getPredecessor(), acousticScore, languageScore);
        if (this.loserManager != null && this.loserManager.hasAlternatePredecessors(token)) {
            for (Token loser : this.loserManager.getAlternatePredecessors(token)) {
                collapseWordPath(parentWordNode, loser, acousticScore, languageScore);
            }
        }
    }

    /**
     * Derives a node ID from a token.
     *
     * @param token the token to identify
     * @return the decimal string form of {@code token.hashCode()}
     */
    private String getNodeID(Token token) {
        return Integer.toString(token.hashCode());
    }

    /**
     * Loads a lattice from a text file. Each non-blank line starts with one of the keywords
     * {@code edge:}, {@code node:}, {@code initialNode:}, {@code terminalNode:} or
     * {@code logBase:}; any other first token is a syntax error.
     *
     * <p>No {@link LogMath} is created; call {@link #setLogMath(LogMath)} before using methods
     * that need one.
     *
     * @param fileName path of the file to read
     * @throws Error if the file cannot be read or a line cannot be parsed; for I/O and parsing
     *               exceptions the original exception is attached as the cause
     */
    public Lattice(String fileName) {
        // Initialise the node and edge collections. Without this they are null and the first
        // node/edge line fails with a NullPointerException.
        this();
        try {
            System.err.println("Loading from " + fileName);
            LineNumberReader in = new LineNumberReader(new FileReader(fileName));
            try {
                String line;
                while ((line = in.readLine()) != null) {
                    StringTokenizer tokens = new StringTokenizer(line);
                    if (!tokens.hasMoreTokens()) {
                        continue;
                    }
                    String type = tokens.nextToken();
                    if (type.equals("edge:")) {
                        Edge.load(this, tokens);
                    } else if (type.equals("node:")) {
                        Node.load(this, tokens);
                    } else if (type.equals("initialNode:")) {
                        setInitialNode(getNode(tokens.nextToken()));
                    } else if (type.equals("terminalNode:")) {
                        setTerminalNode(getNode(tokens.nextToken()));
                    } else if (type.equals("logBase:")) {
                        this.logBase = Double.parseDouble(tokens.nextToken());
                    } else {
                        throw new Error("SYNTAX ERROR: " + fileName + '[' + in.getLineNumber() + "] " + line);
                    }
                }
            } finally {
                // Release the file handle even when parsing fails part-way through.
                in.close();
            }
        }
        catch (Exception e) {
            throw new Error(e.toString(), e);
        }
    }

    /**
     * Creates an edge, attaches it to both endpoint nodes and registers it in this lattice.
     *
     * @param fromNode      node the edge leaves
     * @param toNode        node the edge enters
     * @param acousticScore acoustic score of the edge
     * @param lmScore       language-model score of the edge
     * @return the new edge
     */
    public Edge addEdge(Node fromNode, Node toNode, double acousticScore, double lmScore) {
        Edge e = new Edge(fromNode, toNode, acousticScore, lmScore);
        fromNode.addLeavingEdge(e);
        toNode.addEnteringEdge(e);
        this.edges.add(e);
        return e;
    }

    /**
     * Creates and registers a node whose ID is chosen by the {@link Node} constructor.
     *
     * @param word      the word the node represents
     * @param beginTime begin time passed to the node
     * @param endTime   end time passed to the node
     * @return the new node
     */
    public Node addNode(Word word, int beginTime, int endTime) {
        Node n = new Node(word, beginTime, endTime);
        addNode(n);
        return n;
    }

    /**
     * Creates and registers a node with an explicit ID.
     *
     * @param id        ID of the new node
     * @param word      the word the node represents
     * @param beginTime begin time passed to the node
     * @param endTime   end time passed to the node
     * @return the new node
     */
    protected Node addNode(String id, Word word, int beginTime, int endTime) {
        Node n = new Node(id, word, beginTime, endTime);
        addNode(n);
        return n;
    }

    /**
     * Creates and registers a node with an explicit ID for a word given only by its spelling.
     * The {@link Word} is created with no pronunciations and the flag {@code false}.
     *
     * @param id        ID of the new node
     * @param word      spelling of the word
     * @param beginTime begin time passed to the node
     * @param endTime   end time passed to the node
     * @return the new node
     */
    public Node addNode(String id, String word, int beginTime, int endTime) {
        Word w = new Word(word, new Pronunciation[0], false);
        return addNode(id, w, beginTime, endTime);
    }

    /**
     * Creates and registers a node for a token whose search state is a {@link WordSearchState}.
     * The node ID is derived from the token the same way as for every other token-based node.
     *
     * @param token     token to represent
     * @param beginTime begin time passed to the node
     * @param endTime   end time passed to the node
     * @return the new node
     */
    protected Node addNode(Token token, int beginTime, int endTime) {
        assert (token.getSearchState() instanceof WordSearchState);
        Word word = ((WordSearchState)token.getSearchState()).getPronunciation().getWord();
        return addNode(getNodeID(token), word, beginTime, endTime);
    }

    /**
     * @param edge the edge to look for
     * @return {@code true} if the edge is registered in this lattice
     */
    boolean hasEdge(Edge edge) {
        return this.edges.contains(edge);
    }

    /**
     * @param node the node to look for
     * @return {@code true} if a node with the same ID is registered in this lattice
     */
    boolean hasNode(Node node) {
        return hasNode(node.getId());
    }

    /**
     * @param ID the node ID to look for
     * @return {@code true} if a node with this ID is registered in this lattice
     */
    protected boolean hasNode(String ID) {
        return this.nodes.containsKey(ID);
    }

    /**
     * Registers a node under its ID. The ID is expected not to be in use yet (checked only when
     * assertions are enabled); otherwise the previous node is replaced.
     *
     * @param n the node to register
     */
    protected void addNode(Node n) {
        assert (!hasNode(n.getId()));
        this.nodes.put(n.getId(), n);
    }

    /**
     * Unregisters a node. Its edges are left untouched.
     *
     * @param n the node to remove
     */
    protected void removeNode(Node n) {
        assert (hasNode(n.getId()));
        this.nodes.remove(n.getId());
    }

    /**
     * @param id a node ID
     * @return the node registered under {@code id}, or {@code null} if there is none
     */
    protected Node getNode(String id) {
        return this.nodes.get(id);
    }

    /**
     * @return a new list holding the nodes currently in this lattice
     */
    protected Collection<Node> getCopyOfNodes() {
        return new ArrayList<Node>(this.nodes.values());
    }

    /**
     * @return a live view of the nodes in this lattice (not a copy)
     */
    public Collection<Node> getNodes() {
        return this.nodes.values();
    }

    /**
     * Unregisters an edge. The endpoint nodes are left untouched.
     *
     * @param e the edge to remove
     */
    protected void removeEdge(Edge e) {
        this.edges.remove(e);
    }

    /**
     * @return the live set of edges in this lattice (not a copy)
     */
    public Collection<Edge> getEdges() {
        return this.edges;
    }

    /**
     * Writes the lattice to a file as an AiSee-style {@code graph: { ... }} description.
     *
     * @param fileName path of the file to write
     * @param title    graph title
     * @throws Error if writing fails; the {@link IOException} is attached as the cause
     */
    public void dumpAISee(String fileName, String title) {
        try {
            System.err.println("Dumping " + title + " to " + fileName);
            FileWriter f = new FileWriter(fileName);
            try {
                f.write("graph: {\n");
                f.write("title: \"" + title + "\"\n");
                f.write("display_edge_labels: yes\n");
                for (Node node : this.nodes.values()) {
                    node.dumpAISee(f);
                }
                for (Edge edge : this.edges) {
                    edge.dumpAISee(f);
                }
                f.write("}\n");
            } finally {
                // Close the file even if a node or edge fails to write itself.
                f.close();
            }
        }
        catch (IOException e) {
            throw new Error(e.toString(), e);
        }
    }

    /**
     * Writes the lattice to a file as a Graphviz {@code digraph}, laid out left to right.
     *
     * @param fileName path of the file to write
     * @param title    graph name
     * @throws Error if writing fails; the {@link IOException} is attached as the cause
     */
    public void dumpDot(String fileName, String title) {
        try {
            System.err.println("Dumping " + title + " to " + fileName);
            FileWriter f = new FileWriter(fileName);
            try {
                f.write("digraph \"" + title + "\" {\n");
                f.write("rankdir = LR\n");
                for (Node node : this.nodes.values()) {
                    node.dumpDot(f);
                }
                for (Edge edge : this.edges) {
                    edge.dumpDot(f);
                }
                f.write("}\n");
            } finally {
                // Close the file even if a node or edge fails to write itself.
                f.close();
            }
        }
        catch (IOException e) {
            throw new Error(e.toString(), e);
        }
    }

    /**
     * Writes all nodes, all edges, then the {@code initialNode:}, {@code terminalNode:} and
     * {@code logBase:} lines, and flushes the writer. The writer is not closed.
     *
     * <p>Requires the initial node, terminal node and {@link LogMath} to be set.
     *
     * @param out destination writer
     * @throws IOException declared for overriding implementations and node/edge dumping
     */
    protected void dump(PrintWriter out) throws IOException {
        for (Node node : this.nodes.values()) {
            node.dump(out);
        }
        for (Edge edge : this.edges) {
            edge.dump(out);
        }
        out.println("initialNode: " + this.initialNode.getId());
        out.println("terminalNode: " + this.terminalNode.getId());
        out.println("logBase: " + this.logMath.getLogBase());
        out.flush();
    }

    /**
     * Writes the lattice to a text file using {@link #dump(PrintWriter)}.
     *
     * @param file path of the file to write
     * @throws Error if writing fails; the {@link IOException} is attached as the cause
     */
    public void dump(String file) {
        try {
            PrintWriter out = new PrintWriter(new FileWriter(file));
            try {
                dump(out);
            } finally {
                // dump(PrintWriter) only flushes; close here so the file handle is released.
                out.close();
            }
        }
        catch (IOException e) {
            throw new Error(e.toString(), e);
        }
    }

    /**
     * Removes a node together with every edge entering or leaving it, detaching those edges
     * from the neighbouring nodes as well.
     *
     * @param n the node to remove
     */
    protected void removeNodeAndEdges(Node n) {
        for (Edge e : n.getLeavingEdges()) {
            e.getToNode().removeEnteringEdge(e);
            this.edges.remove(e);
        }
        for (Edge e : n.getEnteringEdges()) {
            e.getFromNode().removeLeavingEdge(e);
            this.edges.remove(e);
        }
        this.nodes.remove(n.getId());
        assert (checkConsistency());
    }

    /**
     * Removes a node and bridges the gap: for every (entering edge, leaving edge) pair a new
     * edge is added from the entering edge's source to the leaving edge's target.
     *
     * @param n the node to remove
     */
    protected void removeNodeAndCrossConnectEdges(Node n) {
        System.err.println("Removing node " + n + " and cross connecting edges");
        for (Edge ei : n.getEnteringEdges()) {
            for (Edge ej : n.getLeavingEdges()) {
                // NOTE(review): the bridging edge carries only the entering edge's scores; the
                // leaving edge's scores are dropped. Kept as-is; confirm this is intended.
                addEdge(ei.getFromNode(), ej.getToNode(), ei.getAcousticScore(), ei.getLMScore());
            }
        }
        removeNodeAndEdges(n);
        assert (checkConsistency());
    }

    /**
     * @return the initial node, or {@code null} if none has been set
     */
    public Node getInitialNode() {
        return this.initialNode;
    }

    /**
     * @param p_initialNode the node to use as the initial node
     */
    public void setInitialNode(Node p_initialNode) {
        this.initialNode = p_initialNode;
    }

    /**
     * @return the terminal node, or {@code null} if none has been set
     */
    public Node getTerminalNode() {
        return this.terminalNode;
    }

    /**
     * @param p_terminalNode the node to use as the terminal node
     */
    public void setTerminalNode(Node p_terminalNode) {
        this.terminalNode = p_terminalNode;
    }

    /**
     * @return the log base reported by this lattice's {@link LogMath}; a {@code LogMath} must
     *         have been set
     */
    public double getLogBase() {
        return this.logMath.getLogBase();
    }

    /**
     * @return the log-math helper, or {@code null} if none has been set
     */
    public LogMath getLogMath() {
        return this.logMath;
    }

    /**
     * @param logMath the log-math helper to use from now on
     */
    public void setLogMath(LogMath logMath) {
        this.logMath = logMath;
    }

    /** Prints every path returned by {@link #allPaths()} to standard output, one per line. */
    public void dumpAllPaths() {
        for (String path : allPaths()) {
            System.out.println(path);
        }
    }

    /**
     * Enumerates every path from the initial node to the terminal node.
     *
     * @return one string per path; each word on the path is preceded by a single space
     */
    public List<String> allPaths() {
        return allPathsFrom("", this.initialNode);
    }

    /**
     * Recursively enumerates every path from {@code n} to the terminal node.
     *
     * @param path the words collected before reaching {@code n}
     * @param n    the node to continue from
     * @return the completed path strings; empty if the terminal node is unreachable from
     *         {@code n}
     */
    protected List<String> allPathsFrom(String path, Node n) {
        String p = path + ' ' + n.getWord();
        LinkedList<String> l = new LinkedList<String>();
        if (n == this.terminalNode) {
            l.add(p);
        } else {
            for (Edge e : n.getLeavingEdges()) {
                l.addAll(allPathsFrom(p, e.getToNode()));
            }
        }
        return l;
    }

    /**
     * Verifies that nodes and edges reference each other consistently: every edge attached to
     * a node is registered, and every registered edge has both endpoints registered and
     * attached to it.
     *
     * @return {@code true} when consistent (so the call can be used inside an {@code assert})
     * @throws Error describing the first inconsistency found
     */
    boolean checkConsistency() {
        for (Node n : this.nodes.values()) {
            for (Edge e : n.getEnteringEdges()) {
                if (!hasEdge(e)) {
                    throw new Error("Lattice has NODE with missing FROM edge: " + n + ',' + e);
                }
            }
            for (Edge e : n.getLeavingEdges()) {
                if (!hasEdge(e)) {
                    throw new Error("Lattice has NODE with missing TO edge: " + n + ',' + e);
                }
            }
        }
        for (Edge e : this.edges) {
            if (!hasNode(e.getFromNode())) {
                throw new Error("Lattice has EDGE with missing FROM node: " + e);
            }
            if (!hasNode(e.getToNode())) {
                throw new Error("Lattice has EDGE with missing TO node: " + e);
            }
            if (!e.getToNode().hasEdgeFromNode(e.getFromNode())) {
                throw new Error("Lattice has EDGE with TO node with no corresponding FROM edge: " + e);
            }
            if (!e.getFromNode().hasEdgeToNode(e.getToNode())) {
                throw new Error("Lattice has EDGE with FROM node with no corresponding TO edge: " + e);
            }
        }
        return true;
    }

    /**
     * Depth-first helper for {@link #sortNodes()}: appends {@code n} to {@code sorted} after all
     * nodes reachable through its leaving edges.
     *
     * @param n       the node to visit
     * @param sorted  output list, filled in post-order
     * @param visited nodes already visited
     * @throws Error if {@code n} is {@code null}
     */
    protected void sortHelper(Node n, List<Node> sorted, Set<Node> visited) {
        // Reject null before touching the visited set; otherwise a second null would be treated
        // as "already visited" and silently skipped.
        if (n == null) {
            throw new Error("Node is null");
        }
        if (!visited.add(n)) {
            return;
        }
        for (Edge e : n.getLeavingEdges()) {
            sortHelper(e.getToNode(), sorted, visited);
        }
        sorted.add(n);
    }

    /**
     * Orders the nodes reachable from the initial node so that each node appears before the
     * nodes its leaving edges point to (reverse depth-first post-order).
     *
     * @return a new list of the reachable nodes, initial node first
     */
    public List<Node> sortNodes() {
        ArrayList<Node> sorted = new ArrayList<Node>(this.nodes.size());
        sortHelper(this.initialNode, sorted, new HashSet<Node>());
        Collections.reverse(sorted);
        return sorted;
    }

    /**
     * Computes node posteriors using both acoustic and language-model scores.
     *
     * @param languageModelWeightAdjustment factor applied to each edge's language-model score
     */
    public void computeNodePosteriors(float languageModelWeightAdjustment) {
        computeNodePosteriors(languageModelWeightAdjustment, false);
    }

    /**
     * Runs a forward pass (which also records Viterbi scores and best predecessors) and a
     * backward pass over the sorted nodes, then sets each node's posterior to
     * {@code forward + backward - forward(terminalNode)}. Does nothing if there is no initial
     * node.
     *
     * @param languageModelWeightAdjustment factor applied to each edge's language-model score
     * @param useAcousticScoresOnly         if {@code true}, language-model scores are ignored
     */
    public void computeNodePosteriors(float languageModelWeightAdjustment, boolean useAcousticScoresOnly) {
        if (this.initialNode == null) {
            return;
        }
        this.initialNode.setForwardScore(LogMath.getLogOne());
        this.initialNode.setViterbiScore(LogMath.getLogOne());
        List<Node> sortedNodes = sortNodes();
        assert (sortedNodes.get(0) == this.initialNode);

        // Forward pass: push each node's score along its leaving edges.
        for (Node currentNode : sortedNodes) {
            for (Edge edge : currentNode.getLeavingEdges()) {
                Node toNode = edge.getToNode();
                double edgeScore = computeEdgeScore(edge, languageModelWeightAdjustment, useAcousticScoresOnly);
                double forwardProb = edge.getFromNode().getForwardScore() + edgeScore;
                toNode.setForwardScore(this.logMath.addAsLinear((float)forwardProb, (float)toNode.getForwardScore()));
                double vs = edge.getFromNode().getViterbiScore() + edgeScore;
                // Keep the best-scoring predecessor; the first one seen always wins initially.
                if (toNode.getBestPredecessor() == null || vs > toNode.getViterbiScore()) {
                    toNode.setBestPredecessor(currentNode);
                    toNode.setViterbiScore(vs);
                }
            }
        }

        // Backward pass: walk the sorted list in reverse, starting just before the last entry
        // (asserted to be the terminal node).
        this.terminalNode.setBackwardScore(LogMath.getLogOne());
        assert (sortedNodes.get(sortedNodes.size() - 1) == this.terminalNode);
        ListIterator<Node> n = sortedNodes.listIterator(sortedNodes.size() - 1);
        while (n.hasPrevious()) {
            Node currentNode = n.previous();
            for (Edge edge : currentNode.getLeavingEdges()) {
                Node fromNode = edge.getFromNode();
                double backwardProb = edge.getToNode().getBackwardScore()
                        + computeEdgeScore(edge, languageModelWeightAdjustment, useAcousticScoresOnly);
                fromNode.setBackwardScore(this.logMath.addAsLinear((float)backwardProb, (float)fromNode.getBackwardScore()));
            }
        }

        double normalizationFactor = this.terminalNode.getForwardScore();
        for (Node node : this.nodes.values()) {
            node.setPosterior(node.getForwardScore() + node.getBackwardScore() - normalizationFactor);
        }
    }

    /**
     * Follows best-predecessor links from the terminal node back to the initial node. Within
     * this class those links are only set by
     * {@link #computeNodePosteriors(float, boolean)}, so call that first.
     *
     * @return the nodes on the path, initial node first and terminal node last
     */
    public List<Node> getViterbiPath() {
        LinkedList<Node> path = new LinkedList<Node>();
        for (Node n = this.terminalNode; n != this.initialNode; n = n.getBestPredecessor()) {
            path.addFirst(n);
        }
        path.addFirst(this.initialNode);
        return path;
    }

    /**
     * Combines an edge's scores into a single value.
     *
     * @param edge                          the edge to score
     * @param languageModelWeightAdjustment factor applied to the language-model score
     * @param useAcousticScoresOnly         if {@code true}, only the acoustic score is returned
     * @return the acoustic score, plus the weighted language-model score unless disabled
     */
    private double computeEdgeScore(Edge edge, float languageModelWeightAdjustment, boolean useAcousticScoresOnly) {
        if (useAcousticScoresOnly) {
            return edge.getAcousticScore();
        }
        return edge.getAcousticScore() + edge.getLMScore() * (double)languageModelWeightAdjustment;
    }

    /**
     * Compares this lattice with another by recursively matching nodes and leaving edges,
     * starting from the two initial nodes. Progress messages are printed to standard output.
     *
     * @param other the lattice to compare against
     * @return {@code true} if the two lattices match
     */
    public boolean isEquivalent(Lattice other) {
        return checkNodesEquivalent(this.initialNode, other.getInitialNode());
    }

    /**
     * Checks that two nodes are equivalent and that their leaving edges can be paired one to
     * one with equivalent target nodes.
     *
     * @param n1 node from this lattice
     * @param n2 node from the other lattice
     * @return {@code true} if the sub-lattices rooted at the two nodes match
     */
    private boolean checkNodesEquivalent(Node n1, Node n2) {
        assert (n1 != null && n2 != null);
        if (!n1.isEquivalent(n2)) {
            return false;
        }
        Collection<Edge> leavingEdges = n1.getCopyOfLeavingEdges();
        Collection<Edge> leavingEdges2 = n2.getCopyOfLeavingEdges();
        System.out.println("# edges: " + leavingEdges.size() + ' ' + leavingEdges2.size());
        for (Edge edge : leavingEdges) {
            Edge e2 = n2.findEquivalentLeavingEdge(edge);
            if (e2 == null) {
                System.out.println("Equivalent edge not found, lattices not equivalent.");
                return false;
            }
            // Each edge of n2 may be matched only once; removing it from the copy enforces that.
            if (!leavingEdges2.remove(e2)) {
                System.out.println("Equivalent edge already matched, lattices not equivalent.");
                return false;
            }
            if (!checkNodesEquivalent(edge.getToNode(), e2.getToNode())) {
                return false;
            }
        }
        if (!leavingEdges2.isEmpty()) {
            System.out.println("One lattice has too many edges.");
            return false;
        }
        return true;
    }

    /**
     * @param node the node to test
     * @return {@code true} if the node's word is spelled {@code "<sil>"}
     */
    boolean isFillerNode(Node node) {
        return node.getWord().getSpelling().equals("<sil>");
    }

    /**
     * Removes every filler node reachable from the initial node, bridging its neighbours with
     * {@link #removeNodeAndCrossConnectEdges(Node)}.
     */
    public void removeFillers() {
        // sortNodes() returns a fresh list, so nodes can be removed while iterating over it.
        for (Node node : sortNodes()) {
            if (!isFillerNode(node)) {
                continue;
            }
            removeNodeAndCrossConnectEdges(node);
            assert (checkConsistency());
        }
    }
}

