Mercurial > repos > bgruening > black_forest_labs_flux
comparison flux.py @ 1:7933bed1ffab draft
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/flux commit 0959e15618c76e05c78e5811218c9c4bc343db27
author | bgruening |
---|---|
date | Tue, 29 Oct 2024 13:53:36 +0000 |
parents | 0d0561746128 |
children |
comparison
equal
deleted
inserted
replaced
0:0d0561746128 | 1:7933bed1ffab |
---|---|
1 import os | |
1 import sys | 2 import sys |
2 | 3 |
3 import torch | 4 import torch |
4 from diffusers import FluxPipeline | 5 from diffusers import FluxPipeline |
5 | 6 |
6 model = sys.argv[1] | 7 model_path = sys.argv[1] |
7 | 8 |
8 prompt_type = sys.argv[2] | 9 prompt_type = sys.argv[2] |
9 if prompt_type == "file": | 10 if prompt_type == "file": |
10 with open(sys.argv[3], "r") as f: | 11 with open(sys.argv[3], "r") as f: |
11 prompt = f.read().strip() | 12 prompt = f.read().strip() |
12 elif prompt_type == "text": | 13 elif prompt_type == "text": |
13 prompt = sys.argv[3] | 14 prompt = sys.argv[3] |
14 | 15 |
15 if model not in ["black-forest-labs/FLUX.1-dev", "black-forest-labs/FLUX.1-schnell"]: | 16 if "dev" in model_path: |
17 num_inference_steps = 20 | |
18 elif "schnell" in model_path: | |
19 num_inference_steps = 4 | |
20 else: | |
16 print("Invalid model!") | 21 print("Invalid model!") |
17 sys.exit(1) | 22 sys.exit(1) |
18 | 23 |
24 snapshots = [] | |
25 for d in os.listdir(os.path.join(model_path, "snapshots")): | |
26 snapshots.append(os.path.join(model_path, "snapshots", d)) | |
27 latest_snapshot_path = max(snapshots, key=os.path.getmtime) | |
19 | 28 |
20 pipe = FluxPipeline.from_pretrained(model, torch_dtype=torch.bfloat16) | 29 pipe = FluxPipeline.from_pretrained(latest_snapshot_path, torch_dtype=torch.bfloat16) |
21 pipe.enable_sequential_cpu_offload() | 30 pipe.enable_sequential_cpu_offload() |
22 pipe.vae.enable_slicing() | 31 pipe.vae.enable_slicing() |
23 pipe.vae.enable_tiling() | 32 pipe.vae.enable_tiling() |
24 pipe.to(torch.float16) | 33 pipe.to(torch.float16) |
25 | 34 |
26 image = pipe( | 35 image = pipe( |
27 prompt, | 36 prompt, |
28 num_inference_steps=4, | 37 num_inference_steps=num_inference_steps, |
29 generator=torch.Generator("cpu").manual_seed(42), | 38 generator=torch.Generator("cpu").manual_seed(42), |
30 ).images[0] | 39 ).images[0] |
31 | 40 |
32 image.save("output.png") | 41 image.save("output.png") |