comparison diagonal_partition.py @ 21:25fa179d9d0a draft

planemo upload for repository https://github.com/richard-burhans/galaxytools/tree/main/tools/segalign commit e4e05d23d9da18ea87bc352122ca9e6cfa73d1c7
author richard-burhans
date Fri, 09 Aug 2024 20:23:12 +0000
parents 08e987868f0f
children
comparison

--- diagonal_partition.py (20:96ff17622b17)
+++ diagonal_partition.py (21:25fa179d9d0a)
@@ -1,18 +1,20 @@
 #!/usr/bin/env python
+
 
 """
 Diagonal partitioning for segment files output by SegAlign.
 
 Usage:
 diagonal_partition.py <max-segments> <lastz-command>
 
-set <max-segments> = 0 to skip partitioning, -1 to infer best parameter
+set <max-segments> = 0 to skip partitioning, -1 to estimate best parameter
 """
 
-
+import collections
 import os
+import statistics
 import sys
 import typing
 
 
 def chunks(lst: tuple[str, ...], n: int) -> typing.Iterator[tuple[str, ...]]:
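
For orientation while reading the hunks: chunks(), whose body follows in the next hunk, is a plain n-sized slicer over a tuple of lines. A minimal standalone sketch of its behavior:

    import typing

    def chunks(lst: tuple[str, ...], n: int) -> typing.Iterator[tuple[str, ...]]:
        # yield successive n-sized slices; the last one may be shorter
        for i in range(0, len(lst), n):
            yield lst[i:i + n]

    lines = ("a\n", "b\n", "c\n", "d\n", "e\n")
    print(list(chunks(lines, 2)))
    # [('a\n', 'b\n'), ('c\n', 'd\n'), ('e\n',)]
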
@@ -20,23 +22,32 @@
     for i in range(0, len(lst), n):
         yield lst[i:i + n]
 
 
 if __name__ == "__main__":
-
+    # TODO: make these optional user defined parameters
+
+    # deletes original segment file after splitting
     DELETE_AFTER_CHUNKING = True
-    MIN_CHUNK_SIZE = 5000  # don't partition segment files with line count below this value
-
-    # input_params = "10000 sad sadsa sad --segments=tmp21.block0.r0.minus.segments dsa sa --strand=plus --output=out.maf sadads 2> logging.err"
-    # sys.argv = [sys.argv[0]] + input_params.split(' ')
-    chunk_size = int(sys.argv[1])  # first parameter contains chunk size
+
+    # don't partition segment files with line count below this value
+    MIN_CHUNK_SIZE = 5000
+
+    # only used when segment size is being estimated
+    MAX_CHUNK_SIZE = 50000
+
+    # include chosen split size in file name
+    DEBUG = False
+
+    # first parameter contains chunk size
+    chunk_size = int(sys.argv[1])
     params = sys.argv[2:]
 
     # don't do anything if 0 chunk size
     if chunk_size == 0:
         print(" ".join(params), flush=True)
-        exit(0)
+        sys.exit(0)
 
     # Parsing command output from SegAlign
     segment_key = "--segments="
     segment_index = None
     input_file = None
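
The script is invoked as diagonal_partition.py <max-segments> <lastz-command>: everything after the first argument is treated as the lastz command line and scanned for options by prefix match. A sketch of that prefix scan, using a made-up command line (the file names echo the debug comment deleted above):

    # hypothetical lastz command line; only --segments= matters here
    params = ["target.fa", "query.fa",
              "--segments=tmp21.block0.r0.minus.segments",
              "--strand=plus", "--output=out.maf", "2>", "logging.err"]

    segment_key = "--segments="
    segment_index = None
    input_file = None
    for index, value in enumerate(params):
        # prefix match, as in the script's parameter scan
        if value[:len(segment_key)] == segment_key:
            segment_index = index
            input_file = value[len(segment_key):]

    print(segment_index, input_file)
    # 2 tmp21.block0.r0.minus.segments
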
@@ -54,45 +65,56 @@
         sys.exit(f"Error: could not get segment file from parameters {params}")
 
     if not os.path.isfile(input_file):
         sys.exit(f"Error: File {input_file} does not exist")
 
-    line_size = None  # each char in 1 byte
+    # each char is 1 byte
+    line_size = None
     file_size = os.path.getsize(input_file)
     with open(input_file, "r") as f:
-        line_size = len(f.readline())  # add 1 for newline
+        # add 1 for newline
+        line_size = len(f.readline())
 
     estimated_lines = file_size // line_size
 
     # check if chunk size should be estimated
     if chunk_size < 0:
         # optimization, do not need to get each file size in this case
         if estimated_lines < MIN_CHUNK_SIZE:
             print(" ".join(params), flush=True)
-            exit(0)
+            sys.exit(0)
 
-        from collections import defaultdict
-        from statistics import quantiles
         # get size of each segment assuming DELETE_AFTER_CHUNKING == True
         # takes into account already split segments
         files = [i for i in os.listdir(".") if i.endswith(".segments")]
-        # find . -maxdepth 1 -name "*.segments" -print0 | du -ba --files0-from=-
-        # if not enough segment files for estimation, continue
-        if len(files) <= 2:
-            print(" ".join(params), flush=True)
-            exit(0)
-
-        fdict: typing.DefaultDict[str, int] = defaultdict(int)
-        for filename in files:
-            size = os.path.getsize(filename)
-            f_ = filename.split(".split", 1)[0]
-            fdict[f_] += size
-        chunk_size = int(quantiles(fdict.values())[-1] // line_size)
-
-    if file_size // line_size <= chunk_size:  # no need to sort if number of lines <= chunk_size
+
+        if len(files) < 2:
+            # if not enough segment files for estimation, use MAX_CHUNK_SIZE
+            chunk_size = MAX_CHUNK_SIZE
+        else:
+            fdict: typing.DefaultDict[str, int] = collections.defaultdict(int)
+            for filename in files:
+                size = os.path.getsize(filename)
+                f_ = filename.split(".split", 1)[0]
+                fdict[f_] += size
+
+            if len(fdict) < 7:
+                # outliers can heavily skew prediction if <7 data points
+                # to be safe, use 50% quantile
+                chunk_size = int(statistics.quantiles(fdict.values())[1] // line_size)
+            else:
+                # otherwise use 75% quantile
+                chunk_size = int(statistics.quantiles(fdict.values())[-1] // line_size)
+        # if not enough data points, there is a chance of getting unlucky
+        # minimize worst case by using MAX_CHUNK_SIZE
+
+        chunk_size = min(chunk_size, MAX_CHUNK_SIZE)
+
+    # no need to sort if number of lines <= chunk_size
+    if (estimated_lines <= chunk_size):
         print(" ".join(params), flush=True)
-        exit(0)
+        sys.exit(0)
 
     # Find rest of relevant parameters
     output_key = "--output="
     output_index = None
     output_alignment_file = None
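
The new estimation path leans on statistics.quantiles, which with the default n=4 returns the three quartile cut points [Q1, median, Q3]; index [1] is the 50% quantile and [-1] the 75% quantile. A sketch with made-up per-pair byte totals (the fdict values) and the script's constants:

    import statistics

    line_size = 50          # made-up bytes per segment line
    MAX_CHUNK_SIZE = 50000

    # made-up aggregated sizes of 7 segment files, in bytes
    sizes = [1_200_000, 1_500_000, 2_000_000, 2_600_000, 3_100_000, 4_000_000, 9_000_000]

    quartiles = statistics.quantiles(sizes)  # [Q1, median, Q3]
    # 7 or more data points: use the 75% quantile, else the median
    if len(sizes) < 7:
        chunk_size = int(quartiles[1] // line_size)
    else:
        chunk_size = int(quartiles[-1] // line_size)
    chunk_size = min(chunk_size, MAX_CHUNK_SIZE)
    print(chunk_size)  # 50000 (4_000_000 // 50 = 80000, capped at MAX_CHUNK_SIZE)
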
@@ -104,29 +126,32 @@
     for index, value in enumerate(params):
         if value[:len(output_key)] == output_key:
             output_index = index
             output_alignment_file = value[len(output_key):]
             output_alignment_file_base, output_format = output_alignment_file.rsplit(".", 1)
+
         if value[:len(strand_key)] == strand_key:
             strand_index = index
 
+    if output_alignment_file_base is None:
+        sys.exit(f"Error: could not get output alignment file base from parameters {params}")
+
+    if output_format is None:
+        sys.exit(f"Error: could not get output format from parameters {params}")
+
     if output_index is None:
         sys.exit(f"Error: could not get output key {output_key} from parameters {params}")
 
-    if output_alignment_file_base is None:
-        sys.exit(f"Error: could not get output alignment file base from parameters {params}")
-
-    if output_format is None:
-        sys.exit(f"Error: could not get output format from parameters {params}")
-
     if strand_index is None:
         sys.exit(f"Error: could not get strand key {strand_key} from parameters {params}")
 
-    err_index = -1  # error file is at very end
+    # error file is at very end
+    err_index = -1
     err_name_base = params[-1].split(".err", 1)[0]
 
-    data: dict[tuple[str, str], list[tuple[int, int, str]]] = {}  # dict of list of tuple (x, y, str)
+    # dict of list of tuple (x, y, str)
+    data: dict[tuple[str, str], list[tuple[int, int, str]]] = {}
 
     direction = None
     if "plus" in params[strand_index]:
         direction = "f"
     elif "minus" in params[strand_index]:
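
The hunk above relies on two small naming conventions: the output file splits into base and format at the last dot, and the trailing error-file parameter is truncated at ".err" to get a renamable base. A sketch with the made-up names from the deleted debug comment:

    output_alignment_file = "out.maf"
    output_alignment_file_base, output_format = output_alignment_file.rsplit(".", 1)
    print(output_alignment_file_base, output_format)  # out maf

    err_name_base = "logging.err".split(".err", 1)[0]
    print(err_name_base)  # logging
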
@@ -147,18 +172,21 @@
             data.setdefault((seq1_name, seq2_name), []).append((seq1_mid, seq2_mid, line))
 
     # If there are chromosome pairs with segment count <= chunk_size
     # then no need to sort and split these pairs into separate files.
     # It is better to keep these pairs in a single segment file.
-    skip_pairs = []  # pairs that have count <= chunk_size. these will not be sorted
+
+    # pairs that have count <= chunk_size. these will not be sorted
+    skip_pairs = []
 
     # save query key order
     # for lastz segment files: 'Query sequence names must appear in the same
     # order as they do in the query file'
+
+    # NOTE: assuming data.keys() preserves order of keys. Requires Python 3.7+
+
     query_key_order = list(dict.fromkeys([i[1] for i in data.keys()]))
-
-    # NOTE: assuming data.keys() preserves order of keys. Requires Python 3.7+
 
     if len(data.keys()) > 1:
         for pair in data.keys():
             if len(data[pair]) <= chunk_size:
                 skip_pairs.append(pair)
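
The query_key_order line above deduplicates query names while keeping first-seen order, which is what the lastz requirement quoted in the comment needs; dict.fromkeys preserves insertion order on Python 3.7+. A sketch with made-up (target, query) pair keys:

    # made-up (target, query) pair keys, in file order
    pairs = [("chr1", "seqA"), ("chr2", "seqA"), ("chr1", "seqB"), ("chr3", "seqB")]

    # order-preserving de-duplication of the query column
    query_key_order = list(dict.fromkeys([i[1] for i in pairs]))
    print(query_key_order)  # ['seqA', 'seqB']
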
@@ -166,24 +194,30 @@
     # sorting for forward segments
     if direction == "r":
         for pair in data.keys():
             if pair not in skip_pairs:
                 data[pair] = sorted(data[pair], key=lambda coord: (coord[1] - coord[0], coord[0]))
+
     # sorting for reverse segments
     elif direction == "f":
         for pair in data.keys():
             if pair not in skip_pairs:
                 data[pair] = sorted(data[pair], key=lambda coord: (coord[1] + coord[0], coord[0]))
     else:
         sys.exit(f"INVALID DIRECTION VALUE: {direction}")
 
     # Writing file in chunks
     ctr = 0
-    for pair in data.keys() - skip_pairs:  # [i for i in data_keys if i not in set(skip_pairs)]:
+    # [i for i in data_keys if i not in set(skip_pairs)]:
+    for pair in (data.keys() - skip_pairs):
         for chunk in chunks(list(zip(*data[pair]))[2], chunk_size):
             ctr += 1
             name_addition = f".split{ctr}"
+
+            if DEBUG:
+                name_addition = f".{chunk_size}{name_addition}"
+
             fname = input_file.split(".segments", 1)[0] + name_addition + ".segments"
 
             assert len(chunk) != 0
             with open(fname, "w") as f:
                 f.writelines(chunk)
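
The two sort keys above arrange each pair's segments along diagonals of the (target midpoint, query midpoint) plane: y - x is constant along a diagonal, y + x along an anti-diagonal, with x as the tie-breaker. A sketch with made-up midpoint triples:

    # made-up (x, y, line) triples: target midpoint, query midpoint, raw line
    segs = [(100, 300, "s1\n"), (200, 150, "s2\n"), (10, 390, "s3\n")]

    # y - x is constant along a diagonal (used when direction == "r")
    by_diagonal = sorted(segs, key=lambda coord: (coord[1] - coord[0], coord[0]))
    # y + x is constant along an anti-diagonal (used when direction == "f")
    by_antidiagonal = sorted(segs, key=lambda coord: (coord[1] + coord[0], coord[0]))

    print([s[2] for s in by_diagonal])      # ['s2\n', 's1\n', 's3\n']
    print([s[2] for s in by_antidiagonal])  # ['s2\n', 's3\n', 's1\n']
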
@@ -195,16 +229,19 @@
             params[-1] = err_name_base + name_addition + ".err"
             print(" ".join(params), flush=True)
 
     # writing unsorted skipped pairs
     if len(skip_pairs) > 0:
-        skip_pairs_with_len = sorted([(len(data[p]), p) for p in skip_pairs])  # list of tuples of (pair length, pair)
-        # NOTE: This can violate lastz query key order requirement
-
-        query_key_order_table = {item: idx for idx, item in enumerate(query_key_order)}  # used for sorting
-
-        aggregated_skip_pairs: list[list[tuple[str, str]]] = []  # list of list of pair names
+        # list of tuples of (pair length, pair)
+        skip_pairs_with_len = sorted([(len(data[p]), p) for p in skip_pairs])
+        # NOTE: This sorting can violate lastz query key order requirement, this is fixed later
+
+        # used for sorting
+        query_key_order_table = {item: idx for idx, item in enumerate(query_key_order)}
+
+        # list of list of pair names
+        aggregated_skip_pairs: list[list[tuple[str, str]]] = []
         current_count = 0
         aggregated_skip_pairs.append([])
         for count, pair in skip_pairs_with_len:
             if current_count + count <= chunk_size:
                 current_count += count
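
The loop above packs small pairs into sequential bins. The comparison view elides the branch that opens a new bin (old lines 211-214 / new 248-251), so the sketch below shows the general greedy accumulation under that assumption, not the verbatim elided code:

    chunk_size = 5
    # made-up (segment count, pair) tuples, pre-sorted as in the script
    skip_pairs_with_len = sorted([(2, "p1"), (3, "p2"), (4, "p3"), (1, "p4")])

    aggregated_skip_pairs = [[]]
    current_count = 0
    for count, pair in skip_pairs_with_len:
        if current_count + count <= chunk_size:
            current_count += count
        else:
            # assumed elided branch: start a fresh bin for the overflow
            aggregated_skip_pairs.append([])
            current_count = count
        aggregated_skip_pairs[-1].append(pair)

    print(aggregated_skip_pairs)  # [['p4', 'p1'], ['p2'], ['p3']]
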
@@ -215,15 +252,20 @@
                 aggregated_skip_pairs[-1].append(pair)
 
         for aggregate in aggregated_skip_pairs:
             ctr += 1
             name_addition = f".split{ctr}"
+
+            if DEBUG:
+                name_addition = f".{chunk_size}{name_addition}"
+
             fname = input_file.split(".segments", 1)[0] + name_addition + ".segments"
 
             with open(fname, "w") as f:
                 # fix possible lastz query key order violations
-                for pair in sorted(aggregate, key=lambda p: query_key_order_table[p[1]]):  # p[1] is query key
+                # p[1] is query key
+                for pair in sorted(aggregate, key=lambda p: query_key_order_table[p[1]]):
                     chunk = list(zip(*data[pair]))[2]
                     f.writelines(chunk)
             # update segment file in command
             params[segment_index] = segment_key + fname
             # update output file in command
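
Finally, within each aggregate file the pairs are written back in query-file order using the rank table built earlier, which repairs the ordering disturbed by the size sort. A sketch with made-up names:

    query_key_order = ["seqA", "seqB", "seqC"]
    query_key_order_table = {item: idx for idx, item in enumerate(query_key_order)}

    aggregate = [("chr2", "seqC"), ("chr1", "seqA"), ("chr3", "seqB")]
    # p[1] is the query key; sorting restores query-file order
    for pair in sorted(aggregate, key=lambda p: query_key_order_table[p[1]]):
        print(pair)
    # ('chr1', 'seqA')
    # ('chr3', 'seqB')
    # ('chr2', 'seqC')
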