diff flux.py @ 1:7933bed1ffab draft

planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/flux commit 0959e15618c76e05c78e5811218c9c4bc343db27
author bgruening
date Tue, 29 Oct 2024 13:53:36 +0000
parents 0d0561746128
children
line wrap: on
line diff
--- a/flux.py	Mon Oct 14 16:50:38 2024 +0000
+++ b/flux.py	Tue Oct 29 13:53:36 2024 +0000
@@ -1,9 +1,10 @@
+import os
 import sys
 
 import torch
 from diffusers import FluxPipeline
 
-model = sys.argv[1]
+model_path = sys.argv[1]
 
 prompt_type = sys.argv[2]
 if prompt_type == "file":
@@ -12,12 +13,20 @@
 elif prompt_type == "text":
     prompt = sys.argv[3]
 
-if model not in ["black-forest-labs/FLUX.1-dev", "black-forest-labs/FLUX.1-schnell"]:
+if "dev" in model_path:
+    num_inference_steps = 20
+elif "schnell" in model_path:
+    num_inference_steps = 4
+else:
     print("Invalid model!")
     sys.exit(1)
 
+snapshots = []
+for d in os.listdir(os.path.join(model_path, "snapshots")):
+    snapshots.append(os.path.join(model_path, "snapshots", d))
+latest_snapshot_path = max(snapshots, key=os.path.getmtime)
 
-pipe = FluxPipeline.from_pretrained(model, torch_dtype=torch.bfloat16)
+pipe = FluxPipeline.from_pretrained(latest_snapshot_path, torch_dtype=torch.bfloat16)
 pipe.enable_sequential_cpu_offload()
 pipe.vae.enable_slicing()
 pipe.vae.enable_tiling()
@@ -25,7 +34,7 @@
 
 image = pipe(
     prompt,
-    num_inference_steps=4,
+    num_inference_steps=num_inference_steps,
     generator=torch.Generator("cpu").manual_seed(42),
 ).images[0]