Mercurial > repos > recetox > matchms_fingerprint_similarity
comparison matchms_similarity_wrapper.py @ 0:84af792d3a78 draft
planemo upload for repository https://github.com/RECETOX/galaxytools/tree/master/tools/matchms commit f79a5b51599254817727bc9028b9797ea994cb4e
author | recetox |
---|---|
date | Tue, 27 Jun 2023 14:27:04 +0000 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:84af792d3a78 |
---|---|
1 import argparse | |
2 import json | |
3 import sys | |
4 | |
5 from matchms import calculate_scores | |
6 from matchms.importing import load_from_mgf, load_from_msp | |
7 from matchms.similarity import (CosineGreedy, CosineHungarian, MetadataMatch, | |
8 ModifiedCosine, NeutralLossesCosine) | |
9 from spec2vec import Spec2Vec | |
10 from spec2vec.serialization.model_importing import load_weights, Word2VecLight | |
11 | |
12 | |
def convert_precursor_mz(spectrum):
    """
    Ensure the spectrum carries a numeric precursor m/z, as precursor-based metrics
    (e.g. ModifiedCosine, NeutralLossesCosine) require it.

    The ``precursor_mz`` metadata value is cast to ``float`` and written back through the
    ``metadata`` setter; the same spectrum object is returned.

    Raises
    ------
    ValueError
        If the spectrum metadata has no ``precursor_mz`` entry.
    """
    if "precursor_mz" not in spectrum.metadata:
        raise ValueError("Precursor_mz missing. Apply 'add_precursor_mz' filter first.")

    updated_metadata = spectrum.metadata
    updated_metadata["precursor_mz"] = float(updated_metadata["precursor_mz"])
    # Re-assign so that objects exposing ``metadata`` as a property see the change.
    spectrum.metadata = updated_metadata
    return spectrum
26 | |
27 | |
def load_model(model_file, weights_file) -> Word2VecLight:
    """
    Read a lightweight version of a :class:`~gensim.models.Word2Vec` model from disk.

    Parameters
    ----------
    model_file:
        A path of json file to load the model.
    weights_file:
        A path of `.npy` file to load the model's weights.

    Returns
    -------
    :class:`~spec2vec.serialization.model_importing.Word2VecLight` – a lightweight version of a
    :class:`~gensim.models.Word2Vec`

    Raises
    ------
    KeyError
        If the model json has no ``"__weights_format"`` entry (needed to read the weights).
    """
    with open(model_file, "r", encoding="utf-8") as f:
        model: dict = json.load(f)

    # "mapfile_path" is stripped before the dict is handed to Word2VecLight. pop() with a
    # default lets model files that never contained the key load instead of failing with
    # an uninformative KeyError.
    model.pop("mapfile_path", None)

    weights = load_weights(weights_file, model["__weights_format"])
    return Word2VecLight(model, weights)
50 | |
51 | |
def _load_spectra(filename, file_format, description):
    """
    Load every spectrum from ``filename`` using the matchms importer for ``file_format``.

    Parameters
    ----------
    filename:
        Path of the spectra file.
    file_format:
        Either ``'msp'`` or ``'mgf'``.
    description:
        Role of the file (e.g. "query spectra"), used only in the error message.

    Returns
    -------
    list
        The spectra yielded by ``load_from_msp`` / ``load_from_mgf``.

    Raises
    ------
    ValueError
        If ``file_format`` is neither ``'msp'`` nor ``'mgf'``.
    """
    if file_format == 'msp':
        return list(load_from_msp(filename))
    if file_format == 'mgf':
        return list(load_from_mgf(filename))
    raise ValueError(f'File format {file_format} not supported for {description}.')


def main(argv=None):
    """
    Compute similarity scores between query and reference spectra and store them as JSON.

    Parameters
    ----------
    argv:
        Command-line arguments without the program name. ``None`` (the default) makes
        argparse read ``sys.argv[1:]``.

    Returns
    -------
    int
        0 on success, -1 if the requested similarity metric is unknown.

    Raises
    ------
    ValueError
        If a spectra file format is unsupported, or a precursor-based metric is used on
        spectra without ``precursor_mz``.
    SystemExit
        On invalid command-line arguments (raised by argparse).
    """
    parser = argparse.ArgumentParser(description="Compute MSP similarity scores")
    parser.add_argument("-r", dest="ri_tolerance", type=float, help="Use RI filtering with given tolerance.")
    parser.add_argument("-s", dest="symmetric", action='store_true', help="Computation is symmetric.")
    # Default mirrors matchms' calculate_scores default so omitting the option does not forward None.
    parser.add_argument("--array_type", type=str, default="numpy", help="Type of array to use for storing scores (numpy or sparse).")
    parser.add_argument("--ref", dest="references_filename", type=str, help="Path to reference spectra library.")
    parser.add_argument("--ref_format", dest="references_format", type=str, help="Reference spectra library file format.")
    parser.add_argument("--spec2vec_model", dest="spec2vec_model", type=str, help="Path to spec2vec model.")
    parser.add_argument("--spec2vec_weights", dest="spec2vec_weights", type=str, help="Path to spec2vec weights.")
    # The given value is multiplied by 100, i.e. it is presumably supplied as a fraction (0.05 -> 5 %).
    parser.add_argument("--allow_missing_percentage", dest="allowed_missing_percentage", type=lambda x: float(x) * 100.0, help="Maximum percentage of missing peaks in model corpus.")
    parser.add_argument("queries_filename", type=str, help="Path to query spectra.")
    parser.add_argument("queries_format", type=str, help="Query spectra file format.")
    parser.add_argument("similarity_metric", type=str, help='Metric to use for matching.')
    parser.add_argument("tolerance", type=float, help="Tolerance to use for peak matching.")
    parser.add_argument("mz_power", type=float, help="The power to raise mz to in the cosine function.")
    parser.add_argument("intensity_power", type=float, help="The power to raise intensity to in the cosine function.")
    parser.add_argument("output_filename_scores", type=str, help="Path where to store the output .json scores.")
    # Parse the arguments that were passed in; previously ``argv`` was ignored and sys.argv was read.
    args = parser.parse_args(argv)

    queries_spectra = _load_spectra(args.queries_filename, args.queries_format, "query spectra")

    if args.symmetric:
        # Shallow copy: both lists hold the same spectrum objects.
        reference_spectra = queries_spectra.copy()
    else:
        reference_spectra = _load_spectra(args.references_filename, args.references_format, "reference spectra library")

    if args.similarity_metric == 'CosineGreedy':
        similarity_metric = CosineGreedy(args.tolerance, args.mz_power, args.intensity_power)
    elif args.similarity_metric == 'CosineHungarian':
        similarity_metric = CosineHungarian(args.tolerance, args.mz_power, args.intensity_power)
    elif args.similarity_metric == 'ModifiedCosine':
        similarity_metric = ModifiedCosine(args.tolerance, args.mz_power, args.intensity_power)
        reference_spectra = list(map(convert_precursor_mz, reference_spectra))
        queries_spectra = list(map(convert_precursor_mz, queries_spectra))
    elif args.similarity_metric == 'NeutralLossesCosine':
        similarity_metric = NeutralLossesCosine(args.tolerance, args.mz_power, args.intensity_power)
        reference_spectra = list(map(convert_precursor_mz, reference_spectra))
        queries_spectra = list(map(convert_precursor_mz, queries_spectra))
    elif args.similarity_metric == 'Spec2Vec':
        if args.spec2vec_model is None or args.spec2vec_weights is None:
            parser.error("Spec2Vec requires both --spec2vec_model and --spec2vec_weights.")
        model = load_model(args.spec2vec_model, args.spec2vec_weights)
        spec2vec_kwargs = {"intensity_weighting_power": args.intensity_power}
        # Only forward the threshold when it was given; passing None explicitly would replace
        # Spec2Vec's own default value with None.
        if args.allowed_missing_percentage is not None:
            spec2vec_kwargs["allowed_missing_percentage"] = args.allowed_missing_percentage
        similarity_metric = Spec2Vec(model, **spec2vec_kwargs)
    else:
        print(f"Unknown similarity metric: {args.similarity_metric}", file=sys.stderr)
        return -1

    print("Calculating scores...")
    scores = calculate_scores(
        references=reference_spectra,
        queries=queries_spectra,
        array_type=args.array_type,
        similarity_function=similarity_metric,
        is_symmetric=args.symmetric
    )

    if args.ri_tolerance is not None:
        print("RI filtering with tolerance ", args.ri_tolerance)
        ri_matches = calculate_scores(references=reference_spectra,
                                      queries=queries_spectra,
                                      similarity_function=MetadataMatch("retention_index", "difference", args.ri_tolerance),
                                      array_type="numpy",
                                      is_symmetric=args.symmetric).scores
        # Keep only score entries that also pass the retention-index match.
        scores.scores.add_coo_matrix(ri_matches, "MetadataMatch", join_type="inner")

    write_outputs(args, scores)
    return 0
126 | |
127 | |
def write_outputs(args, scores):
    """Serialize ``scores`` as JSON to the path held in ``args.output_filename_scores``."""
    print("Storing outputs...")
    destination = args.output_filename_scores
    scores.to_json(destination)
132 | |
133 | |
if __name__ == "__main__":
    # Propagate main()'s return code so an unknown metric (-1) yields a non-zero exit status.
    sys.exit(main(argv=sys.argv[1:]))