diff --git a/lumina_minimal_inference.py b/lumina_minimal_inference.py index 4f915179..31362c00 100644 --- a/lumina_minimal_inference.py +++ b/lumina_minimal_inference.py @@ -257,6 +257,11 @@ def setup_parser() -> argparse.ArgumentParser: help="LoRA weights, each argument is a `path;multiplier` (semi-colon separated)", ) parser.add_argument("--merge_lora_weights", action="store_true", help="Merge LoRA weights to model") + parser.add_argument( + "--interactive", + action="store_true", + help="Enable interactive mode for generating multiple images / 対話モードで複数の画像を生成する", + ) return parser @@ -294,9 +299,7 @@ if __name__ == "__main__": multiplier = 1.0 weights_sd = load_file(weights_file) - lora_model, _ = lora_lumina.create_network_from_weights( - multiplier, None, ae, [gemma2], model, weights_sd, True - ) + lora_model, _ = lora_lumina.create_network_from_weights(multiplier, None, ae, [gemma2], model, weights_sd, True) if args.merge_lora_weights: lora_model.merge_to([gemma2], model, weights_sd) @@ -304,25 +307,109 @@ if __name__ == "__main__": lora_model.apply_to([gemma2], model) info = lora_model.load_state_dict(weights_sd, strict=True) logger.info(f"Loaded LoRA weights from {weights_file}: {info}") + lora_model.to(device) + lora_model.set_multiplier(multiplier) lora_model.eval() lora_models.append(lora_model) - generate_image( - model, - gemma2, - ae, - args.prompt, - args.system_prompt, - args.seed, - args.image_width, - args.image_height, - args.steps, - args.guidance_scale, - args.negative_prompt, - args, - args.cfg_trunc_ratio, - args.renorm_cfg, - ) + if not args.interactive: + generate_image( + model, + gemma2, + ae, + args.prompt, + args.system_prompt, + args.seed, + args.image_width, + args.image_height, + args.steps, + args.guidance_scale, + args.negative_prompt, + args, + args.cfg_trunc_ratio, + args.renorm_cfg, + ) + else: + # Interactive mode loop + image_width = args.image_width + image_height = args.image_height + steps = args.steps + guidance_scale = args.guidance_scale + cfg_trunc_ratio = args.cfg_trunc_ratio + renorm_cfg = args.renorm_cfg + + print("Entering interactive mode.") + while True: + print( + "\nEnter prompt (or 'exit'). Options: --w --h --s --d --g --n --ctr --rcfg --m " + ) + user_input = input() + if user_input.lower() == "exit": + break + if not user_input: + continue + + # Parse options + options = user_input.split("--") + prompt = options[0].strip() + + # Set defaults for each generation + seed = None # New random seed each time unless specified + negative_prompt = args.negative_prompt # Reset to default + + for opt in options[1:]: + try: + opt = opt.strip() + if not opt: + continue + + key, value = (opt.split(None, 1) + [""])[:2] + + if key == "w": + image_width = int(value) + elif key == "h": + image_height = int(value) + elif key == "s": + steps = int(value) + elif key == "d": + seed = int(value) + elif key == "g": + guidance_scale = float(value) + elif key == "n": + negative_prompt = value if value != "-" else "" + elif key == "ctr": + cfg_trunc_ratio = float(value) + elif key == "rcfg": + renorm_cfg = float(value) + elif key == "m": + multipliers = value.split(",") + if len(multipliers) != len(lora_models): + logger.error(f"Invalid number of multipliers, expected {len(lora_models)}") + continue + for i, lora_model in enumerate(lora_models): + lora_model.set_multiplier(float(multipliers[i].strip())) + else: + logger.warning(f"Unknown option: --{key}") + + except (ValueError, IndexError) as e: + logger.error(f"Invalid value for option --{key}: '{value}'. Error: {e}") + + generate_image( + model, + gemma2, + ae, + prompt, + args.system_prompt, + seed, + image_width, + image_height, + steps, + guidance_scale, + negative_prompt, + args, + cfg_trunc_ratio, + renorm_cfg, + ) logger.info("Done.")