Mercurial > repos > richard-burhans > segalign
comparison diagonal_partition.py @ 9:08e987868f0f draft
planemo upload for repository https://github.com/richard-burhans/galaxytools/tree/main/tools/segalign commit 062a761a340e095ea7ef7ed7cd1d3d55b1fdc5c4
| author | richard-burhans |
|---|---|
| date | Wed, 10 Jul 2024 17:06:45 +0000 |
| parents | |
| children | 25fa179d9d0a |
comparison
equal
deleted
inserted
replaced
| 8:150de8a3954a | 9:08e987868f0f |
|---|---|
| 1 #!/usr/bin/env python | |
| 2 | |
| 3 """ | |
| 4 Diagonal partitioning for segment files output by SegAlign. | |
| 5 | |
| 6 Usage: | |
| 7 diagonal_partition.py <max-segments> <lastz-command> | |
| 8 | |
| 9 set <max-segments> = 0 to skip partitioning, -1 to infer best parameter | |
| 10 """ | |
| 11 | |
| 12 | |
| 13 import os | |
| 14 import sys | |
| 15 import typing | |
| 16 | |
| 17 | |
def chunks(lst: tuple[str, ...], n: int) -> typing.Iterator[tuple[str, ...]]:
    """Yield consecutive slices of *lst*, each at most *n* items long."""
    start = 0
    while start < len(lst):
        yield lst[start:start + n]
        start += n
| 22 | |
| 23 | |
if __name__ == "__main__":

    DELETE_AFTER_CHUNKING = True
    # Segment files whose estimated line count is below this are never partitioned.
    MIN_CHUNK_SIZE = 5000

    chunk_size = int(sys.argv[1])  # first parameter contains chunk size
    params = sys.argv[2:]          # remainder is the lastz command to rewrite

    # A chunk size of 0 disables partitioning: echo the command unchanged.
    if chunk_size == 0:
        print(" ".join(params), flush=True)
        sys.exit(0)

    # Locate the --segments=<file> argument produced by SegAlign.
    segment_key = "--segments="
    segment_index = None
    input_file = None

    for index, value in enumerate(params):
        if value.startswith(segment_key):
            segment_index = index
            input_file = value[len(segment_key):]
            break

    if segment_index is None:
        sys.exit(f"Error: could not get segment key {segment_key} from parameters {params}")

    if input_file is None:
        sys.exit(f"Error: could not get segment file from parameters {params}")

    if not os.path.isfile(input_file):
        sys.exit(f"Error: File {input_file} does not exist")

    # Estimate the line count from the file size and the length of the first
    # line (len() of a readline() result already includes the trailing newline).
    file_size = os.path.getsize(input_file)
    with open(input_file, "r") as f:
        line_size = len(f.readline())

    # Empty segment file: nothing to partition, pass the command through.
    if line_size == 0:
        print(" ".join(params), flush=True)
        sys.exit(0)

    estimated_lines = file_size // line_size

    # A negative chunk size means "infer a good value from sibling files".
    if chunk_size < 0:
        # optimization: small files never get partitioned, so skip the
        # per-file stat pass entirely
        if estimated_lines < MIN_CHUNK_SIZE:
            print(" ".join(params), flush=True)
            sys.exit(0)

        from collections import defaultdict
        from statistics import quantiles

        # Get the size of each segment, assuming DELETE_AFTER_CHUNKING == True.
        # Already-split segments are accounted for by summing their
        # "<base>.split<N>.segments" parts back into one total per base name.
        files = [i for i in os.listdir(".") if i.endswith(".segments")]
        # if there are not enough segment files for estimation, continue unchanged
        if len(files) <= 2:
            print(" ".join(params), flush=True)
            sys.exit(0)

        fdict: typing.DefaultDict[str, int] = defaultdict(int)
        for filename in files:
            size = os.path.getsize(filename)
            base_name = filename.split(".split", 1)[0]
            fdict[base_name] += size
        # Use the upper quartile of per-segment sizes as the chunk size;
        # clamp to >= 1 line so chunks() never receives a zero step.
        chunk_size = max(1, int(quantiles(fdict.values())[-1] // line_size))

    # no need to sort/split if the whole file fits in a single chunk
    if file_size // line_size <= chunk_size:
        print(" ".join(params), flush=True)
        sys.exit(0)

    # Find the rest of the relevant parameters.
    output_key = "--output="
    output_index = None
    output_alignment_file = None
    output_alignment_file_base = None
    output_format = None

    strand_key = "--strand="
    strand_index = None
    for index, value in enumerate(params):
        if value.startswith(output_key):
            output_index = index
            output_alignment_file = value[len(output_key):]
            output_alignment_file_base, output_format = output_alignment_file.rsplit(".", 1)
        if value.startswith(strand_key):
            strand_index = index

    if output_index is None:
        sys.exit(f"Error: could not get output key {output_key} from parameters {params}")

    if output_alignment_file_base is None:
        sys.exit(f"Error: could not get output alignment file base from parameters {params}")

    if output_format is None:
        sys.exit(f"Error: could not get output format from parameters {params}")

    if strand_index is None:
        sys.exit(f"Error: could not get strand key {strand_key} from parameters {params}")

    # error file is the very last parameter (e.g. "2> foo.err" -> "foo")
    err_name_base = params[-1].split(".err", 1)[0]

    # (target name, query name) -> list of (target midpoint, query midpoint, raw line)
    data: dict[tuple[str, str], list[tuple[int, int, str]]] = {}

    direction = None
    if "plus" in params[strand_index]:
        direction = "f"
    elif "minus" in params[strand_index]:
        direction = "r"
    else:
        sys.exit(f"Error: could not figure out direction from strand value {params[strand_index]}")

    # Read every segment line and bucket it by chromosome pair, keyed for
    # later diagonal sorting by the midpoints of both intervals.
    with open(input_file, "r") as f:
        for line in f:
            if not line.strip():
                continue  # tolerate blank lines instead of crashing on unpack
            seq1_name, seq1_start, seq1_end, seq2_name, seq2_start, seq2_end, _dir, score = line.split()
            # explicit checks instead of assert: asserts vanish under "python -O"
            if int(seq1_end) <= int(seq1_start) or int(seq2_end) <= int(seq2_start):
                sys.exit(f"Error: invalid segment coordinates in {input_file}: {line!r}")
            half_dist = (int(seq1_end) - int(seq1_start)) // 2
            seq1_mid = int(seq1_start) + half_dist
            seq2_mid = int(seq2_start) + half_dist
            data.setdefault((seq1_name, seq2_name), []).append((seq1_mid, seq2_mid, line))

    # If there are chromosome pairs with segment count <= chunk_size then there
    # is no need to sort and split those pairs into separate files; it is
    # better to keep them together in shared segment files.
    skip_pairs = []  # pairs with count <= chunk_size; these will not be sorted

    # save query key order
    # for lastz segment files: 'Query sequence names must appear in the same
    # order as they do in the query file'
    query_key_order = list(dict.fromkeys([pair[1] for pair in data.keys()]))

    # NOTE: relies on dict preserving insertion order (Python 3.7+).

    if len(data.keys()) > 1:
        for pair in data.keys():
            if len(data[pair]) <= chunk_size:
                skip_pairs.append(pair)

    # Sort each large pair along diagonals so that nearby segments land in the
    # same chunk: anti-diagonal key (y - x) for reverse strand, diagonal key
    # (y + x) for forward strand.
    skip_set = set(skip_pairs)  # O(1) membership tests
    if direction == "r":
        for pair in data.keys():
            if pair not in skip_set:
                data[pair].sort(key=lambda coord: (coord[1] - coord[0], coord[0]))
    elif direction == "f":
        for pair in data.keys():
            if pair not in skip_set:
                data[pair].sort(key=lambda coord: (coord[1] + coord[0], coord[0]))
    else:
        sys.exit(f"INVALID DIRECTION VALUE: {direction}")

    # Write each large pair in chunks, emitting one rewritten lastz command per
    # chunk with matching segment/output/error file names. Iterate data in
    # insertion order (not "data.keys() - skip_pairs", which is an unordered
    # set) so .splitN numbering is deterministic.
    ctr = 0
    for pair in data.keys():
        if pair in skip_set:
            continue
        for chunk in chunks(list(zip(*data[pair]))[2], chunk_size):
            ctr += 1
            name_addition = f".split{ctr}"
            fname = input_file.split(".segments", 1)[0] + name_addition + ".segments"

            with open(fname, "w") as f:
                f.writelines(chunk)
            # update segment file in command
            params[segment_index] = segment_key + fname
            # update output file in command
            params[output_index] = output_key + output_alignment_file_base + name_addition + "." + output_format
            # update error file in command
            params[-1] = err_name_base + name_addition + ".err"
            print(" ".join(params), flush=True)

    # writing unsorted skipped pairs, greedily packed into shared files of at
    # most chunk_size lines each (first-fit over pairs sorted by ascending size)
    if len(skip_pairs) > 0:
        skip_pairs_with_len = sorted([(len(data[p]), p) for p in skip_pairs])
        # NOTE: aggregation can violate lastz's query key order requirement;
        # each file is re-sorted by query order below to compensate
        query_key_order_table = {item: idx for idx, item in enumerate(query_key_order)}

        aggregated_skip_pairs: list[list[tuple[str, str]]] = []  # list of lists of pair names
        current_count = 0
        aggregated_skip_pairs.append([])
        for count, pair in skip_pairs_with_len:
            if current_count + count <= chunk_size:
                current_count += count
            else:
                aggregated_skip_pairs.append([])
                current_count = count
            aggregated_skip_pairs[-1].append(pair)

        for aggregate in aggregated_skip_pairs:
            ctr += 1
            name_addition = f".split{ctr}"
            fname = input_file.split(".segments", 1)[0] + name_addition + ".segments"

            with open(fname, "w") as f:
                # fix possible lastz query key order violations
                for pair in sorted(aggregate, key=lambda p: query_key_order_table[p[1]]):  # p[1] is query key
                    f.writelines(list(zip(*data[pair]))[2])
            # update segment file in command
            params[segment_index] = segment_key + fname
            # update output file in command
            params[output_index] = output_key + output_alignment_file_base + name_addition + "." + output_format
            # update error file in command
            params[-1] = err_name_base + name_addition + ".err"
            print(" ".join(params), flush=True)

    if DELETE_AFTER_CHUNKING:
        os.remove(input_file)
