Fix lpwp to support sdv2 and clip skip
@@ -2703,8 +2703,7 @@ def sample_images(
     accelerator, args: argparse.Namespace, epoch, steps, device, vae, tokenizer, text_encoder, unet, prompt_replacement=None
 ):
     """
-    The default Diffusers pipeline is used for generation, so prompt weighting is not supported.
-    clip skip is supported.
+    A modified StableDiffusionLongPromptWeightingPipeline is now used, so both clip skip and prompt weighting are supported.
     """
     if args.sample_every_n_steps is None and args.sample_every_n_epochs is None:
         return
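As background for the new docstring: the community lpw_stable_diffusion pipeline parses A1111-style emphasis syntax, where `(word:1.3)` scales a token's attention weight up and `[word]` scales it down. A minimal sketch of what prompt weighting means, assuming the modified pipeline keeps the community version's `parse_prompt_attention` helper and lives at `library/lpw_stable_diffusion.py` (both are assumptions, not shown in this diff):

```python
# Assumption: the modified pipeline keeps the community helper
# parse_prompt_attention and is located at library/lpw_stable_diffusion.py.
from library.lpw_stable_diffusion import parse_prompt_attention

# "(red:1.3)" boosts the weight of "red"; "[blurry]" damps "blurry".
for text, weight in parse_prompt_attention("a (red:1.3) rose, [blurry]"):
    print(f"{text!r} -> {weight}")
```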
@@ -2724,26 +2723,6 @@ def sample_images(
     org_vae_device = vae.device  # should be on the CPU
     vae.to(device)
 
-    # create a wrapper to support clip skip
-    if args.clip_skip is None:
-        text_encoder_or_wrapper = text_encoder
-    else:
-
-        class Wrapper:
-            def __init__(self, tenc) -> None:
-                self.tenc = tenc
-                self.config = {}
-                super().__init__()
-
-            def __call__(self, input_ids, attention_mask):
-                enc_out = self.tenc(input_ids, output_hidden_states=True, return_dict=True)
-                encoder_hidden_states = enc_out["hidden_states"][-args.clip_skip]
-                encoder_hidden_states = self.tenc.text_model.final_layer_norm(encoder_hidden_states)
-                pooled_output = enc_out["pooler_output"]
-                return encoder_hidden_states, pooled_output  # only the 1st output is used
-
-        text_encoder_or_wrapper = Wrapper(text_encoder)
-
     # read prompts
     with open(args.sample_prompts, "rt", encoding="utf-8") as f:
         prompts = f.readlines()
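The deleted wrapper is the whole clip-skip mechanism in miniature: take the hidden states of the clip_skip-th layer from the end and re-apply the encoder's final layer norm. That logic now has to live inside the pipeline, behind the `clip_skip` argument added in the next hunk. A standalone sketch of the same computation, with the `clip_skip=None` fallback from the deleted branch folded in (the function name is illustrative):

```python
import torch
from transformers import CLIPTextModel

def encode_with_clip_skip(text_encoder: CLIPTextModel, input_ids: torch.Tensor, clip_skip=None):
    # Mirrors the deleted Wrapper: pick the clip_skip-th hidden layer from the
    # end and re-apply the final layer norm; clip_skip=None keeps the encoder's
    # normal final-layer output (the deleted `if args.clip_skip is None` branch).
    enc_out = text_encoder(input_ids, output_hidden_states=True, return_dict=True)
    if clip_skip is None:
        return enc_out.last_hidden_state
    hidden_states = enc_out.hidden_states[-clip_skip]
    return text_encoder.text_model.final_layer_norm(hidden_states)
```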
@@ -2792,8 +2771,17 @@ def sample_images(
     # print("set clip_sample to True")
     scheduler.config.clip_sample = True
 
-    pipeline = StableDiffusionLongPromptWeightingPipeline(text_encoder=text_encoder_or_wrapper, vae=vae, unet=unet, tokenizer=tokenizer,
-                                                          scheduler=scheduler, safety_checker=None, feature_extractor=None, requires_safety_checker=False)
+    pipeline = StableDiffusionLongPromptWeightingPipeline(
+        text_encoder=text_encoder,
+        vae=vae,
+        unet=unet,
+        tokenizer=tokenizer,
+        scheduler=scheduler,
+        clip_skip=args.clip_skip,
+        safety_checker=None,
+        feature_extractor=None,
+        requires_safety_checker=False,
+    )
     pipeline.to(device)
 
     save_dir = args.output_dir + "/sample"
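Put together, sampling through the new interface looks roughly like the sketch below. The model components (`text_encoder`, `vae`, `unet`, `tokenizer`, `scheduler`) are assumed to be loaded already, and the import path is an assumption about the repo layout, not part of this diff:

```python
# Sketch only: components are assumed to be loaded elsewhere, and the
# import path is an assumption about this repo's layout.
from library.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipeline

pipeline = StableDiffusionLongPromptWeightingPipeline(
    text_encoder=text_encoder,
    vae=vae,
    unet=unet,
    tokenizer=tokenizer,
    scheduler=scheduler,
    clip_skip=2,  # or None to use the encoder's final layer
    safety_checker=None,
    feature_extractor=None,
    requires_safety_checker=False,
)
pipeline.to("cuda")
image = pipeline(
    prompt="a (red:1.3) rose, [blurry]",  # weighted prompt now works here
    height=512,
    width=512,
    num_inference_steps=30,
    guidance_scale=7.5,
).images[0]
image.save("sample.png")
```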
@@ -2872,7 +2860,14 @@ def sample_images(
             print(f"width: {width}")
             print(f"sample_steps: {sample_steps}")
             print(f"scale: {scale}")
-            image = pipeline(prompt=prompt, height=height, width=width,num_inference_steps=sample_steps,guidance_scale=scale,negative_prompt=negative_prompt).images[0]
+            image = pipeline(
+                prompt=prompt,
+                height=height,
+                width=width,
+                num_inference_steps=sample_steps,
+                guidance_scale=scale,
+                negative_prompt=negative_prompt,
+            ).images[0]
 
             ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
             num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}"