update to make interactive mode work

This commit is contained in:
kohya-ss
2024-08-12 13:24:10 +09:00
parent 9e09a69df1
commit 4af36f9632
2 changed files with 29 additions and 6 deletions

View File

@@ -39,6 +39,8 @@ The trained LoRA model can be used with ComfyUI.
The inference script is also available. The script is `flux_minimal_inference.py`. See `--help` for options. The inference script is also available. The script is `flux_minimal_inference.py`. See `--help` for options.
Aug 12: `--interactive` option is now working.
``` ```
python flux_minimal_inference.py --ckpt flux1-dev.sft --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors --ae ae.sft --dtype bf16 --prompt "a cat holding a sign that says hello world" --out path/to/output/dir --seed 1 --flux_dtype fp8 --offload --lora lora-flux-name.safetensors;1.0 python flux_minimal_inference.py --ckpt flux1-dev.sft --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors --ae ae.sft --dtype bf16 --prompt "a cat holding a sign that says hello world" --out path/to/output/dir --seed 1 --flux_dtype fp8 --offload --lora lora-flux-name.safetensors;1.0
``` ```

View File

@@ -5,7 +5,7 @@ import datetime
import math import math
import os import os
import random import random
from typing import Callable, Optional, Tuple from typing import Callable, List, Optional, Tuple
import einops import einops
import numpy as np import numpy as np
@@ -121,6 +121,9 @@ def generate_image(
steps: Optional[int], steps: Optional[int],
guidance: float, guidance: float,
): ):
seed = seed if seed is not None else random.randint(0, 2**32 - 1)
logger.info(f"Seed: {seed}")
# make first noise with packed shape # make first noise with packed shape
# original: b,16,2*h//16,2*w//16, packed: b,h//16*w//16,16*2*2 # original: b,16,2*h//16,2*w//16, packed: b,h//16*w//16,16*2*2
packed_latent_height, packed_latent_width = math.ceil(image_height / 16), math.ceil(image_width / 16) packed_latent_height, packed_latent_width = math.ceil(image_height / 16), math.ceil(image_width / 16)
@@ -183,9 +186,7 @@ def generate_image(
steps = 4 if is_schnell else 50 steps = 4 if is_schnell else 50
img_ids = img_ids.to(device) img_ids = img_ids.to(device)
x = do_sample( x = do_sample(accelerator, model, noise, img_ids, l_pooled, t5_out, txt_ids, steps, guidance, is_schnell, device, flux_dtype)
accelerator, model, noise, img_ids, l_pooled, t5_out, txt_ids, steps, guidance_scale, is_schnell, device, flux_dtype
)
if args.offload: if args.offload:
model = model.cpu() model = model.cpu()
# del model # del model
@@ -255,6 +256,7 @@ if __name__ == "__main__":
default=[], default=[],
help="LoRA weights, only supports networks.lora_flux, each argument is a `path;multiplier` (semi-colon separated)", help="LoRA weights, only supports networks.lora_flux, each argument is a `path;multiplier` (semi-colon separated)",
) )
parser.add_argument("--merge_lora_weights", action="store_true", help="Merge LoRA weights to model")
parser.add_argument("--width", type=int, default=target_width) parser.add_argument("--width", type=int, default=target_width)
parser.add_argument("--height", type=int, default=target_height) parser.add_argument("--height", type=int, default=target_height)
parser.add_argument("--interactive", action="store_true") parser.add_argument("--interactive", action="store_true")
@@ -341,6 +343,7 @@ if __name__ == "__main__":
ae = accelerator.prepare(ae) ae = accelerator.prepare(ae)
# LoRA # LoRA
lora_models: List[lora_flux.LoRANetwork] = []
for weights_file in args.lora_weights: for weights_file in args.lora_weights:
if ";" in weights_file: if ";" in weights_file:
weights_file, multiplier = weights_file.split(";") weights_file, multiplier = weights_file.split(";")
@@ -351,7 +354,16 @@ if __name__ == "__main__":
lora_model, weights_sd = lora_flux.create_network_from_weights( lora_model, weights_sd = lora_flux.create_network_from_weights(
multiplier, weights_file, ae, [clip_l, t5xxl], model, None, True multiplier, weights_file, ae, [clip_l, t5xxl], model, None, True
) )
if args.merge_lora_weights:
lora_model.merge_to([clip_l, t5xxl], model, weights_sd) lora_model.merge_to([clip_l, t5xxl], model, weights_sd)
else:
lora_model.apply_to([clip_l, t5xxl], model)
info = lora_model.load_state_dict(weights_sd, strict=True)
logger.info(f"Loaded LoRA weights from {weights_file}: {info}")
lora_model.eval()
lora_model.to(device)
lora_models.append(lora_model)
if not args.interactive: if not args.interactive:
generate_image(model, clip_l, t5xxl, ae, args.prompt, args.seed, args.width, args.height, args.steps, args.guidance) generate_image(model, clip_l, t5xxl, ae, args.prompt, args.seed, args.width, args.height, args.steps, args.guidance)
@@ -363,7 +375,9 @@ if __name__ == "__main__":
guidance = args.guidance guidance = args.guidance
while True: while True:
print("Enter prompt (empty to exit). Options: --w <width> --h <height> --s <steps> --d <seed> --g <guidance>") print(
"Enter prompt (empty to exit). Options: --w <width> --h <height> --s <steps> --d <seed> --g <guidance> --m <multipliers for LoRA>"
)
prompt = input() prompt = input()
if prompt == "": if prompt == "":
break break
@@ -384,6 +398,13 @@ if __name__ == "__main__":
seed = int(opt[1:].strip()) seed = int(opt[1:].strip())
elif opt.startswith("g"): elif opt.startswith("g"):
guidance = float(opt[1:].strip()) guidance = float(opt[1:].strip())
elif opt.startswith("m"):
multipliers = opt[1:].strip().split(",")
if len(multipliers) != len(lora_models):
logger.error(f"Invalid number of multipliers, expected {len(lora_models)}")
continue
for i, lora_model in enumerate(lora_models):
lora_model.set_multiplier(float(multipliers[i]))
generate_image(model, clip_l, t5xxl, ae, prompt, seed, width, height, steps, guidance) generate_image(model, clip_l, t5xxl, ae, prompt, seed, width, height, steps, guidance)