sample images with weight and no length limit

This commit is contained in:
mio
2023-03-12 16:08:31 +08:00
parent 7c1cf7f4ea
commit e24a43ae0b
2 changed files with 1155 additions and 4 deletions

(File diff suppressed because it is too large — [Load Diff])

View File

@@ -43,7 +43,7 @@ import cv2
 from einops import rearrange
 from torch import einsum
 import safetensors.torch
+from library.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipeline
 import library.model_util as model_util
 # Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う
@@ -2293,8 +2293,8 @@ def sample_images(accelerator, args: argparse.Namespace, epoch, steps, device, v
 if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is False:
     # print("set clip_sample to True")
     scheduler.config.clip_sample = True
-pipeline = StableDiffusionPipeline(text_encoder=text_encoder_or_wrapper, vae=vae, unet=unet, tokenizer=tokenizer,
+pipeline = StableDiffusionLongPromptWeightingPipeline(text_encoder=text_encoder_or_wrapper, vae=vae, unet=unet, tokenizer=tokenizer,
                                    scheduler=scheduler, safety_checker=None, feature_extractor=None, requires_safety_checker=False)
 pipeline.to(device)
@@ -2374,7 +2374,7 @@ def sample_images(accelerator, args: argparse.Namespace, epoch, steps, device, v
 print(f"width: {width}")
 print(f"sample_steps: {sample_steps}")
 print(f"scale: {scale}")
-image = pipeline(prompt, height, width, sample_steps, scale, negative_prompt).images[0]
+image = pipeline(prompt=prompt, height=height, width=width,num_inference_steps=sample_steps,guidance_scale=scale,negative_prompt=negative_prompt).images[0]
 ts_str = time.strftime('%Y%m%d%H%M%S', time.localtime())
 num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}"