Mercurial > repos > iuc > data_manager_clair3_models
view data_manager/model_fetcher.py @ 0:11e42265a9b0 draft default tip
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/main/data_managers/data_manager_clair3_models commit 2672414472cc968c736dc7d42f5a119ff8c16c62
author | iuc |
---|---|
date | Thu, 20 Feb 2025 17:57:11 +0000 |
parents | |
children |
line wrap: on
line source
#!/usr/bin/env python3 import argparse import json import sys import tarfile from hashlib import sha256 from io import BytesIO, StringIO from pathlib import Path from urllib.error import HTTPError from urllib.request import Request, urlopen DATA_TABLE_NAME = 'clair3_models' def find_latest_models(): # based on the README.rst of the rerio repository as of 7 January 2025 url = 'https://raw.githubusercontent.com/nanoporetech/rerio/refs/heads/master/README.rst' httprequest = Request(url) with urlopen(httprequest) as response: if response.status != 200: raise IOError(f'Failed to fetch the latest models: {response.status}') data = response.read().decode('utf-8') init_line_seen = False latest_seen = False config_line_seen = False read_lines = False models = [] # the file that we are parsing has a section that looks like this: # Clair3 Models # ------------- # Clair3 models for the following configurations are available: # Latest: # ========================== =================== ======================= # Config Chemistry Dorado basecaller model # ========================== =================== ======================= # r1041_e82_400bps_sup_v500 R10.4.1 E8.2 (5kHz) v5.0.0 SUP # r1041_e82_400bps_hac_v500 R10.4.1 E8.2 (5kHz) v5.0.0 HAC # r1041_e82_400bps_sup_v410 R10.4.1 E8.2 (4kHz) v4.1.0 SUP # r1041_e82_400bps_hac_v410 R10.4.1 E8.2 (4kHz) v4.1.0 HAC # ========================== =================== ======================= # # and the aim is to extract the list of model names from the table by successfully looking for # "Clair3 Models", then "Latest:", then "Config" and then "=====" and then reading the lines until # the next "=====" is encountered for line in StringIO(data): if read_lines: if line.startswith('====='): read_lines = False break model = line.split()[0] models.append(model) if config_line_seen and line.startswith('====='): read_lines = True continue if init_line_seen and line.startswith('Latest:'): latest_seen = True continue if latest_seen and line.startswith('Config'): config_line_seen = True continue if line.startswith('Clair3 Models'): init_line_seen = True continue return models def fetch_model(model_name): # the model files are tar gzipped, with a structure like: # model_name/pileup.index # model_name/full_alignment.index # and other files, with the key point being that the model_name becoomes the model_directory url = f'https://raw.githubusercontent.com/nanoporetech/rerio/refs/heads/master/clair3_models/{model_name}_model' httprequest = Request(url) try: # urlopen throws a HTTPError if it gets a 404 status (and perhaps other non-200 status?) with urlopen(httprequest) as response: if response.status != 200: raise IOError(f'Failed to fetch the model {model_name}: {response.status}') final_url = response.read().decode('utf-8').strip() httprequest = Request(final_url) except HTTPError as e: raise IOError(f'Failed to fetch the model {model_name}: {e}') with urlopen(httprequest) as response: if response.status != 200: raise IOError(f'Failed to fetch the model {model_name} from CDN URL {final_url}: {response.status}') data = response.read() return data def unpack_model(data, outdir): with tarfile.open(fileobj=BytesIO(data), mode='r:*') as tar: tar.extractall(outdir) if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('dm_filename', type=str, help='The filename of the data manager file to read parameters from and write outputs to') parser.add_argument('--known_models', type=str, help='List of models already known in the Galaxy data table') parser.add_argument('--sha256_sums', type=str, help='List of sha256sums of the models already known in the Galaxy data table') parser.add_argument('--download_latest', action='store_true', default=False, help='Download the latest models as per the Rerio repository') parser.add_argument('--download_models', type=str, help='Comma separated list of models to download') args = parser.parse_args() # parameters to a data manager are passed in a JSON file (see https://docs.galaxyproject.org/en/latest/dev/data_managers.html) and # similarily a JSON file is created to pass the output back to Galaxy models = [] if args.download_latest: models.extend(find_latest_models()) if args.download_models: models.extend(args.download_models.split(',')) if not models: sys.exit('No models to download, please specify either --download_latest or --download_models') with open(args.dm_filename) as fh: config = json.load(fh) if 'extra_files_path' not in config.get('output_data', [{}])[0]: sys.exit('Please specify the output directory in the data manager configuration (the extra_files_path)') output_directory = config["output_data"][0]["extra_files_path"] if not Path(output_directory).exists(): Path(output_directory).mkdir(parents=True) data_manager_dict = {} data_manager_dict["data_tables"] = config.get("data_tables", {}) data_manager_dict["data_tables"][DATA_TABLE_NAME] = [] known_models = set(args.known_models.split(',')) if args.known_models else set() model_to_sha256 = {} if args.known_models: sha256_sums = args.sha256_sums.split(',') for (i, model) in enumerate(known_models): model_to_sha256[model] = sha256_sums[i] for model in models: model_dir = Path(output_directory) / model # The data table cannot handle duplicate entries, so we skip models that are already in the data table if model in known_models: print(f'Model {model} already exists, skipping', file=sys.stderr) continue data = fetch_model(model) sha256sum = sha256(data).hexdigest() # Since we skip models that are already known we cannot test the sha256sum here. This code is retained to illustrate that an # alternative logic would be to download the model each time and check if the sha256sum matches what is already known. Hopefully # ONT does not update the models while keeping the same name, so this is not needed. The sha256sum is stored in the data table # in case it is needed in the future. # if model in model_to_sha256 and sha256sum != model_to_sha256[model]: # sys.exit(f'Model {model} already exists with a different sha256sum {model_to_sha256[model]}. This is a serious error, inform the Galaxy admin') unpack_model(data, output_directory) data_manager_dict["data_tables"][DATA_TABLE_NAME].append( dict( value=model, platform="ont", sha256=sha256sum, path=str(model_dir), source="rerio" ) ) with open(args.dm_filename, 'w') as fh: json.dump(data_manager_dict, fh, sort_keys=True, indent=4)