Mercurial > repos > richard-burhans > segalign
diff diagonal_partition.py @ 0:5c72425b7f1b draft
planemo upload for repository https://github.com/richard-burhans/galaxytools/tree/main/tools/segalign commit 98a4dd44360447aa96d92143384d78e116d7581b
author | richard-burhans |
---|---|
date | Wed, 17 Apr 2024 18:06:54 +0000 |
parents | |
children | 9e34b25a8670 |
line wrap: on
line diff
#!/usr/bin/env python3

"""
Diagonal partitioning for segment files output by SegAlign.

Reads the segment file named in a lastz command line, groups segments by
(query, target) sequence pair, sorts each pair's segments along the
alignment diagonal, and splits them into chunk files holding at most
<max-segments> entries.  For every chunk produced it prints a rewritten
copy of the lastz command (updated --segments=, --output= and error-file
arguments) to stdout, one command per line.

Usage:
diagonal_partition <max-segments> <lastz-command>
"""

import os
import sys


def chunks(lst, n):
    """Yield successive n-sized chunks from list."""
    for i in range(0, len(lst), n):
        yield lst[i: i + n]


if __name__ == "__main__":

    DELETE_AFTER_CHUNKING = True

    chunk_size = int(sys.argv[1])  # first parameter contains chunk size
    params = sys.argv[2:]  # remainder is the lastz command to rewrite

    # Locate the --segments= argument; it names the input segment file.
    segment_key = "--segments="
    segment_index = None
    input_file = None

    for index, value in enumerate(params):
        if value[: len(segment_key)] == segment_key:
            segment_index = index
            input_file = value[len(segment_key):]
            break
    if segment_index is None:
        print(f"Error: could not find segment key {segment_key} in parameters {params}")
        # NOTE(review): error paths exit 0 — presumably deliberate so the
        # surrounding pipeline keeps going; confirm before changing.
        exit(0)

    if not os.path.isfile(input_file):
        print(f"Error: File {input_file} does not exist")
        exit(0)

    # No need to sort or split if the whole file fits in a single chunk.
    with open(input_file) as counting_handle:
        segment_count = sum(1 for _ in counting_handle)
    if chunk_size == 0 or segment_count <= chunk_size:
        print(" ".join(params), flush=True)
        exit(0)

    # Find the rest of the relevant parameters: output file and strand.
    output_key = "--output="
    output_index = None
    output_alignment_file = None
    output_alignment_file_base = None
    output_format = None

    strand_key = "--strand="
    strand_index = None
    for index, value in enumerate(params):
        if value[: len(output_key)] == output_key:
            output_index = index
            output_alignment_file = value[len(output_key):]
            output_alignment_file_base, output_format = output_alignment_file.rsplit(
                ".", 1
            )
        if value[: len(strand_key)] == strand_key:
            strand_index = index
    # BUG FIX: the original re-tested segment_index here, so a missing
    # --output= argument went undetected and crashed later when
    # params[output_index] was indexed with None.
    if output_index is None:
        print(f"Error: could not find output key {output_key} in parameters {params}")
        exit(0)
    if strand_index is None:
        print(f"Error: could not find strand key {strand_key} in parameters {params}")
        exit(0)

    # The error file is the very last parameter (e.g. "... 2> foo.err").
    err_name_base = params[-1].split(".err", 1)[0]

    data = {}  # (seq1_name, seq2_name) -> list of (seq1_mid, seq2_mid, line)

    direction = None
    if "plus" in params[strand_index]:
        direction = "f"
    elif "minus" in params[strand_index]:
        direction = "r"
    else:
        print(
            f"Error: could not figure out direction from strand value {params[strand_index]}"
        )
        exit(0)

    with open(input_file, "r") as segment_handle:
        for line in segment_handle:
            # Skip blank lines.  (The original tested `line == ""`, which
            # never matches because iterated lines keep their newline, so a
            # blank line crashed the 8-way unpack below.)
            if not line.strip():
                continue
            (
                seq1_name,
                seq1_start,
                seq1_end,
                seq2_name,
                seq2_start,
                seq2_end,
                _dir,
                score,
            ) = line.split()
            # Midpoint of the segment, used as its diagonal coordinate.
            half_dist = int((int(seq1_end) - int(seq1_start)) // 2)
            assert int(seq1_end) > int(seq1_start)
            assert int(seq2_end) > int(seq2_start)
            seq1_mid = int(seq1_start) + half_dist
            seq2_mid = int(seq2_start) + half_dist
            data.setdefault((seq1_name, seq2_name), []).append(
                (seq1_mid, seq2_mid, line)
            )

    # Sequence pairs with segment count <= chunk_size need no sorting or
    # splitting; it is better to aggregate them into shared files later.
    skip_pairs = []  # pairs that have count <= chunk_size; these will not be sorted
    if len(data.keys()) > 1:
        for pair in data.keys():
            if len(data[pair]) <= chunk_size:
                skip_pairs.append(pair)

    # Sort the remaining pairs along the alignment diagonal.
    # NOTE(review): the original comments labelled the "r" branch "forward"
    # and the "f" branch "reverse", contradicting the direction values; the
    # sort keys are preserved exactly as written — confirm upstream which
    # labelling is correct before touching them.
    if direction == "r":
        for pair in data.keys():
            if pair not in skip_pairs:
                data[pair] = sorted(
                    data[pair], key=lambda coord: (coord[1] - coord[0], coord[0])
                )
    elif direction == "f":
        for pair in data.keys():
            if pair not in skip_pairs:
                data[pair] = sorted(
                    data[pair], key=lambda coord: (coord[1] + coord[0], coord[0])
                )
    else:
        print(f"INVALID DIRECTION VALUE: {direction}")
        exit(0)

    # Write each sorted pair out in chunk_size-sized chunk files, emitting a
    # rewritten lastz command for each file produced.
    ctr = 0
    for pair in data.keys() - skip_pairs:
        for chunk in chunks(list(zip(*data[pair]))[2], chunk_size):
            ctr += 1
            name_addition = f".split{ctr}"
            fname = input_file.split(".segments", 1)[0] + name_addition + ".segments"

            with open(fname, "w") as f:
                f.writelines(chunk)
            # Update segment, output and error file arguments in the command.
            params[segment_index] = segment_key + fname
            params[output_index] = (
                output_key
                + output_alignment_file_base
                + name_addition
                + "."
                + output_format
            )
            params[-1] = err_name_base + name_addition + ".err"
            print(" ".join(params), flush=True)

    # Pack the small (skipped) pairs together, greedily filling each shared
    # file up to chunk_size segments, without sorting their contents.
    if len(skip_pairs) > 0:
        skip_pairs_with_len = sorted(
            [(len(data[p]), p) for p in skip_pairs]
        )  # list of tuples of (pair length, pair)
        aggregated_skip_pairs = [[]]  # list of lists of pair names
        current_count = 0
        for count, pair in skip_pairs_with_len:
            if current_count + count <= chunk_size:
                current_count += count
            else:
                aggregated_skip_pairs.append([])
                current_count = count
            aggregated_skip_pairs[-1].append(pair)

        for aggregate in aggregated_skip_pairs:
            ctr += 1
            name_addition = f".split{ctr}"
            fname = input_file.split(".segments", 1)[0] + name_addition + ".segments"
            with open(fname, "w") as f:
                for pair in sorted(aggregate, key=lambda p: (p[1], p[0])):
                    f.writelines(list(zip(*data[pair]))[2])
            # Update segment, output and error file arguments in the command.
            params[segment_index] = segment_key + fname
            params[output_index] = (
                output_key
                + output_alignment_file_base
                + name_addition
                + "."
                + output_format
            )
            params[-1] = err_name_base + name_addition + ".err"
            print(" ".join(params), flush=True)

    # The original (now fully partitioned) segment file is no longer needed.
    if DELETE_AFTER_CHUNKING:
        os.remove(input_file)