Mercurial > repos > bgruening > create_tool_recommendation_model
comparison optimise_hyperparameters.py @ 4:afec8c595124 draft
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 65d36f271296a38deeceb0d0e8d471b2898ee8f4"
author | bgruening |
---|---|
date | Tue, 07 Jul 2020 03:25:49 -0400 |
parents | 5b3c08710e47 |
children | 4f7e6612906b |
comparison
equal
deleted
inserted
replaced
3:5b3c08710e47 | 4:afec8c595124 |
---|---|
18 class HyperparameterOptimisation: | 18 class HyperparameterOptimisation: |
19 | 19 |
20 def __init__(self): | 20 def __init__(self): |
21 """ Init method. """ | 21 """ Init method. """ |
22 | 22 |
23 def train_model(self, config, reverse_dictionary, train_data, train_labels, test_data, test_labels, l_tool_tr_samples, class_weights): | 23 def train_model(self, config, reverse_dictionary, train_data, train_labels, test_data, test_labels, tool_tr_samples, class_weights): |
24 """ | 24 """ |
25 Train a model and report accuracy | 25 Train a model and report accuracy |
26 """ | 26 """ |
27 # convert items to integer | 27 # convert items to integer |
28 l_batch_size = list(map(int, config["batch_size"].split(","))) | 28 l_batch_size = list(map(int, config["batch_size"].split(","))) |
69 model_fit = model.fit_generator( | 69 model_fit = model.fit_generator( |
70 utils.balanced_sample_generator( | 70 utils.balanced_sample_generator( |
71 train_data, | 71 train_data, |
72 train_labels, | 72 train_labels, |
73 batch_size, | 73 batch_size, |
74 l_tool_tr_samples | 74 tool_tr_samples, |
75 reverse_dictionary | |
75 ), | 76 ), |
76 steps_per_epoch=len(train_data) // batch_size, | 77 steps_per_epoch=len(train_data) // batch_size, |
77 epochs=optimize_n_epochs, | 78 epochs=optimize_n_epochs, |
78 callbacks=[early_stopping], | 79 callbacks=[early_stopping], |
79 validation_data=(test_data, test_labels), | 80 validation_data=(test_data, test_labels), |