Mirror of https://github.com/kohya-ss/sd-scripts.git (synced 2026-04-08 22:35:09 +00:00)
add feature to sample images during sdxl training
This commit is contained in:
1346
library/sdxl_lpw_stable_diffusion.py
Normal file
1346
library/sdxl_lpw_stable_diffusion.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -9,6 +9,7 @@ from tqdm import tqdm
|
|||||||
from transformers import CLIPTokenizer
|
from transformers import CLIPTokenizer
|
||||||
import open_clip
|
import open_clip
|
||||||
from library import model_util, sdxl_model_util, train_util
|
from library import model_util, sdxl_model_util, train_util
|
||||||
|
from library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline
|
||||||
|
|
||||||
TOKENIZER_PATH = "openai/clip-vit-large-patch14"
|
TOKENIZER_PATH = "openai/clip-vit-large-patch14"
|
||||||
|
|
||||||
@@ -87,13 +88,22 @@ class WrapperTokenizer:
|
|||||||
def __call__(self, *args: Any, **kwds: Any) -> Any:
|
def __call__(self, *args: Any, **kwds: Any) -> Any:
|
||||||
return self.tokenize(*args, **kwds)
|
return self.tokenize(*args, **kwds)
|
||||||
|
|
||||||
def tokenize(self, text, padding, truncation, max_length, return_tensors):
|
def tokenize(self, text, padding=False, truncation=None, max_length=None, return_tensors=None):
|
||||||
assert padding == "max_length"
|
if padding == "max_length":
|
||||||
assert truncation == True
|
# for training
|
||||||
assert return_tensors == "pt"
|
assert max_length is not None
|
||||||
input_ids = open_clip.tokenize(text, context_length=max_length)
|
assert truncation == True
|
||||||
return SimpleNamespace(**{"input_ids": input_ids})
|
assert return_tensors == "pt"
|
||||||
|
input_ids = open_clip.tokenize(text, context_length=max_length)
|
||||||
|
return SimpleNamespace(**{"input_ids": input_ids})
|
||||||
|
|
||||||
|
# for weighted prompt
|
||||||
|
input_ids = open_clip.tokenize(text, context_length=self.model_max_length)
|
||||||
|
|
||||||
|
# find eos
|
||||||
|
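eos_index = (input_ids == self.eos_token_id).nonzero()[0].max() # NOTE: nonzero()[0] is only the first EOS match's (row, col) pair, so this yields the EOS column for a single-prompt batch, not a per-batch maximum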
|
||||||
|
input_ids = input_ids[:, : eos_index + 1] # include eos
|
||||||
|
return SimpleNamespace(**{"input_ids": input_ids})
|
||||||
|
|
||||||
def load_tokenizers(args: argparse.Namespace):
|
def load_tokenizers(args: argparse.Namespace):
|
||||||
print("prepare tokenizers")
|
print("prepare tokenizers")
|
||||||
@@ -381,3 +391,7 @@ def verify_sdxl_training_args(args: argparse.Namespace):
|
|||||||
assert (
|
assert (
|
||||||
not args.weighted_captions
|
not args.weighted_captions
|
||||||
), "weighted_captions cannot be enabled in SDXL training currently / SDXL学習では今のところweighted_captionsを有効にすることはできません"
|
), "weighted_captions cannot be enabled in SDXL training currently / SDXL学習では今のところweighted_captionsを有効にすることはできません"
|
||||||
|
|
||||||
|
|
||||||
|
def sample_images(*args, **kwargs):
|
||||||
|
return train_util.sample_images_common(SdxlStableDiffusionLongPromptWeightingPipeline, *args, **kwargs)
|
||||||
|
|||||||
@@ -3695,7 +3695,12 @@ SCHEDULER_TIMESTEPS = 1000
|
|||||||
SCHEDLER_SCHEDULE = "scaled_linear"
|
SCHEDLER_SCHEDULE = "scaled_linear"
|
||||||
|
|
||||||
|
|
||||||
def sample_images(
|
def sample_images(*args, **kwargs):
|
||||||
|
return sample_images_common(StableDiffusionLongPromptWeightingPipeline, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def sample_images_common(
|
||||||
|
pipe_class,
|
||||||
accelerator,
|
accelerator,
|
||||||
args: argparse.Namespace,
|
args: argparse.Namespace,
|
||||||
epoch,
|
epoch,
|
||||||
@@ -3790,7 +3795,7 @@ def sample_images(
|
|||||||
# print("set clip_sample to True")
|
# print("set clip_sample to True")
|
||||||
scheduler.config.clip_sample = True
|
scheduler.config.clip_sample = True
|
||||||
|
|
||||||
pipeline = StableDiffusionLongPromptWeightingPipeline(
|
pipeline = pipe_class(
|
||||||
text_encoder=text_encoder,
|
text_encoder=text_encoder,
|
||||||
vae=vae,
|
vae=vae,
|
||||||
unet=unet,
|
unet=unet,
|
||||||
@@ -3801,9 +3806,8 @@ def sample_images(
|
|||||||
requires_safety_checker=False,
|
requires_safety_checker=False,
|
||||||
clip_skip=args.clip_skip,
|
clip_skip=args.clip_skip,
|
||||||
)
|
)
|
||||||
pipeline.clip_skip = args.clip_skip # clip_skip cannot be added to the Pipeline constructor, so it is set afterwards
|
|
||||||
pipeline.to(device)
|
pipeline.to(device)
|
||||||
|
|
||||||
save_dir = args.output_dir + "/sample"
|
save_dir = args.output_dir + "/sample"
|
||||||
os.makedirs(save_dir, exist_ok=True)
|
os.makedirs(save_dir, exist_ok=True)
|
||||||
|
|
||||||
|
|||||||
@@ -290,8 +290,9 @@ def train(args):
|
|||||||
args, accelerator, (tokenizer1, tokenizer2), (text_encoder1, text_encoder2), train_dataloader, None
|
args, accelerator, (tokenizer1, tokenizer2), (text_encoder1, text_encoder2), train_dataloader, None
|
||||||
)
|
)
|
||||||
accelerator.wait_for_everyone()
|
accelerator.wait_for_everyone()
|
||||||
text_encoder1.to("cpu")
|
# Text Encoder doesn't work on CPU with fp16
|
||||||
text_encoder2.to("cpu")
|
text_encoder1.to("cpu", dtype=torch.float32)
|
||||||
|
text_encoder2.to("cpu", dtype=torch.float32)
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
else:
|
else:
|
||||||
@@ -467,19 +468,17 @@ def train(args):
|
|||||||
progress_bar.update(1)
|
progress_bar.update(1)
|
||||||
global_step += 1
|
global_step += 1
|
||||||
|
|
||||||
# sdxl_train_util.sample_images(
|
sdxl_train_util.sample_images(
|
||||||
# accelerator,
|
accelerator,
|
||||||
# args,
|
args,
|
||||||
# None,
|
None,
|
||||||
# global_step,
|
global_step,
|
||||||
# accelerator.device,
|
accelerator.device,
|
||||||
# vae,
|
vae,
|
||||||
# tokenizer1,
|
[tokenizer1, tokenizer2],
|
||||||
# tokenizer2,
|
[text_encoder1, text_encoder2],
|
||||||
# text_encoder1,
|
unet,
|
||||||
# text_encoder2,
|
)
|
||||||
# unet,
|
|
||||||
# )
|
|
||||||
|
|
||||||
# Save the model every specified number of steps
|
# Save the model every specified number of steps
|
||||||
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
|
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
|
||||||
@@ -553,7 +552,17 @@ def train(args):
|
|||||||
ckpt_info,
|
ckpt_info,
|
||||||
)
|
)
|
||||||
|
|
||||||
# train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
sdxl_train_util.sample_images(
|
||||||
|
accelerator,
|
||||||
|
args,
|
||||||
|
epoch + 1,
|
||||||
|
global_step,
|
||||||
|
accelerator.device,
|
||||||
|
vae,
|
||||||
|
[tokenizer1, tokenizer2],
|
||||||
|
[text_encoder1, text_encoder2],
|
||||||
|
unet,
|
||||||
|
)
|
||||||
|
|
||||||
is_main_process = accelerator.is_main_process
|
is_main_process = accelerator.is_main_process
|
||||||
# if is_main_process:
|
# if is_main_process:
|
||||||
|
|||||||
@@ -8,7 +8,6 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.vae_scale_factor = sdxl_model_util.VAE_SCALE_FACTOR
|
self.vae_scale_factor = sdxl_model_util.VAE_SCALE_FACTOR
|
||||||
self.sampling_warning_showed = False
|
|
||||||
|
|
||||||
def assert_extra_args(self, args, train_dataset_group):
|
def assert_extra_args(self, args, train_dataset_group):
|
||||||
super().assert_extra_args(args, train_dataset_group)
|
super().assert_extra_args(args, train_dataset_group)
|
||||||
@@ -65,8 +64,8 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
|
|||||||
args, accelerator, tokenizers, text_encoders, data_loader, weight_dtype
|
args, accelerator, tokenizers, text_encoders, data_loader, weight_dtype
|
||||||
)
|
)
|
||||||
accelerator.wait_for_everyone()
|
accelerator.wait_for_everyone()
|
||||||
text_encoders[0].to("cpu")
|
text_encoders[0].to("cpu", dtype=torch.float32) # Text Encoder doesn't work with fp16 on CPU
|
||||||
text_encoders[1].to("cpu")
|
text_encoders[1].to("cpu", dtype=torch.float32)
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
@@ -149,9 +148,7 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
|
|||||||
return noise_pred
|
return noise_pred
|
||||||
|
|
||||||
def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet):
|
def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet):
|
||||||
if not self.sampling_warning_showed:
|
sdxl_train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet)
|
||||||
print("sample_images is not implemented")
|
|
||||||
self.sampling_warning_showed = True
|
|
||||||
|
|
||||||
|
|
||||||
def setup_parser() -> argparse.ArgumentParser:
|
def setup_parser() -> argparse.ArgumentParser:
|
||||||
|
|||||||
Reference in New Issue
Block a user