comparison optimise_hyperparameters.py @ 4:afec8c595124 draft

"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 65d36f271296a38deeceb0d0e8d471b2898ee8f4"
author bgruening
date Tue, 07 Jul 2020 03:25:49 -0400
parents 5b3c08710e47
children 4f7e6612906b
comparison
equal deleted inserted replaced
3:5b3c08710e47 4:afec8c595124
18 class HyperparameterOptimisation: 18 class HyperparameterOptimisation:
19 19
20 def __init__(self): 20 def __init__(self):
21 """ Init method. """ 21 """ Init method. """
22 22
23 def train_model(self, config, reverse_dictionary, train_data, train_labels, test_data, test_labels, l_tool_tr_samples, class_weights): 23 def train_model(self, config, reverse_dictionary, train_data, train_labels, test_data, test_labels, tool_tr_samples, class_weights):
24 """ 24 """
25 Train a model and report accuracy 25 Train a model and report accuracy
26 """ 26 """
27 # convert items to integer 27 # convert items to integer
28 l_batch_size = list(map(int, config["batch_size"].split(","))) 28 l_batch_size = list(map(int, config["batch_size"].split(",")))
69 model_fit = model.fit_generator( 69 model_fit = model.fit_generator(
70 utils.balanced_sample_generator( 70 utils.balanced_sample_generator(
71 train_data, 71 train_data,
72 train_labels, 72 train_labels,
73 batch_size, 73 batch_size,
74 l_tool_tr_samples 74 tool_tr_samples,
75 reverse_dictionary
75 ), 76 ),
76 steps_per_epoch=len(train_data) // batch_size, 77 steps_per_epoch=len(train_data) // batch_size,
77 epochs=optimize_n_epochs, 78 epochs=optimize_n_epochs,
78 callbacks=[early_stopping], 79 callbacks=[early_stopping],
79 validation_data=(test_data, test_labels), 80 validation_data=(test_data, test_labels),