comparison flux.py @ 1:7933bed1ffab draft
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/flux commit 0959e15618c76e05c78e5811218c9c4bc343db27
| | |
|---|---|
| author | bgruening |
| date | Tue, 29 Oct 2024 13:53:36 +0000 |
| parents | 0d0561746128 |
| children | 7939ae8c5fd5 |
Comparison of flux.py between the parent revision 0:0d0561746128 and revision 1:7933bed1ffab:

```diff
+import os
 import sys
 
 import torch
 from diffusers import FluxPipeline
 
-model = sys.argv[1]
+model_path = sys.argv[1]
 
 prompt_type = sys.argv[2]
 if prompt_type == "file":
     with open(sys.argv[3], "r") as f:
         prompt = f.read().strip()
 elif prompt_type == "text":
     prompt = sys.argv[3]
 
-if model not in ["black-forest-labs/FLUX.1-dev", "black-forest-labs/FLUX.1-schnell"]:
+if "dev" in model_path:
+    num_inference_steps = 20
+elif "schnell" in model_path:
+    num_inference_steps = 4
+else:
     print("Invalid model!")
     sys.exit(1)
 
+snapshots = []
+for d in os.listdir(os.path.join(model_path, "snapshots")):
+    snapshots.append(os.path.join(model_path, "snapshots", d))
+latest_snapshot_path = max(snapshots, key=os.path.getmtime)
 
-pipe = FluxPipeline.from_pretrained(model, torch_dtype=torch.bfloat16)
+pipe = FluxPipeline.from_pretrained(latest_snapshot_path, torch_dtype=torch.bfloat16)
 pipe.enable_sequential_cpu_offload()
 pipe.vae.enable_slicing()
 pipe.vae.enable_tiling()
 pipe.to(torch.float16)
 
 image = pipe(
     prompt,
-    num_inference_steps=4,
+    num_inference_steps=num_inference_steps,
     generator=torch.Generator("cpu").manual_seed(42),
 ).images[0]
 
 image.save("output.png")
```
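The main functional change in this revision is that the script no longer passes a model ID to `FluxPipeline.from_pretrained`; it now receives a local cache directory, loads the most recently modified entry under `<model_path>/snapshots`, and picks the step count from the model name ("dev" → 20, "schnell" → 4). Below is a minimal, self-contained sketch of just that snapshot-selection logic; the directory names are hypothetical stand-ins for a real Hugging Face-style cache.

```python
import os
import tempfile

# Throwaway directory mimicking the layout the updated flux.py expects:
#   <model_path>/snapshots/<revision-hash>/
# The snapshot names below are hypothetical.
model_path = tempfile.mkdtemp()
snapshots_dir = os.path.join(model_path, "snapshots")
os.makedirs(os.path.join(snapshots_dir, "aaa111"))
os.makedirs(os.path.join(snapshots_dir, "bbb222"))

# Force distinct modification times so the example is deterministic.
os.utime(os.path.join(snapshots_dir, "aaa111"), (1_000_000, 1_000_000))
os.utime(os.path.join(snapshots_dir, "bbb222"), (2_000_000, 2_000_000))

# Same selection logic as the updated script: newest snapshot by mtime wins.
snapshots = [os.path.join(snapshots_dir, d) for d in os.listdir(snapshots_dir)]
latest_snapshot_path = max(snapshots, key=os.path.getmtime)
print(latest_snapshot_path)  # ends with .../snapshots/bbb222
```

The script itself is driven by three positional arguments, so an invocation would look roughly like the sketch below; the model path and prompt are assumptions, and actually running it requires the cached FLUX weights and enough memory for the pipeline.

```python
import subprocess

# argv[1] = local model directory (must contain a snapshots/ subdirectory)
# argv[2] = prompt type: "file" or "text"
# argv[3] = prompt text, or the path to a file holding the prompt
subprocess.run(
    ["python", "flux.py", "/path/to/FLUX.1-schnell", "text", "a watercolor fox"],
    check=True,
)
```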
