comparison flux.py @ 1:7933bed1ffab draft

planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/flux commit 0959e15618c76e05c78e5811218c9c4bc343db27
author bgruening
date Tue, 29 Oct 2024 13:53:36 +0000
parents 0d0561746128
children
comparison
equal deleted inserted replaced
0:0d0561746128 1:7933bed1ffab
"""Generate one image from a prompt with a locally stored FLUX.1 model.

Usage::

    python flux.py MODEL_PATH {file|text} PROMPT

MODEL_PATH must contain a ``snapshots/`` subdirectory (presumably a Hugging
Face hub cache repository directory -- confirm against the data manager that
populates it); the most recently modified snapshot directory is loaded.
With ``file`` the PROMPT argument is a path whose stripped contents are the
prompt; with ``text`` it is the prompt itself. The generated image is written
to ``output.png`` in the current working directory. Usage errors exit with
status 1 and a message on stderr.
"""

import os
import sys

USAGE = "Usage: flux.py MODEL_PATH {file|text} PROMPT"


def read_prompt(prompt_type, value):
    """Return the prompt described by *prompt_type* and *value*.

    Args:
        prompt_type: ``"file"`` (``value`` is a path to a UTF-8 text file whose
            stripped contents are the prompt) or ``"text"`` (``value`` is the
            prompt itself).
        value: The file path or the literal prompt.

    Raises:
        SystemExit: If *prompt_type* is anything else, so the script never
            reaches the pipeline call without a prompt.
    """
    if prompt_type == "file":
        with open(value, "r", encoding="utf-8") as handle:
            return handle.read().strip()
    if prompt_type == "text":
        return value
    sys.exit(f"Invalid prompt type {prompt_type!r}; expected 'file' or 'text'.")


def select_num_inference_steps(model_path):
    """Return the number of inference steps for the FLUX.1 variant at *model_path*.

    Returns 20 when the model directory name contains "dev" and 4 when it
    contains "schnell". Only the last path component is inspected, so parent
    directories whose names happen to contain "dev" (e.g. ``/home/devuser``)
    cannot change the detected variant.

    Raises:
        SystemExit: With "Invalid model!" if neither variant name appears.
    """
    # normpath drops a trailing separator so basename is never empty.
    model_name = os.path.basename(os.path.normpath(model_path))
    if "dev" in model_name:
        return 20
    if "schnell" in model_name:
        return 4
    sys.exit("Invalid model!")


def find_latest_snapshot(model_path):
    """Return the most recently modified directory in ``<model_path>/snapshots``.

    Non-directory entries are ignored so a stray file can never be handed to
    the pipeline loader.

    Raises:
        SystemExit: If the snapshots directory cannot be listed or contains
            no subdirectories.
    """
    snapshots_dir = os.path.join(model_path, "snapshots")
    try:
        entries = os.listdir(snapshots_dir)
    except OSError as err:
        sys.exit(f"Cannot list snapshots in {snapshots_dir}: {err}")
    candidates = (os.path.join(snapshots_dir, name) for name in entries)
    snapshots = [path for path in candidates if os.path.isdir(path)]
    if not snapshots:
        sys.exit(f"No snapshots found in {snapshots_dir}")
    return max(snapshots, key=os.path.getmtime)


def main(argv=None):
    """Validate the command line, run the FLUX pipeline and save ``output.png``.

    Args:
        argv: Full argument vector including the program name; defaults to
            ``sys.argv``.

    Raises:
        SystemExit: On any usage error (see the helpers above).
    """
    args = sys.argv if argv is None else argv
    if len(args) < 4:
        sys.exit(USAGE)
    model_path, prompt_type, prompt_arg = args[1], args[2], args[3]

    prompt = read_prompt(prompt_type, prompt_arg)
    num_inference_steps = select_num_inference_steps(model_path)
    latest_snapshot_path = find_latest_snapshot(model_path)

    # Deferred so argument errors are reported without first importing the
    # heavy torch/diffusers stack.
    import torch
    from diffusers import FluxPipeline

    pipe = FluxPipeline.from_pretrained(
        latest_snapshot_path, torch_dtype=torch.bfloat16
    )
    pipe.enable_sequential_cpu_offload()
    pipe.vae.enable_slicing()
    pipe.vae.enable_tiling()
    # NOTE(review): loading in bfloat16 and casting to float16 only after
    # offloading is enabled appears to mirror the diffusers FLUX low-memory
    # example; confirm before reordering these calls.
    pipe.to(torch.float16)

    image = pipe(
        prompt,
        num_inference_steps=num_inference_steps,
        # Fixed seed (42) on a CPU generator.
        generator=torch.Generator("cpu").manual_seed(42),
    ).images[0]

    image.save("output.png")


if __name__ == "__main__":
    main()