comparison diagonal_partition.py @ 21:25fa179d9d0a draft
planemo upload for repository https://github.com/richard-burhans/galaxytools/tree/main/tools/segalign commit e4e05d23d9da18ea87bc352122ca9e6cfa73d1c7
author | richard-burhans
date | Fri, 09 Aug 2024 20:23:12 +0000
parents | 08e987868f0f
children |
20:96ff17622b17 (parent) | 21:25fa179d9d0a (this revision) |
---|---|
1 #!/usr/bin/env python | 1 #!/usr/bin/env python |
2 | |
2 | 3 |
3 """ | 4 """ |
4 Diagonal partitioning for segment files output by SegAlign. | 5 Diagonal partitioning for segment files output by SegAlign. |
5 | 6 |
6 Usage: | 7 Usage: |
7 diagonal_partition.py <max-segments> <lastz-command> | 8 diagonal_partition.py <max-segments> <lastz-command> |
8 | 9 |
9 set <max-segments> = 0 to skip partitioning, -1 to infer best parameter | 10 set <max-segments> = 0 to skip partitioning, -1 to estimate best parameter |
10 """ | 11 """ |
11 | 12 |
12 | 13 import collections |
13 import os | 14 import os |
15 import statistics | |
14 import sys | 16 import sys |
15 import typing | 17 import typing |
16 | 18 |
17 | 19 |
18 def chunks(lst: tuple[str, ...], n: int) -> typing.Iterator[tuple[str, ...]]: | 20 def chunks(lst: tuple[str, ...], n: int) -> typing.Iterator[tuple[str, ...]]: |
20 for i in range(0, len(lst), n): | 22 for i in range(0, len(lst), n): |
21 yield lst[i:i + n] | 23 yield lst[i:i + n] |
22 | 24 |
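A quick check of the slicing helper above; the final chunk may be shorter than n:

    # Mirrors the chunks() generator defined above.
    def chunks(lst, n):
        for i in range(0, len(lst), n):
            yield lst[i:i + n]

    assert list(chunks(tuple("abcdefg"), 3)) == [
        ("a", "b", "c"), ("d", "e", "f"), ("g",)
    ]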
23 | 25 |
24 if __name__ == "__main__": | 26 if __name__ == "__main__": |
25 | 27 # TODO: make these optional user defined parameters |
28 | |
29 # deletes original segment file after splitting | |
26 DELETE_AFTER_CHUNKING = True | 30 DELETE_AFTER_CHUNKING = True |
27 MIN_CHUNK_SIZE = 5000 # don't partition segment files with line count below this value | 31 |
28 | 32 # don't partition segment files with line count below this value |
29 # input_params = "10000 sad sadsa sad --segments=tmp21.block0.r0.minus.segments dsa sa --strand=plus --output=out.maf sadads 2> logging.err" | 33 MIN_CHUNK_SIZE = 5000 |
30 # sys.argv = [sys.argv[0]] + input_params.split(' ') | 34 |
31 chunk_size = int(sys.argv[1]) # first parameter contains chunk size | 35 # only used when segment size is being estimated |
36 MAX_CHUNK_SIZE = 50000 | |
37 | |
38 # include chosen split size in file name | |
39 DEBUG = False | |
40 | |
41 # first parameter contains chunk size | |
42 chunk_size = int(sys.argv[1]) | |
32 params = sys.argv[2:] | 43 params = sys.argv[2:] |
33 | 44 |
34 # don't do anything if 0 chunk size | 45 # don't do anything if 0 chunk size |
35 if chunk_size == 0: | 46 if chunk_size == 0: |
36 print(" ".join(params), flush=True) | 47 print(" ".join(params), flush=True) |
37 exit(0) | 48 sys.exit(0) |
38 | 49 |
39 # Parsing command output from SegAlign | 50 # Parsing command output from SegAlign |
40 segment_key = "--segments=" | 51 segment_key = "--segments=" |
41 segment_index = None | 52 segment_index = None |
42 input_file = None | 53 input_file = None |
54 sys.exit(f"Error: could not get segment file from parameters {params}") | 65 sys.exit(f"Error: could not get segment file from parameters {params}") |
55 | 66 |
56 if not os.path.isfile(input_file): | 67 if not os.path.isfile(input_file): |
57 sys.exit(f"Error: File {input_file} does not exist") | 68 sys.exit(f"Error: File {input_file} does not exist") |
58 | 69 |
59 line_size = None # each char in 1 byte | 70 # each char is 1 byte |
71 line_size = None | |
60 file_size = os.path.getsize(input_file) | 72 file_size = os.path.getsize(input_file) |
61 with open(input_file, "r") as f: | 73 with open(input_file, "r") as f: |
62 line_size = len(f.readline()) # newline is already included by readline() | 74 # newline is already included by readline() |
75 line_size = len(f.readline()) | |
63 | 76 |
64 estimated_lines = file_size // line_size | 77 estimated_lines = file_size // line_size |
65 | 78 |
66 # check if chunk size should be estimated | 79 # check if chunk size should be estimated |
67 if chunk_size < 0: | 80 if chunk_size < 0: |
68 # optimization, do not need to get each file size in this case | 81 # optimization, do not need to get each file size in this case |
69 if estimated_lines < MIN_CHUNK_SIZE: | 82 if estimated_lines < MIN_CHUNK_SIZE: |
70 print(" ".join(params), flush=True) | 83 print(" ".join(params), flush=True) |
71 exit(0) | 84 sys.exit(0) |
72 | 85 |
73 from collections import defaultdict | |
74 from statistics import quantiles | |
75 # get size of each segment assuming DELETE_AFTER_CHUNKING == True | 86 # get size of each segment assuming DELETE_AFTER_CHUNKING == True |
76 # takes into account already split segments | 87 # takes into account already split segments |
77 files = [i for i in os.listdir(".") if i.endswith(".segments")] | 88 files = [i for i in os.listdir(".") if i.endswith(".segments")] |
78 # find . -maxdepth 1 -name "*.segments" -print0 | du -ba --files0-from=- | 89 |
79 # if not enough segment files for estimation, continue | 90 if len(files) < 2: |
80 if len(files) <= 2: | 91 # if not enough segment files for estimation, use MAX_CHUNK_SIZE |
81 print(" ".join(params), flush=True) | 92 chunk_size = MAX_CHUNK_SIZE |
82 exit(0) | 93 else: |
83 | 94 fdict: typing.DefaultDict[str, int] = collections.defaultdict(int) |
84 fdict: typing.DefaultDict[str, int] = defaultdict(int) | 95 for filename in files: |
85 for filename in files: | 96 size = os.path.getsize(filename) |
86 size = os.path.getsize(filename) | 97 f_ = filename.split(".split", 1)[0] |
87 f_ = filename.split(".split", 1)[0] | 98 fdict[f_] += size |
88 fdict[f_] += size | 99 |
89 chunk_size = int(quantiles(fdict.values())[-1] // line_size) | 100 if len(fdict) < 7: |
90 | 101 # outliers can heavily skew prediction if <7 data points |
91 if file_size // line_size <= chunk_size: # no need to sort if number of lines <= chunk_size | 102 # to be safe, use 50% quantile |
103 chunk_size = int(statistics.quantiles(fdict.values())[1] // line_size) | |
104 else: | |
105 # otherwise use 75% quantile | |
106 chunk_size = int(statistics.quantiles(fdict.values())[-1] // line_size) | |
107 # if not enough data points, there is a chance of getting unlucky | |
108 # minimize worst case by using MAX_CHUNK_SIZE | |
109 | |
110 chunk_size = min(chunk_size, MAX_CHUNK_SIZE) | |
111 | |
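statistics.quantiles() defaults to n=4 and returns the three quartile cut points [Q1, Q2, Q3], so index 1 is the median (used when there are fewer than 7 per-file totals) and index -1 is the 75% quantile. A small sketch of the selection above, with made-up byte counts:

    # Example values are illustrative; line_size would come from the input file.
    import statistics

    sizes = [12_000, 15_000, 80_000, 22_000, 18_000]  # bytes per base segment file
    line_size = 40
    q1, median, q3 = statistics.quantiles(sizes)      # n=4 by default
    quantile = median if len(sizes) < 7 else q3
    chunk_size = min(int(quantile // line_size), 50_000)  # cap at MAX_CHUNK_SIZE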
112 # no need to sort if number of lines <= chunk_size | |
113 if estimated_lines <= chunk_size: | |
92 print(" ".join(params), flush=True) | 114 print(" ".join(params), flush=True) |
93 exit(0) | 115 sys.exit(0) |
94 | 116 |
95 # Find rest of relevant parameters | 117 # Find rest of relevant parameters |
96 output_key = "--output=" | 118 output_key = "--output=" |
97 output_index = None | 119 output_index = None |
98 output_alignment_file = None | 120 output_alignment_file = None |
104 for index, value in enumerate(params): | 126 for index, value in enumerate(params): |
105 if value[:len(output_key)] == output_key: | 127 if value[:len(output_key)] == output_key: |
106 output_index = index | 128 output_index = index |
107 output_alignment_file = value[len(output_key):] | 129 output_alignment_file = value[len(output_key):] |
108 output_alignment_file_base, output_format = output_alignment_file.rsplit(".", 1) | 130 output_alignment_file_base, output_format = output_alignment_file.rsplit(".", 1) |
131 | |
109 if value[:len(strand_key)] == strand_key: | 132 if value[:len(strand_key)] == strand_key: |
110 strand_index = index | 133 strand_index = index |
111 | 134 |
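Options are located by literal prefix match, and their positions are remembered so the same tokens can be rewritten in place for each chunk; argparse is avoided because the rest of the lastz command must pass through untouched. A standalone sketch of the pattern (the find_option helper is illustrative):

    # Find a "--key=value" token by prefix; return its index and value.
    import typing

    def find_option(params: typing.List[str], key: str) -> typing.Optional[typing.Tuple[int, str]]:
        for index, value in enumerate(params):
            if value.startswith(key):
                return index, value[len(key):]
        return None

    params = ["lastz", "--segments=a.segments", "--strand=plus", "--output=out.maf"]
    assert find_option(params, "--output=") == (3, "out.maf")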
135 if output_alignment_file_base is None: | |
136 sys.exit(f"Error: could not get output alignment file base from parameters {params}") | |
137 | |
138 if output_format is None: | |
139 sys.exit(f"Error: could not get output format from parameters {params}") | |
140 | |
112 if output_index is None: | 141 if output_index is None: |
113 sys.exit(f"Error: could not get output key {output_key} from parameters {params}") | 142 sys.exit(f"Error: could not get output key {output_key} from parameters {params}") |
114 | 143 |
115 if output_alignment_file_base is None: | |
116 sys.exit(f"Error: could not get output alignment file base from parameters {params}") | |
117 | |
118 if output_format is None: | |
119 sys.exit(f"Error: could not get output format from parameters {params}") | |
120 | |
121 if strand_index is None: | 144 if strand_index is None: |
122 sys.exit(f"Error: could not get strand key {strand_key} from parameters {params}") | 145 sys.exit(f"Error: could not get strand key {strand_key} from parameters {params}") |
123 | 146 |
124 err_index = -1 # error file is at very end | 147 # error file is at very end |
148 err_index = -1 | |
125 err_name_base = params[-1].split(".err", 1)[0] | 149 err_name_base = params[-1].split(".err", 1)[0] |
126 | 150 |
127 data: dict[tuple[str, str], list[tuple[int, int, str]]] = {} # dict of list of tuple (x, y, str) | 151 # dict of list of tuple (x, y, str) |
152 data: dict[tuple[str, str], list[tuple[int, int, str]]] = {} | |
128 | 153 |
129 direction = None | 154 direction = None |
130 if "plus" in params[strand_index]: | 155 if "plus" in params[strand_index]: |
131 direction = "f" | 156 direction = "f" |
132 elif "minus" in params[strand_index]: | 157 elif "minus" in params[strand_index]: |
147 data.setdefault((seq1_name, seq2_name), []).append((seq1_mid, seq2_mid, line)) | 172 data.setdefault((seq1_name, seq2_name), []).append((seq1_mid, seq2_mid, line)) |
148 | 173 |
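The parsing loop itself (old lines 133-147) is outside this hunk; from the surrounding code it reduces each record to the midpoints of its two intervals, keyed by chromosome pair, keeping the raw line for later rewriting. A sketch under the assumption of lastz-style whitespace-separated fields (name1 start1 end1 name2 start2 end2 ...); the exact layout is not shown in this diff:

    # Field layout is an assumption, not taken from this diff.
    data: dict = {}
    line = "chr1 100 200 chrA 1100 1200 4500\n"  # hypothetical segment record
    fields = line.split()
    seq1_name, seq2_name = fields[0], fields[3]
    seq1_mid = (int(fields[1]) + int(fields[2])) // 2
    seq2_mid = (int(fields[4]) + int(fields[5])) // 2
    data.setdefault((seq1_name, seq2_name), []).append((seq1_mid, seq2_mid, line))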
149 # If there are chromosome pairs with segment count <= chunk_size | 174 # If there are chromosome pairs with segment count <= chunk_size |
150 # then no need to sort and split these pairs into separate files. | 175 # then no need to sort and split these pairs into separate files. |
151 # It is better to keep these pairs in a single segment file. | 176 # It is better to keep these pairs in a single segment file. |
152 skip_pairs = [] # pairs that have count <= chunk_size. these will not be sorted | 177 |
178 # pairs that have count <= chunk_size. these will not be sorted | |
179 skip_pairs = [] | |
153 | 180 |
154 # save query key order | 181 # save query key order |
155 # for lastz segment files: 'Query sequence names must appear in the same | 182 # for lastz segment files: 'Query sequence names must appear in the same |
156 # order as they do in the query file' | 183 # order as they do in the query file' |
184 | |
185 # NOTE: assuming data.keys() preserves order of keys. Requires Python 3.7+ | |
186 | |
157 query_key_order = list(dict.fromkeys([i[1] for i in data.keys()])) | 187 query_key_order = list(dict.fromkeys([i[1] for i in data.keys()])) |
158 | |
159 # NOTE: assuming data.keys() preserves order of keys. Requires Python 3.7+ | |
160 | 188 |
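dict.fromkeys() doubles as an order-preserving de-duplicator here, recording the first-seen order of query names (dict insertion order is guaranteed from Python 3.7):

    # Order-preserving de-duplication of the query column.
    pairs = [("t1", "q2"), ("t2", "q1"), ("t3", "q2")]
    query_key_order = list(dict.fromkeys(p[1] for p in pairs))
    assert query_key_order == ["q2", "q1"]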
161 if len(data.keys()) > 1: | 189 if len(data.keys()) > 1: |
162 for pair in data.keys(): | 190 for pair in data.keys(): |
163 if len(data[pair]) <= chunk_size: | 191 if len(data[pair]) <= chunk_size: |
164 skip_pairs.append(pair) | 192 skip_pairs.append(pair) |
166 # sort minus-strand ("r") segments by diagonal: key (y - x, x) | 194 # sort minus-strand ("r") segments by diagonal: key (y - x, x) |
167 if direction == "r": | 195 if direction == "r": |
168 for pair in data.keys(): | 196 for pair in data.keys(): |
169 if pair not in skip_pairs: | 197 if pair not in skip_pairs: |
170 data[pair] = sorted(data[pair], key=lambda coord: (coord[1] - coord[0], coord[0])) | 198 data[pair] = sorted(data[pair], key=lambda coord: (coord[1] - coord[0], coord[0])) |
199 | |
171 # sort plus-strand ("f") segments by anti-diagonal: key (y + x, x) | 200 # sort plus-strand ("f") segments by anti-diagonal: key (y + x, x) |
172 elif direction == "f": | 201 elif direction == "f": |
173 for pair in data.keys(): | 202 for pair in data.keys(): |
174 if pair not in skip_pairs: | 203 if pair not in skip_pairs: |
175 data[pair] = sorted(data[pair], key=lambda coord: (coord[1] + coord[0], coord[0])) | 204 data[pair] = sorted(data[pair], key=lambda coord: (coord[1] + coord[0], coord[0])) |
176 else: | 205 else: |
177 sys.exit(f"INVALID DIRECTION VALUE: {direction}") | 206 sys.exit(f"INVALID DIRECTION VALUE: {direction}") |
178 | 207 |
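The sort keys group segments by dot-plot diagonal: y - x is constant along one diagonal family and y + x along the other (the anti-diagonals), and the secondary key x orders segments within a diagonal, so segments from the same diagonal end up adjacent and fall into the same chunk. A small demonstration with the y - x key:

    # Two of these (x, y, line) segments share the diagonal y - x == 1000.
    segs = [(5000, 6000, "c\n"), (100, 1100, "a\n"), (300, 9000, "b\n")]
    by_diag = sorted(segs, key=lambda coord: (coord[1] - coord[0], coord[0]))
    assert [s[2] for s in by_diag] == ["a\n", "c\n", "b\n"]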
179 # Writing file in chunks | 208 # Writing file in chunks |
180 ctr = 0 | 209 ctr = 0 |
181 for pair in data.keys() - skip_pairs: # [i for i in data_keys if i not in set(skip_pairs)]: | 210 # [i for i in data_keys if i not in set(skip_pairs)]: |
211 for pair in (data.keys() - skip_pairs): | |
182 for chunk in chunks(list(zip(*data[pair]))[2], chunk_size): | 212 for chunk in chunks(list(zip(*data[pair]))[2], chunk_size): |
183 ctr += 1 | 213 ctr += 1 |
184 name_addition = f".split{ctr}" | 214 name_addition = f".split{ctr}" |
215 | |
216 if DEBUG: | |
217 name_addition = f".{chunk_size}{name_addition}" | |
218 | |
185 fname = input_file.split(".segments", 1)[0] + name_addition + ".segments" | 219 fname = input_file.split(".segments", 1)[0] + name_addition + ".segments" |
186 | 220 |
187 assert len(chunk) != 0 | 221 assert len(chunk) != 0 |
188 with open(fname, "w") as f: | 222 with open(fname, "w") as f: |
189 f.writelines(chunk) | 223 f.writelines(chunk) |
195 params[-1] = err_name_base + name_addition + ".err" | 229 params[-1] = err_name_base + name_addition + ".err" |
196 print(" ".join(params), flush=True) | 230 print(" ".join(params), flush=True) |
197 | 231 |
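list(zip(*data[pair]))[2] transposes the list of (x, y, line) tuples and selects the third column, i.e. the raw segment lines in sorted order; chunks() then slices that column into chunk_size-line files, and the --segments=, --output=, and error-file tokens are rewritten for each one. The column extraction in isolation:

    # Transpose a list of tuples with zip(*...) and take the line column.
    rows = [(1, 10, "line1\n"), (2, 20, "line2\n"), (3, 30, "line3\n")]
    lines = list(zip(*rows))[2]
    assert lines == ("line1\n", "line2\n", "line3\n")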
198 # writing unsorted skipped pairs | 232 # writing unsorted skipped pairs |
199 if len(skip_pairs) > 0: | 233 if len(skip_pairs) > 0: |
200 skip_pairs_with_len = sorted([(len(data[p]), p) for p in skip_pairs]) # list of tuples of (pair length, pair) | 234 # list of tuples of (pair length, pair) |
201 # NOTE: This can violate lastz query key order requirement | 235 skip_pairs_with_len = sorted([(len(data[p]), p) for p in skip_pairs]) |
202 | 236 # NOTE: This sorting can violate the lastz query key order requirement; it is fixed below |
203 query_key_order_table = {item: idx for idx, item in enumerate(query_key_order)} # used for sorting | 237 |
204 | 238 # used for sorting |
205 aggregated_skip_pairs: list[list[tuple[str, str]]] = [] # list of list of pair names | 239 query_key_order_table = {item: idx for idx, item in enumerate(query_key_order)} |
240 | |
241 # list of list of pair names | |
242 aggregated_skip_pairs: list[list[tuple[str, str]]] = [] | |
206 current_count = 0 | 243 current_count = 0 |
207 aggregated_skip_pairs.append([]) | 244 aggregated_skip_pairs.append([]) |
208 for count, pair in skip_pairs_with_len: | 245 for count, pair in skip_pairs_with_len: |
209 if current_count + count <= chunk_size: | 246 if current_count + count <= chunk_size: |
210 current_count += count | 247 current_count += count |
215 aggregated_skip_pairs[-1].append(pair) | 252 aggregated_skip_pairs[-1].append(pair) |
216 | 253 |
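The leftover small pairs are packed greedily: sorted by segment count, each pair joins the current group while the running total stays within chunk_size, otherwise a new group is started (a next-fit packing, simple rather than optimal). A standalone sketch (the aggregate helper name is illustrative):

    # Next-fit packing of (count, pair) items into groups of capacity cap.
    def aggregate(pairs_with_len, cap):
        groups, current = [[]], 0
        for count, pair in sorted(pairs_with_len):
            if current + count <= cap:
                current += count
            else:
                groups.append([])
                current = count
            groups[-1].append(pair)
        return groups

    assert aggregate([(3, "a"), (4, "b"), (2, "c")], 5) == [["c", "a"], ["b"]]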
217 for aggregate in aggregated_skip_pairs: | 254 for aggregate in aggregated_skip_pairs: |
218 ctr += 1 | 255 ctr += 1 |
219 name_addition = f".split{ctr}" | 256 name_addition = f".split{ctr}" |
257 | |
258 if DEBUG: | |
259 name_addition = f".{chunk_size}{name_addition}" | |
260 | |
220 fname = input_file.split(".segments", 1)[0] + name_addition + ".segments" | 261 fname = input_file.split(".segments", 1)[0] + name_addition + ".segments" |
221 | 262 |
222 with open(fname, "w") as f: | 263 with open(fname, "w") as f: |
223 # fix possible lastz query key order violations | 264 # fix possible lastz query key order violations |
224 for pair in sorted(aggregate, key=lambda p: query_key_order_table[p[1]]): # p[1] is query key | 265 # p[1] is query key |
266 for pair in sorted(aggregate, key=lambda p: query_key_order_table[p[1]]): | |
225 chunk = list(zip(*data[pair]))[2] | 267 chunk = list(zip(*data[pair]))[2] |
226 f.writelines(chunk) | 268 f.writelines(chunk) |
227 # update segment file in command | 269 # update segment file in command |
228 params[segment_index] = segment_key + fname | 270 params[segment_index] = segment_key + fname |
229 # update output file in command | 271 # update output file in command |