Mercurial > repos > bgruening > black_forest_labs_flux
diff flux.py @ 1:7933bed1ffab draft
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/flux commit 0959e15618c76e05c78e5811218c9c4bc343db27
author | bgruening |
---|---|
date | Tue, 29 Oct 2024 13:53:36 +0000 |
parents | 0d0561746128 |
children |
line wrap: on
line diff
--- a/flux.py Mon Oct 14 16:50:38 2024 +0000 +++ b/flux.py Tue Oct 29 13:53:36 2024 +0000 @@ -1,9 +1,10 @@ +import os import sys import torch from diffusers import FluxPipeline -model = sys.argv[1] +model_path = sys.argv[1] prompt_type = sys.argv[2] if prompt_type == "file": @@ -12,12 +13,20 @@ elif prompt_type == "text": prompt = sys.argv[3] -if model not in ["black-forest-labs/FLUX.1-dev", "black-forest-labs/FLUX.1-schnell"]: +if "dev" in model_path: + num_inference_steps = 20 +elif "schnell" in model_path: + num_inference_steps = 4 +else: print("Invalid model!") sys.exit(1) +snapshots = [] +for d in os.listdir(os.path.join(model_path, "snapshots")): + snapshots.append(os.path.join(model_path, "snapshots", d)) +latest_snapshot_path = max(snapshots, key=os.path.getmtime) -pipe = FluxPipeline.from_pretrained(model, torch_dtype=torch.bfloat16) +pipe = FluxPipeline.from_pretrained(latest_snapshot_path, torch_dtype=torch.bfloat16) pipe.enable_sequential_cpu_offload() pipe.vae.enable_slicing() pipe.vae.enable_tiling() @@ -25,7 +34,7 @@ image = pipe( prompt, - num_inference_steps=4, + num_inference_steps=num_inference_steps, generator=torch.Generator("cpu").manual_seed(42), ).images[0]