comparison simple_model_fit.py @ 17:a01fa4e8fe4f draft

planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 9981e25b00de29ed881b2229a173a8c812ded9bb
author bgruening
date Wed, 09 Aug 2023 12:54:40 +0000
parents c9ddd20d25d0
children
comparison
equal deleted inserted replaced
16:d0352e8b4c10 17:a01fa4e8fe4f
1 import argparse 1 import argparse
2 import json 2 import json
3 import pickle
4 3
5 import pandas as pd 4 import pandas as pd
6 from galaxy_ml.utils import load_model, read_columns 5 from galaxy_ml.model_persist import dump_model_to_h5, load_model_from_h5
6 from galaxy_ml.utils import read_columns
7 from scipy.io import mmread 7 from scipy.io import mmread
8 from sklearn.pipeline import Pipeline 8 from sklearn.pipeline import Pipeline
9 9
10 N_JOBS = int(__import__("os").environ.get("GALAXY_SLOTS", 1)) 10 N_JOBS = int(__import__("os").environ.get("GALAXY_SLOTS", 1))
11 11
146 """ 146 """
147 with open(inputs, "r") as param_handler: 147 with open(inputs, "r") as param_handler:
148 params = json.load(param_handler) 148 params = json.load(param_handler)
149 149
150 # load model 150 # load model
151 with open(infile_estimator, "rb") as est_handler: 151 estimator = load_model_from_h5(infile_estimator)
152 estimator = load_model(est_handler) 152
153 estimator = clean_params(estimator, n_jobs=N_JOBS) 153 estimator = clean_params(estimator)
154 154
155 X_train, y_train = _get_X_y(params, infile1, infile2) 155 X_train, y_train = _get_X_y(params, infile1, infile2)
156 156
157 estimator.fit(X_train, y_train) 157 estimator.fit(X_train, y_train)
158 158
168 if getattr(main_est, "validation_data", None): 168 if getattr(main_est, "validation_data", None):
169 del main_est.validation_data 169 del main_est.validation_data
170 if getattr(main_est, "data_generator_", None): 170 if getattr(main_est, "data_generator_", None):
171 del main_est.data_generator_ 171 del main_est.data_generator_
172 172
173 with open(out_object, "wb") as output_handler: 173 dump_model_to_h5(estimator, out_object)
174 pickle.dump(estimator, output_handler, pickle.HIGHEST_PROTOCOL)
175 174
176 175
177 if __name__ == "__main__": 176 if __name__ == "__main__":
178 aparser = argparse.ArgumentParser() 177 aparser = argparse.ArgumentParser()
179 aparser.add_argument("-i", "--inputs", dest="inputs", required=True) 178 aparser.add_argument("-i", "--inputs", dest="inputs", required=True)