Fix lpwp to support sdv2 and clip skip

This commit is contained in:
Kohya S
2023-03-19 11:10:17 +09:00
parent cfb19ad0da
commit 1f7babd2c7
2 changed files with 105 additions and 82 deletions

View File

@@ -2703,8 +2703,7 @@ def sample_images(
accelerator, args: argparse.Namespace, epoch, steps, device, vae, tokenizer, text_encoder, unet, prompt_replacement=None
):
"""
生成に使っている Diffusers の Pipeline がデフォルトなので、プロンプトの重みづけに対応していない
clip skipは対応した
StableDiffusionLongPromptWeightingPipelineの改造版を使うようにしたので、clip skipおよびプロンプトの重みづけに対応し
"""
if args.sample_every_n_steps is None and args.sample_every_n_epochs is None:
return
@@ -2724,26 +2723,6 @@ def sample_images(
org_vae_device = vae.device # CPUにいるはず
vae.to(device)
# clip skip 対応のための wrapper を作る
if args.clip_skip is None:
text_encoder_or_wrapper = text_encoder
else:
class Wrapper:
def __init__(self, tenc) -> None:
self.tenc = tenc
self.config = {}
super().__init__()
def __call__(self, input_ids, attention_mask):
enc_out = self.tenc(input_ids, output_hidden_states=True, return_dict=True)
encoder_hidden_states = enc_out["hidden_states"][-args.clip_skip]
encoder_hidden_states = self.tenc.text_model.final_layer_norm(encoder_hidden_states)
pooled_output = enc_out["pooler_output"]
return encoder_hidden_states, pooled_output # 1st output is only used
text_encoder_or_wrapper = Wrapper(text_encoder)
# read prompts
with open(args.sample_prompts, "rt", encoding="utf-8") as f:
prompts = f.readlines()
@@ -2792,8 +2771,17 @@ def sample_images(
# print("set clip_sample to True")
scheduler.config.clip_sample = True
pipeline = StableDiffusionLongPromptWeightingPipeline(text_encoder=text_encoder_or_wrapper, vae=vae, unet=unet, tokenizer=tokenizer,
scheduler=scheduler, safety_checker=None, feature_extractor=None, requires_safety_checker=False)
pipeline = StableDiffusionLongPromptWeightingPipeline(
text_encoder=text_encoder,
vae=vae,
unet=unet,
tokenizer=tokenizer,
scheduler=scheduler,
clip_skip=args.clip_skip,
safety_checker=None,
feature_extractor=None,
requires_safety_checker=False,
)
pipeline.to(device)
save_dir = args.output_dir + "/sample"
@@ -2872,7 +2860,14 @@ def sample_images(
print(f"width: {width}")
print(f"sample_steps: {sample_steps}")
print(f"scale: {scale}")
image = pipeline(prompt=prompt, height=height, width=width,num_inference_steps=sample_steps,guidance_scale=scale,negative_prompt=negative_prompt).images[0]
image = pipeline(
prompt=prompt,
height=height,
width=width,
num_inference_steps=sample_steps,
guidance_scale=scale,
negative_prompt=negative_prompt,
).images[0]
ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}"