mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
Accept --ss to set sample_sampler dynamically
This commit is contained in:
@@ -4477,6 +4477,11 @@ def line_to_prompt_dict(line: str) -> dict:
|
|||||||
prompt_dict['negative_prompt'] = m.group(1)
|
prompt_dict['negative_prompt'] = m.group(1)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
m = re.match(r"ss (.+)", parg, re.IGNORECASE)
|
||||||
|
if m: # negative prompt
|
||||||
|
prompt_dict['sample_sampler'] = m.group(1)
|
||||||
|
continue
|
||||||
|
|
||||||
m = re.match(r"cn (.+)", parg, re.IGNORECASE)
|
m = re.match(r"cn (.+)", parg, re.IGNORECASE)
|
||||||
if m: # negative prompt
|
if m: # negative prompt
|
||||||
prompt_dict['controlnet_image'] = m.group(1)
|
prompt_dict['controlnet_image'] = m.group(1)
|
||||||
@@ -4540,17 +4545,19 @@ def sample_images_common(
|
|||||||
with open(args.sample_prompts, "r", encoding="utf-8") as f:
|
with open(args.sample_prompts, "r", encoding="utf-8") as f:
|
||||||
prompts = json.load(f)
|
prompts = json.load(f)
|
||||||
|
|
||||||
scheduler = get_my_scheduler(
|
schedulers: dict = {}
|
||||||
sample_sampler=args.scheduler,
|
default_scheduler = get_my_scheduler(
|
||||||
|
sample_sampler=args.sample_sampler,
|
||||||
v_parameterization=args.v_parameterization,
|
v_parameterization=args.v_parameterization,
|
||||||
)
|
)
|
||||||
|
schedulers[args.sample_sampler] = default_scheduler
|
||||||
|
|
||||||
pipeline = pipe_class(
|
pipeline = pipe_class(
|
||||||
text_encoder=text_encoder,
|
text_encoder=text_encoder,
|
||||||
vae=vae,
|
vae=vae,
|
||||||
unet=unet,
|
unet=unet,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
scheduler=scheduler,
|
scheduler=default_scheduler,
|
||||||
safety_checker=None,
|
safety_checker=None,
|
||||||
feature_extractor=None,
|
feature_extractor=None,
|
||||||
requires_safety_checker=False,
|
requires_safety_checker=False,
|
||||||
@@ -4582,11 +4589,18 @@ def sample_images_common(
|
|||||||
seed = prompt_dict.get("seed")
|
seed = prompt_dict.get("seed")
|
||||||
controlnet_image = prompt_dict.get("controlnet_image")
|
controlnet_image = prompt_dict.get("controlnet_image")
|
||||||
prompt: str = prompt_dict.get("prompt", "")
|
prompt: str = prompt_dict.get("prompt", "")
|
||||||
|
sampler_name:str = prompt_dict.get("sample_sampler", args.sample_sampler)
|
||||||
|
|
||||||
if seed is not None:
|
if seed is not None:
|
||||||
torch.manual_seed(seed)
|
torch.manual_seed(seed)
|
||||||
torch.cuda.manual_seed(seed)
|
torch.cuda.manual_seed(seed)
|
||||||
|
|
||||||
|
scheduler = schedulers.get(sampler_name)
|
||||||
|
if scheduler is None:
|
||||||
|
scheduler = get_my_scheduler(sample_sampler=sampler_name, v_parameterization=args.v_parameterization,)
|
||||||
|
schedulers[sampler_name] = scheduler
|
||||||
|
pipeline.scheduler = scheduler
|
||||||
|
|
||||||
if prompt_replacement is not None:
|
if prompt_replacement is not None:
|
||||||
prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
|
prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
|
||||||
if negative_prompt is not None:
|
if negative_prompt is not None:
|
||||||
@@ -4604,6 +4618,7 @@ def sample_images_common(
|
|||||||
print(f"width: {width}")
|
print(f"width: {width}")
|
||||||
print(f"sample_steps: {sample_steps}")
|
print(f"sample_steps: {sample_steps}")
|
||||||
print(f"scale: {scale}")
|
print(f"scale: {scale}")
|
||||||
|
print(f"sample_sampler: {sampler_name}")
|
||||||
with accelerator.autocast():
|
with accelerator.autocast():
|
||||||
latents = pipeline(
|
latents = pipeline(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
|
|||||||
Reference in New Issue
Block a user