mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-16 00:49:40 +00:00
Merge 6fff6cc33e into 206adb6438
This commit is contained in:
@@ -576,9 +576,10 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
|
||||
seed = random.randint(0, 2**31) # actual seed is seed + epoch_no
|
||||
for i, dataset in enumerate(datasets):
|
||||
logger.info(f"[Dataset {i}]")
|
||||
dataset.make_buckets()
|
||||
dataset.set_seed(seed)
|
||||
|
||||
dataset.incremental_reg_load()
|
||||
dataset.make_buckets()
|
||||
|
||||
return DatasetGroup(datasets)
|
||||
|
||||
|
||||
|
||||
@@ -29,6 +29,7 @@ import hashlib
|
||||
import subprocess
|
||||
from io import BytesIO
|
||||
import toml
|
||||
import copy
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
@@ -164,6 +165,8 @@ class ImageInfo:
|
||||
self.text_encoder_outputs2: Optional[torch.Tensor] = None
|
||||
self.text_encoder_pool2: Optional[torch.Tensor] = None
|
||||
self.alpha_mask: Optional[torch.Tensor] = None # alpha mask can be flipped in runtime
|
||||
self.latent_cache_checked: bool = False
|
||||
self.te_cache_checked: bool = False
|
||||
|
||||
|
||||
class BucketManager:
|
||||
@@ -653,6 +656,11 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
# caching
|
||||
self.caching_mode = None # None, 'latents', 'text'
|
||||
|
||||
# lists for incremental loading of regularization images
|
||||
self.reg_infos = None
|
||||
self.reg_infos_index = None
|
||||
self.reg_randomize = False
|
||||
|
||||
def adjust_min_max_bucket_reso_by_steps(
|
||||
self, resolution: Tuple[int, int], min_bucket_reso: int, max_bucket_reso: int, bucket_reso_steps: int
|
||||
) -> Tuple[int, int]:
|
||||
@@ -684,6 +692,12 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
def set_seed(self, seed):
    """Store ``seed`` on the dataset as ``self.seed``.

    Args:
        seed: value to remember; it is stored as-is without validation.
    """
    setattr(self, "seed", seed)
|
||||
|
||||
def set_reg_randomize(self, reg_randomize=False):
    """Record the ``reg_randomize`` flag on the dataset.

    Args:
        reg_randomize: value stored in ``self.reg_randomize``; defaults to
            ``False``.
    """
    setattr(self, "reg_randomize", reg_randomize)
|
||||
|
||||
def incremental_reg_load(self, make_bucket=False):
    """Placeholder hook: does nothing unless a subclass overrides it.

    Args:
        make_bucket: accepted for signature compatibility with overriding
            implementations; ignored here.

    Returns:
        None
    """
    return None
|
||||
|
||||
def set_caching_mode(self, mode):
    """Store ``mode`` in ``self.caching_mode``.

    Args:
        mode: new caching mode; stored as-is without validation.
    """
    setattr(self, "caching_mode", mode)
|
||||
|
||||
@@ -951,11 +965,14 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
if self.enable_bucket:
|
||||
self.bucket_info = {"buckets": {}}
|
||||
logger.info("number of images (including repeats) / 各bucketの画像枚数(繰り返し回数を含む)")
|
||||
batch_count: int = 0
|
||||
for i, (reso, bucket) in enumerate(zip(self.bucket_manager.resos, self.bucket_manager.buckets)):
|
||||
count = len(bucket)
|
||||
if count > 0:
|
||||
batch_count += math.ceil(len(bucket) / self.batch_size)
|
||||
self.bucket_info["buckets"][i] = {"resolution": reso, "count": len(bucket)}
|
||||
logger.info(f"bucket {i}: resolution {reso}, count: {len(bucket)}")
|
||||
logger.info(f"bucket {i}: resolution {reso}, count: {len(bucket)}, batches: {int(math.ceil(len(bucket) / self.batch_size))}")
|
||||
logger.info(f"Total batch count: {batch_count}")
|
||||
|
||||
if len(img_ar_errors) == 0:
|
||||
mean_img_ar_error = 0 # avoid NaN
|
||||
@@ -967,6 +984,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
|
||||
# データ参照用indexを作る。このindexはdatasetのshuffleに用いられる
|
||||
self.buckets_indices: List[BucketBatchIndex] = []
|
||||
self.buckets_indices.clear()
|
||||
for bucket_index, bucket in enumerate(self.bucket_manager.buckets):
|
||||
batch_count = int(math.ceil(len(bucket) / self.batch_size))
|
||||
for batch_index in range(batch_count):
|
||||
@@ -1025,6 +1043,10 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
logger.info("caching latents.")
|
||||
|
||||
image_infos = list(self.image_data.values())
|
||||
image_infos = list(filter(lambda info: info.latent_cache_checked == False, image_infos))
|
||||
if len(image_infos) == 0:
|
||||
logger.info("All images latents previously checked and cached. Skipping.")
|
||||
return
|
||||
|
||||
# sort by resolution
|
||||
image_infos.sort(key=lambda info: info.bucket_reso[0] * info.bucket_reso[1])
|
||||
@@ -1054,11 +1076,17 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
subset = self.image_to_subset[info.image_key]
|
||||
|
||||
if info.latents_npz is not None: # fine tuning dataset
|
||||
info.latent_cache_checked = True
|
||||
if self.reg_infos is not None and info.image_key in self.reg_infos:
|
||||
self.reg_infos[info.image_key][0].latent_cache_checked = True
|
||||
continue
|
||||
|
||||
# check disk cache exists and size of latents
|
||||
if cache_to_disk:
|
||||
info.latents_npz = os.path.splitext(info.absolute_path)[0] + ".npz"
|
||||
info.latent_cache_checked = True
|
||||
if self.reg_infos is not None and info.image_key in self.reg_infos:
|
||||
self.reg_infos[info.image_key][0].latent_cache_checked = True
|
||||
if not is_main_process: # store to info only
|
||||
continue
|
||||
|
||||
@@ -1094,6 +1122,15 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
logger.info("caching latents...")
|
||||
for condition, batch in tqdm(batches, smoothing=1, total=len(batches)):
    cache_batch_latents(vae, cache_to_disk, batch, condition.flip_aug, condition.alpha_mask, condition.random_crop)

    # Copy the freshly cached latent data onto the matching entries of
    # self.reg_infos (image_key -> (ImageInfo, subset)), so the stored
    # regularization infos carry the same cache results as the infos that
    # were just processed.
    # NOTE(review): this sync is placed inside the batch loop because it reads
    # the loop variable `batch`; confirm against the un-mangled source.
    if self.reg_infos is None:
        continue
    for info in batch:
        if info.image_key not in self.reg_infos:
            continue  # not a regularization image
        reg_info = self.reg_infos[info.image_key][0]
        reg_info.latents_npz = info.latents_npz
        reg_info.latents_original_size = info.latents_original_size
        reg_info.latents_crop_ltrb = info.latents_crop_ltrb
        # Fix: the flipped latents used to be written into latents_crop_ltrb,
        # overwriting the crop rectangle and never updating latents_flipped.
        reg_info.latents_flipped = info.latents_flipped
        reg_info.latents = info.latents
        reg_info.alpha_mask = info.alpha_mask
|
||||
|
||||
# weight_dtypeを指定するとText Encoderそのもの、およひ出力がweight_dtypeになる
|
||||
# SDXLでのみ有効だが、datasetのメソッドとする必要があるので、sdxl_train_util.pyではなくこちらに実装する
|
||||
@@ -1107,6 +1144,10 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
# またマルチGPUには対応していないので、そちらはtools/cache_latents.pyを使うこと
|
||||
logger.info("caching text encoder outputs.")
|
||||
image_infos = list(self.image_data.values())
|
||||
image_infos = list(filter(lambda info: info.te_cache_checked == False, image_infos))
|
||||
if len(image_infos) == 0:
|
||||
logger.info("Text encoder outputs for all images previously checked and cached. Skipping.")
|
||||
return
|
||||
|
||||
logger.info("checking cache existence...")
|
||||
image_infos_to_cache = []
|
||||
@@ -1115,6 +1156,10 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
if cache_to_disk:
|
||||
te_out_npz = os.path.splitext(info.absolute_path)[0] + TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX
|
||||
info.text_encoder_outputs_npz = te_out_npz
|
||||
info.te_cache_checked = True
|
||||
# Mirror the cache path / checked flag onto the stored regularization info.
# Fix: only keys present in self.reg_infos may be indexed; without the
# membership test every non-regularization image raised KeyError. This matches
# the guard used by the latent-cache path.
if self.reg_infos is not None and info.image_key in self.reg_infos:
    reg_info = self.reg_infos[info.image_key][0]
    reg_info.text_encoder_outputs_npz = te_out_npz
    reg_info.te_cache_checked = True
|
||||
|
||||
if not is_main_process: # store to info only
|
||||
continue
|
||||
@@ -1157,6 +1202,14 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
cache_batch_text_encoder_outputs(
|
||||
infos, tokenizers, text_encoders, self.max_token_length, cache_to_disk, input_ids1, input_ids2, weight_dtype
|
||||
)
|
||||
# Copy the text-encoder cache results onto the stored regularization infos.
if self.reg_infos is not None:
    # Iterate the infos that were just handed to
    # cache_batch_text_encoder_outputs rather than `batch`.
    # NOTE(review): `batch` appears to hold more than ImageInfo objects here
    # (input ids are passed separately) - confirm against the full function.
    for info in infos:
        if info.image_key not in self.reg_infos:
            continue  # not a regularization image
        reg_info = self.reg_infos[info.image_key][0]
        # Fix: use this image's own cache path. `te_out_npz` is a leftover
        # from the earlier existence-check loop (last image's path, or
        # unbound when cache_to_disk is False).
        # NOTE(review): assumes ImageInfo always defines
        # text_encoder_outputs_npz - confirm.
        reg_info.text_encoder_outputs_npz = info.text_encoder_outputs_npz
        reg_info.te_cache_checked = True
        reg_info.text_encoder_outputs1 = info.text_encoder_outputs1
        reg_info.text_encoder_outputs2 = info.text_encoder_outputs2
        reg_info.text_encoder_pool2 = info.text_encoder_pool2
|
||||
|
||||
def get_image_size(self, image_path):
|
||||
return imagesize.get(image_path)
|
||||
@@ -1561,6 +1614,9 @@ class DreamBoothDataset(BaseDataset):
|
||||
self.size = min(self.width, self.height) # 短いほう
|
||||
self.prior_loss_weight = prior_loss_weight
|
||||
self.latents_cache = None
|
||||
self.reg_infos: Dict[str, Tuple[ImageInfo, DreamBoothSubset]] = {}
|
||||
self.reg_infos_index: List[str] = []
|
||||
self.reg_infos_index_traverser = 0
|
||||
|
||||
self.enable_bucket = enable_bucket
|
||||
if self.enable_bucket:
|
||||
@@ -1689,7 +1745,6 @@ class DreamBoothDataset(BaseDataset):
|
||||
logger.info("prepare images.")
|
||||
num_train_images = 0
|
||||
num_reg_images = 0
|
||||
reg_infos: List[Tuple[ImageInfo, DreamBoothSubset]] = []
|
||||
for subset in subsets:
|
||||
if subset.num_repeats < 1:
|
||||
logger.warning(
|
||||
@@ -1720,7 +1775,11 @@ class DreamBoothDataset(BaseDataset):
|
||||
if size is not None:
|
||||
info.image_size = size
|
||||
if subset.is_reg:
|
||||
reg_infos.append((info, subset))
|
||||
if subset.num_repeats > 1:
|
||||
info.num_repeats = 1
|
||||
self.reg_infos[info.image_key] = (info, subset)
|
||||
for i in range(subset.num_repeats):
|
||||
self.reg_infos_index.append(info.image_key)
|
||||
else:
|
||||
self.register_image(info, subset)
|
||||
|
||||
@@ -1731,30 +1790,89 @@ class DreamBoothDataset(BaseDataset):
|
||||
self.num_train_images = num_train_images
|
||||
|
||||
logger.info(f"{num_reg_images} reg images.")
|
||||
if num_train_images < num_reg_images:
|
||||
self.num_reg_images = num_reg_images
|
||||
self.reg_infos_index_traverser = 0
|
||||
|
||||
def set_reg_randomize(self, reg_randomize=False):
    """Store the randomize flag, then rebuild the loaded regularization set.

    The first set of data is loaded before there is any opportunity to
    shuffle, so the traverser is rewound to 0 and the bucket manager is
    dropped before ``incremental_reg_load(True)`` re-initializes the dataset.

    Args:
        reg_randomize: value stored in ``self.reg_randomize``.
    """
    self.reg_randomize = reg_randomize

    # Force a reload from the start of the regularization index.
    self.bucket_manager = None
    self.reg_infos_index_traverser = 0
    self.incremental_reg_load(True)
|
||||
|
||||
def subset_loaded_count(self):
    """Log, for every subset, how many images (counting repeats) are loaded.

    For each subset the report shows ``loaded/expected``: ``loaded`` is the
    sum of ``num_repeats`` over the entries of ``self.image_data`` whose
    ``self.image_to_subset`` value equals the subset, and ``expected`` is
    ``subset.img_count * subset.num_repeats``. The class tokens and image
    directory are included when they are not None. The whole report is
    emitted with a single ``logger.info`` call.
    """
    pieces = []
    for position, current in enumerate(self.subsets):
        has_class = isinstance(current, DreamBoothSubset) and current.class_tokens is not None
        if has_class:
            pieces.append(f"\nSubset {position} (Class: {current.class_tokens}): ")
        else:
            pieces.append(f"\nSubset {position}: ")

        loaded = sum(
            self.image_data[image_key].num_repeats
            for image_key, owner in self.image_to_subset.items()
            if owner == current
        )
        pieces.append(f"{loaded}/{current.img_count * current.num_repeats}")

        if current.image_dir is not None:
            pieces.append(f"\nSubset dir: {current.image_dir}")
        pieces.append("\n\n")

    logger.info("".join(pieces))
|
||||
|
||||
def incremental_reg_load(self, make_bucket = False):
|
||||
# Override that loads the next (optionally randomized) set of regularization images
|
||||
distributed_state = PartialState()
|
||||
|
||||
if self.num_reg_images == 0:
|
||||
logger.warning("no regularization images / 正則化画像が見つかりませんでした")
|
||||
return
|
||||
if self.num_train_images < self.num_reg_images:
|
||||
logger.warning("some of reg images are not used / 正則化画像の数が多いので、一部使用されない正則化画像があります")
|
||||
|
||||
if num_reg_images == 0:
|
||||
logger.warning("no regularization images / 正則化画像が見つかりませんでした")
|
||||
else:
|
||||
# num_repeatsを計算する:どうせ大した数ではないのでループで処理する
|
||||
if not self.num_train_images == self.num_reg_images:
|
||||
logger.info(f"Inititating loading of regularizaion images.")
|
||||
for info, subset in self.reg_infos.values():
|
||||
if info.image_key in self.image_data:
|
||||
self.image_data.pop(info.image_key, None)
|
||||
self.image_to_subset.pop(info.image_key, None)
|
||||
|
||||
temp_reg_infos = copy.deepcopy(self.reg_infos)
|
||||
n = 0
|
||||
first_loop = True
|
||||
while n < num_train_images:
|
||||
for info, subset in reg_infos:
|
||||
if first_loop:
|
||||
self.register_image(info, subset)
|
||||
n += info.num_repeats
|
||||
logger.info(f"self.reg_infos_index_traverser at: {self.reg_infos_index_traverser}\n reg_infos_index len = {len(self.reg_infos_index)}")
|
||||
reg_img_log = f"\nDataset seed: {self.seed}"
|
||||
start_index = self.reg_infos_index_traverser
|
||||
|
||||
while n < self.num_train_images :
|
||||
if self.reg_randomize and self.reg_infos_index_traverser == 0:
|
||||
if distributed_state.num_processes > 1:
|
||||
if not distributed_state.is_main_process:
|
||||
self.reg_infos_index = []
|
||||
else:
|
||||
random.shuffle(self.reg_infos_index)
|
||||
distributed_state.wait_for_everyone()
|
||||
self.reg_infos_index = gather_object(self.reg_infos_index)
|
||||
else:
|
||||
info.num_repeats += 1 # rewrite registered info
|
||||
n += 1
|
||||
if n >= num_train_images:
|
||||
break
|
||||
first_loop = False
|
||||
|
||||
self.num_reg_images = num_reg_images
|
||||
random.shuffle(self.reg_infos_index)
|
||||
info, subset = temp_reg_infos[self.reg_infos_index[self.reg_infos_index_traverser]]
|
||||
if info.image_key in self.image_data:
|
||||
info.num_repeats += 1 # rewrite registered info
|
||||
else:
|
||||
self.register_image(info, subset)
|
||||
|
||||
self.reg_infos_index_traverser += 1
|
||||
if self.reg_infos_index_traverser % len(self.reg_infos_index) == 0:
|
||||
self.reg_infos_index_traverser = 0
|
||||
'''
|
||||
if n < 5:
|
||||
reg_img_log += f"\nRegistering image: {info.absolute_path}, count: {info.num_repeats}"
|
||||
'''
|
||||
n += 1
|
||||
|
||||
# logger.info(reg_img_log)
|
||||
if distributed_state.is_main_process:
|
||||
self.subset_loaded_count()
|
||||
self.bucket_manager = None
|
||||
if make_bucket:
|
||||
self.make_buckets()
|
||||
del temp_reg_infos
|
||||
else:
|
||||
logger.warning(f"Number of training images({self.num_train_images}) is the same as number of regularization images({self.num_reg_images}).\nSkipping randomized/incremental loading of regularization images.")
|
||||
|
||||
class FineTuningDataset(BaseDataset):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -2098,6 +2216,11 @@ class ControlNetDataset(BaseDataset):
|
||||
|
||||
self.conditioning_image_transforms = IMAGE_TRANSFORMS
|
||||
|
||||
def incremental_reg_load(self, make_bucket=False):
    """Forward the reload to the delegate dataset, optionally rebuilding buckets.

    ``self.dreambooth_dataset_delegate.incremental_reg_load()`` is called
    without arguments; afterwards ``self.make_buckets()`` runs only when
    ``make_bucket`` is truthy.

    Args:
        make_bucket: when truthy, rebuild this dataset's buckets after the
            delegate has reloaded.
    """
    delegate = self.dreambooth_dataset_delegate
    delegate.incremental_reg_load()
    if not make_bucket:
        return
    self.make_buckets()
|
||||
|
||||
def make_buckets(self):
|
||||
self.dreambooth_dataset_delegate.make_buckets()
|
||||
self.bucket_manager = self.dreambooth_dataset_delegate.bucket_manager
|
||||
@@ -2185,9 +2308,13 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
|
||||
for dataset in self.datasets:
|
||||
dataset.add_replacement(str_from, str_to)
|
||||
|
||||
# def make_buckets(self):
|
||||
# for dataset in self.datasets:
|
||||
# dataset.make_buckets()
|
||||
def set_reg_randomize(self, reg_randomize=False):
    """Call ``set_reg_randomize(reg_randomize)`` on every dataset in the group.

    Args:
        reg_randomize: flag passed unchanged to each member of
            ``self.datasets``, in order.
    """
    for member in self.datasets:
        member.set_reg_randomize(reg_randomize)
|
||||
|
||||
def make_buckets(self):
    """Call ``make_buckets()`` on every dataset in ``self.datasets``, in order."""
    for member in self.datasets:
        member.make_buckets()
|
||||
|
||||
def enable_XTI(self, *args, **kwargs):
|
||||
for dataset in self.datasets:
|
||||
@@ -2234,7 +2361,14 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
|
||||
def disable_token_padding(self):
    """Call ``disable_token_padding()`` on every dataset in ``self.datasets``, in order."""
    for member in self.datasets:
        member.disable_token_padding()
|
||||
|
||||
def incremental_reg_load(self, make_bucket=False):
    """Call ``incremental_reg_load(make_bucket)`` on every dataset in the group.

    Args:
        make_bucket: passed positionally and unchanged to each member of
            ``self.datasets``, in order.
    """
    for member in self.datasets:
        member.incremental_reg_load(make_bucket)
|
||||
|
||||
def __len__(self):
    """Recompute ``self.cumulative_sizes`` and return the total length.

    The cumulative sizes are refreshed from ``self.datasets`` via
    ``self.cumsum`` on every call, so the result reflects the datasets'
    current sizes; the total is the last cumulative value.
    """
    sizes = self.cumsum(self.datasets)
    self.cumulative_sizes = sizes
    return sizes[-1]
|
||||
|
||||
def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool, alpha_mask: bool):
|
||||
expected_latents_size = (reso[1] // 8, reso[0] // 8) # bucket_resoはWxHなので注意
|
||||
@@ -3579,6 +3713,16 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
||||
default=None,
|
||||
help="tags for model metadata, separated by comma / メタデータに書き込まれるモデルタグ、カンマ区切り",
|
||||
)
|
||||
# The training loop reads this flag as `args.incremental_reg_reload`, so that
# name must be the destination attribute. It was previously registered only as
# `--incremental_reg_load`, whose dest `incremental_reg_load` made every
# `args.incremental_reg_reload` access raise AttributeError. The old spelling
# is kept as an alias so existing command lines keep working.
parser.add_argument(
    "--incremental_reg_reload",
    "--incremental_reg_load",
    dest="incremental_reg_reload",
    action="store_true",
    help="Forces reload of regularization images at each Epoch. Will sequentially load regularization images unless '--randomized_regularization_image' is set. Useful if there are more regularization images than training images",
)
|
||||
parser.add_argument(
|
||||
"--randomized_regularization_image",
|
||||
action="store_true",
|
||||
help="Shuffles regularization images to even out distribution. Useful if there are more regularization images than training images",
|
||||
)
|
||||
|
||||
if support_dreambooth:
|
||||
# DreamBooth training
|
||||
|
||||
@@ -8,6 +8,7 @@ import time
|
||||
import json
|
||||
from multiprocessing import Value
|
||||
import toml
|
||||
import copy
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
@@ -135,6 +136,11 @@ class NetworkTrainer:
|
||||
train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet)
|
||||
|
||||
def train(self, args):
|
||||
# acceleratorを準備する
|
||||
logger.info("preparing accelerator")
|
||||
accelerator = train_util.prepare_accelerator(args)
|
||||
is_main_process = accelerator.is_main_process
|
||||
|
||||
session_id = random.randint(0, 2**32)
|
||||
training_started_at = time.time()
|
||||
train_util.verify_training_args(args)
|
||||
@@ -202,8 +208,15 @@ class NetworkTrainer:
|
||||
|
||||
current_epoch = Value("i", 0)
|
||||
current_step = Value("i", 0)
|
||||
ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
|
||||
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
|
||||
|
||||
if args.incremental_reg_reload:
|
||||
if args.persistent_data_loader_workers:
|
||||
logger.warning("persistent_data_loader_workers has been set to False because incremental_reg_reload is enabled.")
|
||||
args.persistent_data_loader_workers = False
|
||||
if args.randomized_regularization_image:
|
||||
# train_dataset_group.set_reg_randomize() triggers a reload to initial state with randomized regularization images. Ensure that this occurs before initial caching to prevent data mismatch
|
||||
logger.info("Reloading sequentially loaded regularization images to replace with randomly selected regularization images...")
|
||||
train_dataset_group.set_reg_randomize(args.randomized_regularization_image)
|
||||
|
||||
if args.debug_dataset:
|
||||
train_util.debug_dataset(train_dataset_group)
|
||||
@@ -221,11 +234,6 @@ class NetworkTrainer:
|
||||
|
||||
self.assert_extra_args(args, train_dataset_group)
|
||||
|
||||
# acceleratorを準備する
|
||||
logger.info("preparing accelerator")
|
||||
accelerator = train_util.prepare_accelerator(args)
|
||||
is_main_process = accelerator.is_main_process
|
||||
|
||||
# mixed precisionに対応した型を用意しておき適宜castする
|
||||
weight_dtype, save_dtype = train_util.prepare_dtype(args)
|
||||
vae_dtype = torch.float32 if args.no_half_vae else weight_dtype
|
||||
@@ -263,23 +271,24 @@ class NetworkTrainer:
|
||||
|
||||
accelerator.print(f"all weights merged: {', '.join(args.base_weights)}")
|
||||
|
||||
# 学習を準備する
|
||||
if cache_latents:
|
||||
vae.to(accelerator.device, dtype=vae_dtype)
|
||||
vae.requires_grad_(False)
|
||||
vae.eval()
|
||||
with torch.no_grad():
|
||||
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
|
||||
vae.to("cpu")
|
||||
clean_memory_on_device(accelerator.device)
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
# 必要ならテキストエンコーダーの出力をキャッシュする: Text Encoderはcpuまたはgpuへ移される
|
||||
# cache text encoder outputs if needed: Text Encoder is moved to cpu or gpu
|
||||
self.cache_text_encoder_outputs_if_needed(
|
||||
args, accelerator, unet, vae, tokenizers, text_encoders, train_dataset_group, weight_dtype
|
||||
)
|
||||
'''
|
||||
Replacing cache latents and cache text encoder outputs here with code to simulate running through self.cache_text_encoder_outputs_if_needed().
|
||||
Reduces unnecessary caching by avoiding caching until data loaded into train_dataset_group has been finalized.
|
||||
This step is required to ensure text_encoders are loaded onto the correct device for the training.
|
||||
Possibly should replace with method that can be overridden for different handling of TE for different models.
|
||||
'''
|
||||
if args.cache_text_encoder_outputs and self.is_sdxl:
|
||||
# When the Text Encoder is not being trained, it is not prepared by accelerator, so we need to use explicit autocast
|
||||
for t_enc in text_encoders:
|
||||
t_enc.to("cpu", dtype=torch.float32) # Text Encoder doesn't work with fp16 on CPU
|
||||
else:
|
||||
# Text Encoderから毎回出力を取得するので、GPUに乗せておく
|
||||
if not self.is_sdxl:
|
||||
accelerator.print("Text Encoder caching not supported. Overriding args.cache_text_encoder_output to False")
|
||||
args.cache_text_encoder_outputs = False
|
||||
for t_enc in text_encoders:
|
||||
t_enc.to(accelerator.device, dtype=weight_dtype)
|
||||
|
||||
# prepare network
|
||||
net_kwargs = {}
|
||||
@@ -368,8 +377,12 @@ class NetworkTrainer:
|
||||
|
||||
# dataloaderを準備する
|
||||
# DataLoaderのプロセス数:0 は persistent_workers が使えないので注意
|
||||
|
||||
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
|
||||
|
||||
ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
|
||||
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
|
||||
|
||||
train_dataloader = torch.utils.data.DataLoader(
|
||||
train_dataset_group,
|
||||
batch_size=1,
|
||||
@@ -845,7 +858,6 @@ class NetworkTrainer:
|
||||
)
|
||||
|
||||
loss_recorder = train_util.LossRecorder()
|
||||
del train_dataset_group
|
||||
|
||||
# callback for step start
|
||||
if hasattr(accelerator.unwrap_model(network), "on_step_start"):
|
||||
@@ -885,6 +897,8 @@ class NetworkTrainer:
|
||||
for skip_epoch in range(epoch_to_start): # skip epochs
|
||||
logger.info(f"skipping epoch {skip_epoch+1} because initial_step (multiplied) is {initial_step}")
|
||||
initial_step -= len(train_dataloader)
|
||||
if args.incremental_reg_reload:
|
||||
train_dataset_group.incremental_reg_load(True) # Updates the loaded dataset to the next epoch
|
||||
global_step = initial_step
|
||||
|
||||
for epoch in range(epoch_to_start, num_train_epochs):
|
||||
@@ -893,6 +907,39 @@ class NetworkTrainer:
|
||||
|
||||
metadata["ss_epoch"] = str(epoch + 1)
|
||||
|
||||
if epoch == epoch_to_start or args.incremental_reg_reload:
|
||||
if cache_latents:
|
||||
vae.to(accelerator.device, dtype=vae_dtype)
|
||||
vae.requires_grad_(False)
|
||||
vae.eval()
|
||||
with torch.no_grad():
|
||||
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
|
||||
vae.to("cpu")
|
||||
clean_memory_on_device(accelerator.device)
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
# 必要ならテキストエンコーダーの出力をキャッシュする: Text Encoderはcpuまたはgpuへ移される
|
||||
# cache text encoder outputs if needed: Text Encoder is moved to cpu or gpu
|
||||
self.cache_text_encoder_outputs_if_needed(
|
||||
args, accelerator, unet, vae, tokenizers, text_encoders, train_dataset_group, weight_dtype
|
||||
)
|
||||
accelerator.wait_for_everyone() # Ensure all processes sync after potential dataset/cache changes in initial_step block
|
||||
|
||||
ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
|
||||
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
|
||||
|
||||
train_dataloader = torch.utils.data.DataLoader(
|
||||
train_dataset_group, # This is the updated train_dataset_group
|
||||
batch_size=1,
|
||||
shuffle=True,
|
||||
collate_fn=collator,
|
||||
num_workers=n_workers, # Ensure n_workers is available
|
||||
persistent_workers=args.persistent_data_loader_workers,
|
||||
)
|
||||
accelerator.wait_for_everyone()
|
||||
train_dataloader = accelerator.prepare(train_dataloader)
|
||||
|
||||
accelerator.unwrap_model(network).on_epoch_start(text_encoder, unet)
|
||||
|
||||
skipped_dataloader = None
|
||||
@@ -1091,6 +1138,9 @@ class NetworkTrainer:
|
||||
train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1)
|
||||
|
||||
self.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
||||
# Load next batch of regularization images if necessary
|
||||
if args.incremental_reg_reload and epoch + 1 < num_train_epochs:
|
||||
train_dataset_group.incremental_reg_load(True)
|
||||
|
||||
# end of epoch
|
||||
|
||||
|
||||
Reference in New Issue
Block a user