Add feature to sample images during SDXL training

This commit is contained in:
Kohya S
2023-07-02 16:42:19 +09:00
parent 227a62e4c4
commit 64cf922841
5 changed files with 1402 additions and 32 deletions

File diff suppressed because it is too large Load Diff

View File

@@ -9,6 +9,7 @@ from tqdm import tqdm
from transformers import CLIPTokenizer from transformers import CLIPTokenizer
import open_clip import open_clip
from library import model_util, sdxl_model_util, train_util from library import model_util, sdxl_model_util, train_util
from library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline
TOKENIZER_PATH = "openai/clip-vit-large-patch14" TOKENIZER_PATH = "openai/clip-vit-large-patch14"
@@ -87,13 +88,22 @@ class WrapperTokenizer:
def __call__(self, *args: Any, **kwds: Any) -> Any: def __call__(self, *args: Any, **kwds: Any) -> Any:
return self.tokenize(*args, **kwds) return self.tokenize(*args, **kwds)
def tokenize(self, text, padding, truncation, max_length, return_tensors): def tokenize(self, text, padding=False, truncation=None, max_length=None, return_tensors=None):
assert padding == "max_length" if padding == "max_length":
# for training
assert max_length is not None
assert truncation == True assert truncation == True
assert return_tensors == "pt" assert return_tensors == "pt"
input_ids = open_clip.tokenize(text, context_length=max_length) input_ids = open_clip.tokenize(text, context_length=max_length)
return SimpleNamespace(**{"input_ids": input_ids}) return SimpleNamespace(**{"input_ids": input_ids})
# for weighted prompt
input_ids = open_clip.tokenize(text, context_length=self.model_max_length)
# find eos
eos_index = (input_ids == self.eos_token_id).nonzero()[0].max() # max index of each batch
input_ids = input_ids[:, : eos_index + 1] # include eos
return SimpleNamespace(**{"input_ids": input_ids})
def load_tokenizers(args: argparse.Namespace): def load_tokenizers(args: argparse.Namespace):
print("prepare tokenizers") print("prepare tokenizers")
@@ -381,3 +391,7 @@ def verify_sdxl_training_args(args: argparse.Namespace):
assert ( assert (
not args.weighted_captions not args.weighted_captions
), "weighted_captions cannot be enabled in SDXL training currently / SDXL学習では今のところweighted_captionsを有効にすることはできません" ), "weighted_captions cannot be enabled in SDXL training currently / SDXL学習では今のところweighted_captionsを有効にすることはできません"
def sample_images(*args, **kwargs):
return train_util.sample_images_common(SdxlStableDiffusionLongPromptWeightingPipeline, *args, **kwargs)

View File

@@ -3695,7 +3695,12 @@ SCHEDULER_TIMESTEPS = 1000
SCHEDLER_SCHEDULE = "scaled_linear" SCHEDLER_SCHEDULE = "scaled_linear"
def sample_images( def sample_images(*args, **kwargs):
return sample_images_common(StableDiffusionLongPromptWeightingPipeline, *args, **kwargs)
def sample_images_common(
pipe_class,
accelerator, accelerator,
args: argparse.Namespace, args: argparse.Namespace,
epoch, epoch,
@@ -3790,7 +3795,7 @@ def sample_images(
# print("set clip_sample to True") # print("set clip_sample to True")
scheduler.config.clip_sample = True scheduler.config.clip_sample = True
pipeline = StableDiffusionLongPromptWeightingPipeline( pipeline = pipe_class(
text_encoder=text_encoder, text_encoder=text_encoder,
vae=vae, vae=vae,
unet=unet, unet=unet,
@@ -3801,7 +3806,6 @@ def sample_images(
requires_safety_checker=False, requires_safety_checker=False,
clip_skip=args.clip_skip, clip_skip=args.clip_skip,
) )
pipeline.clip_skip = args.clip_skip # Pipelineのコンストラクタにckip_skipを追加できないので後から設定する
pipeline.to(device) pipeline.to(device)
save_dir = args.output_dir + "/sample" save_dir = args.output_dir + "/sample"

View File

@@ -290,8 +290,9 @@ def train(args):
args, accelerator, (tokenizer1, tokenizer2), (text_encoder1, text_encoder2), train_dataloader, None args, accelerator, (tokenizer1, tokenizer2), (text_encoder1, text_encoder2), train_dataloader, None
) )
accelerator.wait_for_everyone() accelerator.wait_for_everyone()
text_encoder1.to("cpu") # Text Encoder doesn't work on CPU with fp16
text_encoder2.to("cpu") text_encoder1.to("cpu", dtype=torch.float32)
text_encoder2.to("cpu", dtype=torch.float32)
if torch.cuda.is_available(): if torch.cuda.is_available():
torch.cuda.empty_cache() torch.cuda.empty_cache()
else: else:
@@ -467,19 +468,17 @@ def train(args):
progress_bar.update(1) progress_bar.update(1)
global_step += 1 global_step += 1
# sdxl_train_util.sample_images( sdxl_train_util.sample_images(
# accelerator, accelerator,
# args, args,
# None, None,
# global_step, global_step,
# accelerator.device, accelerator.device,
# vae, vae,
# tokenizer1, [tokenizer1, tokenizer2],
# tokenizer2, [text_encoder1, text_encoder2],
# text_encoder1, unet,
# text_encoder2, )
# unet,
# )
# 指定ステップごとにモデルを保存 # 指定ステップごとにモデルを保存
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
@@ -553,7 +552,17 @@ def train(args):
ckpt_info, ckpt_info,
) )
# train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) sdxl_train_util.sample_images(
accelerator,
args,
epoch + 1,
global_step,
accelerator.device,
vae,
[tokenizer1, tokenizer2],
[text_encoder1, text_encoder2],
unet,
)
is_main_process = accelerator.is_main_process is_main_process = accelerator.is_main_process
# if is_main_process: # if is_main_process:

View File

@@ -8,7 +8,6 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.vae_scale_factor = sdxl_model_util.VAE_SCALE_FACTOR self.vae_scale_factor = sdxl_model_util.VAE_SCALE_FACTOR
self.sampling_warning_showed = False
def assert_extra_args(self, args, train_dataset_group): def assert_extra_args(self, args, train_dataset_group):
super().assert_extra_args(args, train_dataset_group) super().assert_extra_args(args, train_dataset_group)
@@ -65,8 +64,8 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
args, accelerator, tokenizers, text_encoders, data_loader, weight_dtype args, accelerator, tokenizers, text_encoders, data_loader, weight_dtype
) )
accelerator.wait_for_everyone() accelerator.wait_for_everyone()
text_encoders[0].to("cpu") text_encoders[0].to("cpu", dtype=torch.float32) # Text Encoder doesn't work with fp16 on CPU
text_encoders[1].to("cpu") text_encoders[1].to("cpu", dtype=torch.float32)
if torch.cuda.is_available(): if torch.cuda.is_available():
torch.cuda.empty_cache() torch.cuda.empty_cache()
@@ -149,9 +148,7 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
return noise_pred return noise_pred
def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet): def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet):
if not self.sampling_warning_showed: sdxl_train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet)
print("sample_images is not implemented")
self.sampling_warning_showed = True
def setup_parser() -> argparse.ArgumentParser: def setup_parser() -> argparse.ArgumentParser: