comparison diagonal_partition.py @ 9:08e987868f0f draft

planemo upload for repository https://github.com/richard-burhans/galaxytools/tree/main/tools/segalign commit 062a761a340e095ea7ef7ed7cd1d3d55b1fdc5c4
author richard-burhans
date Wed, 10 Jul 2024 17:06:45 +0000
parents
children 25fa179d9d0a
comparison
equal deleted inserted replaced
8:150de8a3954a 9:08e987868f0f
1 #!/usr/bin/env python
2
3 """
4 Diagonal partitioning for segment files output by SegAlign.
5
6 Usage:
7 diagonal_partition.py <max-segments> <lastz-command>
8
9 set <max-segments> = 0 to skip partitioning, -1 to infer best parameter
10 """
11
12
13 import os
14 import sys
15 import typing
16
17
def chunks(lst: tuple[str, ...], n: int) -> typing.Iterator[tuple[str, ...]]:
    """Yield successive slices of *lst*, each holding at most ``n`` items."""
    for start in range(0, len(lst), n):
        end = start + n
        yield lst[start:end]
22
23
if __name__ == "__main__":

    DELETE_AFTER_CHUNKING = True
    MIN_CHUNK_SIZE = 5000  # don't partition segment files with line count below this value

    chunk_size = int(sys.argv[1])  # first parameter contains chunk size
    params = sys.argv[2:]  # remainder is the lastz command to rewrite

    # don't do anything if 0 chunk size
    if chunk_size == 0:
        print(" ".join(params), flush=True)
        sys.exit(0)

    # Parse the segment-file argument out of the SegAlign command.
    segment_key = "--segments="
    segment_index = None
    input_file = None

    for index, value in enumerate(params):
        if value.startswith(segment_key):
            segment_index = index
            input_file = value[len(segment_key):]
            break

    if segment_index is None:
        sys.exit(f"Error: could not get segment key {segment_key} from parameters {params}")

    if input_file is None:
        sys.exit(f"Error: could not get segment file from parameters {params}")

    if not os.path.isfile(input_file):
        sys.exit(f"Error: File {input_file} does not exist")

    # Estimate the line count from the file size and the length of the first
    # line.  NOTE: readline() keeps the trailing newline, so no "+1" is
    # needed; the estimate assumes roughly uniform line lengths.
    file_size = os.path.getsize(input_file)
    with open(input_file, "r") as f:
        line_size = len(f.readline())

    if line_size == 0:
        # Empty segment file: nothing to partition, pass the command through
        # (previously this crashed with ZeroDivisionError below).
        print(" ".join(params), flush=True)
        sys.exit(0)

    estimated_lines = file_size // line_size

    # check if chunk size should be estimated
    if chunk_size < 0:
        # optimization, do not need to get each file size in this case
        if estimated_lines < MIN_CHUNK_SIZE:
            print(" ".join(params), flush=True)
            sys.exit(0)

        # deferred imports: only the estimation path needs these
        from collections import defaultdict
        from statistics import quantiles

        # get size of each segment file assuming DELETE_AFTER_CHUNKING == True;
        # takes into account already split segments
        files = [i for i in os.listdir(".") if i.endswith(".segments")]
        # if not enough segment files for estimation, continue unpartitioned
        if len(files) <= 2:
            print(" ".join(params), flush=True)
            sys.exit(0)

        # aggregate the sizes of already-split parts back onto their source file
        fdict: typing.DefaultDict[str, int] = defaultdict(int)
        for filename in files:
            size = os.path.getsize(filename)
            base = filename.split(".split", 1)[0]
            fdict[base] += size
        # use the upper quartile of the aggregated sizes as the chunk size
        chunk_size = int(quantiles(fdict.values())[-1] // line_size)

    if file_size // line_size <= chunk_size:  # no need to sort if number of lines <= chunk_size
        print(" ".join(params), flush=True)
        sys.exit(0)

    # Find rest of relevant parameters
    output_key = "--output="
    output_index = None
    output_alignment_file = None
    output_alignment_file_base = None
    output_format = None

    strand_key = "--strand="
    strand_index = None
    for index, value in enumerate(params):
        if value.startswith(output_key):
            output_index = index
            output_alignment_file = value[len(output_key):]
            output_alignment_file_base, output_format = output_alignment_file.rsplit(".", 1)
        if value.startswith(strand_key):
            strand_index = index

    if output_index is None:
        sys.exit(f"Error: could not get output key {output_key} from parameters {params}")

    if output_alignment_file_base is None:
        sys.exit(f"Error: could not get output alignment file base from parameters {params}")

    if output_format is None:
        sys.exit(f"Error: could not get output format from parameters {params}")

    if strand_index is None:
        sys.exit(f"Error: could not get strand key {strand_key} from parameters {params}")

    # error file is the very last parameter
    err_name_base = params[-1].split(".err", 1)[0]

    # (target name, query name) -> list of (target midpoint, query midpoint, raw line)
    data: dict[tuple[str, str], list[tuple[int, int, str]]] = {}

    direction = None
    if "plus" in params[strand_index]:
        direction = "f"
    elif "minus" in params[strand_index]:
        direction = "r"
    else:
        sys.exit(f"Error: could not figure out direction from strand value {params[strand_index]}")

    with open(input_file, "r") as f:
        for line in f:
            # skip blank lines (file iteration keeps the newline, so compare stripped)
            if not line.strip():
                continue
            seq1_name, seq1_start, seq1_end, seq2_name, seq2_start, seq2_end, _dir, score = line.split()
            assert int(seq1_end) > int(seq1_start)
            assert int(seq2_end) > int(seq2_start)
            # midpoint of the segment on each sequence
            half_dist = (int(seq1_end) - int(seq1_start)) // 2
            seq1_mid = int(seq1_start) + half_dist
            seq2_mid = int(seq2_start) + half_dist
            data.setdefault((seq1_name, seq2_name), []).append((seq1_mid, seq2_mid, line))

    # If there are chromosome pairs with segment count <= chunk_size
    # then no need to sort and split these pairs into separate files.
    # It is better to keep these pairs in a single segment file.
    skip_pairs = []  # pairs that have count <= chunk_size. these will not be sorted

    # save query key order
    # for lastz segment files: 'Query sequence names must appear in the same
    # order as they do in the query file'
    # NOTE: dict preserves key insertion order (Python 3.7+)
    query_key_order = list(dict.fromkeys([i[1] for i in data.keys()]))

    if len(data.keys()) > 1:
        for pair in data.keys():
            if len(data[pair]) <= chunk_size:
                skip_pairs.append(pair)

    skip_set = set(skip_pairs)  # O(1) membership tests in the loops below

    # sort the segments of each partitioned pair along the strand-appropriate
    # diagonal so that nearby segments land in the same chunk
    if direction == "r":
        for pair in data.keys():
            if pair not in skip_set:
                data[pair] = sorted(data[pair], key=lambda coord: (coord[1] - coord[0], coord[0]))
    elif direction == "f":
        for pair in data.keys():
            if pair not in skip_set:
                data[pair] = sorted(data[pair], key=lambda coord: (coord[1] + coord[0], coord[0]))
    else:
        sys.exit(f"INVALID DIRECTION VALUE: {direction}")

    # Writing file in chunks.  Iterate in key insertion order: a set
    # difference (data.keys() - skip_pairs) would discard the ordering the
    # NOTE above says this code relies on.
    ctr = 0
    for pair in data.keys():
        if pair in skip_set:
            continue
        for chunk in chunks(list(zip(*data[pair]))[2], chunk_size):
            ctr += 1
            name_addition = f".split{ctr}"
            fname = input_file.split(".segments", 1)[0] + name_addition + ".segments"

            assert len(chunk) != 0
            with open(fname, "w") as f:
                f.writelines(chunk)
            # update segment file in command
            params[segment_index] = segment_key + fname
            # update output file in command
            params[output_index] = output_key + output_alignment_file_base + name_addition + "." + output_format
            # update error file in command
            params[-1] = err_name_base + name_addition + ".err"
            print(" ".join(params), flush=True)

    # writing unsorted skipped pairs, packed greedily into ~chunk_size groups
    if len(skip_pairs) > 0:
        # list of tuples of (pair length, pair), smallest first
        skip_pairs_with_len = sorted([(len(data[p]), p) for p in skip_pairs])
        # NOTE: This can violate lastz query key order requirement

        query_key_order_table = {item: idx for idx, item in enumerate(query_key_order)}  # used for sorting

        aggregated_skip_pairs: list[list[tuple[str, str]]] = []  # list of list of pair names
        current_count = 0
        aggregated_skip_pairs.append([])
        for count, pair in skip_pairs_with_len:
            if current_count + count <= chunk_size:
                current_count += count
                aggregated_skip_pairs[-1].append(pair)
            else:
                aggregated_skip_pairs.append([])
                current_count = count
                aggregated_skip_pairs[-1].append(pair)

        for aggregate in aggregated_skip_pairs:
            ctr += 1
            name_addition = f".split{ctr}"
            fname = input_file.split(".segments", 1)[0] + name_addition + ".segments"

            with open(fname, "w") as f:
                # fix possible lastz query key order violations
                for pair in sorted(aggregate, key=lambda p: query_key_order_table[p[1]]):  # p[1] is query key
                    chunk = list(zip(*data[pair]))[2]
                    f.writelines(chunk)
            # update segment file in command
            params[segment_index] = segment_key + fname
            # update output file in command
            params[output_index] = output_key + output_alignment_file_base + name_addition + "." + output_format
            # update error file in command
            params[-1] = err_name_base + name_addition + ".err"
            print(" ".join(params), flush=True)

    if DELETE_AFTER_CHUNKING:
        os.remove(input_file)