mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-10 23:01:22 +00:00
feat: add interactive mode for generating multiple images
This commit is contained in:
@@ -257,6 +257,11 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
help="LoRA weights, each argument is a `path;multiplier` (semi-colon separated)",
|
||||
)
|
||||
parser.add_argument("--merge_lora_weights", action="store_true", help="Merge LoRA weights to model")
|
||||
parser.add_argument(
|
||||
"--interactive",
|
||||
action="store_true",
|
||||
help="Enable interactive mode for generating multiple images / 対話モードで複数の画像を生成する",
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
@@ -294,9 +299,7 @@ if __name__ == "__main__":
|
||||
multiplier = 1.0
|
||||
|
||||
weights_sd = load_file(weights_file)
|
||||
lora_model, _ = lora_lumina.create_network_from_weights(
|
||||
multiplier, None, ae, [gemma2], model, weights_sd, True
|
||||
)
|
||||
lora_model, _ = lora_lumina.create_network_from_weights(multiplier, None, ae, [gemma2], model, weights_sd, True)
|
||||
|
||||
if args.merge_lora_weights:
|
||||
lora_model.merge_to([gemma2], model, weights_sd)
|
||||
@@ -304,25 +307,109 @@ if __name__ == "__main__":
|
||||
lora_model.apply_to([gemma2], model)
|
||||
info = lora_model.load_state_dict(weights_sd, strict=True)
|
||||
logger.info(f"Loaded LoRA weights from {weights_file}: {info}")
|
||||
lora_model.to(device)
|
||||
lora_model.set_multiplier(multiplier)
|
||||
lora_model.eval()
|
||||
|
||||
lora_models.append(lora_model)
|
||||
|
||||
generate_image(
|
||||
model,
|
||||
gemma2,
|
||||
ae,
|
||||
args.prompt,
|
||||
args.system_prompt,
|
||||
args.seed,
|
||||
args.image_width,
|
||||
args.image_height,
|
||||
args.steps,
|
||||
args.guidance_scale,
|
||||
args.negative_prompt,
|
||||
args,
|
||||
args.cfg_trunc_ratio,
|
||||
args.renorm_cfg,
|
||||
)
|
||||
if not args.interactive:
|
||||
generate_image(
|
||||
model,
|
||||
gemma2,
|
||||
ae,
|
||||
args.prompt,
|
||||
args.system_prompt,
|
||||
args.seed,
|
||||
args.image_width,
|
||||
args.image_height,
|
||||
args.steps,
|
||||
args.guidance_scale,
|
||||
args.negative_prompt,
|
||||
args,
|
||||
args.cfg_trunc_ratio,
|
||||
args.renorm_cfg,
|
||||
)
|
||||
else:
|
||||
# Interactive mode loop
|
||||
image_width = args.image_width
|
||||
image_height = args.image_height
|
||||
steps = args.steps
|
||||
guidance_scale = args.guidance_scale
|
||||
cfg_trunc_ratio = args.cfg_trunc_ratio
|
||||
renorm_cfg = args.renorm_cfg
|
||||
|
||||
print("Entering interactive mode.")
|
||||
while True:
|
||||
print(
|
||||
"\nEnter prompt (or 'exit'). Options: --w <int> --h <int> --s <int> --d <int> --g <float> --n <str> --ctr <float> --rcfg <float> --m <m1,m2...>"
|
||||
)
|
||||
user_input = input()
|
||||
if user_input.lower() == "exit":
|
||||
break
|
||||
if not user_input:
|
||||
continue
|
||||
|
||||
# Parse options
|
||||
options = user_input.split("--")
|
||||
prompt = options[0].strip()
|
||||
|
||||
# Set defaults for each generation
|
||||
seed = None # New random seed each time unless specified
|
||||
negative_prompt = args.negative_prompt # Reset to default
|
||||
|
||||
for opt in options[1:]:
|
||||
try:
|
||||
opt = opt.strip()
|
||||
if not opt:
|
||||
continue
|
||||
|
||||
key, value = (opt.split(None, 1) + [""])[:2]
|
||||
|
||||
if key == "w":
|
||||
image_width = int(value)
|
||||
elif key == "h":
|
||||
image_height = int(value)
|
||||
elif key == "s":
|
||||
steps = int(value)
|
||||
elif key == "d":
|
||||
seed = int(value)
|
||||
elif key == "g":
|
||||
guidance_scale = float(value)
|
||||
elif key == "n":
|
||||
negative_prompt = value if value != "-" else ""
|
||||
elif key == "ctr":
|
||||
cfg_trunc_ratio = float(value)
|
||||
elif key == "rcfg":
|
||||
renorm_cfg = float(value)
|
||||
elif key == "m":
|
||||
multipliers = value.split(",")
|
||||
if len(multipliers) != len(lora_models):
|
||||
logger.error(f"Invalid number of multipliers, expected {len(lora_models)}")
|
||||
continue
|
||||
for i, lora_model in enumerate(lora_models):
|
||||
lora_model.set_multiplier(float(multipliers[i].strip()))
|
||||
else:
|
||||
logger.warning(f"Unknown option: --{key}")
|
||||
|
||||
except (ValueError, IndexError) as e:
|
||||
logger.error(f"Invalid value for option --{key}: '{value}'. Error: {e}")
|
||||
|
||||
generate_image(
|
||||
model,
|
||||
gemma2,
|
||||
ae,
|
||||
prompt,
|
||||
args.system_prompt,
|
||||
seed,
|
||||
image_width,
|
||||
image_height,
|
||||
steps,
|
||||
guidance_scale,
|
||||
negative_prompt,
|
||||
args,
|
||||
cfg_trunc_ratio,
|
||||
renorm_cfg,
|
||||
)
|
||||
|
||||
logger.info("Done.")
|
||||
|
||||
Reference in New Issue
Block a user