Mercurial > repos > cpt > cpt_intron_detection
view cpt_intron_detect/intron_detection.py @ 2:e0e294a1ccdd draft
Uploaded
author | cpt |
---|---|
date | Fri, 20 May 2022 08:07:39 +0000 |
parents | 1a19092729be |
children |
line wrap: on
line source
#!/usr/bin/env python
"""Intron detection.

Groups CDS features whose BLASTp hits point at the same subject protein and
reports each surviving group as a gene/mRNA/CDS model in GFF3.
"""
import sys
import re
import itertools
import argparse
import hashlib
import copy
from CPT_GFFParser import gffParse, gffWrite, gffSeqFeature
from Bio.Blast import NCBIXML
from Bio.SeqFeature import SeqFeature, FeatureLocation
from gff3 import feature_lambda
from collections import OrderedDict
import logging

logging.basicConfig(level=logging.DEBUG)
log = logging.getLogger()


def _strip_affixes(name, prefix="", suffix=""):
    """Return *name* without a leading *prefix* and/or trailing *suffix*.

    Unlike ``str.strip(chars)``, which removes any run of the given
    *characters* from both ends, this removes the exact substrings only.
    """
    if prefix and name.startswith(prefix):
        name = name[len(prefix):]
    if suffix and name.endswith(suffix):
        name = name[: -len(suffix)]
    return name


def parse_xml(blastxml, thresh):
    """Parse a BLAST XML handle into per-query lists of HSP dictionaries.

    Args:
        blastxml: handle passed straight to ``NCBIXML.parse``.
        thresh: minimum identity fraction; HSPs scoring below it are dropped.

    Returns:
        A list with one entry per BLAST record (query). Each entry is a list
        of dicts, one per retained HSP, with the keys ``gi_nos``,
        ``sbjct_length``, ``query_length``, ``sbjct_range``, ``query_range``,
        ``name``, ``evalue``, ``identity``, ``identity_percent``, ``hit_num``,
        ``iter_num`` and ``match_id``.
    """
    blast = []
    discarded_records = 0
    totLen = 0
    for iter_num, blast_record in enumerate(NCBIXML.parse(blastxml), 1):
        blast_gene = []
        # Only the part of the query line before the first space is used as
        # the gene name.
        nice_name = blast_record.query.partition(" ")[0]
        for align_num, alignment in enumerate(blast_record.alignments, 1):
            gi_nos = str(alignment.accession)
            for hsp in alignment.hsps:
                query_span = hsp.query_end - hsp.query_start
                if query_span == 0:
                    # A one-residue HSP has no span to normalise by; it would
                    # divide by zero here and again in output_gff3.
                    discarded_records += 1
                    continue
                x = float(hsp.identities - 1) / query_span
                if x < thresh:
                    discarded_records += 1
                    continue
                blast_gene.append(
                    {
                        "gi_nos": gi_nos,
                        "sbjct_length": alignment.length,
                        "query_length": blast_record.query_length,
                        "sbjct_range": (hsp.sbjct_start, hsp.sbjct_end),
                        "query_range": (hsp.query_start, hsp.query_end),
                        "name": nice_name,
                        "evalue": hsp.expect,
                        "identity": hsp.identities,
                        "identity_percent": x,
                        "hit_num": align_num,
                        "iter_num": iter_num,
                        "match_id": alignment.title.partition(">")[0],
                    }
                )
        blast.append(blast_gene)
        totLen += len(blast_gene)
    log.debug("parse_blastxml %s -> %s", totLen + discarded_records, totLen)
    return blast


def filter_lone_clusters(clusters):
    """Drop clusters that have fewer than two members or an empty key."""
    filtered_clusters = {
        key: members
        for key, members in clusters.items()
        if len(members) > 1 and len(key) > 0
    }
    log.debug("filter_lone_clusters %s -> %s", len(clusters), len(filtered_clusters))
    return filtered_clusters


def test_true(feature, **kwargs):
    """Predicate for ``feature_lambda`` that accepts every feature."""
    return True


def parse_gff(gff3):
    """Collect location data for every CDS feature in a GFF3 handle.

    Args:
        gff3: handle passed straight to ``gffParse``.

    Returns:
        A 3-tuple ``(gff_info, rec, end_base)``:

        * ``gff_info`` maps each CDS feature id to a dict with ``strand``,
          ``start``, ``end``, ``loc``, ``feat``, ``name`` and ``index``
          (rank of the CDS when ordered by start).
        * ``rec`` is the last record yielded by the parser with its
          annotations cleared, or ``None`` when there were no records.
        * ``end_base`` is the larger of that record's sequence length and
          the greatest CDS end coordinate (0 when there were no records).
    """
    log.debug("parse_gff3")
    gff_info = {}
    _rec = None
    # Initialised up front so an input without records does not leave the
    # name unbound at the return statement.
    endBase = 0
    for rec in gffParse(gff3):
        endBase = len(rec.seq)
        _rec = rec
        _rec.annotations = {}
        for feat in feature_lambda(rec.features, test_true, {}, subfeatures=False):
            if feat.type != "CDS":
                continue
            if "Name" in feat.qualifiers:
                CDSname = feat.qualifiers["Name"]
            else:
                CDSname = feat.qualifiers["ID"]
            gff_info[feat.id] = {
                "strand": feat.strand,
                "start": feat.location.start,
                "end": feat.location.end,
                "loc": feat.location,
                "feat": feat,
                "name": CDSname,
            }

    gff_info = OrderedDict(sorted(gff_info.items(), key=lambda k: k[1]["start"]))
    for i, feat_id in enumerate(gff_info):
        gff_info[feat_id].update({"index": i})
        if gff_info[feat_id]["loc"].end > endBase:
            endBase = gff_info[feat_id]["loc"].end

    return dict(gff_info), _rec, endBase


def all_same(genes_list):
    """Return True if every hit in *genes_list* has the same ``name``."""
    return all(gene["name"] == genes_list[0]["name"] for gene in genes_list[1:])


def remove_duplicates(clusters):
    """Drop clusters whose members all come from a single gene name."""
    filtered_clusters = {
        key: members for key, members in clusters.items() if not all_same(members)
    }
    log.debug("remove_duplicates %s -> %s", len(clusters), len(filtered_clusters))
    return filtered_clusters


class IntronFinder(object):
    """Cluster BLAST hits of neighbouring genes that share a subject protein.

    Attributes:
        gff_info: CDS lookup produced by ``parse_gff``.
        rec: the parsed GFF3 record, used as the template for output.
        length: the end coordinate returned by ``parse_gff``.
        blast: per-query HSP lists produced by ``parse_xml``.
        clusters: dict of cluster key -> list of HSP dicts; the filter
            methods return new dicts that the caller assigns back here.
    """

    def __init__(self, gff3, blastp, thresh):
        self.blast = []
        self.clusters = {}
        self.gff_info = {}
        self.length = 0

        (self.gff_info, self.rec, self.length) = parse_gff(gff3)
        self.blast = parse_xml(blastp, thresh)

    def create_clusters(self):
        """Group hits by subject accession; keep groups with 2+ members.

        The cluster key is the MD5 hex digest of the accession (truncated at
        its first space). The result is stored in ``self.clusters``.
        """
        clusters = {}
        for gene in self.blast:
            for hit in gene:
                if " " in hit["gi_nos"]:
                    hit["gi_nos"] = hit["gi_nos"][0 : hit["gi_nos"].index(" ")]

                nameCheck = hit["gi_nos"]
                if nameCheck == "":
                    continue
                name = hashlib.md5((nameCheck).encode()).hexdigest()

                if name in clusters:
                    if hit not in clusters[name]:
                        clusters[name].append(hit)
                else:
                    clusters[name] = [hit]
        log.debug("create_clusters %s -> %s", len(self.blast), len(clusters))
        self.clusters = filter_lone_clusters(clusters)

    def check_strand(self):
        """Split mixed-strand clusters into per-strand clusters.

        A cluster whose genes are all on one strand is kept unchanged. A
        mixed cluster is split into ``<key>_+1`` and ``<key>_-1``, each kept
        only if it has more than one member.

        Hit names are looked up in ``self.gff_info``, so the BLAST query
        names must match the CDS feature ids of the GFF3 input.
        """
        filtered_clusters = {}
        for key in self.clusters:
            pos_strand = []
            neg_strand = []
            for gene in self.clusters[key]:
                if self.gff_info[gene["name"]]["strand"] == 1:
                    pos_strand.append(gene)
                else:
                    neg_strand.append(gene)
            if len(pos_strand) == 0 or len(neg_strand) == 0:
                filtered_clusters[key] = self.clusters[key]
            else:
                if len(pos_strand) > 1:
                    filtered_clusters[key + "_+1"] = pos_strand
                if len(neg_strand) > 1:
                    filtered_clusters[key + "_-1"] = neg_strand
        return filtered_clusters

    def check_gene_gap(self, maximum=10000):
        """Split clusters into groups of genes lying close together.

        Two genes count as close when the gap between them is at most
        *maximum*, measured either directly or across the end of the
        sequence (``self.length``). Groups with fewer than two hits are
        dropped, and so are groups made of a single gene name.

        Args:
            maximum: largest allowed gap between two genes of a group.

        Returns:
            A dict of ``<key>_<group index>`` -> list of hits.
        """
        filtered_clusters = {}
        for key in self.clusters:
            hits_lists = []
            for gene in self.clusters[key]:
                # Reset for every gene. A flag shared across the whole
                # cluster would stay set after the first match and silently
                # drop every later gene that matches no existing group.
                gene_added = False
                gene_info = self.gff_info[gene["name"]]
                # NOTE(review): a gene close to several groups is appended to
                # each of them; only the scan of the current group stops.
                for hits in hits_lists:
                    for hit in hits:
                        hit_info = self.gff_info[hit["name"]]
                        lastStart = max(gene_info["start"], hit_info["start"])
                        lastEnd = max(gene_info["end"], hit_info["end"])
                        firstEnd = min(gene_info["end"], hit_info["end"])
                        firstStart = min(gene_info["start"], hit_info["start"])
                        if (
                            lastStart - firstEnd <= maximum
                            or self.length - lastEnd + firstStart <= maximum
                        ):
                            hits.append(gene)
                            gene_added = True
                            break
                if not gene_added:
                    hits_lists.append([gene])

            for i, hits in enumerate(hits_lists):
                if len(hits) >= 2:
                    filtered_clusters[key + "_" + str(i)] = hits
        log.debug("check_gene_gap %s -> %s", len(self.clusters), len(filtered_clusters))
        # TODO: call remove_duplicates somewhere else?
        return remove_duplicates(filtered_clusters)

    # TODO: maybe figure out how to merge with check_gene_gap?
    # TODO: also need a check for gap in sequence coverage (check_seq_gap)?
    def check_seq_overlap(self, minimum=-1):
        """Keep clusters whose hits cover distinct parts of the subject.

        Every pair of subject ranges in a cluster is compared; one failing
        pair rejects the whole cluster.

        Args:
            minimum: if negative, up to ``-minimum`` overlapping positions
                are tolerated. If zero or positive, no overlap is tolerated
                and the ranges must be at least ``minimum`` apart.

        Returns:
            A dict holding the accepted clusters under their original keys.
        """
        # Overlap tolerated between two subject ranges of one cluster.
        allowed_overlap = -minimum if minimum < 0 else 0

        filtered_clusters = {}
        for key in self.clusters:
            add_cluster = True
            sbjct_ranges = [gene["sbjct_range"] for gene in self.clusters[key]]

            for pair in itertools.combinations(sbjct_ranges, 2):
                # Size of the intersection of the half-open ranges
                # [start, end); equal to intersecting two range() sets.
                overlap = max(
                    0, min(pair[0][1], pair[1][1]) - max(pair[0][0], pair[1][0])
                )
                # Order the pair by start, then by end.
                minPair, maxPair = sorted(pair)

                if overlap > 0:
                    dist1 = maxPair[0] - minPair[0]
                else:
                    dist1 = abs(maxPair[0] - minPair[1])

                if overlap > allowed_overlap:
                    add_cluster = False
                if minimum >= 0 and dist1 < minimum:
                    add_cluster = False
                if not add_cluster:
                    # One failing pair is enough to reject the cluster.
                    break

            if add_cluster:
                filtered_clusters[key] = self.clusters[key]

        log.debug(
            "check_seq_overlap %s -> %s", len(self.clusters), len(filtered_clusters)
        )
        return filtered_clusters

    def cluster_report(self):
        """Map each gene name to the subject ranges of its clustered hits."""
        condensed_report = {}
        for key in self.clusters:
            for gene in self.clusters[key]:
                condensed_report.setdefault(gene["name"], []).append(
                    gene["sbjct_range"]
                )
        return condensed_report

    def cluster_report_2(self):
        """Count clusters per comma-joined list of member gene names.

        The ``CPT_phageK_`` prefix is removed from each gene name.
        """
        condensed_report = {}
        for key in self.clusters:
            gene_names = [
                _strip_affixes(gene["name"], prefix="CPT_phageK_")
                for gene in self.clusters[key]
            ]
            label = ", ".join(gene_names)
            condensed_report[label] = condensed_report.get(label, 0) + 1
        return condensed_report

    def cluster_report_3(self):
        """Map comma-joined member gene names to subject accessions.

        The ``.p01`` suffix and ``CPT_phageK_gp`` prefix are removed from
        each gene name. The accession recorded for a cluster is that of its
        first hit.
        """
        condensed_report = {}
        for key in self.clusters:
            gene_names = []
            gi_nos = []
            for i, gene in enumerate(self.clusters[key]):
                if i == 0:
                    gi_nos = gene["gi_nos"]
                gene_names.append(
                    _strip_affixes(
                        gene["name"], prefix="CPT_phageK_gp", suffix=".p01"
                    )
                )
            condensed_report.setdefault(", ".join(gene_names), []).append(gi_nos)
        return condensed_report

    def output_gff3(self, clusters):
        """Build a record with one gene/mRNA/CDS model per cluster.

        Args:
            clusters: dict of cluster id -> list of hit dicts, as produced by
                the filter methods.

        Returns:
            A deep copy of ``self.rec`` whose features are replaced by one
            ``gene`` feature per cluster. Each gene holds an ``mRNA``
            sub-feature, which in turn holds one CDS per member gene.
        """
        rec = copy.deepcopy(self.rec)
        rec.features = []
        for cluster_idx, cluster_id in enumerate(clusters):
            members = clusters[cluster_id]
            # Get the list of genes in this cluster
            associated_genes = set([x["name"] for x in members])
            # Get the gene locations
            assoc_gene_info = {x: self.gff_info[x]["loc"] for x in associated_genes}
            # Now we construct a gene from the children as a "standard gene
            # model" gene: the span covered by all of the member genes.
            gene_min = min(
                min(loc.start, loc.end) for loc in assoc_gene_info.values()
            )
            gene_max = max(
                max(loc.start, loc.end) for loc in assoc_gene_info.values()
            )

            evidence_notes = []
            for cluster_elem in members:
                note = "{name} had {ident}% identity to NCBI Protein ID {pretty_gi}".format(
                    pretty_gi=(cluster_elem["gi_nos"]),
                    ident=int(
                        100
                        * float(cluster_elem["identity"] - 1.00)
                        / abs(
                            cluster_elem["query_range"][1]
                            - cluster_elem["query_range"][0]
                        )
                    ),
                    **cluster_elem
                )
                evidence_notes.append(note)
            if gene_max - gene_min > 0.8 * float(self.length):
                evidence_notes.append(
                    "Intron is over 80% of the total length of the genome, possible wraparound scenario"
                )
            # With that we can create the top level gene
            gene = gffSeqFeature(
                location=FeatureLocation(gene_min, gene_max),
                type="gene",
                id=cluster_id,
                qualifiers={
                    "ID": ["gp_%s" % cluster_idx],
                    "Percent_Identities": evidence_notes,
                    "Note": members[0]["match_id"],
                },
            )

            # Below that we have an mRNA
            mRNA = gffSeqFeature(
                location=FeatureLocation(gene_min, gene_max),
                type="mRNA",
                id=cluster_id + ".mRNA",
                qualifiers={"ID": ["gp_%s.mRNA" % cluster_idx], "note": evidence_notes},
            )

            # Now come the CDSs, ordered by start coordinate.
            cdss = []
            for idx, gene_name in enumerate(
                sorted(associated_genes, key=lambda x: int(self.gff_info[x]["start"]))
            ):
                # Copy the CDS so we don't muck up a good one
                cds = copy.copy(self.gff_info[gene_name]["feat"])
                # Get the first cluster element for this gene (used in the
                # notes above)
                cluster_elem = next(x for x in members if x["name"] == gene_name)

                # Calculate %identity which we'll use to score
                score = int(
                    1000
                    * float(cluster_elem["identity"])
                    / abs(
                        cluster_elem["query_range"][1]
                        - cluster_elem["query_range"][0]
                    )
                )

                # Trim the CDS to the aligned part of the query: three bases
                # per residue, offset from the CDS start.
                # NOTE(review): the offset is taken from location.start
                # regardless of strand -- confirm this is intended for
                # minus-strand genes.
                tempLoc = FeatureLocation(
                    cds.location.start + (3 * (cluster_elem["query_range"][0] - 1)),
                    cds.location.start + (3 * (cluster_elem["query_range"][1])),
                    cds.location.strand,
                )
                cds.location = tempLoc
                # Set the qualifiers appropriately
                cds.qualifiers = {
                    "ID": ["gp_%s.CDS.%s" % (cluster_idx, idx)],
                    "score": score,
                    "Name": self.gff_info[gene_name]["name"],
                    "evalue": cluster_elem["evalue"],
                    "Identity": cluster_elem["identity_percent"] * 100,
                }
                cdss.append(cds)

            # And we attach the things properly. The parent features take
            # the strand of the last CDS processed.
            strand = cds.location.strand
            mRNA.sub_features = cdss
            mRNA.location = FeatureLocation(
                mRNA.location.start, mRNA.location.end, strand
            )
            gene.sub_features = [mRNA]
            gene.location = FeatureLocation(
                gene.location.start, gene.location.end, strand
            )

            # And append to our record
            rec.features.append(gene)
        return rec

    def output_xml(self, clusters):
        """Start building a per-iteration structure for XML output.

        Unfinished: only an empty dict per BLAST iteration number (taken from
        the first hit of each cluster) is created.

        Returns:
            A dict of iteration number -> empty dict.
        """
        threeLevel = {}
        for cluster_id in clusters:
            iter_num = clusters[cluster_id][0]["iter_num"]
            # Membership is tested against the dict itself; testing against
            # the uncalled ``keys`` method raises TypeError.
            if iter_num not in threeLevel:
                threeLevel[iter_num] = {}
        return threeLevel


def main():
    """Parse the command line, run the clustering and write GFF3 to stdout."""
    parser = argparse.ArgumentParser(description="Intron detection")
    parser.add_argument("gff3", type=argparse.FileType("r"), help="GFF3 gene calls")
    parser.add_argument(
        "blastp", type=argparse.FileType("r"), help="blast XML protein results"
    )
    parser.add_argument(
        "--minimum",
        help="Gap minimum (Default -1, set to a negative number to allow overlap)",
        default=-1,
        type=int,
    )
    parser.add_argument(
        "--maximum",
        help="Gap maximum in genome (Default 10000)",
        default=10000,
        type=int,
    )
    parser.add_argument(
        "--idThresh", help="ID Percent Threshold", default=0.4, type=float
    )

    args = parser.parse_args()

    # Clamp the identity threshold to the range [0, 1].
    threshCap = args.idThresh
    if threshCap > 1.00:
        threshCap = 1.00
    if threshCap < 0:
        threshCap = 0

    # create new IntronFinder object based on user input
    ifinder = IntronFinder(args.gff3, args.blastp, threshCap)
    ifinder.create_clusters()
    ifinder.clusters = ifinder.check_strand()
    ifinder.clusters = ifinder.check_gene_gap(maximum=args.maximum)
    ifinder.clusters = ifinder.check_seq_overlap(minimum=args.minimum)

    rec = ifinder.output_gff3(ifinder.clusters)
    gffWrite([rec], sys.stdout)


if __name__ == "__main__":
    main()