Mercurial > repos > iuc > b2btools_single_sequence
comparison script.py @ 0:b694a77ca1e8 draft default tip
planemo upload commit 599e1135baba020195b3f7576449d595bca9af75
| author | iuc |
|---|---|
| date | Tue, 09 Aug 2022 12:30:52 +0000 |
| parents | |
| children |
comparison
equal
deleted
inserted
replaced
| -1:000000000000 | 0:b694a77ca1e8 |
|---|---|
| 1 import json | |
| 2 import optparse | |
| 3 import os.path | |
| 4 import re | |
| 5 import unicodedata | |
| 6 | |
| 7 import matplotlib.pyplot as plt | |
| 8 import numpy as np | |
| 9 import pandas as pd | |
| 10 from b2bTools import SingleSeq | |
| 11 | |
| 12 | |
def slugify(value, allow_unicode=False):
    """
    Turn an arbitrary value into a filesystem-friendly slug.

    Adapted from Django's ``django.utils.text.slugify``
    (https://github.com/django/django/blob/master/django/utils/text.py).

    The value is converted with ``str()`` and Unicode-normalized. Unless
    ``allow_unicode`` is true, it is then transliterated to ASCII,
    silently dropping characters that have no ASCII decomposition.
    Characters that are not word characters, whitespace or hyphens are
    removed, the text is lowercased, runs of whitespace/hyphens collapse
    to a single hyphen, and leading/trailing hyphens and underscores are
    stripped.

    :param value: object to slugify (converted with ``str()``).
    :param allow_unicode: keep non-ASCII word characters (NFKC
        normalization) instead of transliterating to ASCII (NFKD).
        Defaults to False, which is the behaviour this script relies on.
    :return: the slug string; may be empty.
    """
    value = str(value)
    if allow_unicode:
        value = unicodedata.normalize("NFKC", value)
    else:
        # NFKD splits accented letters into base letter + combining mark,
        # so the ASCII encode keeps the base letter and drops the mark.
        value = (
            unicodedata.normalize("NFKD", value)
            .encode("ascii", "ignore")
            .decode("ascii")
        )
    value = re.sub(r"[^\w\s-]", "", value.lower())
    return re.sub(r"[-\s]+", "-", value).strip("-_")
| 29 | |
| 30 | |
def check_min_max(predicted_values, former_min, former_max, margin=0.1):
    """
    Widen a y-axis range so that it covers a series of predictions.

    Only finite values are considered: ``None``, NaN and +/-inf entries
    are ignored. The builtin ``max()``/``min()`` cannot be used for this,
    because comparisons against NaN are always False, so their result
    would depend on where a NaN happens to sit in the sequence.

    :param predicted_values: iterable of numbers (list, array or pandas
        Series); ``None`` entries are treated as missing.
    :param former_min: current lower bound of the range.
    :param former_max: current upper bound of the range.
    :param margin: padding added beyond the data extremes (default 0.1).
    :return: ``(new_min, new_max)``. Each bound is only ever moved
        outward, and is returned unchanged if there are no finite values.
    """
    values = np.asarray(list(predicted_values), dtype=float)
    finite_values = values[np.isfinite(values)]
    if finite_values.size == 0:
        # Nothing plottable: keep the caller's range as is.
        return former_min, former_max
    seq_max = float(finite_values.max())
    seq_min = float(finite_values.min())
    if seq_max + margin > former_max:
        former_max = seq_max + margin
    if seq_min - margin < former_min:
        former_min = seq_min - margin
    return former_min, former_max
| 47 | |
| 48 | |
def plot_prediction(
    pred_name, hlighting_regions, predicted_values, seq_name, plot_output=None
):
    """
    Plot one predictor's per-residue values and save the figure as a PNG.

    :param pred_name: predictor key (e.g. ``"backbone"``, ``"disoMine"``).
        It is used for the title, the initial y-range and the output
        filename.
    :param hlighting_regions: when true and ``pred_name`` has known
        thresholds, shade the interpretation bands and add a second
        legend listing them from top to bottom.
    :param predicted_values: when ``seq_name == "all"``, a mapping of
        sequence name -> sequence of values, drawn as one line per
        sequence. Otherwise, a single sequence of values.
    :param seq_name: sequence name used as the legend label and in the
        output filename, or ``"all"`` for the combined plot.
    :param plot_output: directory to write the PNG into. When omitted,
        falls back to ``options.plot_output`` from the module-level
        ``options`` parsed by the ``__main__`` block.

    The figure is saved as ``<plot_output>/<slugify(seq_name)>_<pred_name>.png``.
    """
    # Interpretation bands per predictor: band label -> [lower, upper].
    thresholds_dict = {
        "backbone": {
            "membrane spanning": [1.0, 1.5],
            "rigid": [0.8, 1.0],
            "context-dependent": [0.69, 0.8],
            "flexible": [-1.0, 0.69],
        },
        "earlyFolding": {
            "early folds": [0.169, 2.0],
            "late folds": [-1.0, 0.169],
        },
        "disoMine": {"ordered": [-1.0, 0.5], "disordered": [0.5, 2.0]},
    }
    # Bands listed bottom-to-top; each one takes the colour at the same position.
    ordered_regions_dict = {
        "backbone": [
            "flexible",
            "context-dependent",
            "rigid",
            "membrane spanning",
        ],
        "earlyFolding": ["late folds", "early folds"],
        "disoMine": ["ordered", "disordered"],
    }
    colors = ["yellow", "orange", "pink", "red"]
    # Initial y-range per predictor; check_min_max() widens it to fit the data.
    ranges_dict = {
        "backbone": [-0.2, 1.2],
        "sidechain": [-0.2, 1.2],
        "ppII": [-0.2, 1.2],
        "earlyFolding": [-0.2, 1.2],
        "disoMine": [-0.2, 1.2],
        "agmata": [-0.2, 1.2],
        "helix": [-1.0, 1.0],
        "sheet": [-1.0, 1.0],
        "coil": [-1.0, 1.0],
    }
    # Predictors without a preset start from the most common preset
    # instead of raising KeyError.
    default_range = [-0.2, 1.2]

    fig, ax = plt.subplots(1, 1)
    fig.set_figwidth(10)
    fig.set_figheight(5)
    ax.set_title(f"{pred_name} prediction")
    min_value, max_value = ranges_dict.get(pred_name, default_range)

    if seq_name == "all":
        max_len = 0
        for seq, predictions in predicted_values.items():
            min_value, max_value = check_min_max(
                predictions, min_value, max_value
            )
            ax.plot(range(len(predictions)), predictions, label=seq)
            max_len = max(max_len, len(predictions))
        ax.set_xlim([0, max_len - 1])
    else:
        predictions = predicted_values
        min_value, max_value = check_min_max(predictions, min_value, max_value)
        ax.plot(range(len(predictions)), predictions, label=seq_name)
        ax.set_xlim([0, len(predictions) - 1])

    # A later ax.legend() call would replace this legend, so it is also
    # attached as a plain artist to keep both legends on the figure.
    legend_lines = ax.legend(
        bbox_to_anchor=(1.04, 1), loc="upper left", fancybox=True, shadow=True
    )
    ax.add_artist(legend_lines)

    if hlighting_regions and pred_name in ordered_regions_dict:
        region_names = ordered_regions_dict[pred_name]
        for color, region in zip(colors, region_names):
            lower, upper = thresholds_dict[pred_name][region]
            ax.axhspan(lower, upper, alpha=0.3, color=color, label=region)
        # Look the band patches up by label; later entries win on duplicate labels.
        handles, labels = ax.get_legend_handles_labels()
        handles_dict = dict(zip(labels, handles))
        # List the bands from the top band down to the bottom one.
        legend_order = list(reversed(region_names))
        region_legend = ax.legend(
            [handles_dict[region] for region in legend_order],
            legend_order,
            fancybox=True,
            shadow=True,
            loc="lower left",
            bbox_to_anchor=(1.04, 0),
        )
        ax.add_artist(region_legend)

    ax.set_ylim([min_value, max_value])
    ax.set_xlabel("residue index")
    ax.set_ylabel("prediction values")
    ax.grid(axis="y")

    if plot_output is None:
        # Script mode: the directory comes from --plot-output.
        plot_output = options.plot_output
    plt.savefig(
        os.path.join(plot_output, f"{slugify(seq_name)}_{pred_name}.png"),
        bbox_inches="tight",
    )
    plt.close(fig)
| 158 | |
| 159 | |
def df_dict_to_dict_of_values(df_dict, predictor):
    """
    Load one predictor column from each per-sequence TSV file.

    :param df_dict: mapping of sequence name -> path of a tab-separated
        file with a header row.
    :param predictor: name of the column to extract.
    :return: dict mapping each sequence name to that column (a pandas
        Series), in the same key order as ``df_dict``.
    """
    return {
        seq_name: pd.read_csv(tsv_path, sep="\t")[predictor]
        for seq_name, tsv_path in df_dict.items()
    }
| 166 | |
| 167 | |
def main(options):
    """
    Run the selected b2bTools predictors on a FASTA file and write outputs.

    Outputs:
      * ``options.json_output``: every prediction as JSON, with floats
        rounded to 3 decimals and keys sorted.
      * ``<options.output>/<slugify(sequence id)>.tsv``: one tab-separated
        table per sequence. It has a ``residue_index`` index, a
        ``residue`` column and one column per predictor (sorted by name),
        rounded to 3 decimals.
      * ``--plot``: one PNG per sequence and predictor.
      * ``--plot_all``: one PNG per predictor overlaying all sequences,
        with the values re-read from the TSV files written above.

    :param options: parsed optparse values (see the ``__main__`` block).
    :raises ValueError: if b2bTools returned no sequences.
    """
    flags_to_tools = (
        (options.dynamine, "dynamine"),
        (options.disomine, "disomine"),
        (options.efoldmine, "efoldmine"),
        (options.agmata, "agmata"),
    )
    b2b_tools = [tool for enabled, tool in flags_to_tools if enabled]

    single_seq = SingleSeq(options.input_fasta)
    single_seq.predict(b2b_tools)
    predictions = single_seq.get_all_predictions()
    if not predictions:
        raise ValueError(
            "b2bTools returned no predictions for {0}".format(
                options.input_fasta
            )
        )

    def rounder_function(value):
        return round(float(value), 3)

    # Round-trip through JSON so that every float, at any nesting depth,
    # is rounded as it is parsed back.
    rounded_predictions = json.loads(
        json.dumps(predictions), parse_float=rounder_function
    )
    with open(options.json_output, "w") as f:
        json.dump(rounded_predictions, f, indent=2, sort_keys=True)

    # The column layout is taken from the first sequence.
    # NOTE(review): assumes every sequence reports the same predictor keys.
    first_sequence_key = next(iter(predictions))
    prediction_keys = list(predictions[first_sequence_key].keys())
    predictor_names = [key for key in prediction_keys if key != "seq"]
    tsv_column_names = ["residue", *sorted(predictor_names)]

    df_dictionary = {}
    for sequence_key, seq_preds in predictions.items():
        # list() yields one residue per row whether the sequence arrives as
        # a str or as a list (pandas would broadcast a bare str to every row).
        residues = list(seq_preds["seq"])
        sequence_df = pd.DataFrame(
            columns=prediction_keys, index=range(len(residues))
        )
        sequence_df.index.name = "residue_index"
        for predictor in prediction_keys:
            sequence_df[predictor] = (
                residues if predictor == "seq" else seq_preds[predictor]
            )
        sequence_df = sequence_df.rename(columns={"seq": "residue"})
        sequence_df = sequence_df.round(decimals=3)
        filename = os.path.join(options.output, f"{slugify(sequence_key)}.tsv")
        df_dictionary[sequence_key] = filename
        sequence_df.to_csv(
            filename,
            header=True,
            columns=tsv_column_names,
            sep="\t",
        )
        # Per-sequence plots (can be combined with --plot_all).
        # NOTE(review): --highlight is parsed but unused; regions are always shaded.
        if options.plot:
            for predictor in predictor_names:
                plot_prediction(
                    pred_name=predictor,
                    hlighting_regions=True,
                    predicted_values=seq_preds[predictor],
                    seq_name=sequence_key,
                )

    # One combined plot per predictor (can be combined with --plot).
    if options.plot_all:
        for predictor in predictor_names:
            results_dictionary = df_dict_to_dict_of_values(
                df_dictionary, predictor
            )
            plot_prediction(
                pred_name=predictor,
                hlighting_regions=True,
                predicted_values=results_dictionary,
                seq_name="all",
            )
| 241 | |
| 242 | |
if __name__ == "__main__":
    parser = optparse.OptionParser()
    # Predictor selection; at least one is required (validated below).
    for predictor_flag in ("--dynamine", "--disomine", "--efoldmine", "--agmata"):
        parser.add_option(predictor_flag, action="store_true")
    parser.add_option("--file", dest="input_fasta", type="string")
    parser.add_option("--output", dest="output", type="string")
    parser.add_option("--plot-output", type="string", dest="plot_output")
    parser.add_option("--json", dest="json_output", type="string")
    parser.add_option("--plot", action="store_true")
    parser.add_option("--plot_all", action="store_true")
    # NOTE(review): accepted for CLI compatibility but not read by main().
    parser.add_option("--highlight", action="store_true")

    # parse_args() and parser.error() report problems by printing usage
    # and exiting (SystemExit), so no exception handling is needed here.
    options, args = parser.parse_args()
    if not (options.dynamine or options.disomine or options.efoldmine or options.agmata):
        parser.error('At least one predictor is required')
    if not options.input_fasta:
        parser.error('Input file not given (--file)')
    if not options.output:
        parser.error('Output directory not given (--output)')
    if (options.plot or options.plot_all) and not options.plot_output:
        parser.error('Plot output directory not given (--plot-output)')
    if not options.json_output:
        parser.error('Json output file not given (--json)')
    main(options)
