Update train_util.py

This commit is contained in:
DKnight54
2025-05-26 01:02:52 +08:00
committed by GitHub
parent 16584de262
commit d7838ff50c


@@ -29,6 +29,7 @@ import hashlib
import subprocess
from io import BytesIO
import toml
import copy
from tqdm import tqdm
@@ -164,6 +165,8 @@ class ImageInfo:
self.text_encoder_outputs2: Optional[torch.Tensor] = None
self.text_encoder_pool2: Optional[torch.Tensor] = None
self.alpha_mask: Optional[torch.Tensor] = None # alpha mask can be flipped in runtime
self.latent_cache_checked: bool = False # set once this image's latent cache has been verified
self.te_cache_checked: bool = False # set once this image's text encoder output cache has been verified
class BucketManager:
@@ -653,6 +656,11 @@ class BaseDataset(torch.utils.data.Dataset):
# caching
self.caching_mode = None # None, 'latents', 'text'
# lists for incremental loading of regularization images
self.reg_infos = None
self.reg_infos_index = None
self.reg_randomize = False
def adjust_min_max_bucket_reso_by_steps(
self, resolution: Tuple[int, int], min_bucket_reso: int, max_bucket_reso: int, bucket_reso_steps: int
) -> Tuple[int, int]:
@@ -684,6 +692,12 @@ class BaseDataset(torch.utils.data.Dataset):
def set_seed(self, seed):
self.seed = seed
def set_reg_randomize(self, reg_randomize = False):
self.reg_randomize = reg_randomize
def incremental_reg_load(self, make_bucket = False): # Placeholder method, does nothing unless overridden in subclasses.
return
def set_caching_mode(self, mode):
self.caching_mode = mode
@@ -951,11 +965,14 @@ class BaseDataset(torch.utils.data.Dataset):
if self.enable_bucket:
self.bucket_info = {"buckets": {}}
logger.info("number of images (including repeats) / 各bucketの画像枚数繰り返し回数を含む")
batch_count: int = 0
for i, (reso, bucket) in enumerate(zip(self.bucket_manager.resos, self.bucket_manager.buckets)):
count = len(bucket)
if count > 0:
batch_count += math.ceil(len(bucket) / self.batch_size)
self.bucket_info["buckets"][i] = {"resolution": reso, "count": len(bucket)}
logger.info(f"bucket {i}: resolution {reso}, count: {len(bucket)}")
logger.info(f"bucket {i}: resolution {reso}, count: {len(bucket)}, batches: {int(math.ceil(len(bucket) / self.batch_size))}")
logger.info(f"Total batch count: {batch_count}")
if len(img_ar_errors) == 0:
mean_img_ar_error = 0 # avoid NaN
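As a worked example of the per-bucket batch arithmetic above, a standalone sketch with assumed sizes:

import math
batch_size = 4
bucket_sizes = [10, 7, 1]  # images per bucket, illustrative values
total_batches = sum(math.ceil(n / batch_size) for n in bucket_sizes)
print(total_batches)  # 3 + 2 + 1 = 6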
@@ -967,6 +984,7 @@ class BaseDataset(torch.utils.data.Dataset):
# build an index for referencing the data; this index is used when shuffling the dataset
self.buckets_indices: List[BucketBatchIndex] = []
self.buckets_indices.clear()
for bucket_index, bucket in enumerate(self.bucket_manager.buckets):
batch_count = int(math.ceil(len(bucket) / self.batch_size))
for batch_index in range(batch_count):
@@ -1025,6 +1043,10 @@ class BaseDataset(torch.utils.data.Dataset):
logger.info("caching latents.")
image_infos = list(self.image_data.values())
image_infos = list(filter(lambda info: not info.latent_cache_checked, image_infos))
if len(image_infos) == 0:
logger.info("All images latents previously checked and cached. Skipping.")
return
# sort by resolution
image_infos.sort(key=lambda info: info.bucket_reso[0] * info.bucket_reso[1])
@@ -1054,11 +1076,17 @@ class BaseDataset(torch.utils.data.Dataset):
subset = self.image_to_subset[info.image_key]
if info.latents_npz is not None: # fine tuning dataset
info.latent_cache_checked = True
if self.reg_infos is not None and info.image_key in self.reg_infos:
self.reg_infos[info.image_key][0].latent_cache_checked = True
continue
# check disk cache exists and size of latents
if cache_to_disk:
info.latents_npz = os.path.splitext(info.absolute_path)[0] + ".npz"
info.latent_cache_checked = True
if self.reg_infos is not None and info.image_key in self.reg_infos:
self.reg_infos[info.image_key][0].latent_cache_checked = True
if not is_main_process: # store to info only
continue
@@ -1094,6 +1122,15 @@ class BaseDataset(torch.utils.data.Dataset):
logger.info("caching latents...")
for condition, batch in tqdm(batches, smoothing=1, total=len(batches)):
cache_batch_latents(vae, cache_to_disk, batch, condition.flip_aug, condition.alpha_mask, condition.random_crop)
if self.reg_infos is not None:
for info in batch:
if info.image_key in self.reg_infos:
self.reg_infos[info.image_key][0].latents_npz = info.latents_npz
self.reg_infos[info.image_key][0].latents_original_size = info.latents_original_size
self.reg_infos[info.image_key][0].latents_crop_ltrb = info.latents_crop_ltrb
self.reg_infos[info.image_key][0].latents_flipped = info.latents_flipped
self.reg_infos[info.image_key][0].latents = info.latents
self.reg_infos[info.image_key][0].alpha_mask = info.alpha_mask
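The entries in self.reg_infos appear to hold the master ImageInfo objects that survive across epochs, so cached results are mirrored back to them; otherwise the next incremental load would re-cache everything. A standalone sketch of that bookkeeping, with a stripped-down stand-in for ImageInfo and placeholder values:

class ImageInfo:  # stand-in for the real class
    def __init__(self):
        self.latents = None
        self.latent_cache_checked = False

master = ImageInfo()   # persists in reg_infos across epochs
working = ImageInfo()  # per-epoch copy that was just cached
working.latents, working.latent_cache_checked = "latent tensor", True
reg_infos = {"img_000": (master, "subset")}  # image_key -> (master info, subset)
reg_infos["img_000"][0].latents = working.latents
reg_infos["img_000"][0].latent_cache_checked = working.latent_cache_checked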
# if weight_dtype is specified, the Text Encoder itself and its outputs are cast to weight_dtype
# this is only effective for SDXL, but it is implemented here rather than in sdxl_train_util.py because it needs to be a dataset method
@@ -1107,6 +1144,10 @@ class BaseDataset(torch.utils.data.Dataset):
# multi-GPU is also not supported, so use tools/cache_latents.py for that
logger.info("caching text encoder outputs.")
image_infos = list(self.image_data.values())
image_infos = list(filter(lambda info: not info.te_cache_checked, image_infos))
if len(image_infos) == 0:
logger.info("Text encoder outputs for all images previously checked and cached. Skipping.")
return
logger.info("checking cache existence...")
image_infos_to_cache = []
@@ -1115,6 +1156,10 @@ class BaseDataset(torch.utils.data.Dataset):
if cache_to_disk:
te_out_npz = os.path.splitext(info.absolute_path)[0] + TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX
info.text_encoder_outputs_npz = te_out_npz
info.te_cache_checked = True
if self.reg_infos is not None and info.image_key in self.reg_infos:
self.reg_infos[info.image_key][0].text_encoder_outputs_npz = te_out_npz
self.reg_infos[info.image_key][0].te_cache_checked = True
if not is_main_process: # store to info only
continue
@@ -1157,6 +1202,14 @@ class BaseDataset(torch.utils.data.Dataset):
cache_batch_text_encoder_outputs(
infos, tokenizers, text_encoders, self.max_token_length, cache_to_disk, input_ids1, input_ids2, weight_dtype
)
if self.reg_infos is not None:
for info in infos:
if info.image_key in self.reg_infos:
self.reg_infos[info.image_key][0].text_encoder_outputs_npz = info.text_encoder_outputs_npz
self.reg_infos[info.image_key][0].te_cache_checked = True
self.reg_infos[info.image_key][0].text_encoder_outputs1 = info.text_encoder_outputs1
self.reg_infos[info.image_key][0].text_encoder_outputs2 = info.text_encoder_outputs2
self.reg_infos[info.image_key][0].text_encoder_pool2 = info.text_encoder_pool2
def get_image_size(self, image_path):
return imagesize.get(image_path)
@@ -1561,6 +1614,8 @@ class DreamBoothDataset(BaseDataset):
self.size = min(self.width, self.height) # the shorter side
self.prior_loss_weight = prior_loss_weight
self.latents_cache = None
self.reg_infos: Dict[str, Tuple[ImageInfo, DreamBoothSubset]] = {}
self.reg_infos_index: List[str] = []
self.enable_bucket = enable_bucket
if self.enable_bucket:
@@ -1689,7 +1744,6 @@ class DreamBoothDataset(BaseDataset):
logger.info("prepare images.")
num_train_images = 0
num_reg_images = 0
reg_infos: List[Tuple[ImageInfo, DreamBoothSubset]] = []
for subset in subsets:
if subset.num_repeats < 1:
logger.warning(
@@ -1711,7 +1765,11 @@ class DreamBoothDataset(BaseDataset):
continue
if subset.is_reg:
num_reg_images += subset.num_repeats * len(img_paths)
if subset.num_repeats > 1:
info.num_repeats = 1 # repeats are encoded as key multiplicity in reg_infos_index instead
self.reg_infos[info.image_key] = (info, subset)
for _ in range(subset.num_repeats):
self.reg_infos_index.append(info.image_key)
else:
num_train_images += subset.num_repeats * len(img_paths)
@@ -1731,30 +1789,88 @@ class DreamBoothDataset(BaseDataset):
self.num_train_images = num_train_images
logger.info(f"{num_reg_images} reg images.")
if num_train_images < num_reg_images:
logger.warning("some of reg images are not used / 正則化画像の数が多いので、一部使用されない正則化画像があります")
if num_reg_images == 0:
logger.warning("no regularization images / 正則化画像が見つかりませんでした")
else:
# compute num_repeats (the counts are small anyway, so just handle it in a loop)
n = 0
first_loop = True
while n < num_train_images:
for info, subset in reg_infos:
if first_loop:
self.register_image(info, subset)
n += info.num_repeats
else:
info.num_repeats += 1 # rewrite registered info
n += 1
if n >= num_train_images:
break
first_loop = False
self.num_reg_images = num_reg_images
def set_reg_randomize(self, reg_randomize = False):
self.reg_randomize = reg_randomize
# because the first set of data is loaded before the first opportunity to shuffle, force-reset self.reg_infos_index_traverser and reinitialize the dataset
self.reg_infos_index_traverser = 0
self.bucket_manager = None
self.incremental_reg_load(True)
def subset_loaded_count(self):
count_str = ""
for index, subset in enumerate(self.subsets):
counter = 0
count_str += f"\nSubset {index} (Class: {subset.class_tokens}): " if isinstance(subset, DreamBoothSubset) and subset.class_tokens is not None else f"\nSubset {index}: "
img_keys = [key for key, value in self.image_to_subset.items() if value == subset]
for img_key in img_keys:
counter += self.image_data[img_key].num_repeats
count_str += f"{counter}/{subset.img_count * subset.num_repeats}"
count_str += f"\nSubset dir: {subset.image_dir}" if subset.image_dir is not None else ""
count_str += "\n\n"
logger.info(count_str)
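With assumed subset values, the summary printed above might look like the following (illustrative only):

Subset 0 (Class: sks dog): 40/120
Subset dir: /path/to/reg_images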
def incremental_reg_load(self, make_bucket = False):
# override for random/incremental loading of regularization images
distributed_state = PartialState()
if self.num_reg_images == 0:
logger.warning("no regularization images / 正則化画像が見つかりませんでした")
return
if self.num_train_images < self.num_reg_images:
logger.warning("some of reg images are not used / 正則化画像の数が多いので、一部使用されない正則化画像があります")
if self.num_train_images != self.num_reg_images:
logger.info("Initiating loading of regularization images.")
for info, subset in self.reg_infos.values():
if info.image_key in self.image_data:
self.image_data.pop(info.image_key, None)
self.image_to_subset.pop(info.image_key, None)
temp_reg_infos = copy.deepcopy(self.reg_infos)
n = 0
first_loop = True
logger.info(f"self.reg_infos_index_traverser at: {self.reg_infos_index_traverser}\n reg_infos_index len = {len(self.reg_infos_index)}")
reg_img_log = f"\nDataset seed: {self.seed}"
start_index = self.reg_infos_index_traverser
while n < self.num_train_images:
if self.reg_randomize and self.reg_infos_index_traverser == 0:
if distributed_state.num_processes > 1:
if not distributed_state.is_main_process:
self.reg_infos_index = [] # non-main processes contribute empty lists to the gather
else:
random.shuffle(self.reg_infos_index) # only the main process shuffles
distributed_state.wait_for_everyone()
self.reg_infos_index = gather_object(self.reg_infos_index) # every process now holds the main process's order
else:
random.shuffle(self.reg_infos_index)
info, subset = temp_reg_infos[self.reg_infos_index[self.reg_infos_index_traverser]]
if info.image_key in self.image_data:
info.num_repeats += 1 # rewrite registered info
else:
self.register_image(info, subset)
self.reg_infos_index_traverser += 1
if self.reg_infos_index_traverser % len(self.reg_infos_index) == 0:
self.reg_infos_index_traverser = 0 # wrap around after a full pass; triggers a reshuffle if reg_randomize is set
'''
if n < 5:
reg_img_log += f"\nRegistering image: {info.absolute_path}, count: {info.num_repeats}"
'''
n += 1
# logger.info(reg_img_log)
if distributed_state.is_main_process:
self.subset_loaded_count()
self.bucket_manager = None
if make_bucket:
self.make_buckets()
del temp_reg_infos
else:
logger.warning(f"Number of training images({self.num_train_images}) is the same as number of regularization images({self.num_reg_images}).\nSkipping randomized/incremental loading of regularization images.")
class FineTuningDataset(BaseDataset):
def __init__(
self,
@@ -2098,6 +2214,11 @@ class ControlNetDataset(BaseDataset):
self.conditioning_image_transforms = IMAGE_TRANSFORMS
def incremental_reg_load(self, make_bucket = False):
self.dreambooth_dataset_delegate.incremental_reg_load() # delegate loads without bucketing; buckets are rebuilt below to keep this dataset in sync
if make_bucket:
self.make_buckets()
def make_buckets(self):
self.dreambooth_dataset_delegate.make_buckets()
self.bucket_manager = self.dreambooth_dataset_delegate.bucket_manager
@@ -2185,9 +2306,13 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
for dataset in self.datasets:
dataset.add_replacement(str_from, str_to)
# def make_buckets(self):
# for dataset in self.datasets:
# dataset.make_buckets()
def set_reg_randomize(self, reg_randomize = False):
for dataset in self.datasets:
dataset.set_reg_randomize(reg_randomize)
def make_buckets(self):
for dataset in self.datasets:
dataset.make_buckets()
def enable_XTI(self, *args, **kwargs):
for dataset in self.datasets:
@@ -2234,7 +2359,14 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
def disable_token_padding(self):
for dataset in self.datasets:
dataset.disable_token_padding()
def incremental_reg_load(self, make_bucket = False):
for dataset in self.datasets:
dataset.incremental_reg_load(make_bucket)
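# presumably needed because dataset lengths can change after incremental_reg_load:
# recompute the cumulative sizes instead of relying on the value that
# torch.utils.data.ConcatDataset cached at construction time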
def __len__(self):
self.cumulative_sizes = self.cumsum(self.datasets)
return self.cumulative_sizes[-1]
def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool, alpha_mask: bool):
expected_latents_size = (reso[1] // 8, reso[0] // 8) # note that bucket_reso is WxH
@@ -3579,6 +3711,16 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
default=None,
help="tags for model metadata, separated by comma / メタデータに書き込まれるモデルタグ、カンマ区切り",
)
parser.add_argument(
"--incremental_reg_load",
action="store_true",
help="Forces reload of regularization images at each Epoch. Will sequentially load regularization images unless '--randomized_regularization_image' is set. Useful if there are more regularization images than training images",
)
parser.add_argument(
"--randomized_regularization_image",
action="store_true",
help="Shuffles regularization images to even out distribution. Useful if there are more regularization images than training images",
)
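For illustration, a hypothetical invocation combining the two new flags (script name and the other arguments are placeholders):

accelerate launch train_network.py \
  --train_data_dir=... --reg_data_dir=... \
  --incremental_reg_load --randomized_regularization_image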
if support_dreambooth:
# DreamBooth training