Mercurial > repos > bgruening > black_forest_labs_flux
comparison flux.py @ 0:0d0561746128 draft
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/flux commit 71b3dacee16dc999cb4fa113858d6ace1781c71c
| author | bgruening |
|---|---|
| date | Mon, 14 Oct 2024 16:50:38 +0000 |
| parents | |
| children | 7933bed1ffab |
comparison
equal
deleted
inserted
replaced
| -1:000000000000 | 0:0d0561746128 |
|---|---|
| 1 import sys | |
| 2 | |
| 3 import torch | |
| 4 from diffusers import FluxPipeline | |
| 5 | |
# Models this script accepts as its first command-line argument.
ALLOWED_MODELS = (
    "black-forest-labs/FLUX.1-dev",
    "black-forest-labs/FLUX.1-schnell",
)

# Generation settings, fixed for every run.
# NOTE(review): the same step count is applied to both allowed models --
# confirm that this is intended for FLUX.1-dev as well.
NUM_INFERENCE_STEPS = 4
SEED = 42
OUTPUT_PATH = "output.png"

USAGE = "Usage: flux.py <model> <file|text> <prompt-or-path>"


def _read_prompt(prompt_type, value):
    """Return the prompt text, or ``None`` if *prompt_type* is unknown.

    Args:
        prompt_type: ``"file"`` to read the prompt from the file at *value*,
            or ``"text"`` to use *value* itself as the prompt.
        value: Either a path to a prompt file or the literal prompt.

    Raises:
        OSError: If *prompt_type* is ``"file"`` and the file cannot be read.
    """
    if prompt_type == "file":
        with open(value, "r", encoding="utf-8") as handle:
            return handle.read().strip()
    if prompt_type == "text":
        return value
    return None


def _generate_image(model, prompt):
    """Run the Flux pipeline for *prompt* and save the first image.

    The image is written to ``OUTPUT_PATH`` in the current directory.
    """
    pipe = FluxPipeline.from_pretrained(model, torch_dtype=torch.bfloat16)
    pipe.enable_sequential_cpu_offload()
    pipe.vae.enable_slicing()
    pipe.vae.enable_tiling()
    # NOTE(review): the pipeline is loaded as bfloat16 and then cast to
    # float16 here; kept as in the original -- confirm the cast is intended.
    pipe.to(torch.float16)

    # A fixed CPU generator seed is passed so the sampler starts from the
    # same random state on every run.
    image = pipe(
        prompt,
        num_inference_steps=NUM_INFERENCE_STEPS,
        generator=torch.Generator("cpu").manual_seed(SEED),
    ).images[0]

    image.save(OUTPUT_PATH)


def main(argv):
    """Validate the command line, then generate and save one image.

    Args:
        argv: Full argument vector: ``[script, model, prompt_type, prompt]``
            where ``prompt_type`` is ``"file"`` or ``"text"``.

    Returns:
        Process exit status: ``0`` on success, ``1`` when the arguments are
        missing or invalid.
    """
    if len(argv) < 4:
        print(USAGE)
        return 1

    model, prompt_type, prompt_arg = argv[1:4]

    # Reject bad arguments before any file I/O or model loading happens.
    if model not in ALLOWED_MODELS:
        print("Invalid model!")
        return 1

    prompt = _read_prompt(prompt_type, prompt_arg)
    if prompt is None:
        # Previously an unknown prompt type left `prompt` unbound and the
        # script failed with a NameError only after the model was loaded.
        print("Invalid prompt type!")
        return 1

    _generate_image(model, prompt)
    return 0


if __name__ == "__main__":
    sys.exit(main(sys.argv))
