1 #!/usr/bin/env python
2 import logging
3 import argparse
4 from intervaltree import IntervalTree, Interval
5 from CPT_GFFParser import gffParse, gffWrite
6 from Bio.SeqRecord import SeqRecord
7 from Bio.Seq import Seq
9 logging.basicConfig(level=logging.INFO)
10 log = logging.getLogger(__name__)
def validFeat(rec):
    """Tell whether ``rec`` carries at least one usable feature.

    A feature is usable when its ``type`` is neither ``'remark'`` nor
    ``'annotation'``.  Returns ``True`` as soon as one such feature is found
    and ``False`` when there is none (including when ``rec.features`` is empty).
    """
    return any(
        feat.type not in ('remark', 'annotation') for feat in rec.features
    )
def treeFeatures(features, window):
    """Lazily turn features into ``Interval`` objects padded by ``window``.

    For every feature, one ``Interval(begin, end, data)`` is yielded where
    ``begin`` is the feature start minus ``window``, ``end`` is the feature
    end plus ``window`` and ``data`` is the feature's ``id``.
    """
    for feature in features:
        padded_begin = int(feature.location.start) - int(window)
        padded_end = int(feature.location.end) + int(window)
        yield Interval(padded_begin, padded_end, feature.id)
def treeFeatures_noRem(features, window):
    """Lazily turn features into padded ``Interval`` objects, skipping metadata.

    Features whose ``type`` is ``'remark'`` or ``'annotation'`` are ignored.
    Every other feature yields ``Interval(begin, end, data)`` with ``begin``
    set to the feature start minus ``window``, ``end`` set to the feature end
    plus ``window`` and ``data`` set to the feature's ``id``.
    """
    for feature in features:
        if feature.type not in ('remark', 'annotation'):
            padded_begin = int(feature.location.start) - int(window)
            padded_end = int(feature.location.end) + int(window)
            yield Interval(padded_begin, padded_end, feature.id)
def _is_meta_feature(feat):
    """Return True for 'remark' / 'annotation' features, which are never matched."""
    return feat.type in ("remark", "annotation")


def _strand_bucket(feat, stranding):
    """Return the key of the interval tree that ``feat`` belongs to.

    Without stranding every feature shares the single bucket ``None``.  With
    stranding, features with a strictly positive strand go to ``"+"`` and all
    others to ``"-"``.  A strand of ``None`` is treated as 0 so that the
    comparison does not raise ``TypeError`` on Python 3.
    """
    if not stranding:
        return None
    return "+" if (feat.strand or 0) > 0 else "-"


def _build_trees(features, window, stranding):
    """Build one IntervalTree per strand bucket from the non-meta ``features``.

    Each interval spans the feature padded by ``window`` on both sides and
    stores the feature id as its data.
    """
    pad = int(window)
    buckets = {}
    for feat in features:
        if _is_meta_feature(feat):
            continue
        buckets.setdefault(_strand_bucket(feat, stranding), []).append(
            Interval(
                int(feat.location.start) - pad,
                int(feat.location.end) + pad,
                feat.id,
            )
        )
    return {key: IntervalTree(intervals) for key, intervals in buckets.items()}


def _features_near(queries, trees, targets_by_id, stranding):
    """Return the target features whose padded interval overlaps any query.

    ``queries`` are looked up with their own, unpadded coordinates in the tree
    of their strand bucket.  Hits are de-duplicated by feature id, so feature
    objects do not have to be hashable, and the result is sorted by start
    position (ties keep the order of ``targets_by_id``).
    """
    hit_ids = set()
    for feat in queries:
        if _is_meta_feature(feat):
            continue
        tree = trees.get(_strand_bucket(feat, stranding))
        if tree is None:
            # No target feature landed in this bucket, so nothing can match.
            continue
        overlaps = tree[int(feat.location.start) : int(feat.location.end)]
        # The feature id is stored in the interval's data field.
        hit_ids.update(hit.data for hit in overlaps)
    found = [feat for feat_id, feat in targets_by_id.items() if feat_id in hit_ids]
    found.sort(key=lambda feat: feat.location.start)
    return found


def _copy_record(rec, features):
    """Return a new SeqRecord with ``rec``'s metadata and the given features."""
    return SeqRecord(
        rec.seq,
        rec.id,
        rec.name,
        rec.description,
        rec.dbxrefs,
        features,
        rec.annotations,
    )


def intersect(a, b, window, stranding):
    """Find features of two GFF inputs that lie within ``window`` of each other.

    Both inputs are parsed with ``gffParse`` and their records are paired by
    position (first with first, second with second, ...); records beyond the
    shorter input are dropped.  'remark' and 'annotation' features are ignored
    when matching.

    Args:
        a: First GFF input, passed to ``gffParse``.
        b: Second GFF input, passed to ``gffParse``.
        window: Number of bases by which every feature is padded on both
            sides before testing for overlap.
        stranding: When truthy, only features in the same strand bucket are
            compared (strictly positive strand vs. everything else).

    Returns:
        A tuple ``(rec_a_out, rec_b_out)`` of lists with one SeqRecord per
        record pair.  ``rec_a_out[i]`` has the metadata of record ``i`` of
        ``a`` and, as features, the features of ``b`` found near a feature of
        ``a``.  ``rec_b_out[i]`` is the mirror image: metadata of ``b`` with
        the features of ``a`` found near a feature of ``b``.  If either record
        of a pair has no usable feature, both outputs get an empty feature
        list.  If either input has no records, both lists hold one empty
        record with id ``"none"``.
    """
    recs_a = list(gffParse(a))
    recs_b = list(gffParse(b))

    if not recs_a or not recs_b:
        # If one input is empty, output two empty result files.
        return [SeqRecord(Seq(""), "none")], [SeqRecord(Seq(""), "none")]

    rec_a_out = []
    rec_b_out = []
    for rec_a_i, rec_b_i in zip(recs_a, recs_b):
        if not (validFeat(rec_a_i) and validFeat(rec_b_i)):
            rec_a_out.append(_copy_record(rec_a_i, []))
            rec_b_out.append(_copy_record(rec_b_i, []))
            continue

        trees_a = _build_trees(rec_a_i.features, window, stranding)
        trees_b = _build_trees(rec_b_i.features, window, stranding)

        # Used to map interval data (feature ids) back to the full features.
        rec_a_map = {f.id: f for f in rec_a_i.features}
        rec_b_map = {f.id: f for f in rec_b_i.features}

        rec_a_hits_in_b = _features_near(
            rec_a_i.features, trees_b, rec_b_map, stranding
        )
        rec_b_hits_in_a = _features_near(
            rec_b_i.features, trees_a, rec_a_map, stranding
        )

        rec_a_out.append(_copy_record(rec_a_i, rec_a_hits_in_b))
        rec_b_out.append(_copy_record(rec_b_i, rec_b_hits_in_a))

    return rec_a_out, rec_b_out
if __name__ == "__main__":
    # Command-line entry point: parse the two GFF3 inputs, compute the
    # adjacent features and write one output file per input.
    parser = argparse.ArgumentParser(
        description=(
            "Report the features of two GFF3 files that overlap or lie "
            "within a window of each other"
        ),
        epilog="",
    )
    parser.add_argument("a", type=argparse.FileType("r"), help="First GFF3 file")
    parser.add_argument("b", type=argparse.FileType("r"), help="Second GFF3 file")
    parser.add_argument(
        "window",
        type=int,
        # nargs="?" makes the positional optional; without it argparse always
        # requires a value and the declared default is never used.
        nargs="?",
        default=50,
        help="Allows features this far away to still be considered 'adjacent'",
    )
    parser.add_argument(
        "-stranding",
        action="store_true",
        help="Only allow adjacency for same-strand features",
    )
    parser.add_argument("--oa", type=str, default="a_hits_near_b.gff")
    parser.add_argument("--ob", type=str, default="b_hits_near_a.gff")
    args = parser.parse_args()

    try:
        # intersect() returns first the records holding B's features that are
        # near A, then the records holding A's features that are near B, so
        # the results are unpacked in swapped order on purpose.
        b_near_a, a_near_b = intersect(args.a, args.b, args.window, args.stranding)
    finally:
        # argparse.FileType opened these handles; release them once parsed.
        args.a.close()
        args.b.close()

    with open(args.oa, "w") as handle:
        for rec in a_near_b:
            gffWrite([rec], handle)

    with open(args.ob, "w") as handle:
        for rec in b_near_a:
            gffWrite([rec], handle)