comparison: yolov8.py @ 5:ce7a96be8cb6 (draft, default, tip)
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools commit c6c9d43a4ecdc88ebdeaf3451453a550f159c506
| author | bgruening |
|---|---|
| date | Mon, 21 Jul 2025 15:52:12 +0000 |
| parents | 7db48c618bbe |
| children | |
| 4:7db48c618bbe (old) | 5:ce7a96be8cb6 (new) |
|---|---|
| 1 import argparse | 1 import argparse |
| | 2 import csv |
| 2 import os | 3 import os |
| 3 import pathlib | 4 import pathlib |
| 4 import time | 5 import time |
| 5 from argparse import RawTextHelpFormatter | 6 from argparse import RawTextHelpFormatter |
| 6 from collections import defaultdict | 7 from collections import defaultdict |
| 8 import cv2 | 9 import cv2 |
| 9 import numpy as np | 10 import numpy as np |
| 10 from termcolor import colored | 11 from termcolor import colored |
| 11 from tifffile import imwrite | 12 from tifffile import imwrite |
| 12 from ultralytics import YOLO | 13 from ultralytics import YOLO |
| 13 | |
| 14 | 14 |
| 15 # | 15 # |
| 16 # Input arguments | 16 # Input arguments |
| 17 # | 17 # |
| 18 parser = argparse.ArgumentParser( | 18 parser = argparse.ArgumentParser( |
| 77 help="Format of the YOLO model i.e pt, yaml etc.", | 77 help="Format of the YOLO model i.e pt, yaml etc.", |
| 78 default='pt', type=str) | 78 default='pt', type=str) |
| 79 parser.add_argument("--class_names_file", | 79 parser.add_argument("--class_names_file", |
| 80 help="Path to the text file containing class names.", | 80 help="Path to the text file containing class names.", |
| 81 type=str) | 81 type=str) |
| 82 | |
| 83 # For training the model and prediction | 82 # For training the model and prediction |
| 84 parser.add_argument("--mode", | 83 parser.add_argument("--mode", |
| 85 help=( | 84 help=( |
| 86 "detection, segmentation, classification, and pose \n. " | 85 "detection, segmentation, classification, and pose \n. " |
| 87 " Only detection mode available currently i.e. `detect`" | 86 " Only detection mode available currently i.e. `detect`" |
| 127 default='bytetrack.yaml', type=str) | 126 default='bytetrack.yaml', type=str) |
| 128 | 127 |
| 129 # For headless operation | 128 # For headless operation |
| 130 parser.add_argument('--headless', action='store_true') | 129 parser.add_argument('--headless', action='store_true') |
| 131 parser.add_argument('--nextflow', action='store_true') | 130 parser.add_argument('--nextflow', action='store_true') |
| | 131 |
| 132 | 132 |
| 133 # For data augmentation | 133 # For data augmentation |
| 134 parser.add_argument("--hsv_h", | 134 parser.add_argument("--hsv_h", |
| 135 help="(float) image HSV-Hue augmentation (fraction)", | 135 help="(float) image HSV-Hue augmentation (fraction)", |
| 136 default=0.015, type=float) | 136 default=0.015, type=float) |
| 169 "emphasize central features and adapt to object scales, " | 169 "emphasize central features and adapt to object scales, " |
| 170 "reducing background distractions", | 170 "reducing background distractions", |
| 171 default=1.0, type=float) | 171 default=1.0, type=float) |
| 172 | 172 |
| 173 | 173 |
| 174 # | |
| 175 # Functions | |
| 176 # | |
| 177 # Train a new model on the dataset mentioned in yaml file | 174 # Train a new model on the dataset mentioned in yaml file |
| 178 def trainModel(model_path, model_name, yaml_filepath, **kwargs): | 175 def trainModel(model_path, model_name, yaml_filepath, **kwargs): |
| 179 if "imgsz" in kwargs: | 176 if "imgsz" in kwargs: |
| 180 image_size = kwargs['imgsz'] | 177 image_size = kwargs['imgsz'] |
| 181 else: | 178 else: |
| 270 imgsz=image_size, verbose=True, hsv_h=aug_hsv_h, | 267 imgsz=image_size, verbose=True, hsv_h=aug_hsv_h, |
| 271 hsv_s=aug_hsv_s, hsv_v=aug_hsv_v, degrees=aug_degrees, | 268 hsv_s=aug_hsv_s, hsv_v=aug_hsv_v, degrees=aug_degrees, |
| 272 translate=aug_translate, shear=aug_shear, scale=aug_scale, | 269 translate=aug_translate, shear=aug_shear, scale=aug_scale, |
| 273 perspective=aug_perspective, fliplr=aug_fliplr, | 270 perspective=aug_perspective, fliplr=aug_fliplr, |
| 274 flipud=aug_flipud, mosaic=aug_mosaic, crop_fraction=aug_crop_fraction, | 271 flipud=aug_flipud, mosaic=aug_mosaic, crop_fraction=aug_crop_fraction, |
| 275 weight_decay=weight_decay, lr0=init_lr, seed=42) | 272 weight_decay=weight_decay, lr0=init_lr) |
| 276 return model | 273 return model |
| 277 | 274 |
| 278 | 275 |
| 279 # Validate the trained model | 276 # Validate the trained model |
| 280 def validateModel(model): | 277 def validateModel(model): |
| 281 # Validate the model | |
| 282 metrics = model.val() # no args needed, dataset & settings remembered | 278 metrics = model.val() # no args needed, dataset & settings remembered |
| 283 metrics.box.map # map50-95 | 279 metrics.box.map # map50-95 |
| 284 metrics.box.map50 # map50 | 280 metrics.box.map50 # map50 |
| 285 metrics.box.map75 # map75 | 281 metrics.box.map75 # map75 |
| 286 metrics.box.maps # a list contains map50-95 of each category | 282 metrics.box.maps # a list contains map50-95 of each category |
| 314 maximum_detections = 300 | 310 maximum_detections = 300 |
| 315 | 311 |
| 316 run_save_dir = kwargs['run_dir'] # For Galaxy, run_save_dir is always provided via xml wrapper | 312 run_save_dir = kwargs['run_dir'] # For Galaxy, run_save_dir is always provided via xml wrapper |
| 317 if "foldername" in kwargs: | 313 if "foldername" in kwargs: |
| 318 save_folder_name = kwargs['foldername'] | 314 save_folder_name = kwargs['foldername'] |
| | 315 |
| 319 # infer on a local image or directory containing images/videos | 316 # infer on a local image or directory containing images/videos |
| 320 prediction = model.predict(source=source_datapath, save=True, stream=True, | 317 prediction = model.predict(source=source_datapath, save=True, stream=True, |
| 321 conf=confidence, imgsz=image_size, | 318 conf=confidence, imgsz=image_size, |
| 322 save_conf=True, iou=iou_value, max_det=maximum_detections, | 319 save_conf=True, iou=iou_value, max_det=maximum_detections, |
| 323 classes=class_array, save_txt=False, | 320 classes=class_array, save_txt=False, |
| 327 | 324 |
| 328 # Save bounding boxes | 325 # Save bounding boxes |
| 329 def save_yolo_bounding_boxes_to_txt(predictions, save_dir): | 326 def save_yolo_bounding_boxes_to_txt(predictions, save_dir): |
| 330 """ | 327 """ |
| 331 Function to save YOLO bounding boxes to text files. | 328 Function to save YOLO bounding boxes to text files. |
| | 329 |
| 332 Parameters: | 330 Parameters: |
| 333 - predictions: List of results from YOLO model inference. | 331 - predictions: List of results from YOLO model inference. |
| 334 - save_dir: Directory where the text files will be saved. | 332 - save_dir: Directory where the text files will be saved. |
| 335 """ | 333 """ |
| 336 for result in predictions: | 334 for result in predictions: |
| 337 result = result.to("cpu").numpy() | 335 result = result.to("cpu").numpy() |
| 338 # Using bounding_boxes, confidence_scores, and class_num which are defined in the list | 336 # Using bounding_boxes, confidence_scores, and class_num which are defined in the list |
| 339 bounding_boxes = result.boxes.xyxy # Bounding boxes in xyxy format | 337 bounding_boxes = result.boxes.xyxy # Bounding boxes in xyxy format |
| 340 confidence_scores = result.boxes.conf # Confidence scores | 338 confidence_scores = result.boxes.conf # Confidence scores |
| 341 class_nums = result.boxes.cls # Class numbers | 339 class_nums = result.boxes.cls # Class numbers |
| | 340 |
| 342 # Create save directory if it doesn't exist | 341 # Create save directory if it doesn't exist |
| 343 save_path = pathlib.Path(save_dir).absolute() | 342 save_path = pathlib.Path(save_dir).absolute() |
| 344 save_path.mkdir(parents=True, exist_ok=True) | 343 save_path.mkdir(parents=True, exist_ok=True) |
| | 344 |
| 345 # Construct filename for the text file | 345 # Construct filename for the text file |
| 346 image_filename = pathlib.Path(result.path).stem | 346 image_filename = pathlib.Path(result.path).stem |
| 347 text_filename = save_path / f"{image_filename}.txt" | 347 text_filename = save_path / f"{image_filename}.txt" |
| | 348 |
| 348 # Write bounding boxes info into the text file | 349 # Write bounding boxes info into the text file |
| 349 with open(text_filename, 'w') as f: | 350 with open(text_filename, 'w') as f: |
| 350 for i in range(bounding_boxes.shape[0]): | 351 for i in range(bounding_boxes.shape[0]): |
| 351 x1, y1, x2, y2 = bounding_boxes[i] | 352 x1, y1, x2, y2 = bounding_boxes[i] |
| 352 confidence = confidence_scores[i] | 353 confidence = confidence_scores[i] |
| 353 class_num = int(class_nums[i]) | 354 class_num = int(class_nums[i]) |
| 354 f.write(f'{class_num:01} {x1:06.2f} {y1:06.2f} {x2:06.2f} {y2:06.2f} {confidence:0.02} \n') | 355 f.write(f'{class_num:01} {x1:06.2f} {y1:06.2f} {x2:06.2f} {y2:06.2f} {confidence:0.02} \n') |
| 355 print(colored(f"Bounding boxes saved in: {text_filename}", 'green')) | 356 print(colored(f"Bounding boxes saved in: {text_filename}", 'green')) |
| 356 | 357 |
| 357 | 358 |
| | 359 # Main code |
| 358 if __name__ == '__main__': | 360 if __name__ == '__main__': |
| 359 args = parser.parse_args() | 361 args = parser.parse_args() |
| 360 os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" | 362 os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" |
| | 363 |
| 361 # Train/load model | 364 # Train/load model |
| 362 if (args.train): | 365 if (args.train): |
| 363 model = trainModel(args.model_path, args.model_name, args.yaml_path, | 366 model = trainModel(args.model_path, args.model_name, args.yaml_path, |
| 364 imgsz=args.image_size, epochs=args.epochs, | 367 imgsz=args.image_size, epochs=args.epochs, |
| 365 hsv_h=args.hsv_h, hsv_s=args.hsv_s, hsv_v=args.hsv_v, | 368 hsv_h=args.hsv_h, hsv_s=args.hsv_s, hsv_v=args.hsv_v, |
| 375 "train", "weights", "best.pt")) and (args.model_name == 'sam'): | 378 "train", "weights", "best.pt")) and (args.model_name == 'sam'): |
| 376 model = YOLO(os.path.join(train_save_path, | 379 model = YOLO(os.path.join(train_save_path, |
| 377 "train", "weights", "best.pt")) | 380 "train", "weights", "best.pt")) |
| 378 else: | 381 else: |
| 379 model = YOLO(os.path.join(args.model_path, | 382 model = YOLO(os.path.join(args.model_path, |
| 380 args.model_name + ".pt")) | 383 args.model_name + ".pt")) |
| 381 model.info(verbose=True) | 384 model.info(verbose=True) |
| 382 elapsed = time.time() - t | 385 elapsed = time.time() - t |
| 383 print(colored(f"\nYOLO model loaded in : '{elapsed}' sec \n", 'white', 'on_yellow')) | 386 print(colored(f"\nYOLO model loaded in : '{elapsed}' sec \n", 'white', 'on_yellow')) |
| 384 | 387 |
| 385 if (args.save_dir): | 388 if (args.save_dir): |
| 420 elif (args.mode == "track"): | 423 elif (args.mode == "track"): |
| 421 results = model.track(source=datapath_for_prediction, | 424 results = model.track(source=datapath_for_prediction, |
| 422 tracker=args.tracker_file, | 425 tracker=args.tracker_file, |
| 423 conf=args.confidence, | 426 conf=args.confidence, |
| 424 iou=args.iou, | 427 iou=args.iou, |
| 425 persist=False, | 428 persist=True, |
| 426 show=True, | 429 show=False, |
| 427 save=True, | 430 save=True, |
| 428 project=args.run_dir, | 431 project=args.run_dir, |
| 429 name=args.foldername) | 432 name=args.foldername) |
| 430 # Store the track history | 433 # Store the track history |
| 431 track_history = defaultdict(lambda: []) | 434 track_history = defaultdict(lambda: []) |
| 432 | 435 |
| 433 for result in results: | 436 tsv_path = os.path.join(args.save_dir, "tracks.tsv") |
| 434 # Get the boxes and track IDs | 437 with open(tsv_path, "w", newline="") as tsvfile: |
| 435 if result.boxes and result.boxes.is_track: | 438 writer = csv.writer(tsvfile, delimiter='\t') |
| 436 boxes = result.boxes.xywh.cpu() | 439 writer.writerow(['track_id', 'frame', 'class', 'centroid_x', 'centroid_y']) |
| 437 track_ids = result.boxes.id.int().cpu().tolist() | 440 frame_idx = 0 |
| 438 # Visualize the result on the frame | 441 for result in results: |
| 439 frame = result.plot() | 442 # Get the boxes and track IDs |
| 440 # Plot the tracks | 443 if result.boxes and result.boxes.is_track: |
| 441 for box, track_id in zip(boxes, track_ids): | 444 track_ids = result.boxes.id.int().cpu().tolist() |
| 442 x, y, w, h = box | 445 labels = result.boxes.cls.int().cpu().tolist() if hasattr(result.boxes, "cls") else [0] * len(track_ids) |
| 443 track = track_history[track_id] | 446 # Prepare mask image |
| 444 track.append((float(x), float(y))) # x, y center point | 447 img_shape = result.orig_shape if hasattr(result, "orig_shape") else result.orig_img.shape |
| 445 if len(track) > 30: # retain 30 tracks for 30 frames | 448 mask = np.zeros(img_shape[:2], dtype=np.uint16) |
| 446 track.pop(0) | 449 # Check if polygons (masks) are available |
| 447 | 450 if hasattr(result, "masks") and result.masks is not None and hasattr(result.masks, "xy"): |
| 448 # Draw the tracking lines | 451 polygons = result.masks.xy |
| 449 points = np.hstack(track).astype(np.int32).reshape((-1, 1, 2)) | 452 for i, (track_id, label) in enumerate(zip(track_ids, labels)): |
| 450 cv2.polylines(frame, [points], isClosed=False, color=(230, 230, 230), thickness=2) | 453 if i < len(polygons): |
| 451 | 454 contour = polygons[i].astype(np.int32) |
| 452 # Display the annotated frame | 455 contour = contour.reshape(-1, 1, 2) |
| 453 cv2.imshow("YOLO11 Tracking", frame) | 456 cv2.drawContours(mask, [contour], -1, int(track_id), cv2.FILLED) |
| 454 print(colored(f"Tracking results saved in : '{args.save_dir}' \n", 'green')) | 457 # Calculate centroid of the polygon |
| | 458 M = cv2.moments(contour) |
| | 459 if M["m00"] != 0: |
| | 460 cx = float(M["m10"] / M["m00"]) |
| | 461 cy = float(M["m01"] / M["m00"]) |
| | 462 else: |
| | 463 cx, cy = 0.0, 0.0 |
| | 464 writer.writerow([track_id, frame_idx, label, cx, cy]) |
| | 465 else: |
| | 466 # Fallback to bounding boxes if polygons are not available |
| | 467 boxes = result.boxes.xywh.cpu() |
| | 468 xyxy_boxes = result.boxes.xyxy.cpu().numpy() |
| | 469 for i, (box, xyxy, track_id, label) in enumerate(zip(boxes, xyxy_boxes, track_ids, labels)): |
| | 470 x, y, w, h = box |
| | 471 writer.writerow([track_id, frame_idx, label, float(x), float(y)]) |
| | 472 x1, y1, x2, y2 = map(int, xyxy) |
| | 473 cv2.rectangle(mask, (x1, y1), (x2, y2), int(track_id), thickness=-1) |
| | 474 # Collect masks for TYX stack |
| | 475 if frame_idx == 0: |
| | 476 mask_stack = [] |
| | 477 mask_stack.append(mask) |
| | 478 frame_idx += 1 |
| | 479 # Save TYX stack (T=frames, Y, X) |
| | 480 if 'mask_stack' in locals() and len(mask_stack) > 0: |
| | 481 tyx_array = np.stack(mask_stack, axis=0) |
| | 482 # Remove string from last underscore in filename |
| | 483 stem = pathlib.Path(result.path).stem |
| | 484 stem = stem.rsplit('_', 1)[0] if '_' in stem else stem |
| | 485 mask_save_as = str(pathlib.Path(os.path.join(args.save_dir, stem + "_mask.tiff")).absolute()) |
| | 486 imwrite(mask_save_as, tyx_array) |
| | 487 print(colored(f"TYX mask stack saved as : '{mask_save_as}'", 'magenta')) |
| | 488 print(colored(f"Tracking results saved in : '{args.save_dir}' \n", 'green')) |
| 455 elif (args.mode == "segment"): | 489 elif (args.mode == "segment"): |
| 456 # Read class names from the file | 490 # Read class names from the file |
| 457 with open(args.class_names_file, 'r') as f: | 491 with open(args.class_names_file, 'r') as f: |
| 458 class_names = [line.strip() for line in f.readlines()] | 492 class_names = [line.strip() for line in f.readlines()] |
| | 493 # Create a mapping from class names to indices |
| 459 class_to_index = {class_name: i for i, class_name in enumerate(class_names)} | 494 class_to_index = {class_name: i for i, class_name in enumerate(class_names)} |
| 460 | 495 |
| 461 # Save polygon coordinates | 496 # Save polygon coordinates |
| 462 for result in predictions: | 497 for result in predictions: |
| | 498 # Create binary mask |
| 463 img = np.copy(result.orig_img) | 499 img = np.copy(result.orig_img) |
| 464 filename = pathlib.Path(result.path).stem | 500 filename = pathlib.Path(result.path).stem |
| 465 b_mask = np.zeros(img.shape[:2], np.uint8) | 501 b_mask = np.zeros(img.shape[:2], np.uint8) |
| 466 mask_save_as = str(pathlib.Path(os.path.join(args.save_dir, filename + "_mask.tiff")).absolute()) | 502 mask_save_as = str(pathlib.Path(os.path.join(args.save_dir, filename + "_mask.tiff")).absolute()) |
| | 503 # Define output file path for text file |
| | 504 output_filename = os.path.splitext(filename)[0] + ".txt" |
| 467 txt_save_as = str(pathlib.Path(os.path.join(args.save_dir, filename + ".txt")).absolute()) | 505 txt_save_as = str(pathlib.Path(os.path.join(args.save_dir, filename + ".txt")).absolute()) |
| 468 | 506 instance_id = 1 # Start instance IDs from 1 |
| 469 for c, ci in enumerate(result): | 507 for c, ci in enumerate(result): |
| 470 if ci.masks is not None and ci.masks.xy: | 508 # Extract contour result |
| 471 # Extract contour | 509 contour = ci.masks.xy.pop() |
| 472 contour = ci.masks.xy.pop() | 510 contour = contour.astype(np.int32) |
| 473 contour = contour.astype(np.int32).reshape(-1, 1, 2) | 511 contour = contour.reshape(-1, 1, 2) |
| 474 _ = cv2.drawContours(b_mask, [contour], -1, (255, 255, 255), cv2.FILLED) | 512 # Draw contour onto mask with unique instance id |
| 475 | 513 _ = cv2.drawContours(b_mask, [contour], -1, instance_id, cv2.FILLED) |
| 476 # Normalized polygon points | 514 |
| 477 points = ci.masks.xyn.pop() | 515 # Normalized polygon points |
| 478 obj_class = int(ci.boxes.cls.to("cpu").numpy().item()) | 516 points = ci.masks.xyn.pop() |
| 479 confidence = result.boxes.conf.to("cpu").numpy()[c] | 517 confidence = result.boxes.conf.to("cpu").numpy()[c] |
| 480 | 518 |
| 481 with open(txt_save_as, 'a') as f: | 519 with open(txt_save_as, 'a') as f: |
| 482 segmentation_points = ['{} {}'.format(points[i][0], points[i][1]) for i in range(len(points))] | 520 segmentation_points = ['{} {}'.format(points[i][0], points[i][1]) for i in range(len(points))] |
| 483 segmentation_points_string = ' '.join(segmentation_points) | 521 segmentation_points_string = ' '.join(segmentation_points) |
| 484 line = '{} {} {}\n'.format(obj_class, segmentation_points_string, confidence) | 522 line = '{} {} {}\n'.format(instance_id, segmentation_points_string, confidence) |
| 485 f.write(line) | 523 f.write(line) |
| 486 else: | 524 |
| 487 print(colored(f"⚠️ No mask found for object {c} in '{filename}'. Skipping.", "yellow")) | 525 instance_id += 1 # Increment for next object |
| 488 | 526 |
| 489 # Overlay mask onto original image | 527 imwrite(mask_save_as, b_mask, imagej=True) # save label mask image |
| 490 colored_mask = cv2.merge([b_mask, np.zeros_like(b_mask), np.zeros_like(b_mask)]) | 528 print(colored(f"Saved label mask as : \n '{mask_save_as}' \n", 'magenta')) |
| 491 blended = cv2.addWeighted(img, 1.0, colored_mask, 0.5, 0) | |
| 492 overlay_path = os.path.join(args.save_dir, filename + "_overlay.jpg") | |
| 493 cv2.imwrite(overlay_path, blended) | |
| 494 | |
| 495 imwrite(mask_save_as, b_mask, imagej=True) | |
| 496 print(colored(f"Saved binary mask as : \n '{mask_save_as}' \n", 'magenta')) | |
| 497 print(colored(f"Polygon coordinates saved as : \n '{txt_save_as}' \n", 'cyan')) | 529 print(colored(f"Polygon coordinates saved as : \n '{txt_save_as}' \n", 'cyan')) |
| | 530 else: |
| | 531 raise Exception(("Currently only 'detect', 'segment' and 'track' modes are available")) |
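
The track branch added in this revision writes two machine-readable outputs: a `tracks.tsv` table (columns `track_id`, `frame`, `class`, `centroid_x`, `centroid_y`) and a TYX `uint16` label-mask stack saved as `<stem>_mask.tiff`, where pixel values are track IDs and 0 is background. Below is a minimal sketch of reading both back, assuming exactly those formats as written by revision 5:ce7a96be8cb6; the `save_dir` value and the mask file name are hypothetical placeholders.

```python
# Sketch: load the outputs of the new track branch. Assumptions: tracks.tsv
# columns as written by the diff above; TYX uint16 mask stack with pixel
# values = track IDs and 0 = background.
import csv

from tifffile import imread

save_dir = "runs/track"                        # hypothetical --save_dir value
masks = imread(f"{save_dir}/video_mask.tiff")  # hypothetical <stem> = "video"
print(masks.shape, masks.dtype)                # (T, Y, X), uint16

# Group per-frame centroids by track ID.
tracks = {}
with open(f"{save_dir}/tracks.tsv", newline="") as f:
    for row in csv.DictReader(f, delimiter="\t"):
        tracks.setdefault(int(row["track_id"]), []).append(
            (int(row["frame"]), float(row["centroid_x"]), float(row["centroid_y"]))
        )

for track_id, points in sorted(tracks.items()):
    # Pixel support of this track in the first frame of the mask stack.
    n_px = int((masks[0] == track_id).sum())
    print(f"track {track_id}: {len(points)} frames, {n_px} px in frame 0")
```

The segment branch now writes an analogous per-image label mask (instance IDs drawn into the mask instead of a uniform 255), so the same `imread` approach applies to its `<filename>_mask.tiff` outputs.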
