diff --git a/library/train_util.py b/library/train_util.py index b93b8ea4..949a8206 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -4477,6 +4477,11 @@ def line_to_prompt_dict(line: str) -> dict: prompt_dict['negative_prompt'] = m.group(1) continue + m = re.match(r"ss (.+)", parg, re.IGNORECASE) + if m: # negative prompt + prompt_dict['sample_sampler'] = m.group(1) + continue + m = re.match(r"cn (.+)", parg, re.IGNORECASE) if m: # negative prompt prompt_dict['controlnet_image'] = m.group(1) @@ -4540,17 +4545,19 @@ def sample_images_common( with open(args.sample_prompts, "r", encoding="utf-8") as f: prompts = json.load(f) - scheduler = get_my_scheduler( - sample_sampler=args.scheduler, + schedulers: dict = {} + default_scheduler = get_my_scheduler( + sample_sampler=args.sample_sampler, v_parameterization=args.v_parameterization, ) + schedulers[args.sample_sampler] = default_scheduler pipeline = pipe_class( text_encoder=text_encoder, vae=vae, unet=unet, tokenizer=tokenizer, - scheduler=scheduler, + scheduler=default_scheduler, safety_checker=None, feature_extractor=None, requires_safety_checker=False, @@ -4582,11 +4589,18 @@ def sample_images_common( seed = prompt_dict.get("seed") controlnet_image = prompt_dict.get("controlnet_image") prompt: str = prompt_dict.get("prompt", "") + sampler_name:str = prompt_dict.get("sample_sampler", args.sample_sampler) if seed is not None: torch.manual_seed(seed) torch.cuda.manual_seed(seed) + scheduler = schedulers.get(sampler_name) + if scheduler is None: + scheduler = get_my_scheduler(sample_sampler=sampler_name, v_parameterization=args.v_parameterization,) + schedulers[sampler_name] = scheduler + pipeline.scheduler = scheduler + if prompt_replacement is not None: prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1]) if negative_prompt is not None: @@ -4604,6 +4618,7 @@ def sample_images_common( print(f"width: {width}") print(f"sample_steps: {sample_steps}") print(f"scale: {scale}") + print(f"sample_sampler: {sampler_name}") with accelerator.autocast(): latents = pipeline( prompt=prompt,