mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Refactor caching mechanism for latents and text encoder outputs, etc.
This commit is contained in:
@@ -12,6 +12,7 @@ import re
|
||||
import shutil
|
||||
import time
|
||||
from typing import (
|
||||
Any,
|
||||
Dict,
|
||||
List,
|
||||
NamedTuple,
|
||||
@@ -34,6 +35,7 @@ from tqdm import tqdm
|
||||
|
||||
import torch
|
||||
from library.device_utils import init_ipex, clean_memory_on_device
|
||||
from library.strategy_base import LatentsCachingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy, TextEncodingStrategy
|
||||
|
||||
init_ipex()
|
||||
|
||||
@@ -81,10 +83,6 @@ logger = logging.getLogger(__name__)
|
||||
# from library.hypernetwork import replace_attentions_for_hypernetwork
|
||||
from library.original_unet import UNet2DConditionModel
|
||||
|
||||
# Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う
|
||||
TOKENIZER_PATH = "openai/clip-vit-large-patch14"
|
||||
V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う v2とv2.1はtokenizer仕様は同じ
|
||||
|
||||
HIGH_VRAM = False
|
||||
|
||||
# checkpointファイル名
|
||||
@@ -148,18 +146,24 @@ class ImageInfo:
|
||||
self.image_size: Tuple[int, int] = None
|
||||
self.resized_size: Tuple[int, int] = None
|
||||
self.bucket_reso: Tuple[int, int] = None
|
||||
self.latents: torch.Tensor = None
|
||||
self.latents_flipped: torch.Tensor = None
|
||||
self.latents_npz: str = None
|
||||
self.latents_original_size: Tuple[int, int] = None # original image size, not latents size
|
||||
self.latents_crop_ltrb: Tuple[int, int] = None # crop left top right bottom in original pixel size, not latents size
|
||||
self.cond_img_path: str = None
|
||||
self.latents: Optional[torch.Tensor] = None
|
||||
self.latents_flipped: Optional[torch.Tensor] = None
|
||||
self.latents_npz: Optional[str] = None # set in cache_latents
|
||||
self.latents_original_size: Optional[Tuple[int, int]] = None # original image size, not latents size
|
||||
self.latents_crop_ltrb: Optional[Tuple[int, int]] = (
|
||||
None # crop left top right bottom in original pixel size, not latents size
|
||||
)
|
||||
self.cond_img_path: Optional[str] = None
|
||||
self.image: Optional[Image.Image] = None # optional, original PIL Image
|
||||
# SDXL, optional
|
||||
self.text_encoder_outputs_npz: Optional[str] = None
|
||||
self.text_encoder_outputs_npz: Optional[str] = None # set in cache_text_encoder_outputs
|
||||
|
||||
# new
|
||||
self.text_encoder_outputs: Optional[List[torch.Tensor]] = None
|
||||
# old
|
||||
self.text_encoder_outputs1: Optional[torch.Tensor] = None
|
||||
self.text_encoder_outputs2: Optional[torch.Tensor] = None
|
||||
self.text_encoder_pool2: Optional[torch.Tensor] = None
|
||||
|
||||
self.alpha_mask: Optional[torch.Tensor] = None # alpha mask can be flipped in runtime
|
||||
|
||||
|
||||
@@ -359,47 +363,6 @@ class AugHelper:
|
||||
return self.color_aug if use_color_aug else None
|
||||
|
||||
|
||||
class LatentsCachingStrategy:
|
||||
_strategy = None # strategy instance: actual strategy class
|
||||
|
||||
def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
|
||||
self._cache_to_disk = cache_to_disk
|
||||
self._batch_size = batch_size
|
||||
self.skip_disk_cache_validity_check = skip_disk_cache_validity_check
|
||||
|
||||
@classmethod
|
||||
def set_strategy(cls, strategy):
|
||||
if cls._strategy is not None:
|
||||
raise RuntimeError(f"Internal error. {cls.__name__} strategy is already set")
|
||||
cls._strategy = strategy
|
||||
|
||||
@classmethod
|
||||
def get_strategy(cls) -> Optional["LatentsCachingStrategy"]:
|
||||
return cls._strategy
|
||||
|
||||
@property
|
||||
def cache_to_disk(self):
|
||||
return self._cache_to_disk
|
||||
|
||||
@property
|
||||
def batch_size(self):
|
||||
return self._batch_size
|
||||
|
||||
def get_image_size_from_image_absolute_path(self, absolute_path: str) -> Tuple[Optional[int], Optional[int]]:
|
||||
raise NotImplementedError
|
||||
|
||||
def get_latents_npz_path(self, absolute_path: str, bucket_reso: Tuple[int, int]) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
def is_disk_cached_latents_expected(
|
||||
self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool
|
||||
) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
def cache_batch_latents(self, batch: List[ImageInfo], flip_aug: bool, alpha_mask: bool, random_crop: bool):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class BaseSubset:
|
||||
def __init__(
|
||||
self,
|
||||
@@ -639,17 +602,12 @@ class ControlNetSubset(BaseSubset):
|
||||
class BaseDataset(torch.utils.data.Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer: Union[CLIPTokenizer, List[CLIPTokenizer]],
|
||||
max_token_length: int,
|
||||
resolution: Optional[Tuple[int, int]],
|
||||
network_multiplier: float,
|
||||
debug_dataset: bool,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.tokenizers = tokenizer if isinstance(tokenizer, list) else [tokenizer]
|
||||
|
||||
self.max_token_length = max_token_length
|
||||
# width/height is used when enable_bucket==False
|
||||
self.width, self.height = (None, None) if resolution is None else resolution
|
||||
self.network_multiplier = network_multiplier
|
||||
@@ -670,8 +628,6 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
self.bucket_no_upscale = None
|
||||
self.bucket_info = None # for metadata
|
||||
|
||||
self.tokenizer_max_length = self.tokenizers[0].model_max_length if max_token_length is None else max_token_length + 2
|
||||
|
||||
self.current_epoch: int = 0 # インスタンスがepochごとに新しく作られるようなので外側から渡さないとダメ
|
||||
|
||||
self.current_step: int = 0
|
||||
@@ -690,6 +646,15 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
|
||||
# caching
|
||||
self.caching_mode = None # None, 'latents', 'text'
|
||||
|
||||
self.tokenize_strategy = None
|
||||
self.text_encoder_output_caching_strategy = None
|
||||
self.latents_caching_strategy = None
|
||||
|
||||
def set_current_strategies(self):
|
||||
self.tokenize_strategy = TokenizeStrategy.get_strategy()
|
||||
self.text_encoder_output_caching_strategy = TextEncoderOutputsCachingStrategy.get_strategy()
|
||||
self.latents_caching_strategy = LatentsCachingStrategy.get_strategy()
|
||||
|
||||
def set_seed(self, seed):
|
||||
self.seed = seed
|
||||
@@ -979,22 +944,6 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
for batch_index in range(batch_count):
|
||||
self.buckets_indices.append(BucketBatchIndex(bucket_index, self.batch_size, batch_index))
|
||||
|
||||
# ↓以下はbucketごとのbatch件数があまりにも増えて混乱を招くので元に戻す
|
||||
# 学習時はステップ数がランダムなので、同一画像が同一batch内にあってもそれほど悪影響はないであろう、と考えられる
|
||||
#
|
||||
# # bucketが細分化されることにより、ひとつのbucketに一種類の画像のみというケースが増え、つまりそれは
|
||||
# # ひとつのbatchが同じ画像で占められることになるので、さすがに良くないであろう
|
||||
# # そのためバッチサイズを画像種類までに制限する
|
||||
# # ただそれでも同一画像が同一バッチに含まれる可能性はあるので、繰り返し回数が少ないほうがshuffleの品質は良くなることは間違いない?
|
||||
# # TO DO 正則化画像をepochまたがりで利用する仕組み
|
||||
# num_of_image_types = len(set(bucket))
|
||||
# bucket_batch_size = min(self.batch_size, num_of_image_types)
|
||||
# batch_count = int(math.ceil(len(bucket) / bucket_batch_size))
|
||||
# # logger.info(bucket_index, num_of_image_types, bucket_batch_size, batch_count)
|
||||
# for batch_index in range(batch_count):
|
||||
# self.buckets_indices.append(BucketBatchIndex(bucket_index, bucket_batch_size, batch_index))
|
||||
# ↑ここまで
|
||||
|
||||
self.shuffle_buckets()
|
||||
self._length = len(self.buckets_indices)
|
||||
|
||||
@@ -1027,12 +976,13 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
]
|
||||
)
|
||||
|
||||
def new_cache_latents(self, is_main_process: bool, caching_strategy: LatentsCachingStrategy):
|
||||
def new_cache_latents(self, model: Any, is_main_process: bool):
|
||||
r"""
|
||||
a brand new method to cache latents. This method caches latents with caching strategy.
|
||||
normal cache_latents method is used by default, but this method is used when caching strategy is specified.
|
||||
"""
|
||||
logger.info("caching latents with caching strategy.")
|
||||
caching_strategy = LatentsCachingStrategy.get_strategy()
|
||||
image_infos = list(self.image_data.values())
|
||||
|
||||
# sort by resolution
|
||||
@@ -1088,7 +1038,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
logger.info("caching latents...")
|
||||
for batch in tqdm(batches, smoothing=1, total=len(batches)):
|
||||
# cache_batch_latents(vae, cache_to_disk, batch, subset.flip_aug, subset.alpha_mask, subset.random_crop)
|
||||
caching_strategy.cache_batch_latents(batch, subset.flip_aug, subset.alpha_mask, subset.random_crop)
|
||||
caching_strategy.cache_batch_latents(model, batch, subset.flip_aug, subset.alpha_mask, subset.random_crop)
|
||||
|
||||
def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True, file_suffix=".npz"):
|
||||
# マルチGPUには対応していないので、そちらはtools/cache_latents.pyを使うこと
|
||||
@@ -1145,6 +1095,56 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
for batch in tqdm(batches, smoothing=1, total=len(batches)):
|
||||
cache_batch_latents(vae, cache_to_disk, batch, subset.flip_aug, subset.alpha_mask, subset.random_crop)
|
||||
|
||||
def new_cache_text_encoder_outputs(self, models: List[Any], is_main_process: bool):
|
||||
r"""
|
||||
a brand new method to cache text encoder outputs. This method caches text encoder outputs with caching strategy.
|
||||
"""
|
||||
tokenize_strategy = TokenizeStrategy.get_strategy()
|
||||
text_encoding_strategy = TextEncodingStrategy.get_strategy()
|
||||
caching_strategy = TextEncoderOutputsCachingStrategy.get_strategy()
|
||||
batch_size = caching_strategy.batch_size or self.batch_size
|
||||
|
||||
# if cache to disk, don't cache TE outputs in non-main process
|
||||
if caching_strategy.cache_to_disk and not is_main_process:
|
||||
return
|
||||
|
||||
logger.info("caching Text Encoder outputs with caching strategy.")
|
||||
image_infos = list(self.image_data.values())
|
||||
|
||||
# split by resolution
|
||||
batches = []
|
||||
batch = []
|
||||
logger.info("checking cache validity...")
|
||||
for info in tqdm(image_infos):
|
||||
te_out_npz = caching_strategy.get_outputs_npz_path(info.absolute_path)
|
||||
|
||||
# check disk cache exists and size of latents
|
||||
if caching_strategy.cache_to_disk:
|
||||
info.text_encoder_outputs_npz = te_out_npz
|
||||
cache_available = caching_strategy.is_disk_cached_outputs_expected(te_out_npz)
|
||||
if cache_available: # do not add to batch
|
||||
continue
|
||||
|
||||
batch.append(info)
|
||||
|
||||
# if number of data in batch is enough, flush the batch
|
||||
if len(batch) >= batch_size:
|
||||
batches.append(batch)
|
||||
batch = []
|
||||
|
||||
if len(batch) > 0:
|
||||
batches.append(batch)
|
||||
|
||||
if len(batches) == 0:
|
||||
logger.info("no Text Encoder outputs to cache")
|
||||
return
|
||||
|
||||
# iterate batches
|
||||
logger.info("caching Text Encoder outputs...")
|
||||
for batch in tqdm(batches, smoothing=1, total=len(batches)):
|
||||
# cache_batch_latents(vae, cache_to_disk, batch, subset.flip_aug, subset.alpha_mask, subset.random_crop)
|
||||
caching_strategy.cache_batch_outputs(tokenize_strategy, models, text_encoding_strategy, batch)
|
||||
|
||||
# if weight_dtype is specified, Text Encoder itself and output will be converted to the dtype
|
||||
# this method is only for SDXL, but it should be implemented here because it needs to be a method of dataset
|
||||
# to support SD1/2, it needs a flag for v2, but it is postponed
|
||||
@@ -1188,6 +1188,8 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
# またマルチGPUには対応していないので、そちらはtools/cache_latents.pyを使うこと
|
||||
logger.info("caching text encoder outputs.")
|
||||
|
||||
tokenize_strategy = TokenizeStrategy.get_strategy()
|
||||
|
||||
if batch_size is None:
|
||||
batch_size = self.batch_size
|
||||
|
||||
@@ -1229,7 +1231,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
input_ids2 = self.get_input_ids(info.caption, tokenizers[1])
|
||||
batch.append((info, input_ids1, input_ids2))
|
||||
else:
|
||||
l_tokens, g_tokens, t5_tokens = tokenizers[0].tokenize_with_weights(info.caption)
|
||||
l_tokens, g_tokens, t5_tokens = tokenize_strategy.tokenize(info.caption)
|
||||
batch.append((info, l_tokens, g_tokens, t5_tokens))
|
||||
|
||||
if len(batch) >= batch_size:
|
||||
@@ -1347,7 +1349,6 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
loss_weights = []
|
||||
captions = []
|
||||
input_ids_list = []
|
||||
input_ids2_list = []
|
||||
latents_list = []
|
||||
alpha_mask_list = []
|
||||
images = []
|
||||
@@ -1355,16 +1356,14 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
crop_top_lefts = []
|
||||
target_sizes_hw = []
|
||||
flippeds = [] # 変数名が微妙
|
||||
text_encoder_outputs1_list = []
|
||||
text_encoder_outputs2_list = []
|
||||
text_encoder_pool2_list = []
|
||||
text_encoder_outputs_list = []
|
||||
|
||||
for image_key in bucket[image_index : image_index + bucket_batch_size]:
|
||||
image_info = self.image_data[image_key]
|
||||
subset = self.image_to_subset[image_key]
|
||||
loss_weights.append(
|
||||
self.prior_loss_weight if image_info.is_reg else 1.0
|
||||
) # in case of fine tuning, is_reg is always False
|
||||
|
||||
# in case of fine tuning, is_reg is always False
|
||||
loss_weights.append(self.prior_loss_weight if image_info.is_reg else 1.0)
|
||||
|
||||
flipped = subset.flip_aug and random.random() < 0.5 # not flipped or flipped with 50% chance
|
||||
|
||||
@@ -1381,7 +1380,9 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
|
||||
image = None
|
||||
elif image_info.latents_npz is not None: # FineTuningDatasetまたはcache_latents_to_disk=Trueの場合
|
||||
latents, original_size, crop_ltrb, flipped_latents, alpha_mask = load_latents_from_disk(image_info.latents_npz)
|
||||
latents, original_size, crop_ltrb, flipped_latents, alpha_mask = (
|
||||
self.latents_caching_strategy.load_latents_from_disk(image_info.latents_npz)
|
||||
)
|
||||
if flipped:
|
||||
latents = flipped_latents
|
||||
alpha_mask = None if alpha_mask is None else alpha_mask[:, ::-1].copy() # copy to avoid negative stride problem
|
||||
@@ -1470,75 +1471,67 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
|
||||
# captionとtext encoder outputを処理する
|
||||
caption = image_info.caption # default
|
||||
if image_info.text_encoder_outputs1 is not None:
|
||||
text_encoder_outputs1_list.append(image_info.text_encoder_outputs1)
|
||||
text_encoder_outputs2_list.append(image_info.text_encoder_outputs2)
|
||||
text_encoder_pool2_list.append(image_info.text_encoder_pool2)
|
||||
captions.append(caption)
|
||||
|
||||
tokenization_required = (
|
||||
self.text_encoder_output_caching_strategy is None or self.text_encoder_output_caching_strategy.is_partial
|
||||
)
|
||||
text_encoder_outputs = None
|
||||
input_ids = None
|
||||
|
||||
if image_info.text_encoder_outputs is not None:
|
||||
# cached
|
||||
text_encoder_outputs = image_info.text_encoder_outputs
|
||||
elif image_info.text_encoder_outputs_npz is not None:
|
||||
text_encoder_outputs1, text_encoder_outputs2, text_encoder_pool2 = load_text_encoder_outputs_from_disk(
|
||||
# on disk
|
||||
text_encoder_outputs = self.text_encoder_output_caching_strategy.load_outputs_npz(
|
||||
image_info.text_encoder_outputs_npz
|
||||
)
|
||||
text_encoder_outputs1_list.append(text_encoder_outputs1)
|
||||
text_encoder_outputs2_list.append(text_encoder_outputs2)
|
||||
text_encoder_pool2_list.append(text_encoder_pool2)
|
||||
captions.append(caption)
|
||||
else:
|
||||
tokenization_required = True
|
||||
text_encoder_outputs_list.append(text_encoder_outputs)
|
||||
|
||||
if tokenization_required:
|
||||
caption = self.process_caption(subset, image_info.caption)
|
||||
if self.XTI_layers:
|
||||
caption_layer = []
|
||||
for layer in self.XTI_layers:
|
||||
token_strings_from = " ".join(self.token_strings)
|
||||
token_strings_to = " ".join([f"{x}_{layer}" for x in self.token_strings])
|
||||
caption_ = caption.replace(token_strings_from, token_strings_to)
|
||||
caption_layer.append(caption_)
|
||||
captions.append(caption_layer)
|
||||
else:
|
||||
captions.append(caption)
|
||||
input_ids = [ids[0] for ids in self.tokenize_strategy.tokenize(caption)] # remove batch dimension
|
||||
# if self.XTI_layers:
|
||||
# caption_layer = []
|
||||
# for layer in self.XTI_layers:
|
||||
# token_strings_from = " ".join(self.token_strings)
|
||||
# token_strings_to = " ".join([f"{x}_{layer}" for x in self.token_strings])
|
||||
# caption_ = caption.replace(token_strings_from, token_strings_to)
|
||||
# caption_layer.append(caption_)
|
||||
# captions.append(caption_layer)
|
||||
# else:
|
||||
# captions.append(caption)
|
||||
|
||||
if not self.token_padding_disabled: # this option might be omitted in future
|
||||
# TODO get_input_ids must support SD3
|
||||
if self.XTI_layers:
|
||||
token_caption = self.get_input_ids(caption_layer, self.tokenizers[0])
|
||||
else:
|
||||
token_caption = self.get_input_ids(caption, self.tokenizers[0])
|
||||
input_ids_list.append(token_caption)
|
||||
# if not self.token_padding_disabled: # this option might be omitted in future
|
||||
# # TODO get_input_ids must support SD3
|
||||
# if self.XTI_layers:
|
||||
# token_caption = self.get_input_ids(caption_layer, self.tokenizers[0])
|
||||
# else:
|
||||
# token_caption = self.get_input_ids(caption, self.tokenizers[0])
|
||||
# input_ids_list.append(token_caption)
|
||||
|
||||
if len(self.tokenizers) > 1:
|
||||
if self.XTI_layers:
|
||||
token_caption2 = self.get_input_ids(caption_layer, self.tokenizers[1])
|
||||
else:
|
||||
token_caption2 = self.get_input_ids(caption, self.tokenizers[1])
|
||||
input_ids2_list.append(token_caption2)
|
||||
# if len(self.tokenizers) > 1:
|
||||
# if self.XTI_layers:
|
||||
# token_caption2 = self.get_input_ids(caption_layer, self.tokenizers[1])
|
||||
# else:
|
||||
# token_caption2 = self.get_input_ids(caption, self.tokenizers[1])
|
||||
# input_ids2_list.append(token_caption2)
|
||||
|
||||
input_ids_list.append(input_ids)
|
||||
captions.append(caption)
|
||||
|
||||
def none_or_stack_elements(tensors_list, converter):
|
||||
# [[clip_l, clip_g, t5xxl], [clip_l, clip_g, t5xxl], ...] -> [torch.stack(clip_l), torch.stack(clip_g), torch.stack(t5xxl)]
|
||||
if len(tensors_list) == 0 or tensors_list[0] == None or len(tensors_list[0]) == 0 or tensors_list[0][0] is None:
|
||||
return None
|
||||
return [torch.stack([converter(x[i]) for x in tensors_list]) for i in range(len(tensors_list[0]))]
|
||||
|
||||
example = {}
|
||||
example["loss_weights"] = torch.FloatTensor(loss_weights)
|
||||
|
||||
if len(text_encoder_outputs1_list) == 0:
|
||||
if self.token_padding_disabled:
|
||||
# padding=True means pad in the batch
|
||||
example["input_ids"] = self.tokenizer[0](captions, padding=True, truncation=True, return_tensors="pt").input_ids
|
||||
if len(self.tokenizers) > 1:
|
||||
example["input_ids2"] = self.tokenizer[1](
|
||||
captions, padding=True, truncation=True, return_tensors="pt"
|
||||
).input_ids
|
||||
else:
|
||||
example["input_ids2"] = None
|
||||
else:
|
||||
example["input_ids"] = torch.stack(input_ids_list)
|
||||
example["input_ids2"] = torch.stack(input_ids2_list) if len(self.tokenizers) > 1 else None
|
||||
example["text_encoder_outputs1_list"] = None
|
||||
example["text_encoder_outputs2_list"] = None
|
||||
example["text_encoder_pool2_list"] = None
|
||||
else:
|
||||
example["input_ids"] = None
|
||||
example["input_ids2"] = None
|
||||
# # for assertion
|
||||
# example["input_ids"] = torch.stack([self.get_input_ids(cap, self.tokenizers[0]) for cap in captions])
|
||||
# example["input_ids2"] = torch.stack([self.get_input_ids(cap, self.tokenizers[1]) for cap in captions])
|
||||
example["text_encoder_outputs1_list"] = torch.stack(text_encoder_outputs1_list)
|
||||
example["text_encoder_outputs2_list"] = torch.stack(text_encoder_outputs2_list)
|
||||
example["text_encoder_pool2_list"] = torch.stack(text_encoder_pool2_list)
|
||||
example["text_encoder_outputs_list"] = none_or_stack_elements(text_encoder_outputs_list, torch.FloatTensor)
|
||||
example["input_ids_list"] = none_or_stack_elements(input_ids_list, lambda x: x)
|
||||
|
||||
# if one of alpha_masks is not None, we need to replace None with ones
|
||||
none_or_not = [x is None for x in alpha_mask_list]
|
||||
@@ -1652,8 +1645,6 @@ class DreamBoothDataset(BaseDataset):
|
||||
self,
|
||||
subsets: Sequence[DreamBoothSubset],
|
||||
batch_size: int,
|
||||
tokenizer,
|
||||
max_token_length,
|
||||
resolution,
|
||||
network_multiplier: float,
|
||||
enable_bucket: bool,
|
||||
@@ -1664,7 +1655,7 @@ class DreamBoothDataset(BaseDataset):
|
||||
prior_loss_weight: float,
|
||||
debug_dataset: bool,
|
||||
) -> None:
|
||||
super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset)
|
||||
super().__init__(resolution, network_multiplier, debug_dataset)
|
||||
|
||||
assert resolution is not None, f"resolution is required / resolution(解像度)指定は必須です"
|
||||
|
||||
@@ -1750,10 +1741,10 @@ class DreamBoothDataset(BaseDataset):
|
||||
# new caching: get image size from cache files
|
||||
strategy = LatentsCachingStrategy.get_strategy()
|
||||
if strategy is not None:
|
||||
logger.info("get image size from cache files")
|
||||
logger.info("get image size from name of cache files")
|
||||
size_set_count = 0
|
||||
for i, img_path in enumerate(tqdm(img_paths)):
|
||||
w, h = strategy.get_image_size_from_image_absolute_path(img_path)
|
||||
w, h = strategy.get_image_size_from_disk_cache_path(img_path)
|
||||
if w is not None and h is not None:
|
||||
sizes[i] = [w, h]
|
||||
size_set_count += 1
|
||||
@@ -1886,8 +1877,6 @@ class FineTuningDataset(BaseDataset):
|
||||
self,
|
||||
subsets: Sequence[FineTuningSubset],
|
||||
batch_size: int,
|
||||
tokenizer,
|
||||
max_token_length,
|
||||
resolution,
|
||||
network_multiplier: float,
|
||||
enable_bucket: bool,
|
||||
@@ -1897,7 +1886,7 @@ class FineTuningDataset(BaseDataset):
|
||||
bucket_no_upscale: bool,
|
||||
debug_dataset: bool,
|
||||
) -> None:
|
||||
super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset)
|
||||
super().__init__(resolution, network_multiplier, debug_dataset)
|
||||
|
||||
self.batch_size = batch_size
|
||||
|
||||
@@ -2111,8 +2100,6 @@ class ControlNetDataset(BaseDataset):
|
||||
self,
|
||||
subsets: Sequence[ControlNetSubset],
|
||||
batch_size: int,
|
||||
tokenizer,
|
||||
max_token_length,
|
||||
resolution,
|
||||
network_multiplier: float,
|
||||
enable_bucket: bool,
|
||||
@@ -2122,7 +2109,7 @@ class ControlNetDataset(BaseDataset):
|
||||
bucket_no_upscale: bool,
|
||||
debug_dataset: float,
|
||||
) -> None:
|
||||
super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset)
|
||||
super().__init__(resolution, network_multiplier, debug_dataset)
|
||||
|
||||
db_subsets = []
|
||||
for subset in subsets:
|
||||
@@ -2160,8 +2147,6 @@ class ControlNetDataset(BaseDataset):
|
||||
self.dreambooth_dataset_delegate = DreamBoothDataset(
|
||||
db_subsets,
|
||||
batch_size,
|
||||
tokenizer,
|
||||
max_token_length,
|
||||
resolution,
|
||||
network_multiplier,
|
||||
enable_bucket,
|
||||
@@ -2221,6 +2206,9 @@ class ControlNetDataset(BaseDataset):
|
||||
|
||||
self.conditioning_image_transforms = IMAGE_TRANSFORMS
|
||||
|
||||
def set_current_strategies(self):
|
||||
return self.dreambooth_dataset_delegate.set_current_strategies()
|
||||
|
||||
def make_buckets(self):
|
||||
self.dreambooth_dataset_delegate.make_buckets()
|
||||
self.bucket_manager = self.dreambooth_dataset_delegate.bucket_manager
|
||||
@@ -2229,6 +2217,12 @@ class ControlNetDataset(BaseDataset):
|
||||
def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True):
|
||||
return self.dreambooth_dataset_delegate.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process)
|
||||
|
||||
def new_cache_latents(self, model: Any, is_main_process: bool):
|
||||
return self.dreambooth_dataset_delegate.new_cache_latents(model, is_main_process)
|
||||
|
||||
def new_cache_text_encoder_outputs(self, models: List[Any], is_main_process: bool):
|
||||
return self.dreambooth_dataset_delegate.new_cache_text_encoder_outputs(models, is_main_process)
|
||||
|
||||
def __len__(self):
|
||||
return self.dreambooth_dataset_delegate.__len__()
|
||||
|
||||
@@ -2314,6 +2308,13 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
|
||||
# for dataset in self.datasets:
|
||||
# dataset.make_buckets()
|
||||
|
||||
def set_text_encoder_output_caching_strategy(self, strategy: TextEncoderOutputsCachingStrategy):
|
||||
"""
|
||||
DataLoader is run in multiple processes, so we need to set the strategy manually.
|
||||
"""
|
||||
for dataset in self.datasets:
|
||||
dataset.set_text_encoder_output_caching_strategy(strategy)
|
||||
|
||||
def enable_XTI(self, *args, **kwargs):
|
||||
for dataset in self.datasets:
|
||||
dataset.enable_XTI(*args, **kwargs)
|
||||
@@ -2323,10 +2324,10 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
|
||||
logger.info(f"[Dataset {i}]")
|
||||
dataset.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process, file_suffix)
|
||||
|
||||
def new_cache_latents(self, is_main_process: bool, strategy: LatentsCachingStrategy):
|
||||
def new_cache_latents(self, model: Any, is_main_process: bool):
|
||||
for i, dataset in enumerate(self.datasets):
|
||||
logger.info(f"[Dataset {i}]")
|
||||
dataset.new_cache_latents(is_main_process, strategy)
|
||||
dataset.new_cache_latents(model, is_main_process)
|
||||
|
||||
def cache_text_encoder_outputs(
|
||||
self, tokenizers, text_encoders, device, weight_dtype, cache_to_disk=False, is_main_process=True
|
||||
@@ -2344,6 +2345,11 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
|
||||
tokenizer, text_encoders, device, output_dtype, te_dtypes, cache_to_disk, is_main_process, batch_size
|
||||
)
|
||||
|
||||
def new_cache_text_encoder_outputs(self, models: List[Any], is_main_process: bool):
|
||||
for i, dataset in enumerate(self.datasets):
|
||||
logger.info(f"[Dataset {i}]")
|
||||
dataset.new_cache_text_encoder_outputs(models, is_main_process)
|
||||
|
||||
def set_caching_mode(self, caching_mode):
|
||||
for dataset in self.datasets:
|
||||
dataset.set_caching_mode(caching_mode)
|
||||
@@ -2358,6 +2364,10 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
|
||||
def is_text_encoder_output_cacheable(self) -> bool:
|
||||
return all([dataset.is_text_encoder_output_cacheable() for dataset in self.datasets])
|
||||
|
||||
def set_current_strategies(self):
|
||||
for dataset in self.datasets:
|
||||
dataset.set_current_strategies()
|
||||
|
||||
def set_current_epoch(self, epoch):
|
||||
for dataset in self.datasets:
|
||||
dataset.set_current_epoch(epoch)
|
||||
@@ -2411,34 +2421,34 @@ def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool, alph
|
||||
|
||||
# 戻り値は、latents_tensor, (original_size width, original_size height), (crop left, crop top)
|
||||
# TODO update to use CachingStrategy
|
||||
def load_latents_from_disk(
|
||||
npz_path,
|
||||
) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
|
||||
npz = np.load(npz_path)
|
||||
if "latents" not in npz:
|
||||
raise ValueError(f"error: npz is old format. please re-generate {npz_path}")
|
||||
# def load_latents_from_disk(
|
||||
# npz_path,
|
||||
# ) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
|
||||
# npz = np.load(npz_path)
|
||||
# if "latents" not in npz:
|
||||
# raise ValueError(f"error: npz is old format. please re-generate {npz_path}")
|
||||
|
||||
latents = npz["latents"]
|
||||
original_size = npz["original_size"].tolist()
|
||||
crop_ltrb = npz["crop_ltrb"].tolist()
|
||||
flipped_latents = npz["latents_flipped"] if "latents_flipped" in npz else None
|
||||
alpha_mask = npz["alpha_mask"] if "alpha_mask" in npz else None
|
||||
return latents, original_size, crop_ltrb, flipped_latents, alpha_mask
|
||||
# latents = npz["latents"]
|
||||
# original_size = npz["original_size"].tolist()
|
||||
# crop_ltrb = npz["crop_ltrb"].tolist()
|
||||
# flipped_latents = npz["latents_flipped"] if "latents_flipped" in npz else None
|
||||
# alpha_mask = npz["alpha_mask"] if "alpha_mask" in npz else None
|
||||
# return latents, original_size, crop_ltrb, flipped_latents, alpha_mask
|
||||
|
||||
|
||||
def save_latents_to_disk(npz_path, latents_tensor, original_size, crop_ltrb, flipped_latents_tensor=None, alpha_mask=None):
|
||||
kwargs = {}
|
||||
if flipped_latents_tensor is not None:
|
||||
kwargs["latents_flipped"] = flipped_latents_tensor.float().cpu().numpy()
|
||||
if alpha_mask is not None:
|
||||
kwargs["alpha_mask"] = alpha_mask.float().cpu().numpy()
|
||||
np.savez(
|
||||
npz_path,
|
||||
latents=latents_tensor.float().cpu().numpy(),
|
||||
original_size=np.array(original_size),
|
||||
crop_ltrb=np.array(crop_ltrb),
|
||||
**kwargs,
|
||||
)
|
||||
# def save_latents_to_disk(npz_path, latents_tensor, original_size, crop_ltrb, flipped_latents_tensor=None, alpha_mask=None):
|
||||
# kwargs = {}
|
||||
# if flipped_latents_tensor is not None:
|
||||
# kwargs["latents_flipped"] = flipped_latents_tensor.float().cpu().numpy()
|
||||
# if alpha_mask is not None:
|
||||
# kwargs["alpha_mask"] = alpha_mask.float().cpu().numpy()
|
||||
# np.savez(
|
||||
# npz_path,
|
||||
# latents=latents_tensor.float().cpu().numpy(),
|
||||
# original_size=np.array(original_size),
|
||||
# crop_ltrb=np.array(crop_ltrb),
|
||||
# **kwargs,
|
||||
# )
|
||||
|
||||
|
||||
def debug_dataset(train_dataset, show_input_ids=False):
|
||||
@@ -2465,12 +2475,12 @@ def debug_dataset(train_dataset, show_input_ids=False):
|
||||
example = train_dataset[idx]
|
||||
if example["latents"] is not None:
|
||||
logger.info(f"sample has latents from npz file: {example['latents'].size()}")
|
||||
for j, (ik, cap, lw, iid, orgsz, crptl, trgsz, flpdz) in enumerate(
|
||||
for j, (ik, cap, lw, orgsz, crptl, trgsz, flpdz) in enumerate(
|
||||
zip(
|
||||
example["image_keys"],
|
||||
example["captions"],
|
||||
example["loss_weights"],
|
||||
example["input_ids"],
|
||||
# example["input_ids"],
|
||||
example["original_sizes_hw"],
|
||||
example["crop_top_lefts"],
|
||||
example["target_sizes_hw"],
|
||||
@@ -2483,10 +2493,10 @@ def debug_dataset(train_dataset, show_input_ids=False):
|
||||
if "network_multipliers" in example:
|
||||
print(f"network multiplier: {example['network_multipliers'][j]}")
|
||||
|
||||
if show_input_ids:
|
||||
logger.info(f"input ids: {iid}")
|
||||
if "input_ids2" in example:
|
||||
logger.info(f"input ids2: {example['input_ids2'][j]}")
|
||||
# if show_input_ids:
|
||||
# logger.info(f"input ids: {iid}")
|
||||
# if "input_ids2" in example:
|
||||
# logger.info(f"input ids2: {example['input_ids2'][j]}")
|
||||
if example["images"] is not None:
|
||||
im = example["images"][j]
|
||||
logger.info(f"image size: {im.size()}")
|
||||
@@ -2555,8 +2565,8 @@ def glob_images_pathlib(dir_path, recursive):
|
||||
|
||||
|
||||
class MinimalDataset(BaseDataset):
|
||||
def __init__(self, tokenizer, max_token_length, resolution, network_multiplier, debug_dataset=False):
|
||||
super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset)
|
||||
def __init__(self, resolution, network_multiplier, debug_dataset=False):
|
||||
super().__init__(resolution, network_multiplier, debug_dataset)
|
||||
|
||||
self.num_train_images = 0 # update in subclass
|
||||
self.num_reg_images = 0 # update in subclass
|
||||
@@ -2773,14 +2783,15 @@ def cache_batch_latents(
|
||||
raise RuntimeError(f"NaN detected in latents: {info.absolute_path}")
|
||||
|
||||
if cache_to_disk:
|
||||
save_latents_to_disk(
|
||||
info.latents_npz,
|
||||
latent,
|
||||
info.latents_original_size,
|
||||
info.latents_crop_ltrb,
|
||||
flipped_latent,
|
||||
alpha_mask,
|
||||
)
|
||||
# save_latents_to_disk(
|
||||
# info.latents_npz,
|
||||
# latent,
|
||||
# info.latents_original_size,
|
||||
# info.latents_crop_ltrb,
|
||||
# flipped_latent,
|
||||
# alpha_mask,
|
||||
# )
|
||||
pass
|
||||
else:
|
||||
info.latents = latent
|
||||
if flip_aug:
|
||||
@@ -4662,33 +4673,6 @@ def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool):
|
||||
)
|
||||
|
||||
|
||||
def load_tokenizer(args: argparse.Namespace):
|
||||
logger.info("prepare tokenizer")
|
||||
original_path = V2_STABLE_DIFFUSION_PATH if args.v2 else TOKENIZER_PATH
|
||||
|
||||
tokenizer: CLIPTokenizer = None
|
||||
if args.tokenizer_cache_dir:
|
||||
local_tokenizer_path = os.path.join(args.tokenizer_cache_dir, original_path.replace("/", "_"))
|
||||
if os.path.exists(local_tokenizer_path):
|
||||
logger.info(f"load tokenizer from cache: {local_tokenizer_path}")
|
||||
tokenizer = CLIPTokenizer.from_pretrained(local_tokenizer_path) # same for v1 and v2
|
||||
|
||||
if tokenizer is None:
|
||||
if args.v2:
|
||||
tokenizer = CLIPTokenizer.from_pretrained(original_path, subfolder="tokenizer")
|
||||
else:
|
||||
tokenizer = CLIPTokenizer.from_pretrained(original_path)
|
||||
|
||||
if hasattr(args, "max_token_length") and args.max_token_length is not None:
|
||||
logger.info(f"update token length: {args.max_token_length}")
|
||||
|
||||
if args.tokenizer_cache_dir and not os.path.exists(local_tokenizer_path):
|
||||
logger.info(f"save Tokenizer to cache: {local_tokenizer_path}")
|
||||
tokenizer.save_pretrained(local_tokenizer_path)
|
||||
|
||||
return tokenizer
|
||||
|
||||
|
||||
def prepare_accelerator(args: argparse.Namespace):
|
||||
"""
|
||||
this function also prepares deepspeed plugin
|
||||
@@ -5550,6 +5534,7 @@ def sample_images_common(
|
||||
):
|
||||
"""
|
||||
StableDiffusionLongPromptWeightingPipelineの改造版を使うようにしたので、clip skipおよびプロンプトの重みづけに対応した
|
||||
TODO Use strategies here
|
||||
"""
|
||||
|
||||
if steps == 0:
|
||||
|
||||
Reference in New Issue
Block a user