mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
add feature to sample images during sdxl training
This commit is contained in:
1346
library/sdxl_lpw_stable_diffusion.py
Normal file
1346
library/sdxl_lpw_stable_diffusion.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -9,6 +9,7 @@ from tqdm import tqdm
|
||||
from transformers import CLIPTokenizer
|
||||
import open_clip
|
||||
from library import model_util, sdxl_model_util, train_util
|
||||
from library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline
|
||||
|
||||
TOKENIZER_PATH = "openai/clip-vit-large-patch14"
|
||||
|
||||
@@ -87,13 +88,22 @@ class WrapperTokenizer:
|
||||
def __call__(self, *args: Any, **kwds: Any) -> Any:
    """Make the wrapper directly callable, like a HF tokenizer instance.

    Every positional and keyword argument is forwarded unchanged to
    ``tokenize``, so ``wrapper(text, ...)`` and ``wrapper.tokenize(text, ...)``
    are interchangeable.
    """
    return self.tokenize(*args, **kwds)
|
||||
def tokenize(self, text, padding, truncation, max_length, return_tensors):
    """Tokenize ``text`` with open_clip while presenting a HF-tokenizer-like API.

    Only the exact call pattern used during training is accepted:
    padding="max_length", truncation=True, return_tensors="pt".
    Returns a SimpleNamespace exposing the token tensor as ``input_ids``,
    mimicking a transformers ``BatchEncoding``.
    """
    # Guard: this shim only understands the fixed-length training call shape.
    assert padding == "max_length"
    assert truncation == True
    assert return_tensors == "pt"
    token_tensor = open_clip.tokenize(text, context_length=max_length)
    # Wrap so callers can read result.input_ids as with a real tokenizer.
    return SimpleNamespace(**{"input_ids": token_tensor})
|
||||
def tokenize(self, text, padding=False, truncation=None, max_length=None, return_tensors=None):
    """open_clip-backed tokenizer mimicking the HF tokenizer call contract.

    Two call patterns are supported:

    * Training: ``padding="max_length"`` with an explicit ``max_length`` —
      ids are produced at exactly ``max_length`` tokens.
    * Weighted-prompt parsing (no padding requested): ids are produced at
      ``self.model_max_length`` and then truncated just past the EOS token.

    Returns a SimpleNamespace with an ``input_ids`` tensor, mirroring a
    transformers ``BatchEncoding``.
    """
    if padding == "max_length":
        # for training
        assert max_length is not None
        assert truncation == True
        assert return_tensors == "pt"
        input_ids = open_clip.tokenize(text, context_length=max_length)
        return SimpleNamespace(**{"input_ids": input_ids})

    # for weighted prompt
    input_ids = open_clip.tokenize(text, context_length=self.model_max_length)

    # find eos
    # NOTE(review): ``nonzero()[0]`` is the coordinate row of the FIRST match,
    # and ``.max()`` takes the largest value in that row — presumably the
    # column index of the first EOS. Verify this does what is intended when
    # the batch dimension is larger than one.
    eos_index = (input_ids == self.eos_token_id).nonzero()[0].max()  # max index of each batch
    input_ids = input_ids[:, : eos_index + 1]  # include eos
    return SimpleNamespace(**{"input_ids": input_ids})
|
||||
|
||||
def load_tokenizers(args: argparse.Namespace):
|
||||
print("prepare tokenizers")
|
||||
@@ -381,3 +391,7 @@ def verify_sdxl_training_args(args: argparse.Namespace):
|
||||
assert (
|
||||
not args.weighted_captions
|
||||
), "weighted_captions cannot be enabled in SDXL training currently / SDXL学習では今のところweighted_captionsを有効にすることはできません"
|
||||
|
||||
|
||||
def sample_images(*args, **kwargs):
    """Generate sample images during SDXL training.

    Thin shim that routes to ``train_util.sample_images_common`` with the
    SDXL long-prompt-weighting pipeline class pre-bound as the pipeline
    implementation.
    """
    pipeline_cls = SdxlStableDiffusionLongPromptWeightingPipeline
    return train_util.sample_images_common(pipeline_cls, *args, **kwargs)
|
||||
|
||||
@@ -3695,7 +3695,12 @@ SCHEDULER_TIMESTEPS = 1000
|
||||
# Beta schedule name handed to the diffusers scheduler when the sampling
# pipeline is built.  NOTE(review): the identifier is misspelled
# ("SCHEDLER") but is kept as-is since code elsewhere may reference it.
SCHEDLER_SCHEDULE = "scaled_linear"
|
||||
|
||||
|
||||
def sample_images(
|
||||
def sample_images(*args, **kwargs):
    """Backward-compatible entry point for sampling during SD training.

    Delegates to ``sample_images_common`` with the pipeline class fixed to
    ``StableDiffusionLongPromptWeightingPipeline``; all other arguments are
    passed through untouched.
    """
    pipe_cls = StableDiffusionLongPromptWeightingPipeline
    return sample_images_common(pipe_cls, *args, **kwargs)
|
||||
|
||||
|
||||
def sample_images_common(
|
||||
pipe_class,
|
||||
accelerator,
|
||||
args: argparse.Namespace,
|
||||
epoch,
|
||||
@@ -3790,7 +3795,7 @@ def sample_images(
|
||||
# print("set clip_sample to True")
|
||||
scheduler.config.clip_sample = True
|
||||
|
||||
pipeline = StableDiffusionLongPromptWeightingPipeline(
|
||||
pipeline = pipe_class(
|
||||
text_encoder=text_encoder,
|
||||
vae=vae,
|
||||
unet=unet,
|
||||
@@ -3801,9 +3806,8 @@ def sample_images(
|
||||
requires_safety_checker=False,
|
||||
clip_skip=args.clip_skip,
|
||||
)
|
||||
pipeline.clip_skip = args.clip_skip # Pipelineのコンストラクタにckip_skipを追加できないので後から設定する
|
||||
pipeline.to(device)
|
||||
|
||||
|
||||
save_dir = args.output_dir + "/sample"
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user