feat: add interactive mode for generating multiple images

This commit is contained in:
Kohya S
2025-07-13 20:45:09 +09:00
parent 8a72f56c9f
commit 1a9bf2ab56

View File

@@ -257,6 +257,11 @@ def setup_parser() -> argparse.ArgumentParser:
help="LoRA weights, each argument is a `path;multiplier` (semi-colon separated)",
)
parser.add_argument("--merge_lora_weights", action="store_true", help="Merge LoRA weights to model")
parser.add_argument(
"--interactive",
action="store_true",
help="Enable interactive mode for generating multiple images / 対話モードで複数の画像を生成する",
)
return parser
@@ -294,9 +299,7 @@ if __name__ == "__main__":
multiplier = 1.0
weights_sd = load_file(weights_file)
lora_model, _ = lora_lumina.create_network_from_weights(
multiplier, None, ae, [gemma2], model, weights_sd, True
)
lora_model, _ = lora_lumina.create_network_from_weights(multiplier, None, ae, [gemma2], model, weights_sd, True)
if args.merge_lora_weights:
lora_model.merge_to([gemma2], model, weights_sd)
@@ -304,25 +307,109 @@ if __name__ == "__main__":
lora_model.apply_to([gemma2], model)
info = lora_model.load_state_dict(weights_sd, strict=True)
logger.info(f"Loaded LoRA weights from {weights_file}: {info}")
lora_model.to(device)
lora_model.set_multiplier(multiplier)
lora_model.eval()
lora_models.append(lora_model)
generate_image(
model,
gemma2,
ae,
args.prompt,
args.system_prompt,
args.seed,
args.image_width,
args.image_height,
args.steps,
args.guidance_scale,
args.negative_prompt,
args,
args.cfg_trunc_ratio,
args.renorm_cfg,
)
if not args.interactive:
generate_image(
model,
gemma2,
ae,
args.prompt,
args.system_prompt,
args.seed,
args.image_width,
args.image_height,
args.steps,
args.guidance_scale,
args.negative_prompt,
args,
args.cfg_trunc_ratio,
args.renorm_cfg,
)
else:
# Interactive mode loop
image_width = args.image_width
image_height = args.image_height
steps = args.steps
guidance_scale = args.guidance_scale
cfg_trunc_ratio = args.cfg_trunc_ratio
renorm_cfg = args.renorm_cfg
print("Entering interactive mode.")
while True:
print(
"\nEnter prompt (or 'exit'). Options: --w <int> --h <int> --s <int> --d <int> --g <float> --n <str> --ctr <float> --rcfg <float> --m <m1,m2...>"
)
user_input = input()
if user_input.lower() == "exit":
break
if not user_input:
continue
# Parse options
options = user_input.split("--")
prompt = options[0].strip()
# Set defaults for each generation
seed = None # New random seed each time unless specified
negative_prompt = args.negative_prompt # Reset to default
for opt in options[1:]:
try:
opt = opt.strip()
if not opt:
continue
key, value = (opt.split(None, 1) + [""])[:2]
if key == "w":
image_width = int(value)
elif key == "h":
image_height = int(value)
elif key == "s":
steps = int(value)
elif key == "d":
seed = int(value)
elif key == "g":
guidance_scale = float(value)
elif key == "n":
negative_prompt = value if value != "-" else ""
elif key == "ctr":
cfg_trunc_ratio = float(value)
elif key == "rcfg":
renorm_cfg = float(value)
elif key == "m":
multipliers = value.split(",")
if len(multipliers) != len(lora_models):
logger.error(f"Invalid number of multipliers, expected {len(lora_models)}")
continue
for i, lora_model in enumerate(lora_models):
lora_model.set_multiplier(float(multipliers[i].strip()))
else:
logger.warning(f"Unknown option: --{key}")
except (ValueError, IndexError) as e:
logger.error(f"Invalid value for option --{key}: '{value}'. Error: {e}")
generate_image(
model,
gemma2,
ae,
prompt,
args.system_prompt,
seed,
image_width,
image_height,
steps,
guidance_scale,
negative_prompt,
args,
cfg_trunc_ratio,
renorm_cfg,
)
logger.info("Done.")