mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
sd3 training
This commit is contained in:
@@ -58,7 +58,7 @@ from diffusers import (
|
||||
KDPM2AncestralDiscreteScheduler,
|
||||
AutoencoderKL,
|
||||
)
|
||||
from library import custom_train_functions
|
||||
from library import custom_train_functions, sd3_utils
|
||||
from library.original_unet import UNet2DConditionModel
|
||||
from huggingface_hub import hf_hub_download
|
||||
import numpy as np
|
||||
@@ -135,6 +135,7 @@ IMAGE_TRANSFORMS = transforms.Compose(
|
||||
)
|
||||
|
||||
TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX = "_te_outputs.npz"
|
||||
TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX_SD3 = "_sd3_te.npz"
|
||||
|
||||
|
||||
class ImageInfo:
|
||||
@@ -985,7 +986,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
]
|
||||
)
|
||||
|
||||
def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True):
|
||||
def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True, file_suffix=".npz"):
|
||||
# マルチGPUには対応していないので、そちらはtools/cache_latents.pyを使うこと
|
||||
logger.info("caching latents.")
|
||||
|
||||
@@ -1006,7 +1007,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
|
||||
# check disk cache exists and size of latents
|
||||
if cache_to_disk:
|
||||
info.latents_npz = os.path.splitext(info.absolute_path)[0] + ".npz"
|
||||
info.latents_npz = os.path.splitext(info.absolute_path)[0] + file_suffix
|
||||
if not is_main_process: # store to info only
|
||||
continue
|
||||
|
||||
@@ -1040,14 +1041,43 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
for batch in tqdm(batches, smoothing=1, total=len(batches)):
|
||||
cache_batch_latents(vae, cache_to_disk, batch, subset.flip_aug, subset.alpha_mask, subset.random_crop)
|
||||
|
||||
# If output_dtype is specified, the Text Encoders themselves and their outputs are converted to that dtype.
# This method is only for SDXL, but it is implemented here because it needs to be a method of the dataset.
# Supporting SD1/2 would need a flag for v2, so that is postponed.
def cache_text_encoder_outputs(
    self, tokenizers, text_encoders, device, output_dtype, cache_to_disk=False, is_main_process=True
):
    """Cache text encoder outputs for SDXL (exactly two tokenizers).

    Thin wrapper around ``cache_text_encoder_outputs_common`` that uses the
    same device and the same dtype for both text encoders and the default
    cache file suffix.

    Args:
        tokenizers: the two SDXL tokenizers.
        text_encoders: the text encoders matching ``tokenizers``.
        device: device every text encoder is moved to before encoding.
        output_dtype: dtype forwarded for the cached outputs; also used as
            the dtype of each text encoder (``None`` leaves the encoders'
            dtype untouched).
        cache_to_disk: forwarded unchanged to the shared implementation.
        is_main_process: forwarded unchanged to the shared implementation.

    Returns:
        Whatever ``cache_text_encoder_outputs_common`` returns.

    Raises:
        AssertionError: if ``tokenizers`` does not contain exactly two items.
    """
    assert len(tokenizers) == 2, "only support SDXL"

    devices = [device, device]
    # The shared implementation walks zip(text_encoders, devices, te_dtypes),
    # which stops at the shortest sequence. A single-element dtype list would
    # therefore leave the second text encoder on its old device and dtype,
    # so supply one dtype entry per device entry.
    te_dtypes = [output_dtype] * len(devices)

    return self.cache_text_encoder_outputs_common(
        tokenizers, text_encoders, devices, output_dtype, te_dtypes, cache_to_disk, is_main_process
    )
|
||||
|
||||
# SD3 counterpart of cache_text_encoder_outputs
def cache_text_encoder_outputs_sd3(
    self, tokenizer, text_encoders, devices, output_dtype, te_dtypes, cache_to_disk=False, is_main_process=True
):
    """Cache text encoder outputs for SD3.

    Wraps the single SD3 tokenizer in a one-element list (the shared
    implementation treats a one-element tokenizer list as the SD3 case) and
    delegates to ``cache_text_encoder_outputs_common`` with the SD3-specific
    cache file suffix.

    Args:
        tokenizer: the single SD3 tokenizer object.
        text_encoders: text encoders, forwarded unchanged.
        devices: per-encoder devices, forwarded unchanged.
        output_dtype: dtype forwarded for the cached outputs.
        te_dtypes: per-encoder dtypes, forwarded unchanged.
        cache_to_disk: forwarded unchanged.
        is_main_process: forwarded unchanged.

    Returns:
        Whatever ``cache_text_encoder_outputs_common`` returns.
    """
    return self.cache_text_encoder_outputs_common(
        tokenizers=[tokenizer],
        text_encoders=text_encoders,
        devices=devices,
        output_dtype=output_dtype,
        te_dtypes=te_dtypes,
        cache_to_disk=cache_to_disk,
        is_main_process=is_main_process,
        file_suffix=TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX_SD3,
    )
|
||||
|
||||
def cache_text_encoder_outputs_common(
|
||||
self,
|
||||
tokenizers,
|
||||
text_encoders,
|
||||
devices,
|
||||
output_dtype,
|
||||
te_dtypes,
|
||||
cache_to_disk=False,
|
||||
is_main_process=True,
|
||||
file_suffix=TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX,
|
||||
):
|
||||
# latentsのキャッシュと同様に、ディスクへのキャッシュに対応する
|
||||
# またマルチGPUには対応していないので、そちらはtools/cache_latents.pyを使うこと
|
||||
logger.info("caching text encoder outputs.")
|
||||
@@ -1058,13 +1088,14 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
for info in tqdm(image_infos):
|
||||
# subset = self.image_to_subset[info.image_key]
|
||||
if cache_to_disk:
|
||||
te_out_npz = os.path.splitext(info.absolute_path)[0] + TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX
|
||||
te_out_npz = os.path.splitext(info.absolute_path)[0] + file_suffix
|
||||
info.text_encoder_outputs_npz = te_out_npz
|
||||
|
||||
if not is_main_process: # store to info only
|
||||
continue
|
||||
|
||||
if os.path.exists(te_out_npz):
|
||||
# TODO check validity of cache here
|
||||
continue
|
||||
|
||||
image_infos_to_cache.append(info)
|
||||
@@ -1073,18 +1104,23 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
return
|
||||
|
||||
# prepare tokenizers and text encoders
|
||||
for text_encoder in text_encoders:
|
||||
for text_encoder, device, te_dtype in zip(text_encoders, devices, te_dtypes):
|
||||
text_encoder.to(device)
|
||||
if weight_dtype is not None:
|
||||
text_encoder.to(dtype=weight_dtype)
|
||||
if te_dtype is not None:
|
||||
text_encoder.to(dtype=te_dtype)
|
||||
|
||||
# create batch
|
||||
is_sd3 = len(tokenizers) == 1
|
||||
batch = []
|
||||
batches = []
|
||||
for info in image_infos_to_cache:
|
||||
input_ids1 = self.get_input_ids(info.caption, tokenizers[0])
|
||||
input_ids2 = self.get_input_ids(info.caption, tokenizers[1])
|
||||
batch.append((info, input_ids1, input_ids2))
|
||||
if not is_sd3:
|
||||
input_ids1 = self.get_input_ids(info.caption, tokenizers[0])
|
||||
input_ids2 = self.get_input_ids(info.caption, tokenizers[1])
|
||||
batch.append((info, input_ids1, input_ids2))
|
||||
else:
|
||||
l_tokens, g_tokens, t5_tokens = tokenizers[0].tokenize_with_weights(info.caption)
|
||||
batch.append((info, l_tokens, g_tokens, t5_tokens))
|
||||
|
||||
if len(batch) >= self.batch_size:
|
||||
batches.append(batch)
|
||||
@@ -1095,13 +1131,32 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
|
||||
# iterate batches: call text encoder and cache outputs for memory or disk
|
||||
logger.info("caching text encoder outputs...")
|
||||
for batch in tqdm(batches):
|
||||
infos, input_ids1, input_ids2 = zip(*batch)
|
||||
input_ids1 = torch.stack(input_ids1, dim=0)
|
||||
input_ids2 = torch.stack(input_ids2, dim=0)
|
||||
cache_batch_text_encoder_outputs(
|
||||
infos, tokenizers, text_encoders, self.max_token_length, cache_to_disk, input_ids1, input_ids2, weight_dtype
|
||||
)
|
||||
if not is_sd3:
|
||||
for batch in tqdm(batches):
|
||||
infos, input_ids1, input_ids2 = zip(*batch)
|
||||
input_ids1 = torch.stack(input_ids1, dim=0)
|
||||
input_ids2 = torch.stack(input_ids2, dim=0)
|
||||
cache_batch_text_encoder_outputs(
|
||||
infos, tokenizers, text_encoders, self.max_token_length, cache_to_disk, input_ids1, input_ids2, output_dtype
|
||||
)
|
||||
else:
|
||||
for batch in tqdm(batches):
|
||||
infos, l_tokens, g_tokens, t5_tokens = zip(*batch)
|
||||
|
||||
# stack tokens
|
||||
# l_tokens = [tokens[0] for tokens in l_tokens]
|
||||
# g_tokens = [tokens[0] for tokens in g_tokens]
|
||||
# t5_tokens = [tokens[0] for tokens in t5_tokens]
|
||||
|
||||
cache_batch_text_encoder_outputs_sd3(
|
||||
infos,
|
||||
tokenizers[0],
|
||||
text_encoders,
|
||||
self.max_token_length,
|
||||
cache_to_disk,
|
||||
(l_tokens, g_tokens, t5_tokens),
|
||||
output_dtype,
|
||||
)
|
||||
|
||||
def get_image_size(self, image_path):
|
||||
return imagesize.get(image_path)
|
||||
@@ -1332,6 +1387,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
captions.append(caption)
|
||||
|
||||
if not self.token_padding_disabled: # this option might be omitted in future
|
||||
# TODO get_input_ids must support SD3
|
||||
if self.XTI_layers:
|
||||
token_caption = self.get_input_ids(caption_layer, self.tokenizers[0])
|
||||
else:
|
||||
@@ -2140,10 +2196,10 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
|
||||
for dataset in self.datasets:
|
||||
dataset.enable_XTI(*args, **kwargs)
|
||||
|
||||
def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True, file_suffix=".npz"):
    """Run latent caching on every dataset of the group, in order.

    Logs a ``[Dataset i]`` header before each dataset and forwards all
    arguments, including ``file_suffix``, unchanged to that dataset's
    ``cache_latents``.
    """
    forwarded = (vae, vae_batch_size, cache_to_disk, is_main_process, file_suffix)
    for i, member in enumerate(self.datasets):
        logger.info(f"[Dataset {i}]")
        member.cache_latents(*forwarded)
|
||||
|
||||
def cache_text_encoder_outputs(
|
||||
self, tokenizers, text_encoders, device, weight_dtype, cache_to_disk=False, is_main_process=True
|
||||
@@ -2152,6 +2208,15 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
|
||||
logger.info(f"[Dataset {i}]")
|
||||
dataset.cache_text_encoder_outputs(tokenizers, text_encoders, device, weight_dtype, cache_to_disk, is_main_process)
|
||||
|
||||
def cache_text_encoder_outputs_sd3(
    self, tokenizer, text_encoders, device, output_dtype, te_dtypes, cache_to_disk=False, is_main_process=True
):
    """Run SD3 text encoder output caching on every dataset of the group.

    Logs a ``[Dataset i]`` header before each dataset and forwards all
    arguments positionally to that dataset's
    ``cache_text_encoder_outputs_sd3``.

    NOTE(review): ``device`` is passed in the position the per-dataset method
    names ``devices``, so it is presumably expected to be a per-encoder
    sequence rather than a single device; confirm against callers.
    """
    forwarded = (tokenizer, text_encoders, device, output_dtype, te_dtypes, cache_to_disk, is_main_process)
    for i, member in enumerate(self.datasets):
        logger.info(f"[Dataset {i}]")
        member.cache_text_encoder_outputs_sd3(*forwarded)
|
||||
|
||||
def set_caching_mode(self, caching_mode):
|
||||
for dataset in self.datasets:
|
||||
dataset.set_caching_mode(caching_mode)
|
||||
@@ -2585,6 +2650,30 @@ def cache_batch_text_encoder_outputs(
|
||||
info.text_encoder_pool2 = pool2
|
||||
|
||||
|
||||
def cache_batch_text_encoder_outputs_sd3(
    image_infos, tokenizer, text_encoders, max_token_length, cache_to_disk, input_ids, output_dtype
):
    """Encode one batch of SD3 captions and store the results per image.

    Args:
        image_infos: per-image info objects, one per batch item; each receives
            its outputs as attributes, or supplies the destination path via
            ``text_encoder_outputs_npz`` when caching to disk.
        tokenizer: not used by this function.
        text_encoders: sequence unpacked as ``(clip_l, clip_g, t5xxl)``.
        max_token_length: not used by this function.
        cache_to_disk: if true, each item's outputs are written with
            ``save_text_encoder_outputs_to_disk``; otherwise they are kept on
            the info object.
        input_ids: sequence unpacked as ``(l_tokens, g_tokens, t5_tokens)``.
        output_dtype: forwarded to ``sd3_utils.get_cond_from_tokens``.
    """
    # split the token bundle and the encoder bundle into their three parts
    l_tokens, g_tokens, t5_tokens = input_ids
    clip_l, clip_g, t5xxl = text_encoders

    with torch.no_grad():
        # NOTE(review): the literal "cpu" is presumably the device the helper
        # places its outputs on; confirm against sd3_utils.
        cond = sd3_utils.get_cond_from_tokens(
            l_tokens, g_tokens, t5_tokens, clip_l, clip_g, t5xxl, "cpu", output_dtype
        )
        lg_batch, t5_batch, pool_batch = cond
        lg_batch = lg_batch.detach()
        t5_batch = t5_batch.detach()
        pool_batch = pool_batch.detach()

    for info, lg_out, t5_out, pool in zip(image_infos, lg_batch, t5_batch, pool_batch):
        if not cache_to_disk:
            # in-memory cache: keep the tensors on the info object
            info.text_encoder_outputs1 = lg_out
            info.text_encoder_outputs2 = t5_out
            info.text_encoder_pool2 = pool
            continue

        save_text_encoder_outputs_to_disk(info.text_encoder_outputs_npz, lg_out, t5_out, pool)
|
||||
|
||||
|
||||
def save_text_encoder_outputs_to_disk(npz_path, hidden_state1, hidden_state2, pool2):
|
||||
np.savez(
|
||||
npz_path,
|
||||
@@ -2907,6 +2996,7 @@ def get_sai_model_spec(
|
||||
lora: bool,
|
||||
textual_inversion: bool,
|
||||
is_stable_diffusion_ckpt: Optional[bool] = None, # None for TI and LoRA
|
||||
sd3: str = None,
|
||||
):
|
||||
timestamp = time.time()
|
||||
|
||||
@@ -2940,6 +3030,7 @@ def get_sai_model_spec(
|
||||
tags=args.metadata_tags,
|
||||
timesteps=timesteps,
|
||||
clip_skip=args.clip_skip, # None or int
|
||||
sd3=sd3,
|
||||
)
|
||||
return metadata
|
||||
|
||||
|
||||
Reference in New Issue
Block a user