diff gplib.py @ 5:ddcf35a868b8 draft default tip

planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/rna_tools/graphprot commit ad60258f5759eaa205fec4af6143c728ea131419
author bgruening
date Wed, 05 Jun 2024 16:40:51 +0000
parents ace92c9a4653
children
line wrap: on
line diff
--- a/gplib.py	Thu Jan 28 15:06:14 2021 +0000
+++ b/gplib.py	Wed Jun 05 16:40:51 2024 +0000
@@ -1,4 +1,3 @@
-
 import gzip
 import random
 import re
@@ -6,7 +5,6 @@
 import subprocess
 from distutils.spawn import find_executable
 
-
 """
 
 Run doctests:
@@ -17,7 +15,8 @@
 """
 
 
-###############################################################################
+#######################################################################
+
 
 def graphprot_predictions_get_median(predictions_file):
     """
@@ -41,11 +40,12 @@
     return statistics.median(sc_list)
 
 
-###############################################################################
+#######################################################################
+
 
-def graphprot_profile_get_tsm(profile_file,
-                              profile_type="profile",
-                              avg_profile_extlr=5):
+def graphprot_profile_get_tsm(
+    profile_file, profile_type="profile", avg_profile_extlr=5
+):
 
     """
     Given a GraphProt .profile file, extract for each site (identified by
@@ -88,23 +88,21 @@
             max_list.append(max_sc)
         elif profile_type == "avg_profile":
             # Convert profile score list to average profile scores list.
-            aps_list = \
-                list_moving_window_average_values(lists_dic[seq_id],
-                                                  win_extlr=avg_profile_extlr)
+            aps_list = list_moving_window_average_values(
+                lists_dic[seq_id], win_extlr=avg_profile_extlr
+            )
             max_sc = max(aps_list)
             max_list.append(max_sc)
         else:
-            assert 0, "invalid profile_type argument given: \"%s\"" \
-                % (profile_type)
+            assert 0, 'invalid profile_type argument given: "%s"' % (profile_type)
     # Return the median.
     return statistics.median(max_list)
 
 
-###############################################################################
+#######################################################################
 
-def list_moving_window_average_values(in_list,
-                                      win_extlr=5,
-                                      method=1):
+
+def list_moving_window_average_values(in_list, win_extlr=5, method=1):
     """
     Take a list of numeric values, and calculate for each position a new value,
     by taking the mean value of the window of positions -win_extlr and
@@ -152,7 +150,8 @@
     return new_list
 
 
-###############################################################################
+#######################################################################
+
 
 def echo_add_to_file(echo_string, out_file):
     """
@@ -167,14 +166,16 @@
     assert not error, "echo is complaining:\n%s\n%s" % (check_cmd, output)
 
 
-###############################################################################
+#######################################################################
+
 
 def is_tool(name):
     """Check whether tool "name" is in PATH."""
     return find_executable(name) is not None
 
 
-###############################################################################
+#######################################################################
+
 
 def count_fasta_headers(fasta_file):
     """
@@ -194,7 +195,8 @@
     return row_count
 
 
-###############################################################################
+#######################################################################
+
 
 def make_file_copy(in_file, out_file):
     """
@@ -202,21 +204,26 @@
 
     """
     check_cmd = "cat " + in_file + " > " + out_file
-    assert in_file != out_file, \
-        "cat does not like to cat file into same file (%s)" % (check_cmd)
+    assert in_file != out_file, "cat does not like to cat file into same file (%s)" % (
+        check_cmd
+    )
     output = subprocess.getoutput(check_cmd)
     error = False
     if output:
         error = True
-    assert not error, \
-        "cat did not like your input (in_file: %s, out_file: %s):\n%s" \
-        % (in_file, out_file, output)
+    assert not error, "cat did not like your input (in_file: %s, out_file: %s):\n%s" % (
+        in_file,
+        out_file,
+        output,
+    )
 
 
-###############################################################################
+#######################################################################
+
 
-def split_fasta_into_test_train_files(in_fasta, test_out_fa, train_out_fa,
-                                      test_size=500):
+def split_fasta_into_test_train_files(
+    in_fasta, test_out_fa, train_out_fa, test_size=500
+):
     """
     Split in_fasta .fa file into two files (e.g. test, train).
 
@@ -230,7 +237,7 @@
     TRAINOUT = open(train_out_fa, "w")
     for seq_id in rand_ids_list:
         seq = seqs_dic[seq_id]
-        if (c_out >= test_size):
+        if c_out >= test_size:
             TRAINOUT.write(">%s\n%s\n" % (seq_id, seq))
         else:
             TESTOUT.write(">%s\n%s\n" % (seq_id, seq))
@@ -239,7 +246,8 @@
     TRAINOUT.close()
 
 
-###############################################################################
+#######################################################################
+
 
 def check_seqs_dic_format(seqs_dic):
     """
@@ -267,16 +275,19 @@
     return bad_seq_ids
 
 
-###############################################################################
+#######################################################################
+
 
-def read_fasta_into_dic(fasta_file,
-                        seqs_dic=False,
-                        ids_dic=False,
-                        read_dna=False,
-                        short_ensembl=False,
-                        reject_lc=False,
-                        convert_to_uc=False,
-                        skip_n_seqs=True):
+def read_fasta_into_dic(
+    fasta_file,
+    seqs_dic=False,
+    ids_dic=False,
+    read_dna=False,
+    short_ensembl=False,
+    reject_lc=False,
+    convert_to_uc=False,
+    skip_n_seqs=True,
+):
     """
     Read in FASTA sequences, convert to RNA, store in dictionary
     and return dictionary.
@@ -302,7 +313,7 @@
 
     # Go through FASTA file, extract sequences.
     if re.search(r".+\.gz$", fasta_file):
-        f = gzip.open(fasta_file, 'rt')
+        f = gzip.open(fasta_file, "rt")
     else:
         f = open(fasta_file, "r")
     for line in f:
@@ -315,9 +326,10 @@
                 if re.search(r".+\..+", seq_id):
                     m = re.search(r"(.+?)\..+", seq_id)
                     seq_id = m.group(1)
-            assert seq_id not in seqs_dic, \
-                "non-unique FASTA header \"%s\" in \"%s\"" \
-                % (seq_id, fasta_file)
+            assert seq_id not in seqs_dic, 'non-unique FASTA header "%s" in "%s"' % (
+                seq_id,
+                fasta_file,
+            )
             if ids_dic:
                 if seq_id in ids_dic:
                     seqs_dic[seq_id] = ""
@@ -328,31 +340,31 @@
                 m = re.search("([ACGTUN]+)", line, re.I)
                 seq = m.group(1)
                 if reject_lc:
-                    assert \
-                        not re.search("[a-z]", seq), \
-                        "lc char detected in seq \"%i\" (reject_lc=True)" \
-                        % (seq_id)
+                    assert not re.search(
+                        "[a-z]", seq
+                    ), 'lc char detected in seq "%i" (reject_lc=True)' % (seq_id)
                 if convert_to_uc:
                     seq = seq.upper()
                 # If sequences with N nucleotides should be skipped.
                 if skip_n_seqs:
                     if "n" in m.group(1) or "N" in m.group(1):
-                        print("WARNING: \"%s\" contains N. Discarding "
-                              "sequence ... " % (seq_id))
+                        print(
+                            'WARNING: "%s" contains N. Discarding '
+                            "sequence ... " % (seq_id)
+                        )
                         del seqs_dic[seq_id]
                         continue
                 # Convert to RNA, concatenate sequence.
                 if read_dna:
-                    seqs_dic[seq_id] += \
-                        m.group(1).replace("U", "T").replace("u", "t")
+                    seqs_dic[seq_id] += m.group(1).replace("U", "T").replace("u", "t")
                 else:
-                    seqs_dic[seq_id] += \
-                        m.group(1).replace("T", "U").replace("t", "u")
+                    seqs_dic[seq_id] += m.group(1).replace("T", "U").replace("t", "u")
     f.close()
     return seqs_dic
 
 
-###############################################################################
+#######################################################################
+
 
 def random_order_dic_keys_into_list(in_dic):
     """
@@ -366,7 +378,8 @@
     return id_list
 
 
-###############################################################################
+#######################################################################
+
 
 def graphprot_get_param_string(params_file):
     """
@@ -394,11 +407,12 @@
                 else:
                     param_string += "-%s %s " % (par, setting)
             else:
-                assert 0, "pattern matching failed for string \"%s\"" % (param)
+                assert 0, 'pattern matching failed for string "%s"' % (param)
     return param_string
 
 
-###############################################################################
+#######################################################################
+
 
 def seqs_dic_count_uc_nts(seqs_dic):
     """
@@ -416,11 +430,12 @@
     assert seqs_dic, "Given sequence dictionary empty"
     c_uc = 0
     for seq_id in seqs_dic:
-        c_uc += len(re.findall(r'[A-Z]', seqs_dic[seq_id]))
+        c_uc += len(re.findall(r"[A-Z]", seqs_dic[seq_id]))
     return c_uc
 
 
-###############################################################################
+#######################################################################
+
 
 def seqs_dic_count_lc_nts(seqs_dic):
     """
@@ -438,11 +453,12 @@
     assert seqs_dic, "Given sequence dictionary empty"
     c_uc = 0
     for seq_id in seqs_dic:
-        c_uc += len(re.findall(r'[a-z]', seqs_dic[seq_id]))
+        c_uc += len(re.findall(r"[a-z]", seqs_dic[seq_id]))
     return c_uc
 
 
-###############################################################################
+#######################################################################
+
 
 def count_file_rows(in_file):
     """
@@ -462,7 +478,8 @@
     return row_count
 
 
-###############################################################################
+#######################################################################
+
 
 def bed_check_six_col_format(bed_file):
     """
@@ -488,7 +505,8 @@
     return six_col_format
 
 
-###############################################################################
+#######################################################################
+
 
 def bed_check_unique_ids(bed_file):
     """
@@ -512,7 +530,8 @@
         return True
 
 
-###############################################################################
+#######################################################################
+
 
 def get_seq_lengths_from_seqs_dic(seqs_dic):
     """
@@ -527,7 +546,8 @@
     return seq_len_dic
 
 
-###############################################################################
+#######################################################################
+
 
 def bed_get_region_lengths(bed_file):
     """
@@ -548,19 +568,19 @@
             site_e = int(cols[2])
             site_id = cols[3]
             site_l = site_e - site_s
-            assert site_id \
-                not in id2len_dic, \
-                "column 4 IDs not unique in given .bed file \"%s\"" \
-                % (bed_file)
+            assert (
+                site_id not in id2len_dic
+            ), 'column 4 IDs not unique in given .bed file "%s"' % (bed_file)
             id2len_dic[site_id] = site_l
     f.closed
-    assert id2len_dic, \
-        "No IDs read into dic (input file \"%s\" empty or malformatted?)" \
-        % (bed_file)
+    assert (
+        id2len_dic
+    ), 'No IDs read into dic (input file "%s" empty or malformatted?)' % (bed_file)
     return id2len_dic
 
 
-###############################################################################
+#######################################################################
+
 
 def graphprot_get_param_dic(params_file):
     """
@@ -599,10 +619,10 @@
     return param_dic
 
 
-###############################################################################
+#######################################################################
 
-def graphprot_filter_predictions_file(in_file, out_file,
-                                      sc_thr=0):
+
+def graphprot_filter_predictions_file(in_file, out_file, sc_thr=0):
     """
     Filter GraphProt .predictions file by given score thr_sc.
     """
@@ -619,7 +639,8 @@
     OUTPRED.close()
 
 
-###############################################################################
+#######################################################################
+
 
 def fasta_read_in_ids(fasta_file):
     """
@@ -642,12 +663,12 @@
     return ids_list
 
 
-###############################################################################
+#######################################################################
+
 
-def graphprot_profile_calc_avg_profile(in_file, out_file,
-                                       ap_extlr=5,
-                                       seq_ids_list=False,
-                                       method=1):
+def graphprot_profile_calc_avg_profile(
+    in_file, out_file, ap_extlr=5, seq_ids_list=False, method=1
+):
     """
     Given a GraphProt .profile file, calculate average profiles and output
     average profile file.
@@ -702,15 +723,16 @@
         if seq_ids_list:
             c_seq_ids = len(seq_ids_list)
             c_site_ids = len(site_starts_dic)
-            assert c_seq_ids == c_site_ids, \
-                "# sequence IDs != # site IDs (%i != %i)" \
-                % (c_seq_ids, c_site_ids)
+            assert (
+                c_seq_ids == c_site_ids
+            ), "# sequence IDs != # site IDs (%i != %i)" % (c_seq_ids, c_site_ids)
         OUTPROF = open(out_file, "w")
         # For each site, calculate average profile scores list.
         for site_id in lists_dic:
             # Convert profile score list to average profile scores list.
-            aps_list = list_moving_window_average_values(lists_dic[site_id],
-                                                         win_extlr=ap_extlr)
+            aps_list = list_moving_window_average_values(
+                lists_dic[site_id], win_extlr=ap_extlr
+            )
             start_pos = site_starts_dic[site_id]
             # Get original FASTA sequence ID.
             if seq_ids_list:
@@ -741,10 +763,9 @@
                 if cur_id != old_id:
                     # Process old id scores.
                     if scores_list:
-                        aps_list = \
-                            list_moving_window_average_values(
-                                scores_list,
-                                win_extlr=ap_extlr)
+                        aps_list = list_moving_window_average_values(
+                            scores_list, win_extlr=ap_extlr
+                        )
                         start_pos = site_starts_dic[old_id]
                         seq_id = old_id
                         # Get original FASTA sequence ID.
@@ -763,8 +784,9 @@
         f.close()
         # Process last block.
         if scores_list:
-            aps_list = list_moving_window_average_values(scores_list,
-                                                         win_extlr=ap_extlr)
+            aps_list = list_moving_window_average_values(
+                scores_list, win_extlr=ap_extlr
+            )
             start_pos = site_starts_dic[old_id]
             seq_id = old_id
             # Get original FASTA sequence ID.
@@ -776,11 +798,12 @@
         OUTPROF.close()
 
 
-###############################################################################
+#######################################################################
+
 
-def graphprot_profile_extract_peak_regions(in_file, out_file,
-                                           max_merge_dist=0,
-                                           sc_thr=0):
+def graphprot_profile_extract_peak_regions(
+    in_file, out_file, max_merge_dist=0, sc_thr=0
+):
     """
     Extract peak regions from GraphProt .profile file.
     Store the peak regions (defined as regions with scores >= sc_thr)
@@ -835,21 +858,22 @@
                 # Process old id scores.
                 if scores_list:
                     # Extract peaks from region.
-                    peak_list = \
-                        list_extract_peaks(scores_list,
-                                           max_merge_dist=max_merge_dist,
-                                           coords="bed",
-                                           sc_thr=sc_thr)
+                    peak_list = list_extract_peaks(
+                        scores_list,
+                        max_merge_dist=max_merge_dist,
+                        coords="bed",
+                        sc_thr=sc_thr,
+                    )
                     start_pos = site_starts_dic[old_id]
                     # Print out peaks in .bed format.
                     for ln in peak_list:
                         peak_s = start_pos + ln[0]
                         peak_e = start_pos + ln[1]
                         site_id = "%s,%i" % (old_id, ln[2])
-                        OUTPEAKS.write("%s\t%i\t%i"
-                                       "\t%s\t%f\t+\n"
-                                       % (old_id, peak_s,
-                                          peak_e, site_id, ln[3]))
+                        OUTPEAKS.write(
+                            "%s\t%i\t%i"
+                            "\t%s\t%f\t+\n" % (old_id, peak_s, peak_e, site_id, ln[3])
+                        )
                     # Reset list.
                     scores_list = []
                 old_id = cur_id
@@ -861,27 +885,25 @@
     # Process last block.
     if scores_list:
         # Extract peaks from region.
-        peak_list = list_extract_peaks(scores_list,
-                                       max_merge_dist=max_merge_dist,
-                                       coords="bed",
-                                       sc_thr=sc_thr)
+        peak_list = list_extract_peaks(
+            scores_list, max_merge_dist=max_merge_dist, coords="bed", sc_thr=sc_thr
+        )
         start_pos = site_starts_dic[old_id]
         # Print out peaks in .bed format.
         for ln in peak_list:
             peak_s = start_pos + ln[0]
             peak_e = start_pos + ln[1]
             site_id = "%s,%i" % (old_id, ln[2])  # best score also 1-based.
-            OUTPEAKS.write("%s\t%i\t%i\t%s\t%f\t+\n"
-                           % (old_id, peak_s, peak_e, site_id, ln[3]))
+            OUTPEAKS.write(
+                "%s\t%i\t%i\t%s\t%f\t+\n" % (old_id, peak_s, peak_e, site_id, ln[3])
+            )
     OUTPEAKS.close()
 
 
-###############################################################################
+#######################################################################
 
-def list_extract_peaks(in_list,
-                       max_merge_dist=0,
-                       coords="list",
-                       sc_thr=0):
+
+def list_extract_peaks(in_list, max_merge_dist=0, coords="list", sc_thr=0):
     """
     Extract peak regions from list.
     Peak region is defined as region >= score threshold.
@@ -969,8 +991,12 @@
                     if peak_list[i][3] < peak_list[j][3]:
                         new_top_pos = peak_list[j][2]
                         new_top_sc = peak_list[j][3]
-                    new_peak = [peak_list[i][0], peak_list[j][1],
-                                new_top_pos, new_top_sc]
+                    new_peak = [
+                        peak_list[i][0],
+                        peak_list[j][1],
+                        new_top_pos,
+                        new_top_sc,
+                    ]
                 # If two peaks were merged.
                 if new_peak:
                     merged_peak_list.append(new_peak)
@@ -991,10 +1017,12 @@
     return peak_list
 
 
-###############################################################################
+#######################################################################
+
 
-def bed_peaks_to_genomic_peaks(peak_file, genomic_peak_file, genomic_sites_bed,
-                               print_rows=False):
+def bed_peaks_to_genomic_peaks(
+    peak_file, genomic_peak_file, genomic_sites_bed, print_rows=False
+):
     """
     Given a .bed file of sequence peak regions (possible coordinates from
     0 to length of s), convert peak coordinates to genomic coordinates.
@@ -1017,10 +1045,9 @@
             row = line.strip()
             cols = line.strip().split("\t")
             site_id = cols[3]
-            assert site_id \
-                not in id2row_dic, \
-                "column 4 IDs not unique in given .bed file \"%s\"" \
-                % (genomic_sites_bed)
+            assert (
+                site_id not in id2row_dic
+            ), 'column 4 IDs not unique in given .bed file "%s"' % (genomic_sites_bed)
             id2row_dic[site_id] = row
     f.close()
 
@@ -1034,13 +1061,14 @@
             site_e = int(cols[2])
             site_id2 = cols[3]
             site_sc = float(cols[4])
-            assert re.search(".+,.+", site_id2), \
-                "regular expression failed for ID \"%s\"" % (site_id2)
+            assert re.search(
+                ".+,.+", site_id2
+            ), 'regular expression failed for ID "%s"' % (site_id2)
             m = re.search(r".+,(\d+)", site_id2)
             sc_pos = int(m.group(1))  # 1-based.
-            assert site_id in id2row_dic, \
-                "site ID \"%s\" not found in genomic sites dictionary" \
-                % (site_id)
+            assert (
+                site_id in id2row_dic
+            ), 'site ID "%s" not found in genomic sites dictionary' % (site_id)
             row = id2row_dic[site_id]
             rowl = row.split("\t")
             gen_chr = rowl[0]
@@ -1054,16 +1082,23 @@
                 new_s = gen_e - site_e
                 new_e = gen_e - site_s
                 new_sc_pos = gen_e - sc_pos + 1  # keep 1-based.
-            new_row = "%s\t%i\t%i\t%s,%i\t%f\t%s" \
-                      % (gen_chr, new_s, new_e,
-                         site_id, new_sc_pos, site_sc, gen_pol)
+            new_row = "%s\t%i\t%i\t%s,%i\t%f\t%s" % (
+                gen_chr,
+                new_s,
+                new_e,
+                site_id,
+                new_sc_pos,
+                site_sc,
+                gen_pol,
+            )
             OUTPEAKS.write("%s\n" % (new_row))
             if print_rows:
                 print(new_row)
     OUTPEAKS.close()
 
 
-###############################################################################
+#######################################################################
+
 
 def diff_two_files_identical(file1, file2):
     """
@@ -1087,4 +1122,4 @@
     return same
 
 
-###############################################################################
+#######################################################################