This commit is contained in:
DKnight54
2025-09-30 09:53:07 +05:30
committed by GitHub
3 changed files with 244 additions and 49 deletions

View File

@@ -576,9 +576,10 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
seed = random.randint(0, 2**31) # actual seed is seed + epoch_no
for i, dataset in enumerate(datasets):
logger.info(f"[Dataset {i}]")
dataset.make_buckets()
dataset.set_seed(seed)
dataset.incremental_reg_load()
dataset.make_buckets()
return DatasetGroup(datasets)

View File

@@ -29,6 +29,7 @@ import hashlib
import subprocess
from io import BytesIO
import toml
import copy
from tqdm import tqdm
@@ -164,6 +165,8 @@ class ImageInfo:
self.text_encoder_outputs2: Optional[torch.Tensor] = None
self.text_encoder_pool2: Optional[torch.Tensor] = None
self.alpha_mask: Optional[torch.Tensor] = None # alpha mask can be flipped in runtime
self.latent_cache_checked: bool = False
self.te_cache_checked: bool = False
class BucketManager:
@@ -653,6 +656,11 @@ class BaseDataset(torch.utils.data.Dataset):
# caching
self.caching_mode = None # None, 'latents', 'text'
# lists for incremental loading of regularization images
self.reg_infos = None
self.reg_infos_index = None
self.reg_randomize = False
def adjust_min_max_bucket_reso_by_steps(
self, resolution: Tuple[int, int], min_bucket_reso: int, max_bucket_reso: int, bucket_reso_steps: int
) -> Tuple[int, int]:
@@ -684,6 +692,12 @@ class BaseDataset(torch.utils.data.Dataset):
def set_seed(self, seed):
    """Record the base RNG seed used by this dataset."""
    self.seed = seed
def set_reg_randomize(self, reg_randomize=False):
    """Set whether regularization images should be shuffled when (re)loaded."""
    self.reg_randomize = reg_randomize
def incremental_reg_load(self, make_bucket=False):
    """No-op placeholder; subclasses that support incremental loading of
    regularization images override this method."""
    return
def set_caching_mode(self, mode):
    """Select the caching mode: None, 'latents', or 'text'."""
    self.caching_mode = mode
@@ -951,11 +965,14 @@ class BaseDataset(torch.utils.data.Dataset):
if self.enable_bucket:
self.bucket_info = {"buckets": {}}
logger.info("number of images (including repeats) / 各bucketの画像枚数繰り返し回数を含む")
batch_count: int = 0
for i, (reso, bucket) in enumerate(zip(self.bucket_manager.resos, self.bucket_manager.buckets)):
count = len(bucket)
if count > 0:
batch_count += math.ceil(len(bucket) / self.batch_size)
self.bucket_info["buckets"][i] = {"resolution": reso, "count": len(bucket)}
logger.info(f"bucket {i}: resolution {reso}, count: {len(bucket)}")
logger.info(f"bucket {i}: resolution {reso}, count: {len(bucket)}, batches: {int(math.ceil(len(bucket) / self.batch_size))}")
logger.info(f"Total batch count: {batch_count}")
if len(img_ar_errors) == 0:
mean_img_ar_error = 0 # avoid NaN
@@ -967,6 +984,7 @@ class BaseDataset(torch.utils.data.Dataset):
# データ参照用indexを作る。このindexはdatasetのshuffleに用いられる
self.buckets_indices: List[BucketBatchIndex] = []
self.buckets_indices.clear()
for bucket_index, bucket in enumerate(self.bucket_manager.buckets):
batch_count = int(math.ceil(len(bucket) / self.batch_size))
for batch_index in range(batch_count):
@@ -1025,6 +1043,10 @@ class BaseDataset(torch.utils.data.Dataset):
logger.info("caching latents.")
image_infos = list(self.image_data.values())
image_infos = list(filter(lambda info: info.latent_cache_checked == False, image_infos))
if len(image_infos) == 0:
logger.info("All images latents previously checked and cached. Skipping.")
return
# sort by resolution
image_infos.sort(key=lambda info: info.bucket_reso[0] * info.bucket_reso[1])
@@ -1054,11 +1076,17 @@ class BaseDataset(torch.utils.data.Dataset):
subset = self.image_to_subset[info.image_key]
if info.latents_npz is not None: # fine tuning dataset
info.latent_cache_checked = True
if self.reg_infos is not None and info.image_key in self.reg_infos:
self.reg_infos[info.image_key][0].latent_cache_checked = True
continue
# check disk cache exists and size of latents
if cache_to_disk:
info.latents_npz = os.path.splitext(info.absolute_path)[0] + ".npz"
info.latent_cache_checked = True
if self.reg_infos is not None and info.image_key in self.reg_infos:
self.reg_infos[info.image_key][0].latent_cache_checked = True
if not is_main_process: # store to info only
continue
@@ -1094,6 +1122,15 @@ class BaseDataset(torch.utils.data.Dataset):
logger.info("caching latents...")
for condition, batch in tqdm(batches, smoothing=1, total=len(batches)):
cache_batch_latents(vae, cache_to_disk, batch, condition.flip_aug, condition.alpha_mask, condition.random_crop)
if self.reg_infos is not None:
for info in batch:
if info.image_key in self.reg_infos:
self.reg_infos[info.image_key][0].latents_npz = info.latents_npz
self.reg_infos[info.image_key][0].latents_original_size = info.latents_original_size
self.reg_infos[info.image_key][0].latents_crop_ltrb = info.latents_crop_ltrb
self.reg_infos[info.image_key][0].latents_crop_ltrb = info.latents_flipped
self.reg_infos[info.image_key][0].latents = info.latents
self.reg_infos[info.image_key][0].alpha_mask = info.alpha_mask
# weight_dtypeを指定するとText Encoderそのもの、およひ出力がweight_dtypeになる
# SDXLでのみ有効だが、datasetのメソッドとする必要があるので、sdxl_train_util.pyではなくこちらに実装する
@@ -1107,6 +1144,10 @@ class BaseDataset(torch.utils.data.Dataset):
# またマルチGPUには対応していないので、そちらはtools/cache_latents.pyを使うこと
logger.info("caching text encoder outputs.")
image_infos = list(self.image_data.values())
image_infos = list(filter(lambda info: info.te_cache_checked == False, image_infos))
if len(image_infos) == 0:
logger.info("Text encoder outputs for all images previously checked and cached. Skipping.")
return
logger.info("checking cache existence...")
image_infos_to_cache = []
@@ -1115,6 +1156,10 @@ class BaseDataset(torch.utils.data.Dataset):
if cache_to_disk:
te_out_npz = os.path.splitext(info.absolute_path)[0] + TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX
info.text_encoder_outputs_npz = te_out_npz
info.te_cache_checked = True
if self.reg_infos is not None:
self.reg_infos[info.image_key][0].text_encoder_outputs_npz = te_out_npz
self.reg_infos[info.image_key][0].te_cache_checked = True
if not is_main_process: # store to info only
continue
@@ -1157,6 +1202,14 @@ class BaseDataset(torch.utils.data.Dataset):
cache_batch_text_encoder_outputs(
infos, tokenizers, text_encoders, self.max_token_length, cache_to_disk, input_ids1, input_ids2, weight_dtype
)
if self.reg_infos is not None:
for info in batch:
if info.image_key in self.reg_infos:
self.reg_infos[info.image_key][0].text_encoder_outputs_npz = te_out_npz
self.reg_infos[info.image_key][0].te_cache_checked = True
self.reg_infos[info.image_key][0].text_encoder_outputs1 = info.text_encoder_outputs1
self.reg_infos[info.image_key][0].text_encoder_outputs2 = info.text_encoder_outputs2
self.reg_infos[info.image_key][0].text_encoder_pool2 = info.text_encoder_pool2
def get_image_size(self, image_path):
    """Return the (width, height) of the image at *image_path* via the imagesize helper."""
    size = imagesize.get(image_path)
    return size
@@ -1561,6 +1614,9 @@ class DreamBoothDataset(BaseDataset):
self.size = min(self.width, self.height) # 短いほう
self.prior_loss_weight = prior_loss_weight
self.latents_cache = None
self.reg_infos: Dict[str, Tuple[ImageInfo, DreamBoothSubset]] = {}
self.reg_infos_index: List[str] = []
self.reg_infos_index_traverser = 0
self.enable_bucket = enable_bucket
if self.enable_bucket:
@@ -1689,7 +1745,6 @@ class DreamBoothDataset(BaseDataset):
logger.info("prepare images.")
num_train_images = 0
num_reg_images = 0
reg_infos: List[Tuple[ImageInfo, DreamBoothSubset]] = []
for subset in subsets:
if subset.num_repeats < 1:
logger.warning(
@@ -1720,7 +1775,11 @@ class DreamBoothDataset(BaseDataset):
if size is not None:
info.image_size = size
if subset.is_reg:
reg_infos.append((info, subset))
if subset.num_repeats > 1:
info.num_repeats = 1
self.reg_infos[info.image_key] = (info, subset)
for i in range(subset.num_repeats):
self.reg_infos_index.append(info.image_key)
else:
self.register_image(info, subset)
@@ -1731,30 +1790,89 @@ class DreamBoothDataset(BaseDataset):
self.num_train_images = num_train_images
logger.info(f"{num_reg_images} reg images.")
if num_train_images < num_reg_images:
self.num_reg_images = num_reg_images
self.reg_infos_index_traverser = 0
def set_reg_randomize(self, reg_randomize = False):
    """Enable/disable shuffled regularization-image loading, then force a reload.

    Setting the flag alone is not enough: the first set of reg images was
    loaded (sequentially) before any opportunity to shuffle, so the traversal
    cursor is reset and the dataset is reinitialized immediately.
    """
    self.reg_randomize = reg_randomize
    # As first set of data is loaded before the first opportunity to shuffle, will need to force reset self.reg_infos_index_traverser and reinitialize dataset
    self.reg_infos_index_traverser = 0
    # Discard current buckets; incremental_reg_load(True) rebuilds them for the new selection.
    self.bucket_manager = None
    self.incremental_reg_load(True)
def subset_loaded_count(self):
    """Log, per subset, how many image repeats are currently registered
    versus the subset's total available (img_count * num_repeats)."""
    entries = []
    for index, subset in enumerate(self.subsets):
        if isinstance(subset, DreamBoothSubset) and subset.class_tokens is not None:
            header = f"\nSubset {index} (Class: {subset.class_tokens}): "
        else:
            header = f"\nSubset {index}: "
        # Sum the repeats of every registered image belonging to this subset.
        loaded = sum(
            self.image_data[key].num_repeats
            for key, owner in self.image_to_subset.items()
            if owner == subset
        )
        entry = header + f"{loaded}/{subset.img_count * subset.num_repeats}"
        if subset.image_dir is not None:
            entry += f"\nSubset dir: {subset.image_dir}"
        entry += f"\n\n"
        entries.append(entry)
    logger.info("".join(entries))
def incremental_reg_load(self, make_bucket=False):
    """Register the next window of regularization images for the coming epoch.

    Overrides the BaseDataset placeholder. Walks ``self.reg_infos_index``
    starting at ``self.reg_infos_index_traverser`` until the number of loaded
    reg-image repeats matches ``self.num_train_images``. When
    ``self.reg_randomize`` is set, the index is reshuffled each time a full
    pass completes; in distributed runs the shuffle happens on the main
    process only and is then gathered so all ranks share the same order.
    NOTE(review): this relies on ``gather_object`` returning rank 0's list
    because non-main ranks contribute an empty list — confirm with accelerate.

    Args:
        make_bucket: when True, rebuild the resolution buckets after loading.
    """
    # Fix vs. previous revision: stale references to the removed locals
    # `num_reg_images`, `num_train_images` and `reg_infos` (NameError at
    # runtime) and the dead first_loop/for-loop path have been removed;
    # unused `reg_img_log`/`start_index` and commented-out debug logging
    # dropped.
    distributed_state = PartialState()
    if self.num_reg_images == 0:
        logger.warning("no regularization images / 正則化画像が見つかりませんでした")
        return
    if self.num_train_images < self.num_reg_images:
        logger.warning("some of reg images are not used / 正則化画像の数が多いので、一部使用されない正則化画像があります")
    if not self.num_train_images == self.num_reg_images:
        logger.info(f"Inititating loading of regularizaion images.")
        # Drop every currently-registered reg image so a fresh selection can be registered.
        for info, subset in self.reg_infos.values():
            if info.image_key in self.image_data:
                self.image_data.pop(info.image_key, None)
                self.image_to_subset.pop(info.image_key, None)
        # Work on copies so num_repeats bumps don't accumulate across epochs.
        temp_reg_infos = copy.deepcopy(self.reg_infos)
        n = 0
        logger.info(f"self.reg_infos_index_traverser at: {self.reg_infos_index_traverser}\n reg_infos_index len = {len(self.reg_infos_index)}")
        while n < self.num_train_images:
            if self.reg_randomize and self.reg_infos_index_traverser == 0:
                # Reshuffle once per full pass over the index.
                if distributed_state.num_processes > 1:
                    # Only the main process shuffles; other ranks clear their
                    # copy and receive the shuffled order via gather_object.
                    if not distributed_state.is_main_process:
                        self.reg_infos_index = []
                    else:
                        random.shuffle(self.reg_infos_index)
                    distributed_state.wait_for_everyone()
                    self.reg_infos_index = gather_object(self.reg_infos_index)
                else:
                    random.shuffle(self.reg_infos_index)
            info, subset = temp_reg_infos[self.reg_infos_index[self.reg_infos_index_traverser]]
            if info.image_key in self.image_data:
                info.num_repeats += 1  # already registered: just bump repeats
            else:
                self.register_image(info, subset)
            self.reg_infos_index_traverser += 1
            # Wrap the cursor so the next epoch continues from the start of the index.
            if self.reg_infos_index_traverser % len(self.reg_infos_index) == 0:
                self.reg_infos_index_traverser = 0
            n += 1
        if distributed_state.is_main_process:
            self.subset_loaded_count()
        # Buckets are stale after changing the registered images.
        self.bucket_manager = None
        if make_bucket:
            self.make_buckets()
        del temp_reg_infos
    else:
        logger.warning(f"Number of training images({self.num_train_images}) is the same as number of regularization images({self.num_reg_images}).\nSkipping randomized/incremental loading of regularization images.")
class FineTuningDataset(BaseDataset):
def __init__(
self,
@@ -2098,6 +2216,11 @@ class ControlNetDataset(BaseDataset):
self.conditioning_image_transforms = IMAGE_TRANSFORMS
def incremental_reg_load(self, make_bucket=False):
    """Delegate regularization-image reloading to the wrapped DreamBooth
    dataset; optionally rebuild this dataset's buckets afterwards."""
    delegate = self.dreambooth_dataset_delegate
    delegate.incremental_reg_load()
    if make_bucket:
        self.make_buckets()
def make_buckets(self):
    """Build buckets on the wrapped DreamBooth dataset and mirror its bucket manager."""
    delegate = self.dreambooth_dataset_delegate
    delegate.make_buckets()
    self.bucket_manager = delegate.bucket_manager
@@ -2185,9 +2308,13 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
for dataset in self.datasets:
dataset.add_replacement(str_from, str_to)
# def make_buckets(self):
# for dataset in self.datasets:
# dataset.make_buckets()
def set_reg_randomize(self, reg_randomize=False):
    """Propagate the regularization-shuffle flag to every member dataset."""
    for member in self.datasets:
        member.set_reg_randomize(reg_randomize)
def make_buckets(self):
    """Ask every member dataset to (re)build its resolution buckets."""
    for member in self.datasets:
        member.make_buckets()
def enable_XTI(self, *args, **kwargs):
for dataset in self.datasets:
@@ -2234,7 +2361,14 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
def disable_token_padding(self):
for dataset in self.datasets:
dataset.disable_token_padding()
def incremental_reg_load(self, make_bucket=False):
    """Trigger incremental regularization-image loading on every member dataset."""
    for member in self.datasets:
        member.incremental_reg_load(make_bucket)
def __len__(self):
self.cumulative_sizes = self.cumsum(self.datasets)
return self.cumulative_sizes[-1]
def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool, alpha_mask: bool):
expected_latents_size = (reso[1] // 8, reso[0] // 8) # bucket_resoはWxHなので注意
@@ -3579,6 +3713,16 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
default=None,
help="tags for model metadata, separated by comma / メタデータに書き込まれるモデルタグ、カンマ区切り",
)
parser.add_argument(
"--incremental_reg_load",
action="store_true",
help="Forces reload of regularization images at each Epoch. Will sequentially load regularization images unless '--randomized_regularization_image' is set. Useful if there are more regularization images than training images",
)
parser.add_argument(
"--randomized_regularization_image",
action="store_true",
help="Shuffles regularization images to even out distribution. Useful if there are more regularization images than training images",
)
if support_dreambooth:
# DreamBooth training

View File

@@ -8,6 +8,7 @@ import time
import json
from multiprocessing import Value
import toml
import copy
from tqdm import tqdm
@@ -135,6 +136,11 @@ class NetworkTrainer:
train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet)
def train(self, args):
# acceleratorを準備する
logger.info("preparing accelerator")
accelerator = train_util.prepare_accelerator(args)
is_main_process = accelerator.is_main_process
session_id = random.randint(0, 2**32)
training_started_at = time.time()
train_util.verify_training_args(args)
@@ -202,8 +208,15 @@ class NetworkTrainer:
current_epoch = Value("i", 0)
current_step = Value("i", 0)
ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
if args.incremental_reg_reload:
if args.persistent_data_loader_workers:
logger.warning("persistent_data_loader_workers has been set to False because incremental_reg_reload is enabled.")
args.persistent_data_loader_workers = False
if args.randomized_regularization_image:
# train_dataset_group.set_reg_randomize() triggers a reload to initial state with randomized regularization images. Ensure that this occurs before initial caching to prevent data mismatch
logger.info("Reloading sequentially loaded regularization images to replace with randomly selected regularization images...")
train_dataset_group.set_reg_randomize(args.randomized_regularization_image)
if args.debug_dataset:
train_util.debug_dataset(train_dataset_group)
@@ -221,11 +234,6 @@ class NetworkTrainer:
self.assert_extra_args(args, train_dataset_group)
# acceleratorを準備する
logger.info("preparing accelerator")
accelerator = train_util.prepare_accelerator(args)
is_main_process = accelerator.is_main_process
# mixed precisionに対応した型を用意しておき適宜castする
weight_dtype, save_dtype = train_util.prepare_dtype(args)
vae_dtype = torch.float32 if args.no_half_vae else weight_dtype
@@ -263,23 +271,24 @@ class NetworkTrainer:
accelerator.print(f"all weights merged: {', '.join(args.base_weights)}")
# 学習を準備する
if cache_latents:
vae.to(accelerator.device, dtype=vae_dtype)
vae.requires_grad_(False)
vae.eval()
with torch.no_grad():
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
vae.to("cpu")
clean_memory_on_device(accelerator.device)
accelerator.wait_for_everyone()
# 必要ならテキストエンコーダーの出力をキャッシュする: Text Encoderはcpuまたはgpuへ移される
# cache text encoder outputs if needed: Text Encoder is moved to cpu or gpu
self.cache_text_encoder_outputs_if_needed(
args, accelerator, unet, vae, tokenizers, text_encoders, train_dataset_group, weight_dtype
)
'''
Replacing cache latents and cache text encoder outputs here with code to simulate running through self.cache_text_encoder_outputs_if_needed().
Reduces unnecessary caching by avoiding caching until data loaded into train_dataset_group has been finalized.
This step is required to ensure text_encoders are loaded onto the correct device for the training.
Possibly should replace with method that can be overridden for different handling of TE for different models.
'''
if args.cache_text_encoder_outputs and self.is_sdxl:
# When TE is not be trained, it will not be prepared so we need to use explicit autocast
for t_enc in text_encoders:
t_enc.to("cpu", dtype=torch.float32) # Text Encoder doesn't work with fp16 on CPU
else:
# Text Encoderから毎回出力を取得するので、GPUに乗せておく
if not self.is_sdxl:
accelerator.print("Text Encoder caching not supported. Overriding args.cache_text_encoder_output to False")
args.cache_text_encoder_outputs = False
for t_enc in text_encoders:
t_enc.to(accelerator.device, dtype=weight_dtype)
# prepare network
net_kwargs = {}
@@ -368,8 +377,12 @@ class NetworkTrainer:
# dataloaderを準備する
# DataLoaderのプロセス数0 は persistent_workers が使えないので注意
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
train_dataloader = torch.utils.data.DataLoader(
train_dataset_group,
batch_size=1,
@@ -845,7 +858,6 @@ class NetworkTrainer:
)
loss_recorder = train_util.LossRecorder()
del train_dataset_group
# callback for step start
if hasattr(accelerator.unwrap_model(network), "on_step_start"):
@@ -885,6 +897,8 @@ class NetworkTrainer:
for skip_epoch in range(epoch_to_start): # skip epochs
logger.info(f"skipping epoch {skip_epoch+1} because initial_step (multiplied) is {initial_step}")
initial_step -= len(train_dataloader)
if args.incremental_reg_reload:
train_dataset_group.incremental_reg_load(True) # Updates the loaded dataset to the next epoch
global_step = initial_step
for epoch in range(epoch_to_start, num_train_epochs):
@@ -893,6 +907,39 @@ class NetworkTrainer:
metadata["ss_epoch"] = str(epoch + 1)
if epoch == epoch_to_start or args.incremental_reg_reload:
if cache_latents:
vae.to(accelerator.device, dtype=vae_dtype)
vae.requires_grad_(False)
vae.eval()
with torch.no_grad():
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
vae.to("cpu")
clean_memory_on_device(accelerator.device)
accelerator.wait_for_everyone()
# 必要ならテキストエンコーダーの出力をキャッシュする: Text Encoderはcpuまたはgpuへ移される
# cache text encoder outputs if needed: Text Encoder is moved to cpu or gpu
self.cache_text_encoder_outputs_if_needed(
args, accelerator, unet, vae, tokenizers, text_encoders, train_dataset_group, weight_dtype
)
accelerator.wait_for_everyone() # Ensure all processes sync after potential dataset/cache changes in initial_step block
ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
train_dataloader = torch.utils.data.DataLoader(
train_dataset_group, # This is the updated train_dataset_group
batch_size=1,
shuffle=True,
collate_fn=collator,
num_workers=n_workers, # Ensure n_workers is available
persistent_workers=args.persistent_data_loader_workers,
)
accelerator.wait_for_everyone()
train_dataloader = accelerator.prepare(train_dataloader)
accelerator.unwrap_model(network).on_epoch_start(text_encoder, unet)
skipped_dataloader = None
@@ -1091,6 +1138,9 @@ class NetworkTrainer:
train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1)
self.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
# Load next batch of regularization images if necessary
if args.incremental_reg_reload and epoch + 1 < num_train_epochs:
train_dataset_group.incremental_reg_load(True)
# end of epoch