Mercurial > repos > goeckslab > image_learner
comparison image_learner_cli.py @ 1:39202fe5cf97 draft
planemo upload for repository https://github.com/goeckslab/gleam.git commit 06c0da44ac93256dfb616a6b40276b5485a71e8e
| author | goeckslab |
|---|---|
| date | Wed, 02 Jul 2025 18:59:10 +0000 |
| parents | 54b871dfc51e |
| children | 186424a7eca7 |
comparison
equal
deleted
inserted
replaced
| 0:54b871dfc51e | 1:39202fe5cf97 |
|---|---|
| 22 from ludwig.visualize import get_visualizations_registry | 22 from ludwig.visualize import get_visualizations_registry |
| 23 from sklearn.model_selection import train_test_split | 23 from sklearn.model_selection import train_test_split |
| 24 from utils import encode_image_to_base64, get_html_closing, get_html_template | 24 from utils import encode_image_to_base64, get_html_closing, get_html_template |
| 25 | 25 |
| 26 # --- Constants --- | 26 # --- Constants --- |
| 27 SPLIT_COLUMN_NAME = 'split' | 27 SPLIT_COLUMN_NAME = "split" |
| 28 LABEL_COLUMN_NAME = 'label' | 28 LABEL_COLUMN_NAME = "label" |
| 29 IMAGE_PATH_COLUMN_NAME = 'image_path' | 29 IMAGE_PATH_COLUMN_NAME = "image_path" |
| 30 DEFAULT_SPLIT_PROBABILITIES = [0.7, 0.1, 0.2] | 30 DEFAULT_SPLIT_PROBABILITIES = [0.7, 0.1, 0.2] |
| 31 TEMP_CSV_FILENAME = "processed_data_for_ludwig.csv" | 31 TEMP_CSV_FILENAME = "processed_data_for_ludwig.csv" |
| 32 TEMP_CONFIG_FILENAME = "ludwig_config.yaml" | 32 TEMP_CONFIG_FILENAME = "ludwig_config.yaml" |
| 33 TEMP_DIR_PREFIX = "ludwig_api_work_" | 33 TEMP_DIR_PREFIX = "ludwig_api_work_" |
| 34 MODEL_ENCODER_TEMPLATES: Dict[str, Any] = { | 34 MODEL_ENCODER_TEMPLATES: Dict[str, Any] = { |
| 35 'stacked_cnn': 'stacked_cnn', | 35 "stacked_cnn": "stacked_cnn", |
| 36 'resnet18': {'type': 'resnet', 'model_variant': 18}, | 36 "resnet18": {"type": "resnet", "model_variant": 18}, |
| 37 'resnet34': {'type': 'resnet', 'model_variant': 34}, | 37 "resnet34": {"type": "resnet", "model_variant": 34}, |
| 38 'resnet50': {'type': 'resnet', 'model_variant': 50}, | 38 "resnet50": {"type": "resnet", "model_variant": 50}, |
| 39 'resnet101': {'type': 'resnet', 'model_variant': 101}, | 39 "resnet101": {"type": "resnet", "model_variant": 101}, |
| 40 'resnet152': {'type': 'resnet', 'model_variant': 152}, | 40 "resnet152": {"type": "resnet", "model_variant": 152}, |
| 41 'resnext50_32x4d': {'type': 'resnext', 'model_variant': '50_32x4d'}, | 41 "resnext50_32x4d": {"type": "resnext", "model_variant": "50_32x4d"}, |
| 42 'resnext101_32x8d': {'type': 'resnext', 'model_variant': '101_32x8d'}, | 42 "resnext101_32x8d": {"type": "resnext", "model_variant": "101_32x8d"}, |
| 43 'resnext101_64x4d': {'type': 'resnext', 'model_variant': '101_64x4d'}, | 43 "resnext101_64x4d": {"type": "resnext", "model_variant": "101_64x4d"}, |
| 44 'resnext152_32x8d': {'type': 'resnext', 'model_variant': '152_32x8d'}, | 44 "resnext152_32x8d": {"type": "resnext", "model_variant": "152_32x8d"}, |
| 45 'wide_resnet50_2': {'type': 'wide_resnet', 'model_variant': '50_2'}, | 45 "wide_resnet50_2": {"type": "wide_resnet", "model_variant": "50_2"}, |
| 46 'wide_resnet101_2': {'type': 'wide_resnet', 'model_variant': '101_2'}, | 46 "wide_resnet101_2": {"type": "wide_resnet", "model_variant": "101_2"}, |
| 47 'wide_resnet103_2': {'type': 'wide_resnet', 'model_variant': '103_2'}, | 47 "wide_resnet103_2": {"type": "wide_resnet", "model_variant": "103_2"}, |
| 48 'efficientnet_b0': {'type': 'efficientnet', 'model_variant': 'b0'}, | 48 "efficientnet_b0": {"type": "efficientnet", "model_variant": "b0"}, |
| 49 'efficientnet_b1': {'type': 'efficientnet', 'model_variant': 'b1'}, | 49 "efficientnet_b1": {"type": "efficientnet", "model_variant": "b1"}, |
| 50 'efficientnet_b2': {'type': 'efficientnet', 'model_variant': 'b2'}, | 50 "efficientnet_b2": {"type": "efficientnet", "model_variant": "b2"}, |
| 51 'efficientnet_b3': {'type': 'efficientnet', 'model_variant': 'b3'}, | 51 "efficientnet_b3": {"type": "efficientnet", "model_variant": "b3"}, |
| 52 'efficientnet_b4': {'type': 'efficientnet', 'model_variant': 'b4'}, | 52 "efficientnet_b4": {"type": "efficientnet", "model_variant": "b4"}, |
| 53 'efficientnet_b5': {'type': 'efficientnet', 'model_variant': 'b5'}, | 53 "efficientnet_b5": {"type": "efficientnet", "model_variant": "b5"}, |
| 54 'efficientnet_b6': {'type': 'efficientnet', 'model_variant': 'b6'}, | 54 "efficientnet_b6": {"type": "efficientnet", "model_variant": "b6"}, |
| 55 'efficientnet_b7': {'type': 'efficientnet', 'model_variant': 'b7'}, | 55 "efficientnet_b7": {"type": "efficientnet", "model_variant": "b7"}, |
| 56 'efficientnet_v2_s': {'type': 'efficientnet', 'model_variant': 'v2_s'}, | 56 "efficientnet_v2_s": {"type": "efficientnet", "model_variant": "v2_s"}, |
| 57 'efficientnet_v2_m': {'type': 'efficientnet', 'model_variant': 'v2_m'}, | 57 "efficientnet_v2_m": {"type": "efficientnet", "model_variant": "v2_m"}, |
| 58 'efficientnet_v2_l': {'type': 'efficientnet', 'model_variant': 'v2_l'}, | 58 "efficientnet_v2_l": {"type": "efficientnet", "model_variant": "v2_l"}, |
| 59 'regnet_y_400mf': {'type': 'regnet', 'model_variant': 'y_400mf'}, | 59 "regnet_y_400mf": {"type": "regnet", "model_variant": "y_400mf"}, |
| 60 'regnet_y_800mf': {'type': 'regnet', 'model_variant': 'y_800mf'}, | 60 "regnet_y_800mf": {"type": "regnet", "model_variant": "y_800mf"}, |
| 61 'regnet_y_1_6gf': {'type': 'regnet', 'model_variant': 'y_1_6gf'}, | 61 "regnet_y_1_6gf": {"type": "regnet", "model_variant": "y_1_6gf"}, |
| 62 'regnet_y_3_2gf': {'type': 'regnet', 'model_variant': 'y_3_2gf'}, | 62 "regnet_y_3_2gf": {"type": "regnet", "model_variant": "y_3_2gf"}, |
| 63 'regnet_y_8gf': {'type': 'regnet', 'model_variant': 'y_8gf'}, | 63 "regnet_y_8gf": {"type": "regnet", "model_variant": "y_8gf"}, |
| 64 'regnet_y_16gf': {'type': 'regnet', 'model_variant': 'y_16gf'}, | 64 "regnet_y_16gf": {"type": "regnet", "model_variant": "y_16gf"}, |
| 65 'regnet_y_32gf': {'type': 'regnet', 'model_variant': 'y_32gf'}, | 65 "regnet_y_32gf": {"type": "regnet", "model_variant": "y_32gf"}, |
| 66 'regnet_y_128gf': {'type': 'regnet', 'model_variant': 'y_128gf'}, | 66 "regnet_y_128gf": {"type": "regnet", "model_variant": "y_128gf"}, |
| 67 'regnet_x_400mf': {'type': 'regnet', 'model_variant': 'x_400mf'}, | 67 "regnet_x_400mf": {"type": "regnet", "model_variant": "x_400mf"}, |
| 68 'regnet_x_800mf': {'type': 'regnet', 'model_variant': 'x_800mf'}, | 68 "regnet_x_800mf": {"type": "regnet", "model_variant": "x_800mf"}, |
| 69 'regnet_x_1_6gf': {'type': 'regnet', 'model_variant': 'x_1_6gf'}, | 69 "regnet_x_1_6gf": {"type": "regnet", "model_variant": "x_1_6gf"}, |
| 70 'regnet_x_3_2gf': {'type': 'regnet', 'model_variant': 'x_3_2gf'}, | 70 "regnet_x_3_2gf": {"type": "regnet", "model_variant": "x_3_2gf"}, |
| 71 'regnet_x_8gf': {'type': 'regnet', 'model_variant': 'x_8gf'}, | 71 "regnet_x_8gf": {"type": "regnet", "model_variant": "x_8gf"}, |
| 72 'regnet_x_16gf': {'type': 'regnet', 'model_variant': 'x_16gf'}, | 72 "regnet_x_16gf": {"type": "regnet", "model_variant": "x_16gf"}, |
| 73 'regnet_x_32gf': {'type': 'regnet', 'model_variant': 'x_32gf'}, | 73 "regnet_x_32gf": {"type": "regnet", "model_variant": "x_32gf"}, |
| 74 'vgg11': {'type': 'vgg', 'model_variant': 11}, | 74 "vgg11": {"type": "vgg", "model_variant": 11}, |
| 75 'vgg11_bn': {'type': 'vgg', 'model_variant': '11_bn'}, | 75 "vgg11_bn": {"type": "vgg", "model_variant": "11_bn"}, |
| 76 'vgg13': {'type': 'vgg', 'model_variant': 13}, | 76 "vgg13": {"type": "vgg", "model_variant": 13}, |
| 77 'vgg13_bn': {'type': 'vgg', 'model_variant': '13_bn'}, | 77 "vgg13_bn": {"type": "vgg", "model_variant": "13_bn"}, |
| 78 'vgg16': {'type': 'vgg', 'model_variant': 16}, | 78 "vgg16": {"type": "vgg", "model_variant": 16}, |
| 79 'vgg16_bn': {'type': 'vgg', 'model_variant': '16_bn'}, | 79 "vgg16_bn": {"type": "vgg", "model_variant": "16_bn"}, |
| 80 'vgg19': {'type': 'vgg', 'model_variant': 19}, | 80 "vgg19": {"type": "vgg", "model_variant": 19}, |
| 81 'vgg19_bn': {'type': 'vgg', 'model_variant': '19_bn'}, | 81 "vgg19_bn": {"type": "vgg", "model_variant": "19_bn"}, |
| 82 'shufflenet_v2_x0_5': {'type': 'shufflenet_v2', 'model_variant': 'x0_5'}, | 82 "shufflenet_v2_x0_5": {"type": "shufflenet_v2", "model_variant": "x0_5"}, |
| 83 'shufflenet_v2_x1_0': {'type': 'shufflenet_v2', 'model_variant': 'x1_0'}, | 83 "shufflenet_v2_x1_0": {"type": "shufflenet_v2", "model_variant": "x1_0"}, |
| 84 'shufflenet_v2_x1_5': {'type': 'shufflenet_v2', 'model_variant': 'x1_5'}, | 84 "shufflenet_v2_x1_5": {"type": "shufflenet_v2", "model_variant": "x1_5"}, |
| 85 'shufflenet_v2_x2_0': {'type': 'shufflenet_v2', 'model_variant': 'x2_0'}, | 85 "shufflenet_v2_x2_0": {"type": "shufflenet_v2", "model_variant": "x2_0"}, |
| 86 'squeezenet1_0': {'type': 'squeezenet', 'model_variant': '1_0'}, | 86 "squeezenet1_0": {"type": "squeezenet", "model_variant": "1_0"}, |
| 87 'squeezenet1_1': {'type': 'squeezenet', 'model_variant': '1_1'}, | 87 "squeezenet1_1": {"type": "squeezenet", "model_variant": "1_1"}, |
| 88 'swin_t': {'type': 'swin_transformer', 'model_variant': 't'}, | 88 "swin_t": {"type": "swin_transformer", "model_variant": "t"}, |
| 89 'swin_s': {'type': 'swin_transformer', 'model_variant': 's'}, | 89 "swin_s": {"type": "swin_transformer", "model_variant": "s"}, |
| 90 'swin_b': {'type': 'swin_transformer', 'model_variant': 'b'}, | 90 "swin_b": {"type": "swin_transformer", "model_variant": "b"}, |
| 91 'swin_v2_t': {'type': 'swin_transformer', 'model_variant': 'v2_t'}, | 91 "swin_v2_t": {"type": "swin_transformer", "model_variant": "v2_t"}, |
| 92 'swin_v2_s': {'type': 'swin_transformer', 'model_variant': 'v2_s'}, | 92 "swin_v2_s": {"type": "swin_transformer", "model_variant": "v2_s"}, |
| 93 'swin_v2_b': {'type': 'swin_transformer', 'model_variant': 'v2_b'}, | 93 "swin_v2_b": {"type": "swin_transformer", "model_variant": "v2_b"}, |
| 94 'vit_b_16': {'type': 'vision_transformer', 'model_variant': 'b_16'}, | 94 "vit_b_16": {"type": "vision_transformer", "model_variant": "b_16"}, |
| 95 'vit_b_32': {'type': 'vision_transformer', 'model_variant': 'b_32'}, | 95 "vit_b_32": {"type": "vision_transformer", "model_variant": "b_32"}, |
| 96 'vit_l_16': {'type': 'vision_transformer', 'model_variant': 'l_16'}, | 96 "vit_l_16": {"type": "vision_transformer", "model_variant": "l_16"}, |
| 97 'vit_l_32': {'type': 'vision_transformer', 'model_variant': 'l_32'}, | 97 "vit_l_32": {"type": "vision_transformer", "model_variant": "l_32"}, |
| 98 'vit_h_14': {'type': 'vision_transformer', 'model_variant': 'h_14'}, | 98 "vit_h_14": {"type": "vision_transformer", "model_variant": "h_14"}, |
| 99 'convnext_tiny': {'type': 'convnext', 'model_variant': 'tiny'}, | 99 "convnext_tiny": {"type": "convnext", "model_variant": "tiny"}, |
| 100 'convnext_small': {'type': 'convnext', 'model_variant': 'small'}, | 100 "convnext_small": {"type": "convnext", "model_variant": "small"}, |
| 101 'convnext_base': {'type': 'convnext', 'model_variant': 'base'}, | 101 "convnext_base": {"type": "convnext", "model_variant": "base"}, |
| 102 'convnext_large': {'type': 'convnext', 'model_variant': 'large'}, | 102 "convnext_large": {"type": "convnext", "model_variant": "large"}, |
| 103 'maxvit_t': {'type': 'maxvit', 'model_variant': 't'}, | 103 "maxvit_t": {"type": "maxvit", "model_variant": "t"}, |
| 104 'alexnet': {'type': 'alexnet'}, | 104 "alexnet": {"type": "alexnet"}, |
| 105 'googlenet': {'type': 'googlenet'}, | 105 "googlenet": {"type": "googlenet"}, |
| 106 'inception_v3': {'type': 'inception_v3'}, | 106 "inception_v3": {"type": "inception_v3"}, |
| 107 'mobilenet_v2': {'type': 'mobilenet_v2'}, | 107 "mobilenet_v2": {"type": "mobilenet_v2"}, |
| 108 'mobilenet_v3_large': {'type': 'mobilenet_v3_large'}, | 108 "mobilenet_v3_large": {"type": "mobilenet_v3_large"}, |
| 109 'mobilenet_v3_small': {'type': 'mobilenet_v3_small'}, | 109 "mobilenet_v3_small": {"type": "mobilenet_v3_small"}, |
| 110 } | |
| 111 METRIC_DISPLAY_NAMES = { | |
| 112 "accuracy": "Accuracy", | |
| 113 "accuracy_micro": "Accuracy-Micro", | |
| 114 "loss": "Loss", | |
| 115 "roc_auc": "ROC-AUC", | |
| 116 "roc_auc_macro": "ROC-AUC-Macro", | |
| 117 "roc_auc_micro": "ROC-AUC-Micro", | |
| 118 "hits_at_k": "Hits at K", | |
| 119 "precision": "Precision", | |
| 120 "recall": "Recall", | |
| 121 "specificity": "Specificity", | |
| 122 "kappa_score": "Cohen's Kappa", | |
| 123 "token_accuracy": "Token Accuracy", | |
| 124 "avg_precision_macro": "Precision-Macro", | |
| 125 "avg_recall_macro": "Recall-Macro", | |
| 126 "avg_f1_score_macro": "F1-score-Macro", | |
| 127 "avg_precision_micro": "Precision-Micro", | |
| 128 "avg_recall_micro": "Recall-Micro", | |
| 129 "avg_f1_score_micro": "F1-score-Micro", | |
| 130 "avg_precision_weighted": "Precision-Weighted", | |
| 131 "avg_recall_weighted": "Recall-Weighted", | |
| 132 "avg_f1_score_weighted": "F1-score-Weighted", | |
| 133 "average_precision_macro": " Precision-Average-Macro", | |
| 134 "average_precision_micro": "Precision-Average-Micro", | |
| 135 "average_precision_samples": "Precision-Average-Samples", | |
| 110 } | 136 } |
| 111 | 137 |
| 112 # --- Logging Setup --- | 138 # --- Logging Setup --- |
| 113 logging.basicConfig( | 139 logging.basicConfig( |
| 114 level=logging.INFO, | 140 level=logging.INFO, |
| 115 format='%(asctime)s %(levelname)s %(name)s: %(message)s' | 141 format="%(asctime)s %(levelname)s %(name)s: %(message)s", |
| 116 ) | 142 ) |
| 117 logger = logging.getLogger("ImageLearner") | 143 logger = logging.getLogger("ImageLearner") |
| 118 | 144 |
| 119 | 145 |
| 146 def get_metrics_help_modal() -> str: | |
| 147 modal_html = """ | |
| 148 <div id="metricsHelpModal" class="modal"> | |
| 149 <div class="modal-content"> | |
| 150 <span class="close">×</span> | |
| 151 <h2>Model Evaluation Metrics — Help Guide</h2> | |
| 152 <div class="metrics-guide"> | |
| 153 <h3>1) General Metrics</h3> | |
| 154 <p><strong>Loss:</strong> Measures the difference between predicted and actual values. Lower is better. Often used for optimization during training.</p> | |
| 155 <p><strong>Accuracy:</strong> Proportion of correct predictions among all predictions. Simple but can be misleading for imbalanced datasets.</p> | |
| 156 <p><strong>Micro Accuracy:</strong> Calculates accuracy by summing up all individual true positives and true negatives across all classes, making it suitable for multiclass or multilabel problems.</p> | |
| 157 <p><strong>Token Accuracy:</strong> Measures how often the predicted tokens (e.g., in sequences) match the true tokens. Useful in sequence prediction tasks like NLP.</p> | |
| 158 <h3>2) Precision, Recall & Specificity</h3> | |
| 159 <p><strong>Precision:</strong> Out of all positive predictions, how many were correct. Precision = TP / (TP + FP). Helps when false positives are costly.</p> | |
| 160 <p><strong>Recall (Sensitivity):</strong> Out of all actual positives, how many were predicted correctly. Recall = TP / (TP + FN). Important when missing positives is risky.</p> | |
| 161 <p><strong>Specificity:</strong> True negative rate. Measures how well the model identifies negatives. Specificity = TN / (TN + FP). Useful in medical testing to avoid false alarms.</p> | |
| 162 <h3>3) Macro, Micro, and Weighted Averages</h3> | |
| 163 <p><strong>Macro Precision / Recall / F1:</strong> Averages the metric across all classes, treating each class equally, regardless of class frequency. Best when class sizes are balanced.</p> | |
| 164 <p><strong>Micro Precision / Recall / F1:</strong> Aggregates TP, FP, FN across all classes before computing the metric. Gives a global view and is ideal for class-imbalanced problems.</p> | |
| 165 <p><strong>Weighted Precision / Recall / F1:</strong> Averages each metric across classes, weighted by the number of true instances per class. Balances importance of classes based on frequency.</p> | |
| 166 <h3>4) Average Precision (PR-AUC Variants)</h3> | |
| 167 <p><strong>Average Precision Macro:</strong> Precision-Recall AUC averaged across all classes equally. Useful for balanced multi-class problems.</p> | |
| 168 <p><strong>Average Precision Micro:</strong> Global Precision-Recall AUC using all instances. Best for imbalanced data or multi-label classification.</p> | |
| 169 <p><strong>Average Precision Samples:</strong> Precision-Recall AUC averaged across individual samples (not classes). Ideal for multi-label problems where each sample can belong to multiple classes.</p> | |
| 170 <h3>5) ROC-AUC Variants</h3> | |
| 171 <p><strong>ROC-AUC:</strong> Measures model's ability to distinguish between classes. AUC = 1 is perfect; 0.5 is random guessing. Use for binary classification.</p> | |
| 172 <p><strong>Macro ROC-AUC:</strong> Averages the AUC across all classes equally. Suitable when classes are balanced and of equal importance.</p> | |
| 173 <p><strong>Micro ROC-AUC:</strong> Computes AUC from aggregated predictions across all classes. Useful in multiclass or multilabel settings with imbalance.</p> | |
| 174 <h3>6) Ranking Metrics</h3> | |
| 175 <p><strong>Hits at K:</strong> Measures whether the true label is among the top-K predictions. Common in recommendation systems and retrieval tasks.</p> | |
| 176 <h3>7) Confusion Matrix Stats (Per Class)</h3> | |
| 177 <p><strong>True Positives / Negatives (TP / TN):</strong> Correct predictions for positives and negatives respectively.</p> | |
| 178 <p><strong>False Positives / Negatives (FP / FN):</strong> Incorrect predictions — false alarms and missed detections.</p> | |
| 179 <h3>8) Other Useful Metrics</h3> | |
| 180 <p><strong>Cohen's Kappa:</strong> Measures agreement between predicted and actual values adjusted for chance. Useful for multiclass classification with imbalanced labels.</p> | |
| 181 <p><strong>Matthews Correlation Coefficient (MCC):</strong> Balanced measure of prediction quality that takes into account TP, TN, FP, and FN. Particularly effective for imbalanced datasets.</p> | |
| 182 <h3>9) Metric Recommendations</h3> | |
| 183 <ul> | |
| 184 <li>Use <strong>Accuracy + F1</strong> for balanced data.</li> | |
| 185 <li>Use <strong>Precision, Recall, ROC-AUC</strong> for imbalanced datasets.</li> | |
| 186 <li>Use <strong>Average Precision Micro</strong> for multilabel or class-imbalanced problems.</li> | |
| 187 <li>Use <strong>Macro scores</strong> when all classes should be treated equally.</li> | |
| 188 <li>Use <strong>Weighted scores</strong> when class imbalance should be accounted for without ignoring small classes.</li> | |
| 189 <li>Use <strong>Confusion Matrix stats</strong> to analyze class-wise performance.</li> | |
| 190 <li>Use <strong>Hits at K</strong> for recommendation or ranking-based tasks.</li> | |
| 191 </ul> | |
| 192 </div> | |
| 193 </div> | |
| 194 </div> | |
| 195 """ | |
| 196 modal_css = """ | |
| 197 <style> | |
| 198 .modal { | |
| 199 display: none; | |
| 200 position: fixed; | |
| 201 z-index: 1; | |
| 202 left: 0; | |
| 203 top: 0; | |
| 204 width: 100%; | |
| 205 height: 100%; | |
| 206 overflow: auto; | |
| 207 background-color: rgba(0,0,0,0.4); | |
| 208 } | |
| 209 .modal-content { | |
| 210 background-color: #fefefe; | |
| 211 margin: 15% auto; | |
| 212 padding: 20px; | |
| 213 border: 1px solid #888; | |
| 214 width: 80%; | |
| 215 max-width: 800px; | |
| 216 } | |
| 217 .close { | |
| 218 color: #aaa; | |
| 219 float: right; | |
| 220 font-size: 28px; | |
| 221 font-weight: bold; | |
| 222 } | |
| 223 .close:hover, | |
| 224 .close:focus { | |
| 225 color: black; | |
| 226 text-decoration: none; | |
| 227 cursor: pointer; | |
| 228 } | |
| 229 .metrics-guide h3 { | |
| 230 margin-top: 20px; | |
| 231 } | |
| 232 .metrics-guide p { | |
| 233 margin: 5px 0; | |
| 234 } | |
| 235 .metrics-guide ul { | |
| 236 margin: 10px 0; | |
| 237 padding-left: 20px; | |
| 238 } | |
| 239 </style> | |
| 240 """ | |
| 241 modal_js = """ | |
| 242 <script> | |
| 243 document.addEventListener("DOMContentLoaded", function() { | |
| 244 var modal = document.getElementById("metricsHelpModal"); | |
| 245 var closeBtn = document.getElementsByClassName("close")[0]; | |
| 246 | |
| 247 document.querySelectorAll(".openMetricsHelp").forEach(btn => { | |
| 248 btn.onclick = function() { | |
| 249 modal.style.display = "block"; | |
| 250 }; | |
| 251 }); | |
| 252 | |
| 253 if (closeBtn) { | |
| 254 closeBtn.onclick = function() { | |
| 255 modal.style.display = "none"; | |
| 256 }; | |
| 257 } | |
| 258 | |
| 259 window.onclick = function(event) { | |
| 260 if (event.target == modal) { | |
| 261 modal.style.display = "none"; | |
| 262 } | |
| 263 } | |
| 264 }); | |
| 265 </script> | |
| 266 """ | |
| 267 return modal_css + modal_html + modal_js | |
| 268 | |
| 269 | |
| 120 def format_config_table_html( | 270 def format_config_table_html( |
| 121 config: dict, | 271 config: dict, |
| 122 split_info: Optional[str] = None, | 272 split_info: Optional[str] = None, |
| 123 training_progress: dict = None) -> str: | 273 training_progress: dict = None, |
| 274 ) -> str: | |
| 124 display_keys = [ | 275 display_keys = [ |
| 125 "model_name", | 276 "model_name", |
| 126 "epochs", | 277 "epochs", |
| 127 "batch_size", | 278 "batch_size", |
| 128 "fine_tune", | 279 "fine_tune", |
| 141 val = int(val) | 292 val = int(val) |
| 142 else: | 293 else: |
| 143 if training_progress: | 294 if training_progress: |
| 144 val = "Auto-selected batch size by Ludwig:<br>" | 295 val = "Auto-selected batch size by Ludwig:<br>" |
| 145 resolved_val = training_progress.get("batch_size") | 296 resolved_val = training_progress.get("batch_size") |
| 146 val += ( | 297 val += f"<span style='font-size: 0.85em;'>{resolved_val}</span><br>" |
| 147 f"<span style='font-size: 0.85em;'>{resolved_val}</span><br>" | |
| 148 ) | |
| 149 else: | 298 else: |
| 150 val = "auto" | 299 val = "auto" |
| 151 if key == "learning_rate": | 300 if key == "learning_rate": |
| 152 resolved_val = None | 301 resolved_val = None |
| 153 if val is None or val == "auto": | 302 if val is None or val == "auto": |
| 154 if training_progress: | 303 if training_progress: |
| 155 resolved_val = training_progress.get("learning_rate") | 304 resolved_val = training_progress.get("learning_rate") |
| 156 val = ( | 305 val = ( |
| 157 "Auto-selected learning rate by Ludwig:<br>" | 306 "Auto-selected learning rate by Ludwig:<br>" |
| 158 f"<span style='font-size: 0.85em;'>{resolved_val if resolved_val else val}</span><br>" | 307 f"<span style='font-size: 0.85em;'>" |
| 308 f"{resolved_val if resolved_val else val}</span><br>" | |
| 159 "<span style='font-size: 0.85em;'>" | 309 "<span style='font-size: 0.85em;'>" |
| 160 "Based on model architecture and training setup (e.g., fine-tuning).<br>" | 310 "Based on model architecture and training setup " |
| 161 "See <a href='https://ludwig.ai/latest/configuration/trainer/#trainer-parameters' " | 311 "(e.g., fine-tuning).<br>" |
| 162 "target='_blank'>Ludwig Trainer Parameters</a> for details." | 312 "See <a href='https://ludwig.ai/latest/configuration/trainer/" |
| 313 "#trainer-parameters' target='_blank'>" | |
| 314 "Ludwig Trainer Parameters</a> for details." | |
| 163 "</span>" | 315 "</span>" |
| 164 ) | 316 ) |
| 165 else: | 317 else: |
| 166 val = ( | 318 val = ( |
| 167 "Auto-selected by Ludwig<br>" | 319 "Auto-selected by Ludwig<br>" |
| 168 "<span style='font-size: 0.85em;'>" | 320 "<span style='font-size: 0.85em;'>" |
| 169 "Automatically tuned based on architecture and dataset.<br>" | 321 "Automatically tuned based on architecture and dataset.<br>" |
| 170 "See <a href='https://ludwig.ai/latest/configuration/trainer/#trainer-parameters' " | 322 "See <a href='https://ludwig.ai/latest/configuration/trainer/" |
| 171 "target='_blank'>Ludwig Trainer Parameters</a> for details." | 323 "#trainer-parameters' target='_blank'>" |
| 324 "Ludwig Trainer Parameters</a> for details." | |
| 172 "</span>" | 325 "</span>" |
| 173 ) | 326 ) |
| 174 else: | 327 else: |
| 175 val = f"{val:.6f}" | 328 val = f"{val:.6f}" |
| 176 if key == "epochs": | 329 if key == "epochs": |
| 177 if training_progress and "epoch" in training_progress and val > training_progress["epoch"]: | 330 if ( |
| 331 training_progress | |
| 332 and "epoch" in training_progress | |
| 333 and val > training_progress["epoch"] | |
| 334 ): | |
| 178 val = ( | 335 val = ( |
| 179 f"Because of early stopping: the training" | 336 f"Because of early stopping: the training " |
| 180 f"stopped at epoch {training_progress['epoch']}" | 337 f"stopped at epoch {training_progress['epoch']}" |
| 181 ) | 338 ) |
| 182 | 339 |
| 183 if val is None: | 340 if val is None: |
| 184 continue | 341 continue |
| 185 rows.append( | 342 rows.append( |
| 186 f"<tr>" | 343 f"<tr>" |
| 187 f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left;'>" | 344 f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left;'>" |
| 188 f"{key.replace('_', ' ').title()}</td>" | 345 f"{key.replace('_', ' ').title()}</td>" |
| 189 f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>{val}</td>" | 346 f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>" |
| 347 f"{val}</td>" | |
| 190 f"</tr>" | 348 f"</tr>" |
| 191 ) | 349 ) |
| 192 | 350 |
| 193 if split_info: | 351 if split_info: |
| 194 rows.append( | 352 rows.append( |
| 195 f"<tr>" | 353 f"<tr>" |
| 196 f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left;'>Data Split</td>" | 354 f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left;'>" |
| 197 f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>{split_info}</td>" | 355 f"Data Split</td>" |
| 356 f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>" | |
| 357 f"{split_info}</td>" | |
| 198 f"</tr>" | 358 f"</tr>" |
| 199 ) | 359 ) |
| 200 | 360 |
| 201 return ( | 361 return ( |
| 202 "<h2 style='text-align: center;'>Training Setup</h2>" | 362 "<h2 style='text-align: center;'>Training Setup</h2>" |
| 203 "<div style='display: flex; justify-content: center;'>" | 363 "<div style='display: flex; justify-content: center;'>" |
| 204 "<table style='border-collapse: collapse; width: 60%; table-layout: auto;'>" | 364 "<table style='border-collapse: collapse; width: 60%; table-layout: auto;'>" |
| 205 "<thead><tr>" | 365 "<thead><tr>" |
| 206 "<th style='padding: 10px; border: 1px solid #ccc; text-align: left;'>Parameter</th>" | 366 "<th style='padding: 10px; border: 1px solid #ccc; text-align: left;'>" |
| 207 "<th style='padding: 10px; border: 1px solid #ccc; text-align: center;'>Value</th>" | 367 "Parameter</th>" |
| 368 "<th style='padding: 10px; border: 1px solid #ccc; text-align: center;'>" | |
| 369 "Value</th>" | |
| 208 "</tr></thead><tbody>" + "".join(rows) + "</tbody></table></div><br>" | 370 "</tr></thead><tbody>" + "".join(rows) + "</tbody></table></div><br>" |
| 209 "<p style='text-align: center; font-size: 0.9em;'>" | 371 "<p style='text-align: center; font-size: 0.9em;'>" |
| 210 "Model trained using Ludwig.<br>" | 372 "Model trained using Ludwig.<br>" |
| 211 "If want to learn more about Ludwig default settings," | 373 "If want to learn more about Ludwig default settings," |
| 212 "please check the their <a href='https://ludwig.ai' target='_blank'>website(ludwig.ai)</a>." | 374 "please check the their <a href='https://ludwig.ai' target='_blank'>" |
| 375 "website(ludwig.ai)</a>." | |
| 213 "</p><hr>" | 376 "</p><hr>" |
| 214 ) | 377 ) |
| 215 | 378 |
| 216 | 379 |
| 217 def format_stats_table_html(training_stats: dict, test_stats: dict) -> str: | 380 def detect_output_type(test_stats): |
| 218 train_metrics = training_stats.get("training", {}).get("label", {}) | 381 """Detects if the output type is 'binary' or 'category' based on test statistics.""" |
| 219 val_metrics = training_stats.get("validation", {}).get("label", {}) | 382 label_stats = test_stats.get("label", {}) |
| 220 test_metrics = test_stats.get("label", {}) | 383 per_class = label_stats.get("per_class_stats", {}) |
| 221 | 384 if len(per_class) == 2: |
| 222 all_metrics = set(train_metrics) | set(val_metrics) | set(test_metrics) | 385 return "binary" |
| 386 return "category" | |
| 387 | |
| 388 | |
| 389 def extract_metrics_from_json( | |
| 390 train_stats: dict, | |
| 391 test_stats: dict, | |
| 392 output_type: str, | |
| 393 ) -> dict: | |
| 394 """Extracts relevant metrics from training and test statistics based on the output type.""" | |
| 395 metrics = {"training": {}, "validation": {}, "test": {}} | |
| 223 | 396 |
| 224 def get_last_value(stats, key): | 397 def get_last_value(stats, key): |
| 225 val = stats.get(key) | 398 val = stats.get(key) |
| 226 if isinstance(val, list) and val: | 399 if isinstance(val, list) and val: |
| 227 return val[-1] | 400 return val[-1] |
| 228 elif isinstance(val, (int, float)): | 401 elif isinstance(val, (int, float)): |
| 229 return val | 402 return val |
| 230 return None | 403 return None |
| 231 | 404 |
| 405 for split in ["training", "validation"]: | |
| 406 split_stats = train_stats.get(split, {}) | |
| 407 if not split_stats: | |
| 408 logging.warning(f"No statistics found for {split} split") | |
| 409 continue | |
| 410 label_stats = split_stats.get("label", {}) | |
| 411 if not label_stats: | |
| 412 logging.warning(f"No label statistics found for {split} split") | |
| 413 continue | |
| 414 if output_type == "binary": | |
| 415 metrics[split] = { | |
| 416 "accuracy": get_last_value(label_stats, "accuracy"), | |
| 417 "loss": get_last_value(label_stats, "loss"), | |
| 418 "precision": get_last_value(label_stats, "precision"), | |
| 419 "recall": get_last_value(label_stats, "recall"), | |
| 420 "specificity": get_last_value(label_stats, "specificity"), | |
| 421 "roc_auc": get_last_value(label_stats, "roc_auc"), | |
| 422 } | |
| 423 else: | |
| 424 metrics[split] = { | |
| 425 "accuracy": get_last_value(label_stats, "accuracy"), | |
| 426 "accuracy_micro": get_last_value(label_stats, "accuracy_micro"), | |
| 427 "loss": get_last_value(label_stats, "loss"), | |
| 428 "roc_auc": get_last_value(label_stats, "roc_auc"), | |
| 429 "hits_at_k": get_last_value(label_stats, "hits_at_k"), | |
| 430 } | |
| 431 | |
| 432 # Test metrics: dynamic extraction according to exclusions | |
| 433 test_label_stats = test_stats.get("label", {}) | |
| 434 if not test_label_stats: | |
| 435 logging.warning("No label statistics found for test split") | |
| 436 else: | |
| 437 combined_stats = test_stats.get("combined", {}) | |
| 438 overall_stats = test_label_stats.get("overall_stats", {}) | |
| 439 | |
| 440 # Define exclusions | |
| 441 if output_type == "binary": | |
| 442 exclude = {"per_class_stats", "precision_recall_curve", "roc_curve"} | |
| 443 else: | |
| 444 exclude = {"per_class_stats", "confusion_matrix"} | |
| 445 | |
| 446 # 1. Get all scalar test_label_stats not excluded | |
| 447 test_metrics = {} | |
| 448 for k, v in test_label_stats.items(): | |
| 449 if k in exclude: | |
| 450 continue | |
| 451 if k == "overall_stats": | |
| 452 continue | |
| 453 if isinstance(v, (int, float, str, bool)): | |
| 454 test_metrics[k] = v | |
| 455 | |
| 456 # 2. Add overall_stats (flattened) | |
| 457 for k, v in overall_stats.items(): | |
| 458 test_metrics[k] = v | |
| 459 | |
| 460 # 3. Optionally include combined/loss if present and not already | |
| 461 if "loss" in combined_stats and "loss" not in test_metrics: | |
| 462 test_metrics["loss"] = combined_stats["loss"] | |
| 463 | |
| 464 metrics["test"] = test_metrics | |
| 465 | |
| 466 return metrics | |
| 467 | |
| 468 | |
| 469 def generate_table_row(cells, styles): | |
| 470 """Helper function to generate an HTML table row.""" | |
| 471 return ( | |
| 472 "<tr>" | |
| 473 + "".join(f"<td style='{styles}'>{cell}</td>" for cell in cells) | |
| 474 + "</tr>" | |
| 475 ) | |
| 476 | |
| 477 | |
def format_stats_table_html(train_stats: dict, test_stats: dict) -> str:
    """Render a combined Train/Validation/Test metrics summary as an HTML table.

    Only metrics present in all three splits with non-None values are shown;
    display names come from METRIC_DISPLAY_NAMES (falling back to a
    title-cased version of the metric key), values are formatted to 4 decimals.
    """
    output_type = detect_output_type(test_stats)
    metrics = extract_metrics_from_json(train_stats, test_stats, output_type)
    train_split = metrics["training"]
    val_split = metrics["validation"]
    test_split = metrics["test"]

    rows = []
    for key in sorted(train_split):
        # Skip metrics that are not reported for every split.
        if key not in val_split or key not in test_split:
            continue
        values = (train_split.get(key), val_split.get(key), test_split.get(key))
        if any(value is None for value in values):
            continue
        label = METRIC_DISPLAY_NAMES.get(key, key.replace("_", " ").title())
        rows.append([label] + [f"{value:.4f}" for value in values])

    if not rows:
        return "<table><tr><td>No metric values found.</td></tr></table>"

    cell_style = (
        "padding: 10px; border: 1px solid #ccc; text-align: center; "
        "white-space: nowrap;"
    )
    header = (
        "<h2 style='text-align: center;'>Model Performance Summary</h2>"
        "<div style='display: flex; justify-content: center;'>"
        "<table style='border-collapse: collapse; table-layout: auto;'>"
        "<thead><tr>"
        "<th style='padding: 10px; border: 1px solid #ccc; text-align: left; "
        "white-space: nowrap;'>Metric</th>"
        "<th style='padding: 10px; border: 1px solid #ccc; text-align: center; "
        "white-space: nowrap;'>Train</th>"
        "<th style='padding: 10px; border: 1px solid #ccc; text-align: center; "
        "white-space: nowrap;'>Validation</th>"
        "<th style='padding: 10px; border: 1px solid #ccc; text-align: center; "
        "white-space: nowrap;'>Test</th>"
        "</tr></thead><tbody>"
    )
    body = "".join(generate_table_row(row, cell_style) for row in rows)
    return header + body + "</tbody></table></div><br>"
| 524 | |
| 525 | |
def format_train_val_stats_table_html(train_stats: dict, test_stats: dict) -> str:
    """Render a Train/Validation metrics summary as an HTML table.

    Only metrics present in both splits with non-None values are shown;
    display names come from METRIC_DISPLAY_NAMES (falling back to a
    title-cased version of the metric key), values are formatted to 4 decimals.
    """
    output_type = detect_output_type(test_stats)
    metrics = extract_metrics_from_json(train_stats, test_stats, output_type)
    train_split = metrics["training"]
    val_split = metrics["validation"]

    rows = []
    for key in sorted(train_split):
        # Only report metrics that exist for both train and validation.
        if key not in val_split:
            continue
        train_value = train_split.get(key)
        val_value = val_split.get(key)
        if train_value is None or val_value is None:
            continue
        label = METRIC_DISPLAY_NAMES.get(key, key.replace("_", " ").title())
        rows.append([label, f"{train_value:.4f}", f"{val_value:.4f}"])

    if not rows:
        return "<table><tr><td>No metric values found for Train/Validation.</td></tr></table>"

    cell_style = (
        "padding: 10px; border: 1px solid #ccc; text-align: center; "
        "white-space: nowrap;"
    )
    header = (
        "<h2 style='text-align: center;'>Train/Validation Performance Summary</h2>"
        "<div style='display: flex; justify-content: center;'>"
        "<table style='border-collapse: collapse; table-layout: auto;'>"
        "<thead><tr>"
        "<th style='padding: 10px; border: 1px solid #ccc; text-align: left; "
        "white-space: nowrap;'>Metric</th>"
        "<th style='padding: 10px; border: 1px solid #ccc; text-align: center; "
        "white-space: nowrap;'>Train</th>"
        "<th style='padding: 10px; border: 1px solid #ccc; text-align: center; "
        "white-space: nowrap;'>Validation</th>"
        "</tr></thead><tbody>"
    )
    body = "".join(generate_table_row(row, cell_style) for row in rows)
    return header + body + "</tbody></table></div><br>"
| 566 | |
| 567 | |
def format_test_merged_stats_table_html(test_metrics: Dict[str, Optional[float]]) -> str:
    """Render the merged test metrics as a centered two-column HTML table.

    Metrics whose value is None are omitted; display names come from
    METRIC_DISPLAY_NAMES (falling back to a title-cased version of the key),
    values are formatted to 4 decimals.
    """
    rows = []
    for key in sorted(test_metrics):
        value = test_metrics[key]
        if value is None:
            continue
        label = METRIC_DISPLAY_NAMES.get(key, key.replace("_", " ").title())
        rows.append([label, f"{value:.4f}"])

    if not rows:
        return "<table><tr><td>No test metric values found.</td></tr></table>"

    cell_style = (
        "padding: 10px; border: 1px solid #ccc; text-align: center; "
        "white-space: nowrap;"
    )
    header = (
        "<h2 style='text-align: center;'>Test Performance Summary</h2>"
        "<div style='display: flex; justify-content: center;'>"
        "<table style='border-collapse: collapse; table-layout: auto;'>"
        "<thead><tr>"
        "<th style='padding: 10px; border: 1px solid #ccc; text-align: left; "
        "white-space: nowrap;'>Metric</th>"
        "<th style='padding: 10px; border: 1px solid #ccc; text-align: center; "
        "white-space: nowrap;'>Test</th>"
        "</tr></thead><tbody>"
    )
    body = "".join(generate_table_row(row, cell_style) for row in rows)
    return header + body + "</tbody></table></div><br>"
| 599 | |
| 600 | |
| 601 def build_tabbed_html(metrics_html: str, train_val_html: str, test_html: str) -> str: | |
| 274 return f""" | 602 return f""" |
| 275 <style> | 603 <style> |
| 276 .tabs {{ | 604 .tabs {{ |
| 277 display: flex; | 605 display: flex; |
| 278 border-bottom: 2px solid #ccc; | 606 border-bottom: 2px solid #ccc; |
| 300 }} | 628 }} |
| 301 .tab-content.active {{ | 629 .tab-content.active {{ |
| 302 display: block; | 630 display: block; |
| 303 }} | 631 }} |
| 304 </style> | 632 </style> |
| 305 | |
| 306 <div class="tabs"> | 633 <div class="tabs"> |
| 307 <div class="tab active" onclick="showTab('metrics')"> Config & Metrics</div> | 634 <div class="tab active" onclick="showTab('metrics')"> Config & Results Summary</div> |
| 308 <div class="tab" onclick="showTab('trainval')"> Train/Validation Plots</div> | 635 <div class="tab" onclick="showTab('trainval')"> Train/Validation Results</div> |
| 309 <div class="tab" onclick="showTab('test')"> Test Plots</div> | 636 <div class="tab" onclick="showTab('test')"> Test Results</div> |
| 310 </div> | 637 </div> |
| 311 | |
| 312 <div id="metrics" class="tab-content active"> | 638 <div id="metrics" class="tab-content active"> |
| 313 {metrics_html} | 639 {metrics_html} |
| 314 </div> | 640 </div> |
| 315 <div id="trainval" class="tab-content"> | 641 <div id="trainval" class="tab-content"> |
| 316 {train_viz_html} | 642 {train_val_html} |
| 317 </div> | 643 </div> |
| 318 <div id="test" class="tab-content"> | 644 <div id="test" class="tab-content"> |
| 319 {test_viz_html} | 645 {test_html} |
| 320 </div> | 646 </div> |
| 321 | |
| 322 <script> | 647 <script> |
| 323 function showTab(id) {{ | 648 function showTab(id) {{ |
| 324 document.querySelectorAll('.tab-content').forEach(el => el.classList.remove('active')); | 649 document.querySelectorAll('.tab-content').forEach(el => el.classList.remove('active')); |
| 325 document.querySelectorAll('.tab').forEach(el => el.classList.remove('active')); | 650 document.querySelectorAll('.tab').forEach(el => el.classList.remove('active')); |
| 326 document.getElementById(id).classList.add('active'); | 651 document.getElementById(id).classList.add('active'); |
| 335 split_column: str, | 660 split_column: str, |
| 336 validation_size: float = 0.15, | 661 validation_size: float = 0.15, |
| 337 random_state: int = 42, | 662 random_state: int = 42, |
| 338 label_column: Optional[str] = None, | 663 label_column: Optional[str] = None, |
| 339 ) -> pd.DataFrame: | 664 ) -> pd.DataFrame: |
| 340 """ | 665 """Given a DataFrame whose split_column only contains {0,2}, re-assign a portion of the 0s to become 1s (validation).""" |
| 341 Given a DataFrame whose split_column only contains {0,2}, re-assign | |
| 342 a portion of the 0s to become 1s (validation). Returns a fresh DataFrame. | |
| 343 """ | |
| 344 # Work on a copy | |
| 345 out = df.copy() | 666 out = df.copy() |
| 346 # Ensure split col is integer dtype | |
| 347 out[split_column] = pd.to_numeric(out[split_column], errors="coerce").astype(int) | 667 out[split_column] = pd.to_numeric(out[split_column], errors="coerce").astype(int) |
| 348 | 668 |
| 349 idx_train = out.index[out[split_column] == 0].tolist() | 669 idx_train = out.index[out[split_column] == 0].tolist() |
| 350 | 670 |
| 351 if not idx_train: | 671 if not idx_train: |
| 352 logger.info("No rows with split=0; nothing to do.") | 672 logger.info("No rows with split=0; nothing to do.") |
| 353 return out | 673 return out |
| 354 | |
| 355 # Determine stratify array if possible | |
| 356 stratify_arr = None | 674 stratify_arr = None |
| 357 if label_column and label_column in out.columns: | 675 if label_column and label_column in out.columns: |
| 358 # Only stratify if at least two classes and enough samples | |
| 359 label_counts = out.loc[idx_train, label_column].value_counts() | 676 label_counts = out.loc[idx_train, label_column].value_counts() |
| 360 if label_counts.size > 1 and (label_counts.min() * validation_size) >= 1: | 677 if label_counts.size > 1 and (label_counts.min() * validation_size) >= 1: |
| 361 stratify_arr = out.loc[idx_train, label_column] | 678 stratify_arr = out.loc[idx_train, label_column] |
| 362 else: | 679 else: |
| 363 logger.warning("Cannot stratify (too few labels); splitting without stratify.") | 680 logger.warning( |
| 364 | 681 "Cannot stratify (too few labels); splitting without stratify." |
| 365 # Edge cases | 682 ) |
| 366 if validation_size <= 0: | 683 if validation_size <= 0: |
| 367 logger.info("validation_size <= 0; keeping all as train.") | 684 logger.info("validation_size <= 0; keeping all as train.") |
| 368 return out | 685 return out |
| 369 if validation_size >= 1: | 686 if validation_size >= 1: |
| 370 logger.info("validation_size >= 1; moving all train → validation.") | 687 logger.info("validation_size >= 1; moving all train → validation.") |
| 371 out.loc[idx_train, split_column] = 1 | 688 out.loc[idx_train, split_column] = 1 |
| 372 return out | 689 return out |
| 373 | |
| 374 # Do the split | |
| 375 try: | 690 try: |
| 376 train_idx, val_idx = train_test_split( | 691 train_idx, val_idx = train_test_split( |
| 377 idx_train, | 692 idx_train, |
| 378 test_size=validation_size, | 693 test_size=validation_size, |
| 379 random_state=random_state, | 694 random_state=random_state, |
| 380 stratify=stratify_arr | 695 stratify=stratify_arr, |
| 381 ) | 696 ) |
| 382 except ValueError as e: | 697 except ValueError as e: |
| 383 logger.warning(f"Stratified split failed ({e}); retrying without stratify.") | 698 logger.warning(f"Stratified split failed ({e}); retrying without stratify.") |
| 384 train_idx, val_idx = train_test_split( | 699 train_idx, val_idx = train_test_split( |
| 385 idx_train, | 700 idx_train, |
| 386 test_size=validation_size, | 701 test_size=validation_size, |
| 387 random_state=random_state, | 702 random_state=random_state, |
| 388 stratify=None | 703 stratify=None, |
| 389 ) | 704 ) |
| 390 | |
| 391 # Assign new splits | |
| 392 out.loc[train_idx, split_column] = 0 | 705 out.loc[train_idx, split_column] = 0 |
| 393 out.loc[val_idx, split_column] = 1 | 706 out.loc[val_idx, split_column] = 1 |
| 394 # idx_test stays at 2 | |
| 395 | |
| 396 # Cast back to a clean integer type | |
| 397 out[split_column] = out[split_column].astype(int) | 707 out[split_column] = out[split_column].astype(int) |
| 398 # print(out) | |
| 399 return out | 708 return out |
| 400 | 709 |
| 401 | 710 |
| 402 class Backend(Protocol): | 711 class Backend(Protocol): |
| 403 """Interface for a machine learning backend.""" | 712 """Interface for a machine learning backend.""" |
| 713 | |
| 404 def prepare_config( | 714 def prepare_config( |
| 405 self, | 715 self, |
| 406 config_params: Dict[str, Any], | 716 config_params: Dict[str, Any], |
| 407 split_config: Dict[str, Any] | 717 split_config: Dict[str, Any], |
| 408 ) -> str: | 718 ) -> str: |
| 409 ... | 719 ... |
| 410 | 720 |
| 411 def run_experiment( | 721 def run_experiment( |
| 412 self, | 722 self, |
| 430 ) -> Path: | 740 ) -> Path: |
| 431 ... | 741 ... |
| 432 | 742 |
| 433 | 743 |
| 434 class LudwigDirectBackend: | 744 class LudwigDirectBackend: |
| 435 """ | 745 """Backend for running Ludwig experiments directly via the internal experiment_cli function.""" |
| 436 Backend for running Ludwig experiments directly via the internal experiment_cli function. | |
| 437 """ | |
| 438 | 746 |
| 439 def prepare_config( | 747 def prepare_config( |
| 440 self, | 748 self, |
| 441 config_params: Dict[str, Any], | 749 config_params: Dict[str, Any], |
| 442 split_config: Dict[str, Any], | 750 split_config: Dict[str, Any], |
| 443 ) -> str: | 751 ) -> str: |
| 444 """ | 752 """Build and serialize the Ludwig YAML configuration.""" |
| 445 Build and serialize the Ludwig YAML configuration. | |
| 446 """ | |
| 447 logger.info("LudwigDirectBackend: Preparing YAML configuration.") | 753 logger.info("LudwigDirectBackend: Preparing YAML configuration.") |
| 448 | 754 |
| 449 model_name = config_params.get("model_name", "resnet18") | 755 model_name = config_params.get("model_name", "resnet18") |
| 450 use_pretrained = config_params.get("use_pretrained", False) | 756 use_pretrained = config_params.get("use_pretrained", False) |
| 451 fine_tune = config_params.get("fine_tune", False) | 757 fine_tune = config_params.get("fine_tune", False) |
| 458 trainable = fine_tune or (not use_pretrained) | 764 trainable = fine_tune or (not use_pretrained) |
| 459 if not use_pretrained and not trainable: | 765 if not use_pretrained and not trainable: |
| 460 logger.warning("trainable=False; use_pretrained=False is ignored.") | 766 logger.warning("trainable=False; use_pretrained=False is ignored.") |
| 461 logger.warning("Setting trainable=True to train the model from scratch.") | 767 logger.warning("Setting trainable=True to train the model from scratch.") |
| 462 trainable = True | 768 trainable = True |
| 463 | |
| 464 # Encoder setup | |
| 465 raw_encoder = MODEL_ENCODER_TEMPLATES.get(model_name, model_name) | 769 raw_encoder = MODEL_ENCODER_TEMPLATES.get(model_name, model_name) |
| 466 if isinstance(raw_encoder, dict): | 770 if isinstance(raw_encoder, dict): |
| 467 encoder_config = { | 771 encoder_config = { |
| 468 **raw_encoder, | 772 **raw_encoder, |
| 469 "use_pretrained": use_pretrained, | 773 "use_pretrained": use_pretrained, |
| 470 "trainable": trainable, | 774 "trainable": trainable, |
| 471 } | 775 } |
| 472 else: | 776 else: |
| 473 encoder_config = {"type": raw_encoder} | 777 encoder_config = {"type": raw_encoder} |
| 474 | 778 |
| 475 # Trainer & optimizer | |
| 476 # optimizer = {"type": "adam", "learning_rate": 5e-5} if fine_tune else {"type": "adam"} | |
| 477 batch_size_cfg = batch_size or "auto" | 779 batch_size_cfg = batch_size or "auto" |
| 780 | |
| 781 label_column_path = config_params.get("label_column_data_path") | |
| 782 if label_column_path is not None and Path(label_column_path).exists(): | |
| 783 try: | |
| 784 label_series = pd.read_csv(label_column_path)[LABEL_COLUMN_NAME] | |
| 785 num_unique_labels = label_series.nunique() | |
| 786 except Exception as e: | |
| 787 logger.warning( | |
| 788 f"Could not determine label cardinality, defaulting to 'binary': {e}" | |
| 789 ) | |
| 790 num_unique_labels = 2 | |
| 791 else: | |
| 792 logger.warning( | |
| 793 "label_column_data_path not provided, defaulting to 'binary'" | |
| 794 ) | |
| 795 num_unique_labels = 2 | |
| 796 | |
| 797 output_type = "binary" if num_unique_labels == 2 else "category" | |
| 478 | 798 |
| 479 conf: Dict[str, Any] = { | 799 conf: Dict[str, Any] = { |
| 480 "model_type": "ecd", | 800 "model_type": "ecd", |
| 481 "input_features": [ | 801 "input_features": [ |
| 482 { | 802 { |
| 483 "name": IMAGE_PATH_COLUMN_NAME, | 803 "name": IMAGE_PATH_COLUMN_NAME, |
| 484 "type": "image", | 804 "type": "image", |
| 485 "encoder": encoder_config, | 805 "encoder": encoder_config, |
| 486 } | 806 } |
| 487 ], | 807 ], |
| 488 "output_features": [ | 808 "output_features": [{"name": LABEL_COLUMN_NAME, "type": output_type}], |
| 489 {"name": LABEL_COLUMN_NAME, "type": "category"} | |
| 490 ], | |
| 491 "combiner": {"type": "concat"}, | 809 "combiner": {"type": "concat"}, |
| 492 "trainer": { | 810 "trainer": { |
| 493 "epochs": epochs, | 811 "epochs": epochs, |
| 494 "early_stop": early_stop, | 812 "early_stop": early_stop, |
| 495 "batch_size": batch_size_cfg, | 813 "batch_size": batch_size_cfg, |
| 506 try: | 824 try: |
| 507 yaml_str = yaml.dump(conf, sort_keys=False, indent=2) | 825 yaml_str = yaml.dump(conf, sort_keys=False, indent=2) |
| 508 logger.info("LudwigDirectBackend: YAML config generated.") | 826 logger.info("LudwigDirectBackend: YAML config generated.") |
| 509 return yaml_str | 827 return yaml_str |
| 510 except Exception: | 828 except Exception: |
| 511 logger.error("LudwigDirectBackend: Failed to serialize YAML.", exc_info=True) | 829 logger.error( |
| 830 "LudwigDirectBackend: Failed to serialize YAML.", | |
| 831 exc_info=True, | |
| 832 ) | |
| 512 raise | 833 raise |
| 513 | 834 |
| 514 def run_experiment( | 835 def run_experiment( |
| 515 self, | 836 self, |
| 516 dataset_path: Path, | 837 dataset_path: Path, |
| 517 config_path: Path, | 838 config_path: Path, |
| 518 output_dir: Path, | 839 output_dir: Path, |
| 519 random_seed: int = 42, | 840 random_seed: int = 42, |
| 520 ) -> None: | 841 ) -> None: |
| 521 """ | 842 """Invoke Ludwig's internal experiment_cli function to run the experiment.""" |
| 522 Invoke Ludwig's internal experiment_cli function to run the experiment. | |
| 523 """ | |
| 524 logger.info("LudwigDirectBackend: Starting experiment execution.") | 843 logger.info("LudwigDirectBackend: Starting experiment execution.") |
| 525 | 844 |
| 526 try: | 845 try: |
| 527 from ludwig.experiment import experiment_cli | 846 from ludwig.experiment import experiment_cli |
| 528 except ImportError as e: | 847 except ImportError as e: |
| 529 logger.error( | 848 logger.error( |
| 530 "LudwigDirectBackend: Could not import experiment_cli.", | 849 "LudwigDirectBackend: Could not import experiment_cli.", |
| 531 exc_info=True | 850 exc_info=True, |
| 532 ) | 851 ) |
| 533 raise RuntimeError("Ludwig import failed.") from e | 852 raise RuntimeError("Ludwig import failed.") from e |
| 534 | 853 |
| 535 output_dir.mkdir(parents=True, exist_ok=True) | 854 output_dir.mkdir(parents=True, exist_ok=True) |
| 536 | 855 |
| 539 dataset=str(dataset_path), | 858 dataset=str(dataset_path), |
| 540 config=str(config_path), | 859 config=str(config_path), |
| 541 output_directory=str(output_dir), | 860 output_directory=str(output_dir), |
| 542 random_seed=random_seed, | 861 random_seed=random_seed, |
| 543 ) | 862 ) |
| 544 logger.info(f"LudwigDirectBackend: Experiment completed. Results in {output_dir}") | 863 logger.info( |
| 864 f"LudwigDirectBackend: Experiment completed. Results in {output_dir}" | |
| 865 ) | |
| 545 except TypeError as e: | 866 except TypeError as e: |
| 546 logger.error( | 867 logger.error( |
| 547 "LudwigDirectBackend: Argument mismatch in experiment_cli call.", | 868 "LudwigDirectBackend: Argument mismatch in experiment_cli call.", |
| 548 exc_info=True | 869 exc_info=True, |
| 549 ) | 870 ) |
| 550 raise RuntimeError("Ludwig argument error.") from e | 871 raise RuntimeError("Ludwig argument error.") from e |
| 551 except Exception: | 872 except Exception: |
| 552 logger.error( | 873 logger.error( |
| 553 "LudwigDirectBackend: Experiment execution error.", | 874 "LudwigDirectBackend: Experiment execution error.", |
| 554 exc_info=True | 875 exc_info=True, |
| 555 ) | 876 ) |
| 556 raise | 877 raise |
| 557 | 878 |
| 558 def get_training_process(self, output_dir) -> float: | 879 def get_training_process(self, output_dir) -> float: |
| 559 """ | 880 """Retrieve the learning rate used in the most recent Ludwig run.""" |
| 560 Retrieve the learning rate used in the most recent Ludwig run. | |
| 561 Returns: | |
| 562 float: learning rate (or None if not found) | |
| 563 """ | |
| 564 output_dir = Path(output_dir) | 881 output_dir = Path(output_dir) |
| 565 exp_dirs = sorted( | 882 exp_dirs = sorted( |
| 566 output_dir.glob("experiment_run*"), | 883 output_dir.glob("experiment_run*"), |
| 567 key=lambda p: p.stat().st_mtime | 884 key=lambda p: p.stat().st_mtime, |
| 568 ) | 885 ) |
| 569 | 886 |
| 570 if not exp_dirs: | 887 if not exp_dirs: |
| 571 logger.warning(f"No experiment run directories found in {output_dir}") | 888 logger.warning(f"No experiment run directories found in {output_dir}") |
| 572 return None | 889 return None |
| 583 "learning_rate": data.get("learning_rate"), | 900 "learning_rate": data.get("learning_rate"), |
| 584 "batch_size": data.get("batch_size"), | 901 "batch_size": data.get("batch_size"), |
| 585 "epoch": data.get("epoch"), | 902 "epoch": data.get("epoch"), |
| 586 } | 903 } |
| 587 except Exception as e: | 904 except Exception as e: |
| 588 self.logger.warning(f"Failed to read training progress info: {e}") | 905 logger.warning(f"Failed to read training progress info: {e}") |
| 589 return {} | 906 return {} |
| 590 | 907 |
| 591 def convert_parquet_to_csv(self, output_dir: Path): | 908 def convert_parquet_to_csv(self, output_dir: Path): |
| 592 """Convert the predictions Parquet file to CSV.""" | 909 """Convert the predictions Parquet file to CSV.""" |
| 593 output_dir = Path(output_dir) | 910 output_dir = Path(output_dir) |
| 594 exp_dirs = sorted( | 911 exp_dirs = sorted( |
| 595 output_dir.glob("experiment_run*"), | 912 output_dir.glob("experiment_run*"), |
| 596 key=lambda p: p.stat().st_mtime | 913 key=lambda p: p.stat().st_mtime, |
| 597 ) | 914 ) |
| 598 if not exp_dirs: | 915 if not exp_dirs: |
| 599 logger.warning(f"No experiment run dirs found in {output_dir}") | 916 logger.warning(f"No experiment run dirs found in {output_dir}") |
| 600 return | 917 return |
| 601 exp_dir = exp_dirs[-1] | 918 exp_dir = exp_dirs[-1] |
| 607 logger.info(f"Converted Parquet to CSV: {csv_path}") | 924 logger.info(f"Converted Parquet to CSV: {csv_path}") |
| 608 except Exception as e: | 925 except Exception as e: |
| 609 logger.error(f"Error converting Parquet to CSV: {e}") | 926 logger.error(f"Error converting Parquet to CSV: {e}") |
| 610 | 927 |
| 611 def generate_plots(self, output_dir: Path) -> None: | 928 def generate_plots(self, output_dir: Path) -> None: |
| 612 """ | 929 """Generate all registered Ludwig visualizations for the latest experiment run.""" |
| 613 Generate _all_ registered Ludwig visualizations for the latest experiment run. | |
| 614 """ | |
| 615 logger.info("Generating all Ludwig visualizations…") | 930 logger.info("Generating all Ludwig visualizations…") |
| 616 | 931 |
| 617 test_plots = { | 932 test_plots = { |
| 618 'compare_performance', | 933 "compare_performance", |
| 619 'compare_classifiers_performance_from_prob', | 934 "compare_classifiers_performance_from_prob", |
| 620 'compare_classifiers_performance_from_pred', | 935 "compare_classifiers_performance_from_pred", |
| 621 'compare_classifiers_performance_changing_k', | 936 "compare_classifiers_performance_changing_k", |
| 622 'compare_classifiers_multiclass_multimetric', | 937 "compare_classifiers_multiclass_multimetric", |
| 623 'compare_classifiers_predictions', | 938 "compare_classifiers_predictions", |
| 624 'confidence_thresholding_2thresholds_2d', | 939 "confidence_thresholding_2thresholds_2d", |
| 625 'confidence_thresholding_2thresholds_3d', | 940 "confidence_thresholding_2thresholds_3d", |
| 626 'confidence_thresholding', | 941 "confidence_thresholding", |
| 627 'confidence_thresholding_data_vs_acc', | 942 "confidence_thresholding_data_vs_acc", |
| 628 'binary_threshold_vs_metric', | 943 "binary_threshold_vs_metric", |
| 629 'roc_curves', | 944 "roc_curves", |
| 630 'roc_curves_from_test_statistics', | 945 "roc_curves_from_test_statistics", |
| 631 'calibration_1_vs_all', | 946 "calibration_1_vs_all", |
| 632 'calibration_multiclass', | 947 "calibration_multiclass", |
| 633 'confusion_matrix', | 948 "confusion_matrix", |
| 634 'frequency_vs_f1', | 949 "frequency_vs_f1", |
| 635 } | 950 } |
| 636 train_plots = { | 951 train_plots = { |
| 637 'learning_curves', | 952 "learning_curves", |
| 638 'compare_classifiers_performance_subset', | 953 "compare_classifiers_performance_subset", |
| 639 } | 954 } |
| 640 | 955 |
| 641 # 1) find the most recent experiment directory | |
| 642 output_dir = Path(output_dir) | 956 output_dir = Path(output_dir) |
| 643 exp_dirs = sorted( | 957 exp_dirs = sorted( |
| 644 output_dir.glob("experiment_run*"), | 958 output_dir.glob("experiment_run*"), |
| 645 key=lambda p: p.stat().st_mtime | 959 key=lambda p: p.stat().st_mtime, |
| 646 ) | 960 ) |
| 647 if not exp_dirs: | 961 if not exp_dirs: |
| 648 logger.warning(f"No experiment run dirs found in {output_dir}") | 962 logger.warning(f"No experiment run dirs found in {output_dir}") |
| 649 return | 963 return |
| 650 exp_dir = exp_dirs[-1] | 964 exp_dir = exp_dirs[-1] |
| 651 | 965 |
| 652 # 2) ensure viz output subfolder exists | |
| 653 viz_dir = exp_dir / "visualizations" | 966 viz_dir = exp_dir / "visualizations" |
| 654 viz_dir.mkdir(exist_ok=True) | 967 viz_dir.mkdir(exist_ok=True) |
| 655 train_viz = viz_dir / "train" | 968 train_viz = viz_dir / "train" |
| 656 test_viz = viz_dir / "test" | 969 test_viz = viz_dir / "test" |
| 657 train_viz.mkdir(parents=True, exist_ok=True) | 970 train_viz.mkdir(parents=True, exist_ok=True) |
| 658 test_viz.mkdir(parents=True, exist_ok=True) | 971 test_viz.mkdir(parents=True, exist_ok=True) |
| 659 | 972 |
| 660 # 3) helper to check file existence | |
| 661 def _check(p: Path) -> Optional[str]: | 973 def _check(p: Path) -> Optional[str]: |
| 662 return str(p) if p.exists() else None | 974 return str(p) if p.exists() else None |
| 663 | 975 |
| 664 # 4) gather standard Ludwig output files | |
| 665 training_stats = _check(exp_dir / "training_statistics.json") | 976 training_stats = _check(exp_dir / "training_statistics.json") |
| 666 test_stats = _check(exp_dir / TEST_STATISTICS_FILE_NAME) | 977 test_stats = _check(exp_dir / TEST_STATISTICS_FILE_NAME) |
| 667 probs_path = _check(exp_dir / PREDICTIONS_PARQUET_FILE_NAME) | 978 probs_path = _check(exp_dir / PREDICTIONS_PARQUET_FILE_NAME) |
| 668 gt_metadata = _check(exp_dir / "model" / TRAIN_SET_METADATA_FILE_NAME) | 979 gt_metadata = _check(exp_dir / "model" / TRAIN_SET_METADATA_FILE_NAME) |
| 669 | 980 |
| 670 # 5) try to read original dataset & split file from description.json | |
| 671 dataset_path = None | 981 dataset_path = None |
| 672 split_file = None | 982 split_file = None |
| 673 desc = exp_dir / DESCRIPTION_FILE_NAME | 983 desc = exp_dir / DESCRIPTION_FILE_NAME |
| 674 if desc.exists(): | 984 if desc.exists(): |
| 675 with open(desc, "r") as f: | 985 with open(desc, "r") as f: |
| 676 cfg = json.load(f) | 986 cfg = json.load(f) |
| 677 dataset_path = _check(Path(cfg.get("dataset", ""))) | 987 dataset_path = _check(Path(cfg.get("dataset", ""))) |
| 678 split_file = _check(Path(get_split_path(cfg.get("dataset", "")))) | 988 split_file = _check(Path(get_split_path(cfg.get("dataset", "")))) |
| 679 | 989 |
| 680 # 6) infer output feature name | |
| 681 output_feature = "" | 990 output_feature = "" |
| 682 if desc.exists(): | 991 if desc.exists(): |
| 683 try: | 992 try: |
| 684 output_feature = cfg["config"]["output_features"][0]["name"] | 993 output_feature = cfg["config"]["output_features"][0]["name"] |
| 685 except Exception: | 994 except Exception: |
| 687 if not output_feature and test_stats: | 996 if not output_feature and test_stats: |
| 688 with open(test_stats, "r") as f: | 997 with open(test_stats, "r") as f: |
| 689 stats = json.load(f) | 998 stats = json.load(f) |
| 690 output_feature = next(iter(stats.keys()), "") | 999 output_feature = next(iter(stats.keys()), "") |
| 691 | 1000 |
| 692 # 7) loop through every registered viz | |
| 693 viz_registry = get_visualizations_registry() | 1001 viz_registry = get_visualizations_registry() |
| 694 for viz_name, viz_func in viz_registry.items(): | 1002 for viz_name, viz_func in viz_registry.items(): |
| 695 viz_dir_plot = None | 1003 viz_dir_plot = None |
| 696 if viz_name in train_plots: | 1004 if viz_name in train_plots: |
| 697 viz_dir_plot = train_viz | 1005 viz_dir_plot = train_viz |
| 719 logger.warning(f"✘ Skipped {viz_name}: {e}") | 1027 logger.warning(f"✘ Skipped {viz_name}: {e}") |
| 720 | 1028 |
| 721 logger.info(f"All visualizations written to {viz_dir}") | 1029 logger.info(f"All visualizations written to {viz_dir}") |
| 722 | 1030 |
| 723 def generate_html_report( | 1031 def generate_html_report( |
| 724 self, | 1032 self, |
| 725 title: str, | 1033 title: str, |
| 726 output_dir: str, | 1034 output_dir: str, |
| 727 config: dict, | 1035 config: dict, |
| 728 split_info: str) -> Path: | 1036 split_info: str, |
| 729 """ | 1037 ) -> Path: |
| 730 Assemble an HTML report from visualizations under train_val/ and test/ folders. | 1038 """Assemble an HTML report from visualizations under train_val/ and test/ folders.""" |
| 731 """ | |
| 732 cwd = Path.cwd() | 1039 cwd = Path.cwd() |
| 733 report_name = title.lower().replace(" ", "_") + "_report.html" | 1040 report_name = title.lower().replace(" ", "_") + "_report.html" |
| 734 report_path = cwd / report_name | 1041 report_path = cwd / report_name |
| 735 output_dir = Path(output_dir) | 1042 output_dir = Path(output_dir) |
| 736 | 1043 |
| 737 # Find latest experiment dir | 1044 exp_dirs = sorted( |
| 738 exp_dirs = sorted(output_dir.glob("experiment_run*"), key=lambda p: p.stat().st_mtime) | 1045 output_dir.glob("experiment_run*"), |
| 1046 key=lambda p: p.stat().st_mtime, | |
| 1047 ) | |
| 739 if not exp_dirs: | 1048 if not exp_dirs: |
| 740 raise RuntimeError(f"No 'experiment*' dirs found in {output_dir}") | 1049 raise RuntimeError(f"No 'experiment*' dirs found in {output_dir}") |
| 741 exp_dir = exp_dirs[-1] | 1050 exp_dir = exp_dirs[-1] |
| 742 | 1051 |
| 743 base_viz_dir = exp_dir / "visualizations" | 1052 base_viz_dir = exp_dir / "visualizations" |
| 746 | 1055 |
| 747 html = get_html_template() | 1056 html = get_html_template() |
| 748 html += f"<h1>{title}</h1>" | 1057 html += f"<h1>{title}</h1>" |
| 749 | 1058 |
| 750 metrics_html = "" | 1059 metrics_html = "" |
| 751 | 1060 train_val_metrics_html = "" |
| 752 # Load and embed metrics table (training/val/test stats) | 1061 test_metrics_html = "" |
| 1062 | |
| 753 try: | 1063 try: |
| 754 train_stats_path = exp_dir / "training_statistics.json" | 1064 train_stats_path = exp_dir / "training_statistics.json" |
| 755 test_stats_path = exp_dir / TEST_STATISTICS_FILE_NAME | 1065 test_stats_path = exp_dir / TEST_STATISTICS_FILE_NAME |
| 756 if train_stats_path.exists() and test_stats_path.exists(): | 1066 if train_stats_path.exists() and test_stats_path.exists(): |
| 757 with open(train_stats_path) as f: | 1067 with open(train_stats_path) as f: |
| 758 train_stats = json.load(f) | 1068 train_stats = json.load(f) |
| 759 with open(test_stats_path) as f: | 1069 with open(test_stats_path) as f: |
| 760 test_stats = json.load(f) | 1070 test_stats = json.load(f) |
| 761 output_feature = next(iter(train_stats.keys()), "") | 1071 output_type = detect_output_type(test_stats) |
| 762 if output_feature: | 1072 all_metrics = extract_metrics_from_json( |
| 763 metrics_html += format_stats_table_html(train_stats, test_stats) | 1073 train_stats, |
| 1074 test_stats, | |
| 1075 output_type, | |
| 1076 ) | |
| 1077 metrics_html = format_stats_table_html(train_stats, test_stats) | |
| 1078 train_val_metrics_html = format_train_val_stats_table_html( | |
| 1079 train_stats, | |
| 1080 test_stats, | |
| 1081 ) | |
| 1082 test_metrics_html = format_test_merged_stats_table_html( | |
| 1083 all_metrics["test"], | |
| 1084 ) | |
| 764 except Exception as e: | 1085 except Exception as e: |
| 765 logger.warning(f"Could not load stats for HTML report: {e}") | 1086 logger.warning( |
| 1087 f"Could not load stats for HTML report: {type(e).__name__}: {e}" | |
| 1088 ) | |
| 766 | 1089 |
| 767 config_html = "" | 1090 config_html = "" |
| 768 training_progress = self.get_training_process(output_dir) | 1091 training_progress = self.get_training_process(output_dir) |
| 769 try: | 1092 try: |
| 770 config_html = format_config_table_html(config, split_info, training_progress) | 1093 config_html = format_config_table_html(config, split_info, training_progress) |
| 771 except Exception as e: | 1094 except Exception as e: |
| 772 logger.warning(f"Could not load config for HTML report: {e}") | 1095 logger.warning(f"Could not load config for HTML report: {e}") |
| 773 | 1096 |
| 774 def render_img_section(title: str, dir_path: Path) -> str: | 1097 def render_img_section(title: str, dir_path: Path, output_type: str = None) -> str: |
| 775 if not dir_path.exists(): | 1098 if not dir_path.exists(): |
| 776 return f"<h2>{title}</h2><p><em>Directory not found.</em></p>" | 1099 return f"<h2>{title}</h2><p><em>Directory not found.</em></p>" |
| 777 imgs = sorted(dir_path.glob("*.png")) | 1100 |
| 1101 imgs = list(dir_path.glob("*.png")) | |
| 778 if not imgs: | 1102 if not imgs: |
| 779 return f"<h2>{title}</h2><p><em>No plots found.</em></p>" | 1103 return f"<h2>{title}</h2><p><em>No plots found.</em></p>" |
| 1104 | |
| 1105 if title == "Test Visualizations" and output_type == "binary": | |
| 1106 order = [ | |
| 1107 "confusion_matrix__label_top2.png", | |
| 1108 "roc_curves_from_prediction_statistics.png", | |
| 1109 "compare_performance_label.png", | |
| 1110 "confusion_matrix_entropy__label_top2.png", | |
| 1111 ] | |
| 1112 img_names = {img.name: img for img in imgs} | |
| 1113 ordered_imgs = [ | |
| 1114 img_names[fname] for fname in order if fname in img_names | |
| 1115 ] | |
| 1116 remaining = sorted( | |
| 1117 [ | |
| 1118 img | |
| 1119 for img in imgs | |
| 1120 if img.name not in order and img.name != "roc_curves.png" | |
| 1121 ] | |
| 1122 ) | |
| 1123 imgs = ordered_imgs + remaining | |
| 1124 | |
| 1125 elif title == "Test Visualizations" and output_type == "category": | |
| 1126 unwanted = { | |
| 1127 "compare_classifiers_multiclass_multimetric__label_best10.png", | |
| 1128 "compare_classifiers_multiclass_multimetric__label_top10.png", | |
| 1129 "compare_classifiers_multiclass_multimetric__label_worst10.png", | |
| 1130 } | |
| 1131 display_order = [ | |
| 1132 "confusion_matrix__label_top10.png", | |
| 1133 "roc_curves.png", | |
| 1134 "compare_performance_label.png", | |
| 1135 "compare_classifiers_performance_from_prob.png", | |
| 1136 "compare_classifiers_multiclass_multimetric__label_sorted.png", | |
| 1137 "confusion_matrix_entropy__label_top10.png", | |
| 1138 ] | |
| 1139 img_names = {img.name: img for img in imgs if img.name not in unwanted} | |
| 1140 ordered_imgs = [ | |
| 1141 img_names[fname] for fname in display_order if fname in img_names | |
| 1142 ] | |
| 1143 remaining = sorted( | |
| 1144 [ | |
| 1145 img | |
| 1146 for img in img_names.values() | |
| 1147 if img.name not in display_order | |
| 1148 ] | |
| 1149 ) | |
| 1150 imgs = ordered_imgs + remaining | |
| 1151 | |
| 1152 else: | |
| 1153 if output_type == "category": | |
| 1154 unwanted = { | |
| 1155 "compare_classifiers_multiclass_multimetric__label_best10.png", | |
| 1156 "compare_classifiers_multiclass_multimetric__label_top10.png", | |
| 1157 "compare_classifiers_multiclass_multimetric__label_worst10.png", | |
| 1158 } | |
| 1159 imgs = sorted([img for img in imgs if img.name not in unwanted]) | |
| 1160 else: | |
| 1161 imgs = sorted(imgs) | |
| 780 | 1162 |
| 781 section_html = f"<h2 style='text-align: center;'>{title}</h2><div>" | 1163 section_html = f"<h2 style='text-align: center;'>{title}</h2><div>" |
| 782 for img in imgs: | 1164 for img in imgs: |
| 783 b64 = encode_image_to_base64(str(img)) | 1165 b64 = encode_image_to_base64(str(img)) |
| 784 section_html += ( | 1166 section_html += ( |
| 785 f'<div class="plot" style="margin-bottom:20px;text-align:center;">' | 1167 f'<div class="plot" style="margin-bottom:20px;text-align:center;">' |
| 786 f"<h3>{img.stem.replace('_',' ').title()}</h3>" | 1168 f"<h3>{img.stem.replace('_', ' ').title()}</h3>" |
| 787 f'<img src="data:image/png;base64,{b64}" ' | 1169 f'<img src="data:image/png;base64,{b64}" ' |
| 788 'style="max-width:90%;max-height:600px;border:1px solid #ddd;" />' | 1170 f'style="max-width:90%;max-height:600px;border:1px solid #ddd;" />' |
| 789 "</div>" | 1171 f"</div>" |
| 790 ) | 1172 ) |
| 791 section_html += "</div>" | 1173 section_html += "</div>" |
| 792 return section_html | 1174 return section_html |
| 793 | 1175 |
| 794 train_plots_html = render_img_section("Training & Validation Visualizations", train_viz_dir) | 1176 button_html = """ |
| 795 test_plots_html = render_img_section("Test Visualizations", test_viz_dir) | 1177 <button class="help-modal-btn openMetricsHelp">Model Evaluation Metrics — Help Guide</button> |
| 796 html += build_tabbed_html(config_html + metrics_html, train_plots_html, test_plots_html) | 1178 <br><br> |
| 1179 <style> | |
| 1180 .help-modal-btn { | |
| 1181 background-color: #17623b; | |
| 1182 color: #fff; | |
| 1183 border: none; | |
| 1184 border-radius: 24px; | |
| 1185 padding: 10px 28px; | |
| 1186 font-size: 1.1rem; | |
| 1187 font-weight: bold; | |
| 1188 letter-spacing: 0.03em; | |
| 1189 cursor: pointer; | |
| 1190 transition: background 0.2s, box-shadow 0.2s; | |
| 1191 box-shadow: 0 2px 8px rgba(23,98,59,0.07); | |
| 1192 } | |
| 1193 .help-modal-btn:hover, .help-modal-btn:focus { | |
| 1194 background-color: #21895e; | |
| 1195 outline: none; | |
| 1196 box-shadow: 0 4px 16px rgba(23,98,59,0.14); | |
| 1197 } | |
| 1198 </style> | |
| 1199 """ | |
| 1200 tab1_content = button_html + config_html + metrics_html | |
| 1201 tab2_content = ( | |
| 1202 button_html | |
| 1203 + train_val_metrics_html | |
| 1204 + render_img_section("Training & Validation Visualizations", train_viz_dir) | |
| 1205 ) | |
| 1206 tab3_content = ( | |
| 1207 button_html | |
| 1208 + test_metrics_html | |
| 1209 + render_img_section("Test Visualizations", test_viz_dir, output_type) | |
| 1210 ) | |
| 1211 | |
| 1212 tabbed_html = build_tabbed_html(tab1_content, tab2_content, tab3_content) | |
| 1213 modal_html = get_metrics_help_modal() | |
| 1214 html += tabbed_html + modal_html | |
| 797 html += get_html_closing() | 1215 html += get_html_closing() |
| 798 | 1216 |
| 799 try: | 1217 try: |
| 800 with open(report_path, "w") as f: | 1218 with open(report_path, "w") as f: |
| 801 f.write(html) | 1219 f.write(html) |
| 806 | 1224 |
| 807 return report_path | 1225 return report_path |
| 808 | 1226 |
| 809 | 1227 |
| 810 class WorkflowOrchestrator: | 1228 class WorkflowOrchestrator: |
| 811 """ | 1229 """Manages the image-classification workflow.""" |
| 812 Manages the image-classification workflow: | |
| 813 1. Creates temp dirs | |
| 814 2. Extracts images | |
| 815 3. Prepares data (CSV + splits) | |
| 816 4. Renders a backend config | |
| 817 5. Runs the experiment | |
| 818 6. Cleans up | |
| 819 """ | |
| 820 | 1230 |
| 821 def __init__(self, args: argparse.Namespace, backend: Backend): | 1231 def __init__(self, args: argparse.Namespace, backend: Backend): |
| 822 self.args = args | 1232 self.args = args |
| 823 self.backend = backend | 1233 self.backend = backend |
| 824 self.temp_dir: Optional[Path] = None | 1234 self.temp_dir: Optional[Path] = None |
| 826 logger.info(f"Orchestrator initialized with backend: {type(backend).__name__}") | 1236 logger.info(f"Orchestrator initialized with backend: {type(backend).__name__}") |
| 827 | 1237 |
| 828 def _create_temp_dirs(self) -> None: | 1238 def _create_temp_dirs(self) -> None: |
| 829 """Create temporary output and image extraction directories.""" | 1239 """Create temporary output and image extraction directories.""" |
| 830 try: | 1240 try: |
| 831 self.temp_dir = Path(tempfile.mkdtemp( | 1241 self.temp_dir = Path( |
| 832 dir=self.args.output_dir, | 1242 tempfile.mkdtemp(dir=self.args.output_dir, prefix=TEMP_DIR_PREFIX) |
| 833 prefix=TEMP_DIR_PREFIX | 1243 ) |
| 834 )) | |
| 835 self.image_extract_dir = self.temp_dir / "images" | 1244 self.image_extract_dir = self.temp_dir / "images" |
| 836 self.image_extract_dir.mkdir() | 1245 self.image_extract_dir.mkdir() |
| 837 logger.info(f"Created temp directory: {self.temp_dir}") | 1246 logger.info(f"Created temp directory: {self.temp_dir}") |
| 838 except Exception: | 1247 except Exception: |
| 839 logger.error("Failed to create temporary directories", exc_info=True) | 1248 logger.error("Failed to create temporary directories", exc_info=True) |
| 841 | 1250 |
| 842 def _extract_images(self) -> None: | 1251 def _extract_images(self) -> None: |
| 843 """Extract images from ZIP into the temp image directory.""" | 1252 """Extract images from ZIP into the temp image directory.""" |
| 844 if self.image_extract_dir is None: | 1253 if self.image_extract_dir is None: |
| 845 raise RuntimeError("Temp image directory not initialized.") | 1254 raise RuntimeError("Temp image directory not initialized.") |
| 846 logger.info(f"Extracting images from {self.args.image_zip} → {self.image_extract_dir}") | 1255 logger.info( |
| 1256 f"Extracting images from {self.args.image_zip} → {self.image_extract_dir}" | |
| 1257 ) | |
| 847 try: | 1258 try: |
| 848 with zipfile.ZipFile(self.args.image_zip, "r") as z: | 1259 with zipfile.ZipFile(self.args.image_zip, "r") as z: |
| 849 z.extractall(self.image_extract_dir) | 1260 z.extractall(self.image_extract_dir) |
| 850 logger.info("Image extraction complete.") | 1261 logger.info("Image extraction complete.") |
| 851 except Exception: | 1262 except Exception: |
| 852 logger.error("Error extracting zip file", exc_info=True) | 1263 logger.error("Error extracting zip file", exc_info=True) |
| 853 raise | 1264 raise |
| 854 | 1265 |
| 855 def _prepare_data(self) -> Tuple[Path, Dict[str, Any]]: | 1266 def _prepare_data(self) -> Tuple[Path, Dict[str, Any]]: |
| 856 """ | 1267 """Load CSV, update image paths, handle splits, and write prepared CSV.""" |
| 857 Load CSV, update image paths, handle splits, and write prepared CSV. | |
| 858 Returns: | |
| 859 final_csv_path: Path to the prepared CSV | |
| 860 split_config: Dict for backend split settings | |
| 861 """ | |
| 862 if not self.temp_dir or not self.image_extract_dir: | 1268 if not self.temp_dir or not self.image_extract_dir: |
| 863 raise RuntimeError("Temp dirs not initialized before data prep.") | 1269 raise RuntimeError("Temp dirs not initialized before data prep.") |
| 864 | 1270 |
| 865 # 1) Load | |
| 866 try: | 1271 try: |
| 867 df = pd.read_csv(self.args.csv_file) | 1272 df = pd.read_csv(self.args.csv_file) |
| 868 logger.info(f"Loaded CSV: {self.args.csv_file}") | 1273 logger.info(f"Loaded CSV: {self.args.csv_file}") |
| 869 except Exception: | 1274 except Exception: |
| 870 logger.error("Error loading CSV file", exc_info=True) | 1275 logger.error("Error loading CSV file", exc_info=True) |
| 871 raise | 1276 raise |
| 872 | 1277 |
| 873 # 2) Validate columns | |
| 874 required = {IMAGE_PATH_COLUMN_NAME, LABEL_COLUMN_NAME} | 1278 required = {IMAGE_PATH_COLUMN_NAME, LABEL_COLUMN_NAME} |
| 875 missing = required - set(df.columns) | 1279 missing = required - set(df.columns) |
| 876 if missing: | 1280 if missing: |
| 877 raise ValueError(f"Missing CSV columns: {', '.join(missing)}") | 1281 raise ValueError(f"Missing CSV columns: {', '.join(missing)}") |
| 878 | 1282 |
| 879 # 3) Update image paths | |
| 880 try: | 1283 try: |
| 881 df[IMAGE_PATH_COLUMN_NAME] = df[IMAGE_PATH_COLUMN_NAME].apply( | 1284 df[IMAGE_PATH_COLUMN_NAME] = df[IMAGE_PATH_COLUMN_NAME].apply( |
| 882 lambda p: str((self.image_extract_dir / p).resolve()) | 1285 lambda p: str((self.image_extract_dir / p).resolve()) |
| 883 ) | 1286 ) |
| 884 except Exception: | 1287 except Exception: |
| 885 logger.error("Error updating image paths", exc_info=True) | 1288 logger.error("Error updating image paths", exc_info=True) |
| 886 raise | 1289 raise |
| 887 | 1290 |
| 888 # 4) Handle splits | |
| 889 if SPLIT_COLUMN_NAME in df.columns: | 1291 if SPLIT_COLUMN_NAME in df.columns: |
| 890 df, split_config, split_info = self._process_fixed_split(df) | 1292 df, split_config, split_info = self._process_fixed_split(df) |
| 891 else: | 1293 else: |
| 892 logger.info("No split column; using random split") | 1294 logger.info("No split column; using random split") |
| 893 split_config = { | 1295 split_config = { |
| 894 "type": "random", | 1296 "type": "random", |
| 895 "probabilities": self.args.split_probabilities | 1297 "probabilities": self.args.split_probabilities, |
| 896 } | 1298 } |
| 897 split_info = ( | 1299 split_info = ( |
| 898 f"No split column in CSV. Used random split: " | 1300 f"No split column in CSV. Used random split: " |
| 899 f"{[int(p*100) for p in self.args.split_probabilities]}% for train/val/test." | 1301 f"{[int(p * 100) for p in self.args.split_probabilities]}% " |
| 900 ) | 1302 f"for train/val/test." |
| 901 | 1303 ) |
| 902 # 5) Write out prepared CSV | 1304 |
| 903 final_csv = TEMP_CSV_FILENAME | 1305 final_csv = TEMP_CSV_FILENAME |
| 904 try: | 1306 try: |
| 905 df.to_csv(final_csv, index=False) | 1307 df.to_csv(final_csv, index=False) |
| 906 logger.info(f"Saved prepared data to {final_csv}") | 1308 logger.info(f"Saved prepared data to {final_csv}") |
| 907 except Exception: | 1309 except Exception: |
| 913 def _process_fixed_split(self, df: pd.DataFrame) -> Dict[str, Any]: | 1315 def _process_fixed_split(self, df: pd.DataFrame) -> Dict[str, Any]: |
| 914 """Process a fixed split column (0=train,1=val,2=test).""" | 1316 """Process a fixed split column (0=train,1=val,2=test).""" |
| 915 logger.info(f"Fixed split column '{SPLIT_COLUMN_NAME}' detected.") | 1317 logger.info(f"Fixed split column '{SPLIT_COLUMN_NAME}' detected.") |
| 916 try: | 1318 try: |
| 917 col = df[SPLIT_COLUMN_NAME] | 1319 col = df[SPLIT_COLUMN_NAME] |
| 918 df[SPLIT_COLUMN_NAME] = pd.to_numeric(col, errors="coerce").astype(pd.Int64Dtype()) | 1320 df[SPLIT_COLUMN_NAME] = pd.to_numeric(col, errors="coerce").astype( |
| 1321 pd.Int64Dtype() | |
| 1322 ) | |
| 919 if df[SPLIT_COLUMN_NAME].isna().any(): | 1323 if df[SPLIT_COLUMN_NAME].isna().any(): |
| 920 logger.warning("Split column contains non-numeric/missing values.") | 1324 logger.warning("Split column contains non-numeric/missing values.") |
| 921 | 1325 |
| 922 unique = set(df[SPLIT_COLUMN_NAME].dropna().unique()) | 1326 unique = set(df[SPLIT_COLUMN_NAME].dropna().unique()) |
| 923 logger.info(f"Unique split values: {unique}") | 1327 logger.info(f"Unique split values: {unique}") |
| 924 | 1328 |
| 925 if unique == {0, 2}: | 1329 if unique == {0, 2}: |
| 926 df = split_data_0_2( | 1330 df = split_data_0_2( |
| 927 df, SPLIT_COLUMN_NAME, | 1331 df, |
| 1332 SPLIT_COLUMN_NAME, | |
| 928 validation_size=self.args.validation_size, | 1333 validation_size=self.args.validation_size, |
| 929 label_column=LABEL_COLUMN_NAME, | 1334 label_column=LABEL_COLUMN_NAME, |
| 930 random_state=self.args.random_seed | 1335 random_state=self.args.random_seed, |
| 931 ) | 1336 ) |
| 932 split_info = ( | 1337 split_info = ( |
| 933 "Detected a split column (with values 0 and 2) in the input CSV. " | 1338 "Detected a split column (with values 0 and 2) in the input CSV. " |
| 934 f"Used this column as a base and" | 1339 f"Used this column as a base and reassigned " |
| 935 f"reassigned {self.args.validation_size * 100:.1f}% " | 1340 f"{self.args.validation_size * 100:.1f}% " |
| 936 "of the training set (originally labeled 0) to validation (labeled 1)." | 1341 "of the training set (originally labeled 0) to validation (labeled 1)." |
| 937 ) | 1342 ) |
| 938 | |
| 939 logger.info("Applied custom 0/2 split.") | 1343 logger.info("Applied custom 0/2 split.") |
| 940 elif unique.issubset({0, 1, 2}): | 1344 elif unique.issubset({0, 1, 2}): |
| 941 split_info = "Used user-defined split column from CSV." | 1345 split_info = "Used user-defined split column from CSV." |
| 942 logger.info("Using fixed split as-is.") | 1346 logger.info("Using fixed split as-is.") |
| 943 else: | 1347 else: |
| 948 except Exception: | 1352 except Exception: |
| 949 logger.error("Error processing fixed split", exc_info=True) | 1353 logger.error("Error processing fixed split", exc_info=True) |
| 950 raise | 1354 raise |
| 951 | 1355 |
| 952 def _cleanup_temp_dirs(self) -> None: | 1356 def _cleanup_temp_dirs(self) -> None: |
| 953 """Remove any temporary directories.""" | |
| 954 if self.temp_dir and self.temp_dir.exists(): | 1357 if self.temp_dir and self.temp_dir.exists(): |
| 955 logger.info(f"Cleaning up temp directory: {self.temp_dir}") | 1358 logger.info(f"Cleaning up temp directory: {self.temp_dir}") |
| 956 shutil.rmtree(self.temp_dir, ignore_errors=True) | 1359 shutil.rmtree(self.temp_dir, ignore_errors=True) |
| 957 self.temp_dir = None | 1360 self.temp_dir = None |
| 958 self.image_extract_dir = None | 1361 self.image_extract_dir = None |
| 978 "preprocessing_num_processes": self.args.preprocessing_num_processes, | 1381 "preprocessing_num_processes": self.args.preprocessing_num_processes, |
| 979 "split_probabilities": self.args.split_probabilities, | 1382 "split_probabilities": self.args.split_probabilities, |
| 980 "learning_rate": self.args.learning_rate, | 1383 "learning_rate": self.args.learning_rate, |
| 981 "random_seed": self.args.random_seed, | 1384 "random_seed": self.args.random_seed, |
| 982 "early_stop": self.args.early_stop, | 1385 "early_stop": self.args.early_stop, |
| 1386 "label_column_data_path": csv_path, | |
| 983 } | 1387 } |
| 984 yaml_str = self.backend.prepare_config(backend_args, split_cfg) | 1388 yaml_str = self.backend.prepare_config(backend_args, split_cfg) |
| 985 | 1389 |
| 986 config_file = self.temp_dir / TEMP_CONFIG_FILENAME | 1390 config_file = self.temp_dir / TEMP_CONFIG_FILENAME |
| 987 config_file.write_text(yaml_str) | 1391 config_file.write_text(yaml_str) |
| 989 | 1393 |
| 990 self.backend.run_experiment( | 1394 self.backend.run_experiment( |
| 991 csv_path, | 1395 csv_path, |
| 992 config_file, | 1396 config_file, |
| 993 self.args.output_dir, | 1397 self.args.output_dir, |
| 994 self.args.random_seed | 1398 self.args.random_seed, |
| 995 ) | 1399 ) |
| 996 logger.info("Workflow completed successfully.") | 1400 logger.info("Workflow completed successfully.") |
| 997 self.backend.generate_plots(self.args.output_dir) | 1401 self.backend.generate_plots(self.args.output_dir) |
| 998 report_file = self.backend.generate_html_report( | 1402 report_file = self.backend.generate_html_report( |
| 999 "Image Classification Results", | 1403 "Image Classification Results", |
| 1000 self.args.output_dir, | 1404 self.args.output_dir, |
| 1001 backend_args, | 1405 backend_args, |
| 1002 split_info | 1406 split_info, |
| 1003 ) | 1407 ) |
| 1004 logger.info(f"HTML report generated at: {report_file}") | 1408 logger.info(f"HTML report generated at: {report_file}") |
| 1005 self.backend.convert_parquet_to_csv(self.args.output_dir) | 1409 self.backend.convert_parquet_to_csv(self.args.output_dir) |
| 1006 logger.info("Converted Parquet to CSV.") | 1410 logger.info("Converted Parquet to CSV.") |
| 1007 except Exception: | 1411 except Exception: |
| 1008 logger.error("Workflow execution failed", exc_info=True) | 1412 logger.error("Workflow execution failed", exc_info=True) |
| 1009 raise | 1413 raise |
| 1010 | |
| 1011 finally: | 1414 finally: |
| 1012 self._cleanup_temp_dirs() | 1415 self._cleanup_temp_dirs() |
| 1013 | 1416 |
| 1014 | 1417 |
| 1015 def parse_learning_rate(s): | 1418 def parse_learning_rate(s): |
| 1019 return None | 1422 return None |
| 1020 | 1423 |
| 1021 | 1424 |
| 1022 class SplitProbAction(argparse.Action): | 1425 class SplitProbAction(argparse.Action): |
| 1023 def __call__(self, parser, namespace, values, option_string=None): | 1426 def __call__(self, parser, namespace, values, option_string=None): |
| 1024 # values is a list of three floats | |
| 1025 train, val, test = values | 1427 train, val, test = values |
| 1026 total = train + val + test | 1428 total = train + val + test |
| 1027 if abs(total - 1.0) > 1e-6: | 1429 if abs(total - 1.0) > 1e-6: |
| 1028 parser.error( | 1430 parser.error( |
| 1029 f"--split-probabilities must sum to 1.0; " | 1431 f"--split-probabilities must sum to 1.0; " |
| 1031 ) | 1433 ) |
| 1032 setattr(namespace, self.dest, values) | 1434 setattr(namespace, self.dest, values) |
| 1033 | 1435 |
| 1034 | 1436 |
| 1035 def main(): | 1437 def main(): |
| 1036 | |
| 1037 parser = argparse.ArgumentParser( | 1438 parser = argparse.ArgumentParser( |
| 1038 description="Image Classification Learner with Pluggable Backends" | 1439 description="Image Classification Learner with Pluggable Backends", |
| 1039 ) | 1440 ) |
| 1040 parser.add_argument( | 1441 parser.add_argument( |
| 1041 "--csv-file", required=True, type=Path, | 1442 "--csv-file", |
| 1042 help="Path to the input CSV" | 1443 required=True, |
| 1444 type=Path, | |
| 1445 help="Path to the input CSV", | |
| 1043 ) | 1446 ) |
| 1044 parser.add_argument( | 1447 parser.add_argument( |
| 1045 "--image-zip", required=True, type=Path, | 1448 "--image-zip", |
| 1046 help="Path to the images ZIP" | 1449 required=True, |
| 1450 type=Path, | |
| 1451 help="Path to the images ZIP", | |
| 1047 ) | 1452 ) |
| 1048 parser.add_argument( | 1453 parser.add_argument( |
| 1049 "--model-name", required=True, | 1454 "--model-name", |
| 1455 required=True, | |
| 1050 choices=MODEL_ENCODER_TEMPLATES.keys(), | 1456 choices=MODEL_ENCODER_TEMPLATES.keys(), |
| 1051 help="Which model template to use" | 1457 help="Which model template to use", |
| 1052 ) | 1458 ) |
| 1053 parser.add_argument( | 1459 parser.add_argument( |
| 1054 "--use-pretrained", action="store_true", | 1460 "--use-pretrained", |
| 1055 help="Use pretrained weights for the model" | 1461 action="store_true", |
| 1462 help="Use pretrained weights for the model", | |
| 1056 ) | 1463 ) |
| 1057 parser.add_argument( | 1464 parser.add_argument( |
| 1058 "--fine-tune", action="store_true", | 1465 "--fine-tune", |
| 1059 help="Enable fine-tuning" | 1466 action="store_true", |
| 1467 help="Enable fine-tuning", | |
| 1060 ) | 1468 ) |
| 1061 parser.add_argument( | 1469 parser.add_argument( |
| 1062 "--epochs", type=int, default=10, | 1470 "--epochs", |
| 1063 help="Number of training epochs" | 1471 type=int, |
| 1472 default=10, | |
| 1473 help="Number of training epochs", | |
| 1064 ) | 1474 ) |
| 1065 parser.add_argument( | 1475 parser.add_argument( |
| 1066 "--early-stop", type=int, default=5, | 1476 "--early-stop", |
| 1067 help="Early stopping patience" | 1477 type=int, |
| 1478 default=5, | |
| 1479 help="Early stopping patience", | |
| 1068 ) | 1480 ) |
| 1069 parser.add_argument( | 1481 parser.add_argument( |
| 1070 "--batch-size", type=int, | 1482 "--batch-size", |
| 1071 help="Batch size (None = auto)" | 1483 type=int, |
| 1484 help="Batch size (None = auto)", | |
| 1072 ) | 1485 ) |
| 1073 parser.add_argument( | 1486 parser.add_argument( |
| 1074 "--output-dir", type=Path, default=Path("learner_output"), | 1487 "--output-dir", |
| 1075 help="Where to write outputs" | 1488 type=Path, |
| 1489 default=Path("learner_output"), | |
| 1490 help="Where to write outputs", | |
| 1076 ) | 1491 ) |
| 1077 parser.add_argument( | 1492 parser.add_argument( |
| 1078 "--validation-size", type=float, default=0.15, | 1493 "--validation-size", |
| 1079 help="Fraction for validation (0.0–1.0)" | 1494 type=float, |
| 1495 default=0.15, | |
| 1496 help="Fraction for validation (0.0–1.0)", | |
| 1080 ) | 1497 ) |
| 1081 parser.add_argument( | 1498 parser.add_argument( |
| 1082 "--preprocessing-num-processes", type=int, | 1499 "--preprocessing-num-processes", |
| 1500 type=int, | |
| 1083 default=max(1, os.cpu_count() // 2), | 1501 default=max(1, os.cpu_count() // 2), |
| 1084 help="CPU processes for data prep" | 1502 help="CPU processes for data prep", |
| 1085 ) | 1503 ) |
| 1086 parser.add_argument( | 1504 parser.add_argument( |
| 1087 "--split-probabilities", type=float, nargs=3, | 1505 "--split-probabilities", |
| 1506 type=float, | |
| 1507 nargs=3, | |
| 1088 metavar=("train", "val", "test"), | 1508 metavar=("train", "val", "test"), |
| 1089 action=SplitProbAction, | 1509 action=SplitProbAction, |
| 1090 default=[0.7, 0.1, 0.2], | 1510 default=[0.7, 0.1, 0.2], |
| 1091 help="Random split proportions (e.g., 0.7 0.1 0.2). Only used if no split column is present." | 1511 help="Random split proportions (e.g., 0.7 0.1 0.2). Only used if no split column.", |
| 1092 ) | 1512 ) |
| 1093 parser.add_argument( | 1513 parser.add_argument( |
| 1094 "--random-seed", type=int, default=42, | 1514 "--random-seed", |
| 1095 help="Random seed used for dataset splitting (default: 42)" | 1515 type=int, |
| 1516 default=42, | |
| 1517 help="Random seed used for dataset splitting (default: 42)", | |
| 1096 ) | 1518 ) |
| 1097 parser.add_argument( | 1519 parser.add_argument( |
| 1098 "--learning-rate", type=parse_learning_rate, default=None, | 1520 "--learning-rate", |
| 1099 help="Learning rate. If not provided, Ludwig will auto-select it." | 1521 type=parse_learning_rate, |
| 1522 default=None, | |
| 1523 help="Learning rate. If not provided, Ludwig will auto-select it.", | |
| 1100 ) | 1524 ) |
| 1101 | 1525 |
| 1102 args = parser.parse_args() | 1526 args = parser.parse_args() |
| 1103 | 1527 |
| 1104 # -- Validation -- | |
| 1105 if not 0.0 <= args.validation_size <= 1.0: | 1528 if not 0.0 <= args.validation_size <= 1.0: |
| 1106 parser.error("validation-size must be between 0.0 and 1.0") | 1529 parser.error("validation-size must be between 0.0 and 1.0") |
| 1107 if not args.csv_file.is_file(): | 1530 if not args.csv_file.is_file(): |
| 1108 parser.error(f"CSV not found: {args.csv_file}") | 1531 parser.error(f"CSV not found: {args.csv_file}") |
| 1109 if not args.image_zip.is_file(): | 1532 if not args.image_zip.is_file(): |
| 1110 parser.error(f"ZIP not found: {args.image_zip}") | 1533 parser.error(f"ZIP not found: {args.image_zip}") |
| 1111 | 1534 |
| 1112 # --- Instantiate Backend and Orchestrator --- | |
| 1113 # Use the new LudwigDirectBackend | |
| 1114 backend_instance = LudwigDirectBackend() | 1535 backend_instance = LudwigDirectBackend() |
| 1115 orchestrator = WorkflowOrchestrator(args, backend_instance) | 1536 orchestrator = WorkflowOrchestrator(args, backend_instance) |
| 1116 | 1537 |
| 1117 # --- Run Workflow --- | |
| 1118 exit_code = 0 | 1538 exit_code = 0 |
| 1119 try: | 1539 try: |
| 1120 orchestrator.run() | 1540 orchestrator.run() |
| 1121 logger.info("Main script finished successfully.") | 1541 logger.info("Main script finished successfully.") |
| 1122 except Exception as e: | 1542 except Exception as e: |
| 1124 exit_code = 1 | 1544 exit_code = 1 |
| 1125 finally: | 1545 finally: |
| 1126 sys.exit(exit_code) | 1546 sys.exit(exit_code) |
| 1127 | 1547 |
| 1128 | 1548 |
| 1129 if __name__ == '__main__': | 1549 if __name__ == "__main__": |
| 1130 try: | 1550 try: |
| 1131 import ludwig | 1551 import ludwig |
| 1552 | |
| 1132 logger.debug(f"Found Ludwig version: {ludwig.globals.LUDWIG_VERSION}") | 1553 logger.debug(f"Found Ludwig version: {ludwig.globals.LUDWIG_VERSION}") |
| 1133 except ImportError: | 1554 except ImportError: |
| 1134 logger.error("Ludwig library not found. Please ensure Ludwig is installed ('pip install ludwig[image]')") | 1555 logger.error( |
| 1556 "Ludwig library not found. Please ensure Ludwig is installed " | |
| 1557 "('pip install ludwig[image]')" | |
| 1558 ) | |
| 1135 sys.exit(1) | 1559 sys.exit(1) |
| 1136 | 1560 |
| 1137 main() | 1561 main() |
