Mercurial > repos > goeckslab > tabular_learner
changeset 5:3d42f82b3c7f draft default tip
planemo upload for repository https://github.com/goeckslab/gleam commit 4a11e8a4c4e9daa884bddedfa47090476c517667
author | goeckslab |
---|---|
date | Thu, 31 Jul 2025 15:41:07 +0000 |
parents | 11fdac5affb3 |
children | |
files | base_model_trainer.py pycaret_train.py tabular_learner.xml |
diffstat | 3 files changed, 40 insertions(+), 15 deletions(-) [+] |
line wrap: on
line diff
--- a/base_model_trainer.py Fri Jul 25 19:02:12 2025 +0000 +++ b/base_model_trainer.py Thu Jul 31 15:41:07 2025 +0000 @@ -175,7 +175,13 @@ if self.task_type == "classification": self.results.rename(columns={"AUC": "ROC-AUC"}, inplace=True) - _ = self.exp.predict_model(self.best_model) + + prob_thresh = getattr(self, "probability_threshold", None) + if self.task_type == "classification" and prob_thresh is not None: + _ = self.exp.predict_model(self.best_model, probability_threshold=prob_thresh) + else: + _ = self.exp.predict_model(self.best_model) + self.test_result_df = self.exp.pull() if self.task_type == "classification": self.test_result_df.rename(columns={"AUC": "ROC-AUC"}, inplace=True) @@ -233,7 +239,7 @@ best_model_name = type(self.best_model).__name__ LOG.info(f"Best model determined as: {best_model_name}") - # 2) Compute training sample count + # 2) Compute training sample count try: n_train = self.exp.X_train.shape[0] except Exception: @@ -241,7 +247,10 @@ total_rows = self.data.shape[0] # 3) Build setup parameters table - all_params = self.setup_params + all_params = self.setup_params.copy() + if self.task_type == "classification" and hasattr(self, "probability_threshold"): + all_params["probability_threshold"] = self.probability_threshold + display_keys = [ "Target", "Session ID", @@ -255,6 +264,7 @@ "Polynomial Features", "Fix Imbalance", "Models", + "Probability Threshold", ] setup_rows = [] for key in display_keys: @@ -281,6 +291,8 @@ dv = v if v is not None else "None" elif key == "Models": dv = ", ".join(map(str, v)) if isinstance(v, (list, tuple)) else "None" + elif key == "Probability Threshold": + dv = v if v is not None else "None" else: dv = v if v is not None else "None" setup_rows.append([key, dv])
--- a/pycaret_train.py Fri Jul 25 19:02:12 2025 +0000 +++ b/pycaret_train.py Thu Jul 31 15:41:07 2025 +0000 @@ -103,16 +103,22 @@ help="Tune the best model hyperparameters after training", ) parser.add_argument( + "--test_file", + type=str, + default=None, + help="Path to the test data file", + ) + parser.add_argument( "--random_seed", type=int, default=42, help="Random seed for PyCaret setup", ) parser.add_argument( - "--test_file", - type=str, + "--probability_threshold", + type=float, default=None, - help="Path to the test data file", + help="Probability threshold for classification decision,", ) args = parser.parse_args() @@ -120,7 +126,7 @@ # Normalize cross-validation flags: --no_cross_validation overrides --cross_validation if args.no_cross_validation: args.cross_validation = False - # If --cross_validation was passed, args.cross_validation is True + # If --cross_validation was passed, args.cross_validation is True # If neither was passed, args.cross_validation remains None # Build the model_kwargs dict from CLI args @@ -137,6 +143,7 @@ "feature_ratio": args.feature_ratio, "fix_imbalance": args.fix_imbalance, "tune_model": args.tune_model, + "probability_threshold": args.probability_threshold, } LOG.info(f"Model kwargs: {model_kwargs}")
--- a/tabular_learner.xml Fri Jul 25 19:02:12 2025 +0000 +++ b/tabular_learner.xml Thu Jul 31 15:41:07 2025 +0000 @@ -22,10 +22,10 @@ #end if #if $customize_defaults == "true" #if $train_size - --train_size '$train_size' + --train_size '$train_size' #end if #if $normalize - --normalize + --normalize #end if #if $feature_selection --feature_selection @@ -34,27 +34,30 @@ --cross_validation #if $cross_validation_folds --cross_validation_folds '$cross_validation_folds' - #end if + #end if #end if #if $enable_cross_validation == "false" --no_cross_validation #end if #if $remove_outliers - --remove_outliers + --remove_outliers #end if #if $remove_multicollinearity - --remove_multicollinearity + --remove_multicollinearity #end if #if $polynomial_features - --polynomial_features + --polynomial_features #end if #if $fix_imbalance - --fix_imbalance + --fix_imbalance + #end if + #if $probability_threshold + --probability_threshold '$probability_threshold' #end if #end if #if $test_file --test_file '$test_file' - #end if + #end if --model_type '$model_type' ]]> </command> @@ -150,6 +153,7 @@ <param name="remove_multicollinearity" type="boolean" truevalue="True" falsevalue="False" checked="false" label="Remove Multicollinearity" help="Whether to remove multicollinear features before training." /> <param name="polynomial_features" type="boolean" truevalue="True" falsevalue="False" checked="false" label="Polynomial Features" help="Whether to create polynomial features before training." /> <param name="fix_imbalance" type="boolean" truevalue="True" falsevalue="False" checked="false" label="Fix Imbalance" help="ONLY for classfication! Whether to use SMOTE or similar methods to fix imbalance in the input dataset." /> + <param name="probability_threshold" type="float" min="0.0" max="1.0" value="0.5" label="Classification Probability Threshold" help="Only applies to classification. Probability above which a prediction is considered positive. Default is 0.5." /> </when> <when value="false"> <!-- No additional parameters to show if the user selects 'No' --> @@ -175,6 +179,7 @@ <param name="cross_validation_folds" value="5"/> <param name="remove_outliers" value="true"/> <param name="remove_multicollinearity" value="true"/> + <param name="probability_threshold" value="0.4" /> <output name="model" file="expected_model_classification_customized.h5" compare="sim_size"/> <output name="comparison_result"> <assert_contents> @@ -197,6 +202,7 @@ <param name="enable_cross_validation" value="false"/> <param name="remove_outliers" value="true"/> <param name="remove_multicollinearity" value="true"/> + <param name="probability_threshold" value="0.6" /> <output name="model" file="expected_model_classification_customized_cross_off.h5" compare="sim_size"/> <output name="comparison_result"> <assert_contents>