Mercurial > repos > recetox > matchms_similarity
view matchms_similarity_wrapper.py @ 1:872d8040f713 draft default tip
planemo upload for repository https://github.com/RECETOX/galaxytools/tree/master/tools/matchms commit b1cc1aebf796f170d93e3dd46ffcdefdc7b8018a
author | recetox |
---|---|
date | Thu, 12 Oct 2023 13:25:30 +0000 |
parents | e5010b19d64d |
children |
line wrap: on
line source
import argparse
import json
import sys

from matchms import calculate_scores
from matchms.importing import load_from_mgf, load_from_msp
from matchms.similarity import (CosineGreedy, CosineHungarian, MetadataMatch,
                                ModifiedCosine, NeutralLossesCosine)
from spec2vec import Spec2Vec
from spec2vec.serialization.model_importing import load_weights, Word2VecLight

# Spectra file format name -> loader function.
_SPECTRA_LOADERS = {
    'msp': load_from_msp,
    'mgf': load_from_mgf,
}

# Metrics constructed as Metric(tolerance, mz_power, intensity_power).
_COSINE_METRICS = {
    'CosineGreedy': CosineGreedy,
    'CosineHungarian': CosineHungarian,
    'ModifiedCosine': ModifiedCosine,
    'NeutralLossesCosine': NeutralLossesCosine,
}

# Metrics for which every spectrum is passed through convert_precursor_mz first.
_PRECURSOR_MZ_METRICS = frozenset({'ModifiedCosine', 'NeutralLossesCosine'})


def convert_precursor_mz(spectrum):
    """
    Check the presence of precursor m/z since it is needed for the
    ModifiedCosine and NeutralLossesCosine similarity metrics.
    Convert to float if needed, raise error if missing.

    Parameters
    ----------
    spectrum:
        Spectrum object exposing a readable and writable ``metadata`` mapping.

    Returns
    -------
    The same spectrum object, with ``metadata["precursor_mz"]`` stored as float.

    Raises
    ------
    ValueError
        If ``precursor_mz`` is not present in the spectrum metadata.
    """
    if "precursor_mz" not in spectrum.metadata:
        raise ValueError("Precursor_mz missing. Apply 'add_precursor_mz' filter first.")

    # Read, modify and assign back the whole mapping instead of mutating it
    # in place, so the change goes through the ``metadata`` setter.
    metadata = spectrum.metadata
    metadata["precursor_mz"] = float(metadata["precursor_mz"])
    spectrum.metadata = metadata
    return spectrum


def load_model(model_file, weights_file) -> Word2VecLight:
    """
    Read a lightweight version of a :class:`~gensim.models.Word2Vec` model from disk.

    Parameters
    ----------
    model_file:
        A path of json file to load the model.
    weights_file:
        A path of `.npy` file to load the model's weights.

    Returns
    -------
    :class:`~spec2vec.serialization.model_importing.Word2VecLight` – a lightweight version of a
    :class:`~gensim.models.Word2Vec`
    """
    with open(model_file, "r", encoding="utf-8") as f:
        model: dict = json.load(f)

    # The key is only discarded, so a model file without it must not fail.
    model.pop("mapfile_path", None)

    weights = load_weights(weights_file, model["__weights_format"])
    return Word2VecLight(model, weights)


def _load_spectra(filename, file_format, description):
    """Load all spectra from ``filename``; ``description`` is used in the error message.

    Raises
    ------
    ValueError
        If ``file_format`` is neither 'msp' nor 'mgf'.
    """
    loader = _SPECTRA_LOADERS.get(file_format)
    if loader is None:
        raise ValueError(f'File format {file_format} not supported for {description}.')
    return list(loader(filename))


def main(argv):
    """
    Compute similarity scores between query and reference spectra and store them as json.

    Parameters
    ----------
    argv:
        Command line arguments without the program name.

    Returns
    -------
    0 on success, -1 if the requested similarity metric is not supported.

    Raises
    ------
    ValueError
        If a spectra file format is not supported, or a spectrum lacks
        ``precursor_mz`` while the chosen metric requires it.
    """
    parser = argparse.ArgumentParser(description="Compute MSP similarity scores")
    parser.add_argument("-r", dest="ri_tolerance", type=float, help="Use RI filtering with given tolerance.")
    parser.add_argument("-s", dest="symmetric", action='store_true', help="Computation is symmetric.")
    parser.add_argument("--array_type", type=str, help="Type of array to use for storing scores (numpy or sparse).")
    parser.add_argument("--ref", dest="references_filename", type=str, help="Path to reference spectra library.")
    parser.add_argument("--ref_format", dest="references_format", type=str, help="Reference spectra library file format.")
    parser.add_argument("--spec2vec_model", dest="spec2vec_model", type=str, help="Path to spec2vec model.")
    parser.add_argument("--spec2vec_weights", dest="spec2vec_weights", type=str, help="Path to spec2vec weights.")
    # The command line value is multiplied by 100 before being handed to Spec2Vec.
    parser.add_argument("--allow_missing_percentage", dest="allowed_missing_percentage", type=lambda x: float(x) * 100.0, help="Maximum percentage of missing peaks in model corpus.")
    parser.add_argument("queries_filename", type=str, help="Path to query spectra.")
    parser.add_argument("queries_format", type=str, help="Query spectra file format.")
    parser.add_argument("similarity_metric", type=str, help='Metric to use for matching.')
    parser.add_argument("tolerance", type=float, help="Tolerance to use for peak matching.")
    parser.add_argument("mz_power", type=float, help="The power to raise mz to in the cosine function.")
    parser.add_argument("intensity_power", type=float, help="The power to raise intensity to in the cosine function.")
    parser.add_argument("output_filename_scores", type=str, help="Path where to store the output .json scores.")
    # Parse the arguments handed to main() rather than implicitly reading sys.argv.
    args = parser.parse_args(argv)

    queries_spectra = _load_spectra(args.queries_filename, args.queries_format, "query spectra")
    if args.symmetric:
        # Symmetric computation compares the queries against themselves.
        reference_spectra = queries_spectra.copy()
    else:
        reference_spectra = _load_spectra(args.references_filename, args.references_format, "reference spectra library")

    metric_name = args.similarity_metric
    if metric_name in _COSINE_METRICS:
        similarity_metric = _COSINE_METRICS[metric_name](args.tolerance, args.mz_power, args.intensity_power)
        if metric_name in _PRECURSOR_MZ_METRICS:
            reference_spectra = list(map(convert_precursor_mz, reference_spectra))
            queries_spectra = list(map(convert_precursor_mz, queries_spectra))
    elif metric_name == 'Spec2Vec':
        model = load_model(args.spec2vec_model, args.spec2vec_weights)
        similarity_metric = Spec2Vec(model, intensity_weighting_power=args.intensity_power, allowed_missing_percentage=args.allowed_missing_percentage)
    else:
        # Report why nothing was computed instead of failing silently.
        print(f"Similarity metric {metric_name} not supported.", file=sys.stderr)
        return -1

    print("Calculating scores...")
    scores = calculate_scores(
        references=reference_spectra,
        queries=queries_spectra,
        array_type=args.array_type,
        similarity_function=similarity_metric,
        is_symmetric=args.symmetric
    )

    if args.ri_tolerance is not None:
        print("RI filtering with tolerance ", args.ri_tolerance)
        ri_matches = calculate_scores(references=reference_spectra,
                                      queries=queries_spectra,
                                      similarity_function=MetadataMatch("retention_index", "difference", args.ri_tolerance),
                                      array_type="numpy",
                                      is_symmetric=args.symmetric).scores
        # NOTE(review): join_type="inner" presumably restricts the stored scores to
        # pairs that also pass the retention index match - confirm against matchms docs.
        scores.scores.add_coo_matrix(ri_matches, "MetadataMatch", join_type="inner")

    write_outputs(args, scores)
    return 0


def write_outputs(args, scores):
    """Write Scores to json file."""
    print("Storing outputs...")
    scores.to_json(args.output_filename_scores)


if __name__ == "__main__":
    # Propagate main()'s return code so a failed run is visible to the caller.
    sys.exit(main(argv=sys.argv[1:]))