mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-16 17:02:45 +00:00
Update train_util.py
This commit is contained in:
@@ -29,6 +29,7 @@ import hashlib
|
||||
import subprocess
|
||||
from io import BytesIO
|
||||
import toml
|
||||
import copy
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
@@ -164,6 +165,8 @@ class ImageInfo:
|
||||
self.text_encoder_outputs2: Optional[torch.Tensor] = None
|
||||
self.text_encoder_pool2: Optional[torch.Tensor] = None
|
||||
self.alpha_mask: Optional[torch.Tensor] = None # alpha mask can be flipped in runtime
|
||||
self.latent_cache_checked: bool = False
|
||||
self.te_cache_checked: bool = False
|
||||
|
||||
|
||||
class BucketManager:
|
||||
@@ -653,6 +656,11 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
# caching
|
||||
self.caching_mode = None # None, 'latents', 'text'
|
||||
|
||||
# lists for incremental loading of regularization images
|
||||
self.reg_infos = None
|
||||
self.reg_infos_index = None
|
||||
self.reg_randomize = False
|
||||
|
||||
def adjust_min_max_bucket_reso_by_steps(
|
||||
self, resolution: Tuple[int, int], min_bucket_reso: int, max_bucket_reso: int, bucket_reso_steps: int
|
||||
) -> Tuple[int, int]:
|
||||
@@ -684,6 +692,12 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
def set_seed(self, seed):
|
||||
self.seed = seed
|
||||
|
||||
def set_reg_randomize(self, reg_randomize = False):
|
||||
self.reg_randomize = reg_randomize
|
||||
|
||||
def incremental_reg_load(self, make_bucket = False): # Placeholder method, does nothing unless overridden in subclasses.
|
||||
return
|
||||
|
||||
def set_caching_mode(self, mode):
|
||||
self.caching_mode = mode
|
||||
|
||||
@@ -951,11 +965,14 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
if self.enable_bucket:
|
||||
self.bucket_info = {"buckets": {}}
|
||||
logger.info("number of images (including repeats) / 各bucketの画像枚数(繰り返し回数を含む)")
|
||||
batch_count: int = 0
|
||||
for i, (reso, bucket) in enumerate(zip(self.bucket_manager.resos, self.bucket_manager.buckets)):
|
||||
count = len(bucket)
|
||||
if count > 0:
|
||||
batch_count += math.ceil(len(bucket) / self.batch_size)
|
||||
self.bucket_info["buckets"][i] = {"resolution": reso, "count": len(bucket)}
|
||||
logger.info(f"bucket {i}: resolution {reso}, count: {len(bucket)}")
|
||||
logger.info(f"bucket {i}: resolution {reso}, count: {len(bucket)}, batches: {int(math.ceil(len(bucket) / self.batch_size))}")
|
||||
logger.info(f"Total batch count: {batch_count}")
|
||||
|
||||
if len(img_ar_errors) == 0:
|
||||
mean_img_ar_error = 0 # avoid NaN
|
||||
@@ -967,6 +984,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
|
||||
# データ参照用indexを作る。このindexはdatasetのshuffleに用いられる
|
||||
self.buckets_indices: List[BucketBatchIndex] = []
|
||||
self.buckets_indices.clear()
|
||||
for bucket_index, bucket in enumerate(self.bucket_manager.buckets):
|
||||
batch_count = int(math.ceil(len(bucket) / self.batch_size))
|
||||
for batch_index in range(batch_count):
|
||||
@@ -1025,6 +1043,10 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
logger.info("caching latents.")
|
||||
|
||||
image_infos = list(self.image_data.values())
|
||||
image_infos = list(filter(lambda info: info.latent_cache_checked == False, image_infos))
|
||||
if len(image_infos) == 0:
|
||||
logger.info("All images latents previously checked and cached. Skipping.")
|
||||
return
|
||||
|
||||
# sort by resolution
|
||||
image_infos.sort(key=lambda info: info.bucket_reso[0] * info.bucket_reso[1])
|
||||
@@ -1054,11 +1076,17 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
subset = self.image_to_subset[info.image_key]
|
||||
|
||||
if info.latents_npz is not None: # fine tuning dataset
|
||||
info.latent_cache_checked = True
|
||||
if self.reg_infos is not None and info.image_key in self.reg_infos:
|
||||
self.reg_infos[info.image_key][0].latent_cache_checked = True
|
||||
continue
|
||||
|
||||
# check disk cache exists and size of latents
|
||||
if cache_to_disk:
|
||||
info.latents_npz = os.path.splitext(info.absolute_path)[0] + ".npz"
|
||||
info.latent_cache_checked = True
|
||||
if self.reg_infos is not None and info.image_key in self.reg_infos:
|
||||
self.reg_infos[info.image_key][0].latent_cache_checked = True
|
||||
if not is_main_process: # store to info only
|
||||
continue
|
||||
|
||||
@@ -1094,6 +1122,15 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
logger.info("caching latents...")
|
||||
for condition, batch in tqdm(batches, smoothing=1, total=len(batches)):
|
||||
cache_batch_latents(vae, cache_to_disk, batch, condition.flip_aug, condition.alpha_mask, condition.random_crop)
|
||||
if self.reg_infos is not None:
|
||||
for info in batch:
|
||||
if info.image_key in self.reg_infos:
|
||||
self.reg_infos[info.image_key][0].latents_npz = info.latents_npz
|
||||
self.reg_infos[info.image_key][0].latents_original_size = info.latents_original_size
|
||||
self.reg_infos[info.image_key][0].latents_crop_ltrb = info.latents_crop_ltrb
|
||||
self.reg_infos[info.image_key][0].latents_crop_ltrb = info.latents_flipped
|
||||
self.reg_infos[info.image_key][0].latents = info.latents
|
||||
self.reg_infos[info.image_key][0].alpha_mask = info.alpha_mask
|
||||
|
||||
# weight_dtypeを指定するとText Encoderそのもの、およひ出力がweight_dtypeになる
|
||||
# SDXLでのみ有効だが、datasetのメソッドとする必要があるので、sdxl_train_util.pyではなくこちらに実装する
|
||||
@@ -1107,6 +1144,10 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
# またマルチGPUには対応していないので、そちらはtools/cache_latents.pyを使うこと
|
||||
logger.info("caching text encoder outputs.")
|
||||
image_infos = list(self.image_data.values())
|
||||
image_infos = list(filter(lambda info: info.te_cache_checked == False, image_infos))
|
||||
if len(image_infos) == 0:
|
||||
logger.info("Text encoder outputs for all images previously checked and cached. Skipping.")
|
||||
return
|
||||
|
||||
logger.info("checking cache existence...")
|
||||
image_infos_to_cache = []
|
||||
@@ -1115,6 +1156,10 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
if cache_to_disk:
|
||||
te_out_npz = os.path.splitext(info.absolute_path)[0] + TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX
|
||||
info.text_encoder_outputs_npz = te_out_npz
|
||||
info.te_cache_checked = True
|
||||
if self.reg_infos is not None:
|
||||
self.reg_infos[info.image_key][0].text_encoder_outputs_npz = te_out_npz
|
||||
self.reg_infos[info.image_key][0].te_cache_checked = True
|
||||
|
||||
if not is_main_process: # store to info only
|
||||
continue
|
||||
@@ -1157,6 +1202,14 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
cache_batch_text_encoder_outputs(
|
||||
infos, tokenizers, text_encoders, self.max_token_length, cache_to_disk, input_ids1, input_ids2, weight_dtype
|
||||
)
|
||||
if self.reg_infos is not None:
|
||||
for info in batch:
|
||||
if info.image_key in self.reg_infos:
|
||||
self.reg_infos[info.image_key][0].text_encoder_outputs_npz = te_out_npz
|
||||
self.reg_infos[info.image_key][0].te_cache_checked = True
|
||||
self.reg_infos[info.image_key][0]. = info.text_encoder_outputs1 = hidden_state1
|
||||
self.reg_infos[info.image_key][0]. = info.text_encoder_outputs2 = hidden_state2
|
||||
self.reg_infos[info.image_key][0]. = info.text_encoder_pool2 = pool2
|
||||
|
||||
def get_image_size(self, image_path):
|
||||
return imagesize.get(image_path)
|
||||
@@ -1561,6 +1614,8 @@ class DreamBoothDataset(BaseDataset):
|
||||
self.size = min(self.width, self.height) # 短いほう
|
||||
self.prior_loss_weight = prior_loss_weight
|
||||
self.latents_cache = None
|
||||
self.reg_infos: Dict[str, Tuple[ImageInfo, DreamBoothSubset]] = {}
|
||||
self.reg_infos_index: List[str] = []
|
||||
|
||||
self.enable_bucket = enable_bucket
|
||||
if self.enable_bucket:
|
||||
@@ -1689,7 +1744,6 @@ class DreamBoothDataset(BaseDataset):
|
||||
logger.info("prepare images.")
|
||||
num_train_images = 0
|
||||
num_reg_images = 0
|
||||
reg_infos: List[Tuple[ImageInfo, DreamBoothSubset]] = []
|
||||
for subset in subsets:
|
||||
if subset.num_repeats < 1:
|
||||
logger.warning(
|
||||
@@ -1711,7 +1765,11 @@ class DreamBoothDataset(BaseDataset):
|
||||
continue
|
||||
|
||||
if subset.is_reg:
|
||||
num_reg_images += subset.num_repeats * len(img_paths)
|
||||
if subset.num_repeats > 1:
|
||||
info.num_repeats = 1
|
||||
self.reg_infos[info.image_key] = (info, subset)
|
||||
for i in range(subset.num_repeats):
|
||||
self.reg_infos_index.append(info.image_key)
|
||||
else:
|
||||
num_train_images += subset.num_repeats * len(img_paths)
|
||||
|
||||
@@ -1731,30 +1789,88 @@ class DreamBoothDataset(BaseDataset):
|
||||
self.num_train_images = num_train_images
|
||||
|
||||
logger.info(f"{num_reg_images} reg images.")
|
||||
if num_train_images < num_reg_images:
|
||||
logger.warning("some of reg images are not used / 正則化画像の数が多いので、一部使用されない正則化画像があります")
|
||||
|
||||
if num_reg_images == 0:
|
||||
logger.warning("no regularization images / 正則化画像が見つかりませんでした")
|
||||
else:
|
||||
# num_repeatsを計算する:どうせ大した数ではないのでループで処理する
|
||||
n = 0
|
||||
first_loop = True
|
||||
while n < num_train_images:
|
||||
for info, subset in reg_infos:
|
||||
if first_loop:
|
||||
self.register_image(info, subset)
|
||||
n += info.num_repeats
|
||||
else:
|
||||
info.num_repeats += 1 # rewrite registered info
|
||||
n += 1
|
||||
if n >= num_train_images:
|
||||
break
|
||||
first_loop = False
|
||||
|
||||
self.num_reg_images = num_reg_images
|
||||
|
||||
def set_reg_randomize(self, reg_randomize = False):
|
||||
self.reg_randomize = reg_randomize
|
||||
# As first set of data is loaded before the first opportunity to shuffle, will need to force reset self.reg_infos_index_traverser and reinitialize dataset
|
||||
self.reg_infos_index_traverser = 0
|
||||
self.bucket_manager = None
|
||||
self.incremental_reg_load(True)
|
||||
|
||||
def subset_loaded_count(self):
|
||||
count_str = ""
|
||||
for index, subset in enumerate(self.subsets):
|
||||
counter = 0
|
||||
count_str += f"\nSubset {index} (Class: {subset.class_tokens}): " if isinstance(subset, DreamBoothSubset) and subset.class_tokens is not None else f"\nSubset {index}: "
|
||||
img_keys = [key for key, value in self.image_to_subset.items() if value == subset]
|
||||
for img_key in img_keys:
|
||||
counter += self.image_data[img_key].num_repeats
|
||||
count_str += f"{counter}/{subset.img_count * subset.num_repeats}"
|
||||
count_str += f"\nSubset dir: {subset.image_dir}" if subset.image_dir is not None else ""
|
||||
count_str += f"\n\n"
|
||||
logger.info(count_str)
|
||||
|
||||
def incremental_reg_load(self, make_bucket = False):
|
||||
#override to for loading random reg images
|
||||
distributed_state = PartialState()
|
||||
|
||||
if self.num_reg_images == 0:
|
||||
logger.warning("no regularization images / 正則化画像が見つかりませんでした")
|
||||
return
|
||||
if self.num_train_images < self.num_reg_images:
|
||||
logger.warning("some of reg images are not used / 正則化画像の数が多いので、一部使用されない正則化画像があります")
|
||||
|
||||
if not self.num_train_images == self.num_reg_images:
|
||||
logger.info(f"Inititating loading of regularizaion images.")
|
||||
for info, subset in self.reg_infos.values():
|
||||
if info.image_key in self.image_data:
|
||||
self.image_data.pop(info.image_key, None)
|
||||
self.image_to_subset.pop(info.image_key, None)
|
||||
|
||||
temp_reg_infos = copy.deepcopy(self.reg_infos)
|
||||
n = 0
|
||||
first_loop = True
|
||||
logger.info(f"self.reg_infos_index_traverser at: {self.reg_infos_index_traverser}\n reg_infos_index len = {len(self.reg_infos_index)}")
|
||||
reg_img_log = f"\nDataset seed: {self.seed}"
|
||||
start_index = self.reg_infos_index_traverser
|
||||
|
||||
while n < self.num_train_images :
|
||||
if self.reg_randomize and self.reg_infos_index_traverser == 0:
|
||||
if distributed_state.num_processes > 1:
|
||||
if not distributed_state.is_main_process:
|
||||
self.reg_infos_index = []
|
||||
else:
|
||||
random.shuffle(self.reg_infos_index)
|
||||
distributed_state.wait_for_everyone()
|
||||
self.reg_infos_index = gather_object(self.reg_infos_index)
|
||||
else:
|
||||
random.shuffle(self.reg_infos_index)
|
||||
info, subset = temp_reg_infos[self.reg_infos_index[self.reg_infos_index_traverser]]
|
||||
if info.image_key in self.image_data:
|
||||
info.num_repeats += 1 # rewrite registered info
|
||||
else:
|
||||
self.register_image(info, subset)
|
||||
|
||||
self.reg_infos_index_traverser += 1
|
||||
if self.reg_infos_index_traverser % len(self.reg_infos_index) == 0:
|
||||
self.reg_infos_index_traverser = 0
|
||||
'''
|
||||
if n < 5:
|
||||
reg_img_log += f"\nRegistering image: {info.absolute_path}, count: {info.num_repeats}"
|
||||
'''
|
||||
n += 1
|
||||
|
||||
# logger.info(reg_img_log)
|
||||
if distributed_state.is_main_process:
|
||||
self.subset_loaded_count()
|
||||
self.bucket_manager = None
|
||||
if make_bucket:
|
||||
self.make_buckets()
|
||||
del temp_reg_infos
|
||||
else:
|
||||
logger.warning(f"Number of training images({self.num_train_images}) is the same as number of regularization images({self.num_reg_images}).\nSkipping randomized/incremental loading of regularization images.")
|
||||
|
||||
class FineTuningDataset(BaseDataset):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -2098,6 +2214,11 @@ class ControlNetDataset(BaseDataset):
|
||||
|
||||
self.conditioning_image_transforms = IMAGE_TRANSFORMS
|
||||
|
||||
def incremental_reg_load(self, make_bucket = False):
|
||||
self.dreambooth_dataset_delegate.incremental_reg_load()
|
||||
if make_bucket:
|
||||
self.make_buckets()
|
||||
|
||||
def make_buckets(self):
|
||||
self.dreambooth_dataset_delegate.make_buckets()
|
||||
self.bucket_manager = self.dreambooth_dataset_delegate.bucket_manager
|
||||
@@ -2185,9 +2306,13 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
|
||||
for dataset in self.datasets:
|
||||
dataset.add_replacement(str_from, str_to)
|
||||
|
||||
# def make_buckets(self):
|
||||
# for dataset in self.datasets:
|
||||
# dataset.make_buckets()
|
||||
def set_reg_randomize(self, reg_randomize = False):
|
||||
for dataset in self.datasets:
|
||||
dataset.set_reg_randomize(reg_randomize)
|
||||
|
||||
def make_buckets(self):
|
||||
for dataset in self.datasets:
|
||||
dataset.make_buckets()
|
||||
|
||||
def enable_XTI(self, *args, **kwargs):
|
||||
for dataset in self.datasets:
|
||||
@@ -2234,7 +2359,14 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
|
||||
def disable_token_padding(self):
|
||||
for dataset in self.datasets:
|
||||
dataset.disable_token_padding()
|
||||
|
||||
def incremental_reg_load(self, make_bucket = False):
|
||||
for dataset in self.datasets:
|
||||
dataset.incremental_reg_load(make_bucket)
|
||||
|
||||
def __len__(self):
|
||||
self.cumulative_sizes = self.cumsum(self.datasets)
|
||||
return self.cumulative_sizes[-1]
|
||||
|
||||
def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool, alpha_mask: bool):
|
||||
expected_latents_size = (reso[1] // 8, reso[0] // 8) # bucket_resoはWxHなので注意
|
||||
@@ -3579,6 +3711,16 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
||||
default=None,
|
||||
help="tags for model metadata, separated by comma / メタデータに書き込まれるモデルタグ、カンマ区切り",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--incremental_reg_load",
|
||||
action="store_true",
|
||||
help="Forces reload of regularization images at each Epoch. Will sequentially load regularization images unless '--randomized_regularization_image' is set. Useful if there are more regularization images than training images",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--randomized_regularization_image",
|
||||
action="store_true",
|
||||
help="Shuffles regularization images to even out distribution. Useful if there are more regularization images than training images",
|
||||
)
|
||||
|
||||
if support_dreambooth:
|
||||
# DreamBooth training
|
||||
|
||||
Reference in New Issue
Block a user