comparison data_manager/model_fetcher.py @ 0:11e42265a9b0 draft default tip

planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/main/data_managers/data_manager_clair3_models commit 2672414472cc968c736dc7d42f5a119ff8c16c62
author iuc
date Thu, 20 Feb 2025 17:57:11 +0000
parents
children
comparison
equal deleted inserted replaced
-1:000000000000 0:11e42265a9b0
1 #!/usr/bin/env python3
2
3 import argparse
4 import json
5 import sys
6 import tarfile
7 from hashlib import sha256
8 from io import BytesIO, StringIO
9 from pathlib import Path
10 from urllib.error import HTTPError
11 from urllib.request import Request, urlopen
12
13 DATA_TABLE_NAME = 'clair3_models'
14
15
def find_latest_models(readme_text=None):
    """Return the names of the "Latest" Clair3 models listed in the Rerio README.

    The parser is based on the README.rst of the rerio repository as of
    7 January 2025, which has a section that looks like this:

        Clair3 Models
        -------------

        Clair3 models for the following configurations are available:

        Latest:

        ========================== =================== =======================
        Config                     Chemistry           Dorado basecaller model
        ========================== =================== =======================
        r1041_e82_400bps_sup_v500  R10.4.1 E8.2 (5kHz) v5.0.0 SUP
        r1041_e82_400bps_hac_v500  R10.4.1 E8.2 (5kHz) v5.0.0 HAC
        ========================== =================== =======================

    The model names are extracted by looking, in this order, for a line that
    starts with "Clair3 Models", then "Latest:", then "Config", then "=====",
    and then taking the first whitespace-separated field of every line until
    the next line starting with "=====".

    Args:
        readme_text: Optional README contents to parse. When None (the
            default) the README is downloaded from the Rerio GitHub repository.

    Returns:
        A list of model names in the order they appear in the table. The list
        is empty when the expected section is not found.

    Raises:
        IOError: If the README download answers with a non-200 status.
    """
    if readme_text is None:
        url = 'https://raw.githubusercontent.com/nanoporetech/rerio/refs/heads/master/README.rst'
        with urlopen(Request(url)) as response:
            if response.status != 200:
                raise IOError(f'Failed to fetch the latest models: {response.status}')
            readme_text = response.read().decode('utf-8')

    # Line prefixes that must be seen, one after the other, before the table body starts.
    markers = ('Clair3 Models', 'Latest:', 'Config', '=====')
    stage = 0
    models = []
    for line in StringIO(readme_text):
        if stage < len(markers):
            # Still searching for the start of the table body.
            if line.startswith(markers[stage]):
                stage += 1
            continue
        if line.startswith('====='):
            # Closing border of the table: nothing more to read.
            break
        fields = line.split()
        # A blank line inside the table has no fields; skip it instead of
        # failing with an IndexError.
        if fields:
            models.append(fields[0])
    return models
69
70
def fetch_model(model_name):
    """Download the archive of a Clair3 model and return its raw bytes.

    The download is a two-step process: the file
    ``clair3_models/<model_name>_model`` in the Rerio repository is read first,
    and its (whitespace-stripped) content is then used as the URL from which
    the actual archive is fetched.

    The model files are tar gzipped, with a structure like:
        model_name/pileup.index
        model_name/full_alignment.index
    and other files, with the key point being that the model_name becomes the
    model directory.

    Args:
        model_name: Name of the model, e.g. ``r1041_e82_400bps_sup_v500``.

    Returns:
        The bytes of the downloaded archive.

    Raises:
        IOError: If the pointer file cannot be fetched, is empty, or if either
            request answers with a non-200 status.
    """
    pointer_url = f'https://raw.githubusercontent.com/nanoporetech/rerio/refs/heads/master/clair3_models/{model_name}_model'
    try:
        # urlopen throws a HTTPError if it gets a 404 status (and perhaps other non-200 status?)
        with urlopen(Request(pointer_url)) as response:
            if response.status != 200:
                raise IOError(f'Failed to fetch the model {model_name}: {response.status}')
            final_url = response.read().decode('utf-8').strip()
    except HTTPError as e:
        # Chain the original error so the HTTP details stay visible in the traceback.
        raise IOError(f'Failed to fetch the model {model_name}: {e}') from e

    if not final_url:
        # Without this check Request('') fails with an unhelpful "unknown url type" ValueError.
        raise IOError(f'Failed to fetch the model {model_name}: no download URL found in {pointer_url}')

    with urlopen(Request(final_url)) as response:
        if response.status != 200:
            raise IOError(f'Failed to fetch the model {model_name} from CDN URL {final_url}: {response.status}')
        return response.read()
94
95
def unpack_model(data, outdir):
    """Extract an in-memory tar archive into ``outdir``.

    Args:
        data: Bytes of a tar archive; the compression is detected
            automatically (mode ``r:*``).
        outdir: Directory into which the archive members are extracted.

    Raises:
        tarfile.TarError: If the data is not a readable tar archive or, when
            extraction filters are available, if a member is rejected by the
            ``data`` filter (for example because it would be written outside
            ``outdir``).
    """
    with tarfile.open(fileobj=BytesIO(data), mode='r:*') as tar:
        if hasattr(tarfile, 'data_filter'):
            # The archive was downloaded from the network, so refuse members
            # that try to escape outdir (absolute paths, "..", unsafe links).
            tar.extractall(outdir, filter='data')
        else:
            # NOTE: this Python version has no extraction filters; members are
            # extracted without path validation. Remove once no longer needed.
            tar.extractall(outdir)
99
100
if __name__ == '__main__':
    # Command line entry point: download the requested Clair3 models, unpack them into the
    # data manager output directory and report them to Galaxy via the data manager JSON file.
    parser = argparse.ArgumentParser()
    parser.add_argument('dm_filename', type=str, help='The filename of the data manager file to read parameters from and write outputs to')
    parser.add_argument('--known_models', type=str, help='List of models already known in the Galaxy data table')
    parser.add_argument('--sha256_sums', type=str, help='List of sha256sums of the models already known in the Galaxy data table')
    parser.add_argument('--download_latest', action='store_true', default=False, help='Download the latest models as per the Rerio repository')
    parser.add_argument('--download_models', type=str, help='Comma separated list of models to download')
    args = parser.parse_args()

    # parameters to a data manager are passed in a JSON file (see https://docs.galaxyproject.org/en/latest/dev/data_managers.html) and
    # similarly a JSON file is created to pass the output back to Galaxy
    requested_models = []
    if args.download_latest:
        requested_models.extend(find_latest_models())
    if args.download_models:
        requested_models.extend(args.download_models.split(','))

    # Strip stray whitespace, drop empty names (e.g. from a trailing comma) and remove duplicates
    # while keeping the requested order. A model that is both in the latest list and in
    # --download_models must only be downloaded and added to the data table once.
    models = list(dict.fromkeys(name.strip() for name in requested_models if name.strip()))

    if not models:
        sys.exit('No models to download, please specify either --download_latest or --download_models')

    with open(args.dm_filename) as fh:
        config = json.load(fh)
    # Treat a missing or empty output_data list the same way as a missing extra_files_path
    output_data = config.get('output_data') or [{}]
    if 'extra_files_path' not in output_data[0]:
        sys.exit('Please specify the output directory in the data manager configuration (the extra_files_path)')
    output_directory = output_data[0]["extra_files_path"]
    Path(output_directory).mkdir(parents=True, exist_ok=True)

    data_manager_dict = {}
    data_manager_dict["data_tables"] = config.get("data_tables", {})
    data_manager_dict["data_tables"][DATA_TABLE_NAME] = []

    # Keep the known model names as an ordered list so that each name is paired with the sha256sum
    # at the same position; iterating over a set would pair them in arbitrary order.
    known_model_names = [name.strip() for name in args.known_models.split(',')] if args.known_models else []
    known_sha256_sums = [digest.strip() for digest in args.sha256_sums.split(',')] if args.sha256_sums else []
    known_models = set(known_model_names)
    model_to_sha256 = dict(zip(known_model_names, known_sha256_sums))

    for model in models:
        model_dir = Path(output_directory) / model
        # The data table cannot handle duplicate entries, so we skip models that are already in the data table
        if model in known_models:
            print(f'Model {model} already exists, skipping', file=sys.stderr)
            continue
        data = fetch_model(model)
        sha256sum = sha256(data).hexdigest()

        # Since we skip models that are already known we cannot test the sha256sum here. This code is retained to illustrate that an
        # alternative logic would be to download the model each time and check if the sha256sum matches what is already known. Hopefully
        # ONT does not update the models while keeping the same name, so this is not needed. The sha256sum is stored in the data table
        # in case it is needed in the future.
        # if model in model_to_sha256 and sha256sum != model_to_sha256[model]:
        #     sys.exit(f'Model {model} already exists with a different sha256sum {model_to_sha256[model]}. This is a serious error, inform the Galaxy admin')

        # The archive unpacks into a directory named after the model, i.e. into model_dir
        unpack_model(data, output_directory)

        data_manager_dict["data_tables"][DATA_TABLE_NAME].append(
            dict(
                value=model,
                platform="ont",
                sha256=sha256sum,
                path=str(model_dir),
                source="rerio"
            )
        )

    with open(args.dm_filename, 'w') as fh:
        json.dump(data_manager_dict, fh, sort_keys=True, indent=4)