comparison spec2vec_training_wrapper.py @ 0:e1e22ada831e draft

planemo upload for repository https://github.com/RECETOX/galaxytools/tree/master/tools/spec2vec commit 2e4bdc2fd94445aa5a8d1882a3d092cca727e4b6
author recetox
date Thu, 05 Jan 2023 10:08:12 +0000
parents
children
comparison
equal deleted inserted replaced
-1:000000000000 0:e1e22ada831e
1 #!/usr/bin/env python
2
3 import argparse
4 import sys
5
6 from matchms.importing import load_from_mgf, load_from_msp
7 from spec2vec import SpectrumDocument
8 from spec2vec.model_building import train_new_word2vec_model
9 from spec2vec.serialization import export_model
10
11
def read_spectra(spectra_file, file_format):
    """Load spectra from ``spectra_file`` using the matchms loader for ``file_format``.

    Args:
        spectra_file: Path to the spectra file, passed straight to the loader.
        file_format: Either ``"mgf"`` or ``"msp"``.

    Returns:
        Whatever the selected matchms loader returns (the caller in this file
        materialises it with ``list(...)``).

    Raises:
        NotImplementedError: If ``file_format`` is neither ``"mgf"`` nor ``"msp"``.
    """
    # Guard clause: reject unknown formats before touching any loader.
    if file_format not in ("mgf", "msp"):
        raise NotImplementedError(f"Unsupported file format: {file_format}.")

    if file_format == "mgf":
        return load_from_mgf(spectra_file)
    return load_from_msp(spectra_file)
19
20
def parse_checkpoints_input(checkpoints_input):
    """Parse a comma-separated string of epoch numbers into a list of ints.

    Spaces are ignored and duplicate values are collapsed.

    Args:
        checkpoints_input: String such as ``"10, 20, 30"``.

    Returns:
        Sorted list of the unique checkpoint epochs.

    Raises:
        ValueError: If any comma-separated item is not an integer (this
            includes empty items, e.g. from a trailing comma).
    """
    tokens = checkpoints_input.replace(" ", "").split(",")
    try:
        # Convert eagerly inside the try block: a lazy ``map(int, ...)`` would
        # only call int() later, after the except clause is out of scope.
        checkpoints = {int(token) for token in tokens}
    except ValueError as err:
        raise ValueError("Checkpoint values must be integers.") from err
    # Sort so the result is deterministic (set iteration order is unspecified).
    return sorted(checkpoints)
28
29
def _str_to_bool(value):
    """Convert a command-line token into a bool.

    ``argparse`` with ``type=bool`` applies ``bool()`` to the raw string, so
    every non-empty value -- including "False" and "0" -- parses as True.
    This converter interprets the usual textual spellings instead.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised boolean.
    """
    token = str(value).strip().lower()
    if token in ("true", "t", "yes", "y", "1", "on"):
        return True
    # The empty string stays False, matching the previous bool("") result.
    if token in ("false", "f", "no", "n", "0", "off", ""):
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}.")


def main(argv):
    """Train a spec2vec model from a spectra file and export it.

    Parses the command line, loads the spectra, converts them to
    ``SpectrumDocument`` objects, trains a word2vec model through
    ``train_new_word2vec_model`` and writes the model/weights JSON files
    (plus an optional pickle of the model).

    Args:
        argv: List of command-line arguments, without the program name.
    """
    parser = argparse.ArgumentParser(description="Train a spec2vec model.")

    # Input data
    parser.add_argument("--spectra_filename", type=str, help="Path to a file containing spectra.")
    parser.add_argument("--spectra_fileformat", type=str, help="Spectra file format.")

    # Training parameters
    parser.add_argument("--epochs", type=int, default=0, help="Number of epochs to train the model.")
    parser.add_argument("--checkpoints", type=str, default=None, help="Epochs after which to save the model.")

    # Hyperparameters
    parser.add_argument("--vector_size", type=int, default=100, help="Dimensionality of the feature vectors.")
    parser.add_argument("--alpha", type=float, default=0.025, help="The initial learning rate.")
    parser.add_argument("--window", type=int, default=5, help="The maximum distance between the current and predicted peak within a spectrum.")
    parser.add_argument("--min_count", type=int, default=5, help="Ignores all peaks with total frequency lower than this.")
    parser.add_argument("--sample", type=float, default=0.001, help="The threshold for configuring which higher-frequency peaks are randomly downsampled.")
    parser.add_argument("--seed", type=int, default=1, help="A seed for model reproducibility.")
    parser.add_argument("--min_alpha", type=float, default=0.0001, help="Learning rate will linearly drop to min_alpha as training progresses.")
    parser.add_argument("--sg", type=int, default=0, help="Training algorithm: 1 for skip-gram; otherwise CBOW.")
    parser.add_argument("--hs", type=int, default=0, help="If 1, hierarchical softmax will be used for model training. If set to 0, and negative is non-zero, negative sampling will be used.")
    parser.add_argument("--negative", type=int, default=5, help="If > 0, negative sampling will be used, the int for negative specifies how many “noise words” should be drawn (usually between 5-20). If set to 0, no negative sampling is used.")
    parser.add_argument("--ns_exponent", type=float, default=0.75, help="The exponent used to shape the negative sampling distribution.")
    parser.add_argument("--cbow_mean", type=int, default=1, help="If 0, use the sum of the context word vectors. If 1, use the mean. Only applies when cbow is used.")
    # Boolean flags use _str_to_bool: type=bool would turn "False"/"0" into True.
    parser.add_argument("--sorted_vocab", type=_str_to_bool, default=True, help="If 1, sort the vocabulary by descending frequency before assigning word indexes.")
    parser.add_argument("--batch_words", type=int, default=10000, help="Target size (in words) for batches of examples passed to worker threads (and thus cython routines). Larger batches will be passed if individual texts are longer than 10000 words, but the standard cython code truncates to that maximum.")
    parser.add_argument("--shrink_windows", type=_str_to_bool, default=True, help="If 1, the input sentence will be truncated to the window size.")
    parser.add_argument("--max_vocab_size", type=int, default=None, help="Limits the RAM during vocabulary building; if there are more unique words than this, then prune the infrequent ones. Every 10 million word types need about 1GB of RAM. Set to None for no limit (default).")
    parser.add_argument("--n_decimals", type=int, default=2, help="Rounds peak position to this number of decimals.")
    parser.add_argument("--n_workers", type=int, default=1, help="Number of worker nodes to train the model.")

    # Output files
    parser.add_argument("--model_filename_pickle", type=str, help="If specified, the model will also be saved as a pickle file.")
    parser.add_argument("--model_filename", type=str, help="Path to the output model json-file.")
    parser.add_argument("--weights_filename", type=str, help="Path to the output weights json-file.")

    # Parse the arguments handed to main() rather than implicitly re-reading
    # sys.argv, so the function honours its own parameter.
    args = parser.parse_args(argv)

    # Load the spectra
    spectra = list(read_spectra(args.spectra_filename, args.spectra_fileformat))
    reference_documents = [SpectrumDocument(spectrum, n_decimals=args.n_decimals) for spectrum in spectra]

    # Process epoch arguments: an explicit checkpoint list takes precedence
    # over the plain epoch count.
    if args.checkpoints:
        iterations = parse_checkpoints_input(args.checkpoints)
    else:
        iterations = args.epochs

    # Train a model
    # NOTE(review): --min_alpha is forwarded as spec2vec's `learning_rate_decay`;
    # spec2vec appears to treat that value as a per-epoch decay rather than a
    # final learning rate -- confirm this mapping is intended.
    model = train_new_word2vec_model(
        documents=reference_documents,
        iterations=iterations,
        filename="spec2vec",
        progress_logger=True,
        workers=args.n_workers,
        vector_size=args.vector_size,
        learning_rate_initial=args.alpha,
        learning_rate_decay=args.min_alpha,
        window=args.window,
        min_count=args.min_count,
        sample=args.sample,
        seed=args.seed,
        sg=args.sg,
        hs=args.hs,
        negative=args.negative,
        ns_exponent=args.ns_exponent,
        cbow_mean=args.cbow_mean,
        sorted_vocab=args.sorted_vocab,
        batch_words=args.batch_words,
        shrink_windows=args.shrink_windows,
        max_vocab_size=args.max_vocab_size)

    # Save the model (the pickle output is optional)
    if args.model_filename_pickle:
        model.save(args.model_filename_pickle)

    export_model(model, args.model_filename, args.weights_filename)
107
108
if __name__ == "__main__":
    # Script entry point: hand main() the CLI arguments without the program name.
    cli_arguments = sys.argv[1:]
    main(argv=cli_arguments)