sd3 training

This commit is contained in:
Kohya S
2024-06-23 23:38:20 +09:00
parent a518e3c819
commit d53ea22b2a
8 changed files with 1909 additions and 44 deletions

View File

@@ -58,7 +58,7 @@ from diffusers import (
KDPM2AncestralDiscreteScheduler,
AutoencoderKL,
)
from library import custom_train_functions
from library import custom_train_functions, sd3_utils
from library.original_unet import UNet2DConditionModel
from huggingface_hub import hf_hub_download
import numpy as np
@@ -135,6 +135,7 @@ IMAGE_TRANSFORMS = transforms.Compose(
)
TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX = "_te_outputs.npz"
TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX_SD3 = "_sd3_te.npz"
class ImageInfo:
@@ -985,7 +986,7 @@ class BaseDataset(torch.utils.data.Dataset):
]
)
def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True):
def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True, file_suffix=".npz"):
# マルチGPUには対応していないので、そちらはtools/cache_latents.pyを使うこと
logger.info("caching latents.")
@@ -1006,7 +1007,7 @@ class BaseDataset(torch.utils.data.Dataset):
# check disk cache exists and size of latents
if cache_to_disk:
info.latents_npz = os.path.splitext(info.absolute_path)[0] + ".npz"
info.latents_npz = os.path.splitext(info.absolute_path)[0] + file_suffix
if not is_main_process: # store to info only
continue
@@ -1040,14 +1041,43 @@ class BaseDataset(torch.utils.data.Dataset):
for batch in tqdm(batches, smoothing=1, total=len(batches)):
cache_batch_latents(vae, cache_to_disk, batch, subset.flip_aug, subset.alpha_mask, subset.random_crop)
# weight_dtypeを指定するとText Encoderそのもの、および出力がweight_dtypeになる
# SDXLでのみ有効だが、datasetのメソッドとする必要があるので、sdxl_train_util.pyではなくこちらに実装する
# SD1/2に対応するにはv2のフラグを持つ必要があるので後回し
# if weight_dtype is specified, Text Encoder itself and output will be converted to the dtype
# this method is only for SDXL, but it should be implemented here because it needs to be a method of dataset
# to support SD1/2, it needs a flag for v2, but it is postponed
def cache_text_encoder_outputs(
self, tokenizers, text_encoders, device, weight_dtype, cache_to_disk=False, is_main_process=True
self, tokenizers, text_encoders, device, output_dtype, cache_to_disk=False, is_main_process=True
):
assert len(tokenizers) == 2, "only support SDXL"
return self.cache_text_encoder_outputs_common(
tokenizers, text_encoders, [device, device], output_dtype, [output_dtype], cache_to_disk, is_main_process
)
# SD3 flavor of the text-encoder-output caching entry point: one combined
# tokenizer, per-encoder devices/dtypes, and the SD3-specific cache suffix.
def cache_text_encoder_outputs_sd3(
    self, tokenizer, text_encoders, devices, output_dtype, te_dtypes, cache_to_disk=False, is_main_process=True
):
    """Cache SD3 text-encoder outputs by delegating to the common implementation.

    The common path expects a list of tokenizers, so the single SD3
    tokenizer is wrapped in a one-element list here.
    """
    common_args = (
        [tokenizer],
        text_encoders,
        devices,
        output_dtype,
        te_dtypes,
        cache_to_disk,
        is_main_process,
        TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX_SD3,
    )
    return self.cache_text_encoder_outputs_common(*common_args)
def cache_text_encoder_outputs_common(
self,
tokenizers,
text_encoders,
devices,
output_dtype,
te_dtypes,
cache_to_disk=False,
is_main_process=True,
file_suffix=TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX,
):
# latentsのキャッシュと同様に、ディスクへのキャッシュに対応する
# またマルチGPUには対応していないので、そちらはtools/cache_latents.pyを使うこと
logger.info("caching text encoder outputs.")
@@ -1058,13 +1088,14 @@ class BaseDataset(torch.utils.data.Dataset):
for info in tqdm(image_infos):
# subset = self.image_to_subset[info.image_key]
if cache_to_disk:
te_out_npz = os.path.splitext(info.absolute_path)[0] + TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX
te_out_npz = os.path.splitext(info.absolute_path)[0] + file_suffix
info.text_encoder_outputs_npz = te_out_npz
if not is_main_process: # store to info only
continue
if os.path.exists(te_out_npz):
# TODO check validity of cache here
continue
image_infos_to_cache.append(info)
@@ -1073,18 +1104,23 @@ class BaseDataset(torch.utils.data.Dataset):
return
# prepare tokenizers and text encoders
for text_encoder in text_encoders:
for text_encoder, device, te_dtype in zip(text_encoders, devices, te_dtypes):
text_encoder.to(device)
if weight_dtype is not None:
text_encoder.to(dtype=weight_dtype)
if te_dtype is not None:
text_encoder.to(dtype=te_dtype)
# create batch
is_sd3 = len(tokenizers) == 1
batch = []
batches = []
for info in image_infos_to_cache:
input_ids1 = self.get_input_ids(info.caption, tokenizers[0])
input_ids2 = self.get_input_ids(info.caption, tokenizers[1])
batch.append((info, input_ids1, input_ids2))
if not is_sd3:
input_ids1 = self.get_input_ids(info.caption, tokenizers[0])
input_ids2 = self.get_input_ids(info.caption, tokenizers[1])
batch.append((info, input_ids1, input_ids2))
else:
l_tokens, g_tokens, t5_tokens = tokenizers[0].tokenize_with_weights(info.caption)
batch.append((info, l_tokens, g_tokens, t5_tokens))
if len(batch) >= self.batch_size:
batches.append(batch)
@@ -1095,13 +1131,32 @@ class BaseDataset(torch.utils.data.Dataset):
# iterate batches: call text encoder and cache outputs for memory or disk
logger.info("caching text encoder outputs...")
for batch in tqdm(batches):
infos, input_ids1, input_ids2 = zip(*batch)
input_ids1 = torch.stack(input_ids1, dim=0)
input_ids2 = torch.stack(input_ids2, dim=0)
cache_batch_text_encoder_outputs(
infos, tokenizers, text_encoders, self.max_token_length, cache_to_disk, input_ids1, input_ids2, weight_dtype
)
if not is_sd3:
for batch in tqdm(batches):
infos, input_ids1, input_ids2 = zip(*batch)
input_ids1 = torch.stack(input_ids1, dim=0)
input_ids2 = torch.stack(input_ids2, dim=0)
cache_batch_text_encoder_outputs(
infos, tokenizers, text_encoders, self.max_token_length, cache_to_disk, input_ids1, input_ids2, output_dtype
)
else:
for batch in tqdm(batches):
infos, l_tokens, g_tokens, t5_tokens = zip(*batch)
# stack tokens
# l_tokens = [tokens[0] for tokens in l_tokens]
# g_tokens = [tokens[0] for tokens in g_tokens]
# t5_tokens = [tokens[0] for tokens in t5_tokens]
cache_batch_text_encoder_outputs_sd3(
infos,
tokenizers[0],
text_encoders,
self.max_token_length,
cache_to_disk,
(l_tokens, g_tokens, t5_tokens),
output_dtype,
)
def get_image_size(self, image_path):
    """Return the size of the image at *image_path* as reported by ``imagesize``."""
    # imagesize inspects the file header rather than decoding the whole image
    size = imagesize.get(image_path)
    return size
@@ -1332,6 +1387,7 @@ class BaseDataset(torch.utils.data.Dataset):
captions.append(caption)
if not self.token_padding_disabled: # this option might be omitted in future
# TODO get_input_ids must support SD3
if self.XTI_layers:
token_caption = self.get_input_ids(caption_layer, self.tokenizers[0])
else:
@@ -2140,10 +2196,10 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
for dataset in self.datasets:
dataset.enable_XTI(*args, **kwargs)
def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True):
def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True, file_suffix=".npz"):
for i, dataset in enumerate(self.datasets):
logger.info(f"[Dataset {i}]")
dataset.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process)
dataset.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process, file_suffix)
def cache_text_encoder_outputs(
self, tokenizers, text_encoders, device, weight_dtype, cache_to_disk=False, is_main_process=True
@@ -2152,6 +2208,15 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
logger.info(f"[Dataset {i}]")
dataset.cache_text_encoder_outputs(tokenizers, text_encoders, device, weight_dtype, cache_to_disk, is_main_process)
def cache_text_encoder_outputs_sd3(
    self, tokenizer, text_encoders, device, output_dtype, te_dtypes, cache_to_disk=False, is_main_process=True
):
    """Run SD3 text-encoder-output caching on every dataset in this group.

    All arguments are forwarded unchanged to each member dataset's
    ``cache_text_encoder_outputs_sd3``.
    """
    for i, ds in enumerate(self.datasets):
        logger.info(f"[Dataset {i}]")
        ds.cache_text_encoder_outputs_sd3(
            tokenizer, text_encoders, device, output_dtype, te_dtypes, cache_to_disk, is_main_process
        )
def set_caching_mode(self, caching_mode):
    """Propagate *caching_mode* to every dataset in this group."""
    for member in self.datasets:
        member.set_caching_mode(caching_mode)
@@ -2585,6 +2650,30 @@ def cache_batch_text_encoder_outputs(
info.text_encoder_pool2 = pool2
def cache_batch_text_encoder_outputs_sd3(
    image_infos, tokenizer, text_encoders, max_token_length, cache_to_disk, input_ids, output_dtype
):
    """Encode a batch of SD3 captions and cache the text-encoder outputs.

    *input_ids* is a tuple of (CLIP-L tokens, CLIP-G tokens, T5-XXL tokens),
    one sequence per entry in *image_infos*. Results are either written to
    each info's ``text_encoder_outputs_npz`` path (when *cache_to_disk*) or
    attached to the info objects in memory.

    NOTE(review): *tokenizer* and *max_token_length* are accepted for
    signature parity with the SDXL variant but are not used in this body.
    """
    tokens_l, tokens_g, tokens_t5 = input_ids
    clip_l, clip_g, t5xxl = text_encoders

    # activations are cached only, never backpropagated through
    with torch.no_grad():
        lg_outs, t5_outs, pools = sd3_utils.get_cond_from_tokens(
            tokens_l, tokens_g, tokens_t5, clip_l, clip_g, t5xxl, "cpu", output_dtype
        )
        lg_outs = lg_outs.detach()
        t5_outs = t5_outs.detach()
        pools = pools.detach()

    for info, lg_out, t5_out, pool in zip(image_infos, lg_outs, t5_outs, pools):
        if cache_to_disk:
            save_text_encoder_outputs_to_disk(info.text_encoder_outputs_npz, lg_out, t5_out, pool)
        else:
            info.text_encoder_outputs1 = lg_out
            info.text_encoder_outputs2 = t5_out
            info.text_encoder_pool2 = pool
def save_text_encoder_outputs_to_disk(npz_path, hidden_state1, hidden_state2, pool2):
np.savez(
npz_path,
@@ -2907,6 +2996,7 @@ def get_sai_model_spec(
lora: bool,
textual_inversion: bool,
is_stable_diffusion_ckpt: Optional[bool] = None, # None for TI and LoRA
sd3: str = None,
):
timestamp = time.time()
@@ -2940,6 +3030,7 @@ def get_sai_model_spec(
tags=args.metadata_tags,
timesteps=timesteps,
clip_skip=args.clip_skip, # None or int
sd3=sd3,
)
return metadata