# HG changeset patch
# User goeckslab
# Date 1753976467 0
# Node ID 3d42f82b3c7f192fc8efc667ebdb2e8cea3cd6d2
# Parent 11fdac5affb3f7b3b43ab69a91dfa208a83fc6b6
planemo upload for repository https://github.com/goeckslab/gleam commit 4a11e8a4c4e9daa884bddedfa47090476c517667
diff -r 11fdac5affb3 -r 3d42f82b3c7f base_model_trainer.py
--- a/base_model_trainer.py Fri Jul 25 19:02:12 2025 +0000
+++ b/base_model_trainer.py Thu Jul 31 15:41:07 2025 +0000
@@ -175,7 +175,13 @@
if self.task_type == "classification":
self.results.rename(columns={"AUC": "ROC-AUC"}, inplace=True)
- _ = self.exp.predict_model(self.best_model)
+
+ prob_thresh = getattr(self, "probability_threshold", None)
+ if self.task_type == "classification" and prob_thresh is not None:
+ _ = self.exp.predict_model(self.best_model, probability_threshold=prob_thresh)
+ else:
+ _ = self.exp.predict_model(self.best_model)
+
self.test_result_df = self.exp.pull()
if self.task_type == "classification":
self.test_result_df.rename(columns={"AUC": "ROC-AUC"}, inplace=True)
@@ -233,7 +239,7 @@
best_model_name = type(self.best_model).__name__
LOG.info(f"Best model determined as: {best_model_name}")
- # 2) Compute training sample count
+ # 2) Compute training sample count
try:
n_train = self.exp.X_train.shape[0]
except Exception:
@@ -241,7 +247,10 @@
total_rows = self.data.shape[0]
# 3) Build setup parameters table
- all_params = self.setup_params
+ all_params = self.setup_params.copy()
+ if self.task_type == "classification" and hasattr(self, "probability_threshold"):
+ all_params["probability_threshold"] = self.probability_threshold
+
display_keys = [
"Target",
"Session ID",
@@ -255,6 +264,7 @@
"Polynomial Features",
"Fix Imbalance",
"Models",
+ "Probability Threshold",
]
setup_rows = []
for key in display_keys:
@@ -281,6 +291,8 @@
dv = v if v is not None else "None"
elif key == "Models":
dv = ", ".join(map(str, v)) if isinstance(v, (list, tuple)) else "None"
+ elif key == "Probability Threshold":
+ dv = v if v is not None else "None"
else:
dv = v if v is not None else "None"
setup_rows.append([key, dv])
diff -r 11fdac5affb3 -r 3d42f82b3c7f pycaret_train.py
--- a/pycaret_train.py Fri Jul 25 19:02:12 2025 +0000
+++ b/pycaret_train.py Thu Jul 31 15:41:07 2025 +0000
@@ -103,16 +103,22 @@
help="Tune the best model hyperparameters after training",
)
parser.add_argument(
+ "--test_file",
+ type=str,
+ default=None,
+ help="Path to the test data file",
+ )
+ parser.add_argument(
"--random_seed",
type=int,
default=42,
help="Random seed for PyCaret setup",
)
parser.add_argument(
- "--test_file",
- type=str,
+ "--probability_threshold",
+ type=float,
default=None,
- help="Path to the test data file",
+ help="Probability threshold for classification decision,",
)
args = parser.parse_args()
@@ -120,7 +126,7 @@
# Normalize cross-validation flags: --no_cross_validation overrides --cross_validation
if args.no_cross_validation:
args.cross_validation = False
- # If --cross_validation was passed, args.cross_validation is True
+ # If --cross_validation was passed, args.cross_validation is True
# If neither was passed, args.cross_validation remains None
# Build the model_kwargs dict from CLI args
@@ -137,6 +143,7 @@
"feature_ratio": args.feature_ratio,
"fix_imbalance": args.fix_imbalance,
"tune_model": args.tune_model,
+ "probability_threshold": args.probability_threshold,
}
LOG.info(f"Model kwargs: {model_kwargs}")
diff -r 11fdac5affb3 -r 3d42f82b3c7f tabular_learner.xml
--- a/tabular_learner.xml Fri Jul 25 19:02:12 2025 +0000
+++ b/tabular_learner.xml Thu Jul 31 15:41:07 2025 +0000
@@ -22,10 +22,10 @@
#end if
#if $customize_defaults == "true"
#if $train_size
- --train_size '$train_size'
+ --train_size '$train_size'
#end if
#if $normalize
- --normalize
+ --normalize
#end if
#if $feature_selection
--feature_selection
@@ -34,27 +34,30 @@
--cross_validation
#if $cross_validation_folds
--cross_validation_folds '$cross_validation_folds'
- #end if
+ #end if
#end if
#if $enable_cross_validation == "false"
--no_cross_validation
#end if
#if $remove_outliers
- --remove_outliers
+ --remove_outliers
#end if
#if $remove_multicollinearity
- --remove_multicollinearity
+ --remove_multicollinearity
#end if
#if $polynomial_features
- --polynomial_features
+ --polynomial_features
#end if
#if $fix_imbalance
- --fix_imbalance
+ --fix_imbalance
+ #end if
+ #if $probability_threshold
+ --probability_threshold '$probability_threshold'
#end if
#end if
#if $test_file
--test_file '$test_file'
- #end if
+ #end if
--model_type '$model_type'
]]>
@@ -150,6 +153,7 @@
+
@@ -175,6 +179,7 @@
+