Mercurial > repos > cpt > cpt_intron_detection
diff cpt_intron_detect/intron_detection.py @ 0:1a19092729be draft
Uploaded
author | cpt |
---|---|
date | Fri, 13 May 2022 05:08:54 +0000 |
parents | |
children |
line wrap: on
line diff
#!/usr/bin/env python
"""Detect putative introns by clustering BLASTP hits shared between genes.

Reads gene calls (GFF3) and protein BLAST results (XML), clusters genes whose
products hit the same subject accession, then filters those clusters by
strand, genomic gap, and subject-sequence overlap before emitting a GFF3
gene -> mRNA -> CDS model for each surviving cluster.
"""
import sys
import re
import itertools
import argparse
import hashlib
import copy
from CPT_GFFParser import gffParse, gffWrite, gffSeqFeature
from Bio.Blast import NCBIXML
from Bio.SeqFeature import SeqFeature, FeatureLocation
from gff3 import feature_lambda
from collections import OrderedDict
import logging

logging.basicConfig(level=logging.DEBUG)
log = logging.getLogger()


def _remove_prefix(text, prefix):
    """Return *text* with the literal *prefix* removed from the front, if present."""
    if prefix and text.startswith(prefix):
        return text[len(prefix):]
    return text


def _remove_suffix(text, suffix):
    """Return *text* with the literal *suffix* removed from the end, if present."""
    if suffix and text.endswith(suffix):
        return text[: -len(suffix)]
    return text


def parse_xml(blastxml, thresh):
    """Parse a blastp XML file into per-query lists of hit dictionaries.

    Parameters
    ----------
    blastxml : file-like
        Open handle to a BLAST XML result file.
    thresh : float
        Minimum identity fraction; HSPs scoring below it are discarded.

    Returns
    -------
    list of list of dict
        One inner list per query record; each dict describes one HSP
        (subject/query ranges, identity, e-value, hit/iteration numbers).
    """
    blast = []
    discarded_records = 0
    totLen = 0
    for iter_num, blast_record in enumerate(NCBIXML.parse(blastxml), 1):
        blast_gene = []
        align_num = 0
        for alignment in blast_record.alignments:
            align_num += 1
            gi_nos = str(alignment.accession)

            for hsp in alignment.hsps:
                # Identity fraction over the aligned query span.
                # NOTE(review): the "- 1" on identities appears deliberate
                # (the same adjustment is repeated in output_gff3) — confirm.
                x = float(hsp.identities - 1) / ((hsp.query_end) - hsp.query_start)
                if x < thresh:
                    discarded_records += 1
                    continue

                # Keep only the first whitespace-delimited token of the query name.
                nice_name = blast_record.query
                if " " in nice_name:
                    nice_name = nice_name[0 : nice_name.index(" ")]

                blast_gene.append(
                    {
                        "gi_nos": gi_nos,
                        "sbjct_length": alignment.length,
                        "query_length": blast_record.query_length,
                        "sbjct_range": (hsp.sbjct_start, hsp.sbjct_end),
                        "query_range": (hsp.query_start, hsp.query_end),
                        "name": nice_name,
                        "evalue": hsp.expect,
                        "identity": hsp.identities,
                        "identity_percent": x,
                        "hit_num": align_num,
                        "iter_num": iter_num,
                        "match_id": alignment.title.partition(">")[0],
                    }
                )

        blast.append(blast_gene)
        totLen += len(blast_gene)
    log.debug("parse_blastxml %s -> %s", totLen + discarded_records, totLen)
    return blast


def filter_lone_clusters(clusters):
    """Remove clusters with only one member and clusters keyed by an empty id."""
    filtered_clusters = {}
    for key in clusters:
        if len(clusters[key]) > 1 and len(key) > 0:
            filtered_clusters[key] = clusters[key]
    log.debug("filter_lone_clusters %s -> %s", len(clusters), len(filtered_clusters))
    return filtered_clusters


def test_true(feature, **kwargs):
    """Predicate for feature_lambda that accepts every feature."""
    return True


def parse_gff(gff3):
    """Extract per-CDS strand/location info used for cluster filtering.

    Parameters
    ----------
    gff3 : file-like
        Open handle to a GFF3 file of gene calls.

    Returns
    -------
    tuple
        (dict keyed by CDS feature id with strand/start/end/location/name,
        the last parsed record, genome length in bases).
    """
    log.debug("parse_gff3")
    gff_info = {}
    _rec = None
    for rec in gffParse(gff3):
        endBase = len(rec.seq)

        _rec = rec
        _rec.annotations = {}
        for feat in feature_lambda(rec.features, test_true, {}, subfeatures=False):
            if feat.type == "CDS":
                # Prefer the human-readable Name qualifier; fall back to ID.
                if "Name" in feat.qualifiers.keys():
                    CDSname = feat.qualifiers["Name"]
                else:
                    CDSname = feat.qualifiers["ID"]
                gff_info[feat.id] = {
                    "strand": feat.strand,
                    "start": feat.location.start,
                    "end": feat.location.end,
                    "loc": feat.location,
                    "feat": feat,
                    "name": CDSname,
                }

    # Order features by genomic start and record each one's rank.
    gff_info = OrderedDict(sorted(gff_info.items(), key=lambda k: k[1]["start"]))
    for i, feat_id in enumerate(gff_info):
        gff_info[feat_id].update({"index": i})
        # Extend the genome length if a feature runs past the sequence end.
        if gff_info[feat_id]["loc"].end > endBase:
            endBase = gff_info[feat_id]["loc"].end

    return dict(gff_info), _rec, endBase


def all_same(genes_list):
    """Return True if all gene names in the cluster are identical."""
    return all(gene["name"] == genes_list[0]["name"] for gene in genes_list[1:])


def remove_duplicates(clusters):
    """Drop clusters whose members all carry the same gene name."""
    filtered_clusters = {}
    for key in clusters:
        if all_same(clusters[key]):
            continue
        else:
            filtered_clusters[key] = clusters[key]
    log.debug("remove_duplicates %s -> %s", len(clusters), len(filtered_clusters))
    return filtered_clusters


class IntronFinder(object):
    """Holds BLAST hits per gene and progressively filters them into clusters."""

    def __init__(self, gff3, blastp, thresh):
        """Parse the GFF3 gene calls and blastp XML up front.

        Parameters
        ----------
        gff3 : file-like
            GFF3 gene calls.
        blastp : file-like
            BLAST XML protein results.
        thresh : float
            Minimum identity fraction passed through to parse_xml.
        """
        self.blast = []
        self.clusters = {}
        self.gff_info = {}
        self.length = 0

        (self.gff_info, self.rec, self.length) = parse_gff(gff3)
        self.blast = parse_xml(blastp, thresh)

    def create_clusters(self):
        """Group hits by subject accession; keep groups of two or more genes."""
        clusters = {}
        for gene in self.blast:
            for hit in gene:
                # Truncate the accession at the first space, if any.
                if " " in hit["gi_nos"]:
                    hit["gi_nos"] = hit["gi_nos"][0 : hit["gi_nos"].index(" ")]

                nameCheck = hit["gi_nos"]
                if nameCheck == "":
                    continue
                # Hash the accession to get a stable cluster key.
                name = hashlib.md5((nameCheck).encode()).hexdigest()

                if name in clusters:
                    if hit not in clusters[name]:
                        clusters[name].append(hit)
                else:
                    clusters[name] = [hit]
        log.debug("create_clusters %s -> %s", len(self.blast), len(clusters))
        self.clusters = filter_lone_clusters(clusters)

    def check_strand(self):
        """Split mixed-strand clusters into same-strand sub-clusters."""
        filtered_clusters = {}
        for key in self.clusters:
            pos_strand = []
            neg_strand = []
            for gene in self.clusters[key]:
                if self.gff_info[gene["name"]]["strand"] == 1:
                    pos_strand.append(gene)
                else:
                    neg_strand.append(gene)
            if len(pos_strand) == 0 or len(neg_strand) == 0:
                # Already single-strand; keep as-is.
                filtered_clusters[key] = self.clusters[key]
            else:
                # Mixed strands: keep each side separately if it still has >1 member.
                if len(pos_strand) > 1:
                    filtered_clusters[key + "_+1"] = pos_strand
                if len(neg_strand) > 1:
                    filtered_clusters[key + "_-1"] = neg_strand

        return filtered_clusters

    def check_gene_gap(self, maximum=10000):
        """Split clusters into groups of genes lying within *maximum* bases.

        Genes are greedily merged into a group when the gap to any member —
        measured directly or wrapping around the genome end — is within
        *maximum*. Groups with fewer than two members are discarded, and
        same-name-only groups are removed via remove_duplicates().
        """
        filtered_clusters = {}
        for key in self.clusters:
            hits_lists = []
            for gene in self.clusters[key]:
                # BUGFIX: previously gene_added was initialized once per
                # cluster, so after the first merge it stayed True and every
                # later unmatched gene was silently dropped (never started
                # its own group). Reset it for each gene.
                gene_added = False
                for hits in hits_lists:
                    for hit in hits:
                        lastStart = max(
                            self.gff_info[gene["name"]]["start"],
                            self.gff_info[hit["name"]]["start"],
                        )
                        lastEnd = max(
                            self.gff_info[gene["name"]]["end"],
                            self.gff_info[hit["name"]]["end"],
                        )
                        firstEnd = min(
                            self.gff_info[gene["name"]]["end"],
                            self.gff_info[hit["name"]]["end"],
                        )
                        firstStart = min(
                            self.gff_info[gene["name"]]["start"],
                            self.gff_info[hit["name"]]["start"],
                        )
                        # Direct gap, or gap across the origin (wraparound).
                        if (
                            lastStart - firstEnd <= maximum
                            or self.length - lastEnd + firstStart <= maximum
                        ):
                            hits.append(gene)
                            gene_added = True
                            break
                    if gene_added:
                        # BUGFIX: also leave the outer loop so a gene is not
                        # appended to more than one group.
                        break
                if not gene_added:
                    hits_lists.append([gene])

            for i, hits in enumerate(hits_lists):
                if len(hits) >= 2:
                    filtered_clusters[key + "_" + str(i)] = hits
        log.debug("check_gene_gap %s -> %s", len(self.clusters), len(filtered_clusters))

        return remove_duplicates(filtered_clusters)

    def check_seq_overlap(self, minimum=-1):
        """Filter clusters by pairwise subject-range overlap/distance.

        A negative *minimum* tolerates up to ``-minimum`` bases of overlap;
        zero or positive values reject any overlap, and a positive value
        additionally requires at least *minimum* bases between ranges.
        """
        filtered_clusters = {}
        for key in self.clusters:
            add_cluster = True
            sbjct_ranges = []
            query_ranges = []
            for gene in self.clusters[key]:
                sbjct_ranges.append(gene["sbjct_range"])
                query_ranges.append(gene["query_range"])

            combinations = list(itertools.combinations(sbjct_ranges, 2))

            for pair in combinations:
                # Overlap in bases between the two subject ranges.
                overlap = len(
                    set(range(pair[0][0], pair[0][1]))
                    & set(range(pair[1][0], pair[1][1]))
                )
                # Order the pair so minPair starts (or ends) first.
                minPair = pair[0]
                maxPair = pair[1]
                if minPair[0] > maxPair[0]:
                    minPair = pair[1]
                    maxPair = pair[0]
                elif minPair[0] == maxPair[0] and minPair[1] > maxPair[1]:
                    minPair = pair[1]
                    maxPair = pair[0]
                # Distance between ranges (start-to-start when overlapping).
                if overlap > 0:
                    dist1 = maxPair[0] - minPair[0]
                else:
                    dist1 = abs(maxPair[0] - minPair[1])

                if minimum < 0:
                    # Negative minimum: allow overlap up to -minimum bases.
                    if overlap > (minimum * -1):
                        add_cluster = False
                elif minimum == 0:
                    if overlap > 0:
                        add_cluster = False
                elif overlap > 0:
                    add_cluster = False

                # Positive minimum also enforces a minimum separation.
                if (dist1 < minimum) and (minimum >= 0):
                    add_cluster = False
            if add_cluster:
                filtered_clusters[key] = self.clusters[key]

        log.debug(
            "check_seq_overlap %s -> %s", len(self.clusters), len(filtered_clusters)
        )
        return filtered_clusters

    def cluster_report(self):
        """Map each gene name to the list of subject ranges it hit."""
        condensed_report = {}
        for key in self.clusters:
            for gene in self.clusters[key]:
                if gene["name"] in condensed_report:
                    condensed_report[gene["name"]].append(gene["sbjct_range"])
                else:
                    condensed_report[gene["name"]] = [gene["sbjct_range"]]
        return condensed_report

    def cluster_report_2(self):
        """Count how often each combination of (shortened) gene names occurs."""
        condensed_report = {}
        for key in self.clusters:
            gene_names = []
            for gene in self.clusters[key]:
                # BUGFIX: str.strip("CPT_phageK_") removes a *character set*
                # from both ends, not the prefix; remove the literal prefix.
                gene_names.append(_remove_prefix(gene["name"], "CPT_phageK_"))
            joined = ", ".join(gene_names)
            if joined in condensed_report:
                condensed_report[joined] += 1
            else:
                condensed_report[joined] = 1
        return condensed_report

    def cluster_report_3(self):
        """Map each combination of (shortened) gene names to its accession."""
        condensed_report = {}
        for key in self.clusters:
            gene_names = []
            gi_nos = []
            for i, gene in enumerate(self.clusters[key]):
                if i == 0:
                    gi_nos = gene["gi_nos"]
                # BUGFIX: as in cluster_report_2, strip() was misused as
                # suffix/prefix removal; remove the literal affixes instead.
                gene_names.append(
                    _remove_prefix(
                        _remove_suffix(gene["name"], ".p01"), "CPT_phageK_gp"
                    )
                )
            joined = ", ".join(gene_names)
            if joined in condensed_report:
                condensed_report[joined].append(gi_nos)
            else:
                condensed_report[joined] = [gi_nos]
        return condensed_report

    def output_gff3(self, clusters):
        """Build a GFF3 record with a gene -> mRNA -> CDS model per cluster."""
        rec = copy.deepcopy(self.rec)
        rec.features = []
        for cluster_idx, cluster_id in enumerate(clusters):
            # Genes participating in this cluster and their locations.
            associated_genes = set([x["name"] for x in clusters[cluster_id]])
            assoc_gene_info = {x: self.gff_info[x]["loc"] for x in associated_genes}
            # Span covered by all children genes ("standard gene model" gene).
            gene_min = min([min(x[1].start, x[1].end) for x in assoc_gene_info.items()])
            gene_max = max([max(x[1].start, x[1].end) for x in assoc_gene_info.items()])

            evidence_notes = []
            for cluster_elem in clusters[cluster_id]:
                note = "{name} had {ident}% identity to NCBI Protein ID {pretty_gi}".format(
                    pretty_gi=(cluster_elem["gi_nos"]),
                    ident=int(
                        100
                        * float(cluster_elem["identity"] - 1.00)
                        / abs(
                            cluster_elem["query_range"][1]
                            - cluster_elem["query_range"][0]
                        )
                    ),
                    **cluster_elem
                )
                evidence_notes.append(note)
            if gene_max - gene_min > 0.8 * float(self.length):
                evidence_notes.append(
                    "Intron is over 80% of the total length of the genome, possible wraparound scenario"
                )
            # Top-level gene feature.
            gene = gffSeqFeature(
                location=FeatureLocation(gene_min, gene_max),
                type="gene",
                id=cluster_id,
                qualifiers={
                    "ID": ["gp_%s" % cluster_idx],
                    "Percent_Identities": evidence_notes,
                    "Note": clusters[cluster_id][0]["match_id"],
                },
            )

            # mRNA child below the gene.
            mRNA = gffSeqFeature(
                location=FeatureLocation(gene_min, gene_max),
                type="mRNA",
                id=cluster_id + ".mRNA",
                qualifiers={"ID": ["gp_%s.mRNA" % cluster_idx], "note": evidence_notes},
            )

            # CDS children, sorted by genomic start.
            cdss = []
            for idx, gene_name in enumerate(
                sorted(associated_genes, key=lambda x: int(self.gff_info[x]["start"]))
            ):
                # Copy the CDS so we don't mutate the parsed original.
                cds = copy.copy(self.gff_info[gene_name]["feat"])
                # The cluster element for this gene (same one used in the notes).
                cluster_elem = [
                    x for x in clusters[cluster_id] if x["name"] == gene_name
                ][0]

                # %identity scaled to a 0-1000 score.
                score = int(
                    1000
                    * float(cluster_elem["identity"])
                    / abs(
                        cluster_elem["query_range"][1] - cluster_elem["query_range"][0]
                    )
                )

                # Trim the CDS to the aligned query span (protein -> nt: x3).
                tempLoc = FeatureLocation(
                    cds.location.start + (3 * (cluster_elem["query_range"][0] - 1)),
                    cds.location.start + (3 * (cluster_elem["query_range"][1])),
                    cds.location.strand,
                )
                cds.location = tempLoc
                cds.qualifiers = {
                    "ID": ["gp_%s.CDS.%s" % (cluster_idx, idx)],
                    "score": score,
                    "Name": self.gff_info[gene_name]["name"],
                    "evalue": cluster_elem["evalue"],
                    "Identity": cluster_elem["identity_percent"] * 100,
                }
                cdss.append(cds)

            # Attach children and propagate the strand of the last CDS.
            mRNA.sub_features = cdss
            mRNA.location = FeatureLocation(
                mRNA.location.start, mRNA.location.end, cds.location.strand
            )
            gene.sub_features = [mRNA]
            gene.location = FeatureLocation(
                gene.location.start, gene.location.end, cds.location.strand
            )

            rec.features.append(gene)
        return rec

    def output_xml(self, clusters):
        """(Incomplete) group clusters by their BLAST iteration number."""
        threeLevel = {}
        for cluster_idx, cluster_id in enumerate(clusters):
            # BUGFIX: "in threeLevel.keys" tested membership against the
            # bound method object itself (a TypeError at runtime); test the
            # dict directly.
            if clusters[cluster_id][0]["iter_num"] not in threeLevel:
                threeLevel[clusters[cluster_id][0]["iter_num"]] = {}


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Intron detection")
    parser.add_argument("gff3", type=argparse.FileType("r"), help="GFF3 gene calls")
    parser.add_argument(
        "blastp", type=argparse.FileType("r"), help="blast XML protein results"
    )
    parser.add_argument(
        "--minimum",
        help="Gap minimum (Default -1, set to a negative number to allow overlap)",
        default=-1,
        type=int,
    )
    parser.add_argument(
        "--maximum",
        help="Gap maximum in genome (Default 10000)",
        default=10000,
        type=int,
    )
    parser.add_argument(
        "--idThresh", help="ID Percent Threshold", default=0.4, type=float
    )

    args = parser.parse_args()

    # Clamp the identity threshold into [0, 1].
    threshCap = min(max(args.idThresh, 0), 1.00)

    # Build the finder, then apply the filter pipeline in order.
    ifinder = IntronFinder(args.gff3, args.blastp, threshCap)
    ifinder.create_clusters()
    ifinder.clusters = ifinder.check_strand()
    ifinder.clusters = ifinder.check_gene_gap(maximum=args.maximum)
    ifinder.clusters = ifinder.check_seq_overlap(minimum=args.minimum)

    condensed_report = ifinder.cluster_report()
    rec = ifinder.output_gff3(ifinder.clusters)
    gffWrite([rec], sys.stdout)