/*
 * Decompiled with CFR 0.152.
 */
package edu.unc.bioinf.ubu.assembly;

import edu.unc.bioinf.ubu.assembly.Aligner;
import edu.unc.bioinf.ubu.assembly.Contig;
import edu.unc.bioinf.ubu.assembly.Counts;
import edu.unc.bioinf.ubu.assembly.Edge;
import edu.unc.bioinf.ubu.assembly.Node;
import edu.unc.bioinf.ubu.assembly.ReadPosition;
import edu.unc.bioinf.ubu.fastq.FastqInputFile;
import edu.unc.bioinf.ubu.fastq.FastqRecord;
import edu.unc.bioinf.ubu.sam.ReadBlock;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import net.sf.samtools.CigarOperator;
import net.sf.samtools.SAMFileHeader;
import net.sf.samtools.SAMFileReader;
import net.sf.samtools.SAMFileWriter;
import net.sf.samtools.SAMFileWriterFactory;
import net.sf.samtools.SAMRecord;

/**
 * Builds a k-mer graph from the reads of a SAM/BAM file, walks the graph to produce contigs,
 * aligns those contigs with an {@link Aligner}, and then rewrites the alignment (CIGAR and
 * start position) of the reads attached to each contig based on the contig's alignment.
 *
 * <p>An instance accumulates state (graph nodes, contigs, updated reads) across calls and
 * makes no attempt at synchronization.
 */
public class Assembler {
    /** Length of the substrings used as graph nodes. */
    private int kmerSize = 33;
    /** Edges with a count below this value are removed from the graph. */
    private int minEdgeFrequency = 15;
    /** Nodes with a count below this value are removed from the graph. */
    private int minNodeFrequency = 15;
    /** Contigs whose sequence is shorter than this are discarded. */
    private int minContigLength = 101;
    /** Ratio handed to {@code Node.getInfrequentEdges} when filtering edges. */
    private double minEdgeRatio = 0.05;
    /** Maximum suffix length examined by {@link #getOverlapIndex(String, String)}. */
    private int minMergeSize = 25;
    private final FastqInputFile fastq = new FastqInputFile();
    /** Graph nodes keyed by their k-mer sequence. */
    private final Map<String, Node> nodes = new HashMap<String, Node>();
    private final Set<Node> rootNodes = new HashSet<Node>();
    private List<Contig> contigs = new ArrayList<Contig>();
    /** Destination for the contig FASTA; only valid while an assembly method is running. */
    private BufferedWriter writer;
    /** Header of the input SAM/BAM, reused when writing the adjusted reads. */
    private SAMFileHeader samHeader;
    Set<SAMRecord> updatedReads = new HashSet<SAMRecord>();
    private final Aligner aligner = new Aligner("/home/lisle/reference/chr17/chr17.fa");

    /**
     * Runs the full pipeline: assemble contigs, align them, adjust the reads and write them out.
     *
     * @param inputSam     path of the SAM/BAM file to read
     * @param outputPrefix prefix for the three output files
     *                     ({@code _contigs.fasta}, {@code _contigs.sam}, {@code _reads.bam})
     * @throws Exception if any stage fails
     */
    public void assemble(String inputSam, String outputPrefix) throws Exception {
        String contigsFasta = outputPrefix + "_contigs.fasta";
        String contigsSam = outputPrefix + "_contigs.sam";
        String readsBam = outputPrefix + "_reads.bam";
        System.out.println("Assembling contigs");
        this.assembleContigs(inputSam, contigsFasta);
        System.out.println("Aligning contigs");
        this.alignContigs(contigsFasta, contigsSam);
        System.out.println("Adjusting reads");
        this.adjustReads(contigsSam);
        System.out.println("Writing adjusted reads");
        this.outputReads(readsBam);
        System.out.println("Done.");
    }

    /**
     * Adds every read of {@code inputSam} to the graph, filters the graph, builds contigs and
     * writes them in FASTA form to {@code output}.
     *
     * @param inputSam path of the SAM/BAM file to read
     * @param output   path of the FASTA file to (over)write
     * @throws FileNotFoundException if a file cannot be opened
     * @throws IOException           if writing the contigs fails
     */
    public void assembleContigs(String inputSam, String output) throws FileNotFoundException, IOException {
        SAMFileReader reader = new SAMFileReader(new File(inputSam));
        try {
            reader.setValidationStringency(SAMFileReader.ValidationStringency.SILENT);
            this.samHeader = reader.getFileHeader();
            this.writer = new BufferedWriter(new FileWriter(output, false));
            try {
                int numRecs = 0;
                for (SAMRecord read : reader) {
                    this.addToGraph(read);
                    ++numRecs;
                }
                System.out.println("Num records: " + numRecs);
                this.buildAndWriteContigsFromGraph();
            } finally {
                // Close even on failure so the file handle is not leaked.
                this.writer.close();
            }
        } finally {
            reader.close();
        }
    }

    /**
     * Shared tail of the assembly entry points: reports graph statistics, filters the graph,
     * builds the contigs and appends them to {@link #writer}.
     *
     * @throws IOException if writing a contig fails
     */
    private void buildAndWriteContigsFromGraph() throws IOException {
        System.out.println("Num nodes: " + this.nodes.size());
        this.printEdgeCounts();
        this.filterLowFrequencyEdges();
        this.filterLowFrequencyNodes();
        this.identifyRootNodes();
        this.buildContigs();
        this.outputContigs();
    }

    /**
     * Reads the aligned contigs from {@code contigSam} and, for every read position recorded on
     * the matching contig, derives an updated read alignment which is collected in
     * {@link #updatedReads}.
     *
     * @param contigSam path of the SAM file holding the aligned contigs
     * @throws IllegalStateException if an aligned record's name matches no assembled contig
     */
    private void adjustReads(String contigSam) {
        Map<String, Contig> contigMap = new HashMap<String, Contig>();
        for (Contig contig : this.contigs) {
            contigMap.put(contig.getDescriptor(), contig);
        }
        SAMFileReader reader = new SAMFileReader(new File(contigSam));
        try {
            reader.setValidationStringency(SAMFileReader.ValidationStringency.SILENT);
            for (SAMRecord contigRead : reader) {
                List<ReadBlock> contigReadBlocks = ReadBlock.getReadBlocks(contigRead);
                Contig contig = contigMap.get(contigRead.getReadName());
                if (contig == null) {
                    // Fail with context instead of an anonymous NullPointerException.
                    throw new IllegalStateException(
                            "No assembled contig matches aligned contig name: " + contigRead.getReadName());
                }
                List<ReadPosition> readPositions = contig.getFilteredReadPositions();
                for (ReadPosition readPosition : readPositions) {
                    SAMRecord updatedRead = this.updateReadAlignment(contigReadBlocks, readPosition);
                    if (updatedRead != null) {
                        this.updatedReads.add(updatedRead);
                    }
                }
            }
        } finally {
            reader.close();
        }
    }

    /**
     * Returns a clone of {@code read}.
     *
     * @throws RuntimeException wrapping the {@link CloneNotSupportedException} if cloning fails
     */
    private SAMRecord cloneRead(SAMRecord read) {
        try {
            return (SAMRecord)read.clone();
        }
        catch (CloneNotSupportedException e) {
            // The cause travels with the rethrown exception; no separate stack dump needed.
            throw new RuntimeException(e);
        }
    }

    /**
     * Derives a new alignment for the read at {@code orig} from the alignment blocks of its contig.
     *
     * <p>Contig blocks that end before the read's position in the contig are skipped. From the
     * remaining blocks, sub-blocks are collected until their non-deletion lengths add up to the
     * read length. The clone of the original read then receives the CIGAR built from those
     * sub-blocks and the reference start of the first one.
     *
     * @param contigReadBlocks alignment blocks of the contig the read belongs to
     * @param orig             the read together with its position within the contig
     * @return the re-aligned clone, or {@code null} when no block overlaps the read
     * @throws IllegalStateException if the collected blocks cover more bases than the read has
     */
    SAMRecord updateReadAlignment(List<ReadBlock> contigReadBlocks, ReadPosition orig) {
        // Declared with the concrete type because it is handed to ReadBlock.toCigarString as such.
        ArrayList<ReadBlock> blocks = new ArrayList<ReadBlock>();
        SAMRecord read = this.cloneRead(orig.getRead());
        int contigPosition = orig.getPosition();
        int readLength = read.getReadLength();
        int accumulatedLength = 0;
        for (ReadBlock contigBlock : contigReadBlocks) {
            // Block lies entirely before the read's position within the contig.
            if (contigBlock.getReadStart() + contigBlock.getReferenceLength() < orig.getPosition() + 1) {
                continue;
            }
            ReadBlock block = contigBlock.getSubBlock(accumulatedLength, contigPosition, readLength - accumulatedLength);
            if (block.getLength() == 0) {
                continue;
            }
            blocks.add(block);
            // Deletions consume no read bases, so they do not count towards the read length.
            if (block.getType() != CigarOperator.D) {
                accumulatedLength += block.getLength();
            }
            if (accumulatedLength > readLength) {
                throw new IllegalStateException("Accumulated Length: " + accumulatedLength + " is greater than read length: " + readLength);
            }
            if (accumulatedLength == readLength) {
                break;
            }
        }
        if (blocks.isEmpty()) {
            return null;
        }
        int newAlignmentStart = blocks.get(0).getReferenceStart();
        String newCigar = ReadBlock.toCigarString(blocks);
        read.setCigarString(newCigar);
        read.setAlignmentStart(newAlignmentStart);
        return read;
    }

    /**
     * Writes all {@link #updatedReads} to {@code readsBam}, marking the header as unsorted.
     *
     * @param readsBam path of the output file
     */
    private void outputReads(String readsBam) {
        this.samHeader.setSortOrder(SAMFileHeader.SortOrder.unsorted);
        SAMFileWriter out = new SAMFileWriterFactory().makeSAMOrBAMWriter(this.samHeader, true, new File(readsBam));
        try {
            for (SAMRecord read : this.updatedReads) {
                out.addAlignment(read);
            }
        } finally {
            // Always close so the output is finalized and the handle released.
            out.close();
        }
    }

    /** Delegates alignment of the contig FASTA to the configured {@link Aligner}. */
    private void alignContigs(String contigFile, String contigSam) throws InterruptedException, IOException {
        this.aligner.align(contigFile, contigSam);
    }

    /**
     * FASTQ-based variant of {@link #assembleContigs(String, String)}: adds every record's
     * sequence to the graph and writes the resulting contigs to {@code output}.
     * Not referenced from within this class.
     *
     * @param inputFastq path of the FASTQ file to read
     * @param output     path of the FASTA file to (over)write
     */
    private void _assemble(String inputFastq, String output) throws FileNotFoundException, IOException {
        this.fastq.init(inputFastq);
        this.writer = new BufferedWriter(new FileWriter(output, false));
        try {
            int numRecs = 0;
            for (FastqRecord rec = this.fastq.getNextRecord(); rec != null; rec = this.fastq.getNextRecord()) {
                this.addToGraph(rec.getSequence());
                ++numRecs;
            }
            System.out.println("Num records: " + numRecs);
            this.buildAndWriteContigsFromGraph();
        } finally {
            this.writer.close();
        }
    }

    /** Sets the k-mer length used when adding sequences to the graph. */
    public void setKmerSize(int kmerSize) {
        this.kmerSize = kmerSize;
    }

    /** Sets the minimum sequence length a contig needs in order to be kept. */
    public void setMinContigLength(int minContigLength) {
        this.minContigLength = minContigLength;
    }

    /** Sets the edge count below which edges are filtered out. */
    public void setMinEdgeFrequency(int minEdgeFrequency) {
        this.minEdgeFrequency = minEdgeFrequency;
    }

    /**
     * Sets the node count below which nodes are filtered out.
     * The misspelled method name is retained for compatibility with existing callers.
     */
    public void setMinNodeFrequncy(int minNodeFrequncy) {
        this.minNodeFrequency = minNodeFrequncy;
    }

    /** Sets the ratio passed to {@code Node.getInfrequentEdges} during edge filtering. */
    public void setMinEdgeRatio(double minEdgeRatio) {
        this.minEdgeRatio = minEdgeRatio;
    }

    /**
     * Removes every node whose count is below the minimum node frequency, together with all
     * edges entering or leaving such a node.
     */
    private void filterLowFrequencyNodes() {
        List<Node> nodesToFilter = new ArrayList<Node>();
        for (Node node : this.nodes.values()) {
            if (node.getCount() < this.minNodeFrequency) {
                nodesToFilter.add(node);
            }
        }
        // Collect first, remove afterwards: removal must not happen while iterating the graph.
        List<Edge> edgesToFilter = new ArrayList<Edge>();
        for (Node node : nodesToFilter) {
            edgesToFilter.addAll(node.getToEdges());
            edgesToFilter.addAll(node.getFromEdges());
        }
        for (Edge edge : edgesToFilter) {
            edge.remove();
        }
        for (Node node : nodesToFilter) {
            this.nodes.remove(node.getSequence());
        }
    }

    /**
     * Removes every outgoing edge whose count is below the minimum edge frequency, plus the
     * edges each node reports via {@code getInfrequentEdges(minEdgeRatio)}.
     */
    private void filterLowFrequencyEdges() {
        Set<Edge> edgesToFilter = new HashSet<Edge>();
        for (Node node : this.nodes.values()) {
            for (Edge edge : node.getToEdges()) {
                if (edge.getCount() < this.minEdgeFrequency) {
                    edgesToFilter.add(edge);
                }
            }
            edgesToFilter.addAll(node.getInfrequentEdges(this.minEdgeRatio));
        }
        for (Edge edge : edgesToFilter) {
            edge.remove();
        }
    }

    /**
     * Renames each contig to {@code contig<index>_<previous descriptor>} and appends it to
     * {@link #writer} as a FASTA record.
     *
     * @throws IOException if writing fails
     */
    private void outputContigs() throws IOException {
        int count = 0;
        for (Contig contig : this.contigs) {
            contig.setDescriptor("contig" + count++ + "_" + contig.getDescriptor());
            this.writer.append(">" + contig.getDescriptor() + "\n");
            this.writer.append(contig.getSequence());
            this.writer.append("\n");
        }
    }

    /** Collects every node that reports itself as a root node into {@link #rootNodes}. */
    private void identifyRootNodes() {
        for (Node node : this.nodes.values()) {
            if (node.isRootNode()) {
                this.rootNodes.add(node);
            }
        }
    }

    /** Starts a fresh contig walk from every root node. */
    private void buildContigs() {
        System.out.println("Num starting nodes: " + this.rootNodes.size());
        for (Node node : this.rootNodes) {
            this.buildContig(node, new HashSet<Node>(), new Contig(), new Counts());
        }
    }

    /**
     * Finishes a contig at {@code node}: unless the walk stopped at a repeat, the node's full
     * sequence is appended; the contig is kept only if it reaches the minimum contig length.
     */
    private void processContigTerminus(Node node, Counts counts, Contig contig) {
        if (!counts.isTerminatedAtRepeat()) {
            contig.append(node, node.getSequence());
        }
        if (contig.getSequence().length() >= this.minContigLength) {
            contig.setDescriptor(counts.toString());
            this.contigs.add(contig);
        }
    }

    /**
     * Recursively extends {@code contig} from {@code node}, branching into an independent copy
     * of the contig, the visited set and the counts for every outgoing edge.
     *
     * <p>The walk ends when a node is revisited (flagged as terminated at a repeat) or when a
     * node has no outgoing edges.
     *
     * @param node         node to visit
     * @param visitedNodes nodes already on the current path
     * @param contig       contig built so far along the current path
     * @param counts       statistics gathered so far along the current path
     */
    private void buildContig(Node node, Set<Node> visitedNodes, Contig contig, Counts counts) {
        if (visitedNodes.contains(node)) {
            counts.setTerminatedAtRepeat(true);
            this.processContigTerminus(node, counts, contig);
            return;
        }
        visitedNodes.add(node);
        Collection<Edge> edges = node.getToEdges();
        if (edges.isEmpty()) {
            this.processContigTerminus(node, counts, contig);
            return;
        }
        // Interior node: contributes only its first character; the terminus adds a full k-mer.
        contig.append(node, Character.toString(node.getSequence().charAt(0)));
        for (Edge edge : edges) {
            // Clone BEFORE recording the edge so that one branch's edge count does not leak
            // into the statistics of the sibling branches processed after it.
            Counts branchCounts = (Counts)counts.clone();
            branchCounts.incrementEdgeCounts(edge.getCount());
            Contig contigBranch = new Contig(contig);
            Set<Node> visitedNodesBranch = new HashSet<Node>(visitedNodes);
            this.buildContig(edge.getTo(), visitedNodesBranch, contigBranch, branchCounts);
        }
    }

    /**
     * Prepends each contig onto every other contig it overlaps (per
     * {@link #getOverlapIndex(String, String)}) and drops the contigs that were merged.
     * Only active when the merge size exceeds the k-mer size. Not referenced from within
     * this class.
     */
    private void mergeContigs() {
        if (this.minMergeSize <= this.kmerSize) {
            return;
        }
        List<Contig> updatedContigs = new ArrayList<Contig>(this.contigs);
        int mergedCount = 0;
        for (Contig contig1 : this.contigs) {
            boolean isMerged = false;
            for (Contig contig2 : updatedContigs) {
                if (contig1 != contig2 && this.getOverlapIndex(contig1.getSequence(), contig2.getSequence()) > -1) {
                    contig2.prependSequence(contig1.getDescriptor(), contig1.getSequence());
                    isMerged = true;
                }
            }
            if (isMerged) {
                // Safe: the inner iteration over updatedContigs has finished at this point.
                updatedContigs.remove(contig1);
                ++mergedCount;
            }
        }
        this.contigs = updatedContigs;
        System.out.println("Merged: " + mergedCount + " overlapping contigs.");
    }

    /**
     * Finds the first index {@code i} in {@code s1}, within the search window, such that
     * {@code s1.substring(i)} is a prefix of {@code s2}.
     *
     * <p>The window starts at the larger of {@code max(s2.length() - s1.length(), 0)} and
     * {@code s1.length() - minMergeSize}.
     *
     * @return the matching index, or {@code -1} if there is none
     */
    private int getOverlapIndex(String s1, String s2) {
        int lengthDiff = s2.length() - s1.length();
        int start = Math.max(Math.max(lengthDiff, 0), s1.length() - this.minMergeSize);
        for (int i = start; i < s1.length(); ++i) {
            if (s2.startsWith(s1.substring(i))) {
                return i;
            }
        }
        return -1;
    }

    /**
     * Prints median/max/min of the number of outgoing edges per node and of the edge counts.
     * Statistics that have no data (no nodes, or no edges) are skipped rather than failing.
     */
    private void printEdgeCounts() {
        long[] edgeCounts = new long[this.nodes.size()];
        List<Integer> edgeSizes = new ArrayList<Integer>();
        int idx = 0;
        for (Node node : this.nodes.values()) {
            edgeCounts[idx++] = node.getToEdges().size();
            for (Edge edge : node.getToEdges()) {
                edgeSizes.add(edge.getCount());
            }
        }
        if (edgeCounts.length == 0) {
            // Empty graph (e.g. no reads, or all reads shorter than the k-mer size).
            System.out.println("No nodes in graph; skipping edge statistics.");
            return;
        }
        Arrays.sort(edgeCounts);
        System.out.println("Median edge count: " + edgeCounts[edgeCounts.length / 2]);
        System.out.println("Max edge count: " + edgeCounts[edgeCounts.length - 1]);
        System.out.println("Min edge count: " + edgeCounts[0]);
        if (edgeSizes.isEmpty()) {
            System.out.println("No edges in graph; skipping edge size statistics.");
            return;
        }
        Integer[] sizes = edgeSizes.toArray(new Integer[edgeSizes.size()]);
        Arrays.sort(sizes);
        System.out.println("Median edge size: " + sizes[sizes.length / 2]);
        System.out.println("Max edge size: " + sizes[sizes.length - 1]);
        System.out.println("Min edge size: " + sizes[0]);
    }

    /**
     * Adds the read's bases to the graph and registers the read on the node of its first k-mer.
     * Reads too short to yield a k-mer are not registered anywhere.
     */
    private void addToGraph(SAMRecord read) {
        Node node = this.addToGraph(read.getReadString());
        if (node != null) {
            node.addStartingRead(read);
        }
    }

    /**
     * Slides a window of {@link #kmerSize} over {@code sequence}, creating a node for each new
     * k-mer (or incrementing the count of an existing one) and linking consecutive k-mers
     * with an edge.
     *
     * @return the node of the first k-mer, or {@code null} if the sequence is shorter than a k-mer
     */
    private Node addToGraph(String sequence) {
        Node prev = null;
        Node firstNode = null;
        int lastStart = sequence.length() - this.kmerSize;
        for (int i = 0; i <= lastStart; ++i) {
            String kmer = sequence.substring(i, i + this.kmerSize);
            Node node = this.nodes.get(kmer);
            if (node == null) {
                node = new Node(kmer, sequence);
                this.nodes.put(kmer, node);
            } else {
                node.incrementCount();
            }
            if (prev != null) {
                prev.addToEdge(node);
            }
            if (firstNode == null) {
                firstNode = node;
            }
            prev = node;
        }
        return firstNode;
    }

    /**
     * Command-line entry point.
     *
     * @param args optionally {@code <input.bam> <outputPrefix>}; when exactly two arguments
     *             are not supplied, the built-in default paths are used
     */
    public static void main(String[] args) throws Exception {
        String input = "/home/lisle/ayc/case0/round2/case0_tumor.bam";
        String outputPrefix = "/home/lisle/ayc/case0/round2/ra_tumor";
        if (args != null && args.length == 2) {
            input = args[0];
            outputPrefix = args[1];
        }
        long s = System.currentTimeMillis();
        Assembler ayc = new Assembler();
        ayc.assemble(input, outputPrefix);
        long e = System.currentTimeMillis();
        System.out.println("Elapsed secs: " + (e - s) / 1000L);
    }
}

