Accept --ss to set sample_sampler dynamically

This commit is contained in:
Yuta Hayashibe
2023-10-29 20:15:04 +09:00
parent 291c29caaf
commit cf876fcdb4

View File

@@ -4477,6 +4477,11 @@ def line_to_prompt_dict(line: str) -> dict:
prompt_dict['negative_prompt'] = m.group(1) prompt_dict['negative_prompt'] = m.group(1)
continue continue
m = re.match(r"ss (.+)", parg, re.IGNORECASE)
if m: # negative prompt
prompt_dict['sample_sampler'] = m.group(1)
continue
m = re.match(r"cn (.+)", parg, re.IGNORECASE) m = re.match(r"cn (.+)", parg, re.IGNORECASE)
if m: # negative prompt if m: # negative prompt
prompt_dict['controlnet_image'] = m.group(1) prompt_dict['controlnet_image'] = m.group(1)
@@ -4540,17 +4545,19 @@ def sample_images_common(
with open(args.sample_prompts, "r", encoding="utf-8") as f: with open(args.sample_prompts, "r", encoding="utf-8") as f:
prompts = json.load(f) prompts = json.load(f)
scheduler = get_my_scheduler( schedulers: dict = {}
sample_sampler=args.scheduler, default_scheduler = get_my_scheduler(
sample_sampler=args.sample_sampler,
v_parameterization=args.v_parameterization, v_parameterization=args.v_parameterization,
) )
schedulers[args.sample_sampler] = default_scheduler
pipeline = pipe_class( pipeline = pipe_class(
text_encoder=text_encoder, text_encoder=text_encoder,
vae=vae, vae=vae,
unet=unet, unet=unet,
tokenizer=tokenizer, tokenizer=tokenizer,
scheduler=scheduler, scheduler=default_scheduler,
safety_checker=None, safety_checker=None,
feature_extractor=None, feature_extractor=None,
requires_safety_checker=False, requires_safety_checker=False,
@@ -4582,11 +4589,18 @@ def sample_images_common(
seed = prompt_dict.get("seed") seed = prompt_dict.get("seed")
controlnet_image = prompt_dict.get("controlnet_image") controlnet_image = prompt_dict.get("controlnet_image")
prompt: str = prompt_dict.get("prompt", "") prompt: str = prompt_dict.get("prompt", "")
sampler_name:str = prompt_dict.get("sample_sampler", args.sample_sampler)
if seed is not None: if seed is not None:
torch.manual_seed(seed) torch.manual_seed(seed)
torch.cuda.manual_seed(seed) torch.cuda.manual_seed(seed)
scheduler = schedulers.get(sampler_name)
if scheduler is None:
scheduler = get_my_scheduler(sample_sampler=sampler_name, v_parameterization=args.v_parameterization,)
schedulers[sampler_name] = scheduler
pipeline.scheduler = scheduler
if prompt_replacement is not None: if prompt_replacement is not None:
prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1]) prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
if negative_prompt is not None: if negative_prompt is not None:
@@ -4604,6 +4618,7 @@ def sample_images_common(
print(f"width: {width}") print(f"width: {width}")
print(f"sample_steps: {sample_steps}") print(f"sample_steps: {sample_steps}")
print(f"scale: {scale}") print(f"scale: {scale}")
print(f"sample_sampler: {sampler_name}")
with accelerator.autocast(): with accelerator.autocast():
latents = pipeline( latents = pipeline(
prompt=prompt, prompt=prompt,