From 05bb9183fae18c62a1730fe5060f80c0b99a21f3 Mon Sep 17 00:00:00 2001 From: Hina Chen Date: Fri, 27 Dec 2024 16:47:59 +0800 Subject: [PATCH] Add Validation loss for LoRA training --- library/config_util.py | 78 +++++++++++++++++++++++- library/train_util.py | 54 ++++++++++++++++- train_network.py | 131 ++++++++++++++++++++++++++++++++++++++++- 3 files changed, 257 insertions(+), 6 deletions(-) diff --git a/library/config_util.py b/library/config_util.py index 12d0be17..a57cd36f 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -73,6 +73,8 @@ class BaseSubsetParams: token_warmup_min: int = 1 token_warmup_step: float = 0 custom_attributes: Optional[Dict[str, Any]] = None + validation_seed: int = 0 + validation_split: float = 0.0 @dataclass @@ -102,6 +104,8 @@ class BaseDatasetParams: resolution: Optional[Tuple[int, int]] = None network_multiplier: float = 1.0 debug_dataset: bool = False + validation_seed: Optional[int] = None + validation_split: float = 0.0 @dataclass @@ -478,9 +482,27 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu dataset_klass = FineTuningDataset subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets] - dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params)) + dataset = dataset_klass(subsets=subsets, is_train=True, **asdict(dataset_blueprint.params)) datasets.append(dataset) + val_datasets:List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = [] + for dataset_blueprint in dataset_group_blueprint.datasets: + if dataset_blueprint.params.validation_split <= 0.0: + continue + if dataset_blueprint.is_controlnet: + subset_klass = ControlNetSubset + dataset_klass = ControlNetDataset + elif dataset_blueprint.is_dreambooth: + subset_klass = DreamBoothSubset + dataset_klass = DreamBoothDataset + else: + subset_klass = FineTuningSubset + dataset_klass = FineTuningDataset + + subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets] + dataset = dataset_klass(subsets=subsets, is_train=False, **asdict(dataset_blueprint.params)) + val_datasets.append(dataset) + # print info info = "" for i, dataset in enumerate(datasets): @@ -566,6 +588,50 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu logger.info(f"{info}") + if len(val_datasets) > 0: + info = "" + + for i, dataset in enumerate(val_datasets): + info += dedent( + f"""\ + [Validation Dataset {i}] + batch_size: {dataset.batch_size} + resolution: {(dataset.width, dataset.height)} + enable_bucket: {dataset.enable_bucket} + network_multiplier: {dataset.network_multiplier} + """ + ) + + if dataset.enable_bucket: + info += indent( + dedent( + f"""\ + min_bucket_reso: {dataset.min_bucket_reso} + max_bucket_reso: {dataset.max_bucket_reso} + bucket_reso_steps: {dataset.bucket_reso_steps} + bucket_no_upscale: {dataset.bucket_no_upscale} + \n""" + ), + " ", + ) + else: + info += "\n" + + for j, subset in enumerate(dataset.subsets): + info += indent( + dedent( + f"""\ + [Subset {j} of Validation Dataset {i}] + image_dir: "{subset.image_dir}" + image_count: {subset.img_count} + num_repeats: {subset.num_repeats} + """ + ), + " ", + ) + + logger.info(f"{info}") + # make buckets first because it determines the length of dataset # and set the same seed for all datasets seed = random.randint(0, 2**31) # actual seed is seed + epoch_no @@ -574,7 +640,15 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu dataset.make_buckets() dataset.set_seed(seed) - return DatasetGroup(datasets) + for i, dataset in enumerate(val_datasets): + logger.info(f"[Validation Dataset {i}]") + dataset.make_buckets() + dataset.set_seed(seed) + + return ( + DatasetGroup(datasets), + DatasetGroup(val_datasets) if val_datasets else None + ) def generate_dreambooth_subsets_config_by_subdirs(train_data_dir: Optional[str] = None, reg_data_dir: Optional[str] = None): diff --git a/library/train_util.py b/library/train_util.py index 72b5b24d..a3fa98e9 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -145,6 +145,17 @@ IMAGE_TRANSFORMS = transforms.Compose( TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX = "_te_outputs.npz" TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX_SD3 = "_sd3_te.npz" +def split_train_val(paths: List[str], validation_split: float, validation_seed: int) -> List[str]: + if validation_seed is not None: + print(f"Using validation seed: {validation_seed}") + prevstate = random.getstate() + random.seed(validation_seed) + random.shuffle(paths) + random.setstate(prevstate) + else: + random.shuffle(paths) + + return paths[len(paths) - round(len(paths) * validation_split):] class ImageInfo: def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool, absolute_path: str) -> None: @@ -397,6 +408,8 @@ class BaseSubset: token_warmup_min: int, token_warmup_step: Union[float, int], custom_attributes: Optional[Dict[str, Any]] = None, + validation_seed: Optional[int] = None, + validation_split: Optional[float] = 0.0, ) -> None: self.image_dir = image_dir self.alpha_mask = alpha_mask if alpha_mask is not None else False @@ -424,6 +437,9 @@ class BaseSubset: self.img_count = 0 + self.validation_seed = validation_seed + self.validation_split = validation_split + class DreamBoothSubset(BaseSubset): def __init__( @@ -453,6 +469,8 @@ class DreamBoothSubset(BaseSubset): token_warmup_min, token_warmup_step, custom_attributes: Optional[Dict[str, Any]] = None, + validation_seed: Optional[int] = None, + validation_split: Optional[float] = 0.0, ) -> None: assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です" @@ -478,6 +496,8 @@ class DreamBoothSubset(BaseSubset): token_warmup_min, token_warmup_step, custom_attributes=custom_attributes, + validation_seed=validation_seed, + validation_split=validation_split, ) self.is_reg = is_reg @@ -518,6 +538,8 @@ class FineTuningSubset(BaseSubset): token_warmup_min, token_warmup_step, custom_attributes: Optional[Dict[str, Any]] = None, + validation_seed: Optional[int] = None, + validation_split: Optional[float] = 0.0, ) -> None: assert metadata_file is not None, "metadata_file must be specified / metadata_fileは指定が必須です" @@ -543,6 +565,8 @@ class FineTuningSubset(BaseSubset): token_warmup_min, token_warmup_step, custom_attributes=custom_attributes, + validation_seed=validation_seed, + validation_split=validation_split, ) self.metadata_file = metadata_file @@ -579,6 +603,8 @@ class ControlNetSubset(BaseSubset): token_warmup_min, token_warmup_step, custom_attributes: Optional[Dict[str, Any]] = None, + validation_seed: Optional[int] = None, + validation_split: Optional[float] = 0.0, ) -> None: assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です" @@ -604,6 +630,8 @@ class ControlNetSubset(BaseSubset): token_warmup_min, token_warmup_step, custom_attributes=custom_attributes, + validation_seed=validation_seed, + validation_split=validation_split, ) self.conditioning_data_dir = conditioning_data_dir @@ -1799,6 +1827,9 @@ class DreamBoothDataset(BaseDataset): bucket_no_upscale: bool, prior_loss_weight: float, debug_dataset: bool, + is_train: bool, + validation_seed: int, + validation_split: float, ) -> None: super().__init__(resolution, network_multiplier, debug_dataset) @@ -1808,6 +1839,9 @@ class DreamBoothDataset(BaseDataset): self.size = min(self.width, self.height) # 短いほう self.prior_loss_weight = prior_loss_weight self.latents_cache = None + self.is_train = is_train + self.validation_seed = validation_seed + self.validation_split = validation_split self.enable_bucket = enable_bucket if self.enable_bucket: @@ -1992,6 +2026,9 @@ class DreamBoothDataset(BaseDataset): ) continue + if self.is_train == False: + img_paths = split_train_val(img_paths, self.validation_split, self.validation_seed) + if subset.is_reg: num_reg_images += subset.num_repeats * len(img_paths) else: @@ -2009,7 +2046,11 @@ class DreamBoothDataset(BaseDataset): subset.img_count = len(img_paths) self.subsets.append(subset) - logger.info(f"{num_train_images} train images with repeating.") + if self.is_train: + logger.info(f"{num_train_images} train images with repeating.") + else: + logger.info(f"{num_train_images} validation images with repeating.") + self.num_train_images = num_train_images logger.info(f"{num_reg_images} reg images.") @@ -2050,6 +2091,9 @@ class FineTuningDataset(BaseDataset): bucket_reso_steps: int, bucket_no_upscale: bool, debug_dataset: bool, + is_train: bool, + validation_seed: int, + validation_split: float, ) -> None: super().__init__(resolution, network_multiplier, debug_dataset) @@ -2276,6 +2320,9 @@ class ControlNetDataset(BaseDataset): bucket_reso_steps: int, bucket_no_upscale: bool, debug_dataset: float, + is_train: bool, + validation_seed: int, + validation_split: float, ) -> None: super().__init__(resolution, network_multiplier, debug_dataset) @@ -2324,6 +2371,9 @@ class ControlNetDataset(BaseDataset): bucket_no_upscale, 1.0, debug_dataset, + is_train, + validation_seed, + validation_split, ) # config_util等から参照される値をいれておく(若干微妙なのでなんとかしたい) @@ -4887,7 +4937,7 @@ def get_optimizer(args, trainable_params) -> tuple[str, str, object]: import schedulefree as sf except ImportError: raise ImportError("No schedulefree / schedulefreeがインストールされていないようです") - + if optimizer_type == "RAdamScheduleFree".lower(): optimizer_class = sf.RAdamScheduleFree logger.info(f"use RAdamScheduleFree optimizer | {optimizer_kwargs}") diff --git a/train_network.py b/train_network.py index 5e82b307..776feaf7 100644 --- a/train_network.py +++ b/train_network.py @@ -9,6 +9,7 @@ import json from multiprocessing import Value from typing import Any, List import toml +import itertools from tqdm import tqdm @@ -114,7 +115,7 @@ class NetworkTrainer: ) if ( args.optimizer_type.lower().endswith("ProdigyPlusScheduleFree".lower()) and optimizer is not None - ): + ): logs[f"lr/d*lr/group{i}"] = ( optimizer.param_groups[i]["d"] * optimizer.param_groups[i]["lr"] ) @@ -373,10 +374,11 @@ class NetworkTrainer: } blueprint = blueprint_generator.generate(user_config, args) - train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) else: # use arbitrary dataset class train_dataset_group = train_util.load_arbitrary_dataset(args) + val_dataset_group = None current_epoch = Value("i", 0) current_step = Value("i", 0) @@ -398,6 +400,11 @@ class NetworkTrainer: train_dataset_group.is_latent_cacheable() ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" + if val_dataset_group is not None: + assert ( + val_dataset_group.is_latent_cacheable() + ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" + self.assert_extra_args(args, train_dataset_group) # may change some args # acceleratorを準備する @@ -444,6 +451,8 @@ class NetworkTrainer: vae.eval() train_dataset_group.new_cache_latents(vae, accelerator) + if val_dataset_group is not None: + val_dataset_group.new_cache_latents(vae, accelerator) vae.to("cpu") clean_memory_on_device(accelerator.device) @@ -459,6 +468,8 @@ class NetworkTrainer: if text_encoder_outputs_caching_strategy is not None: strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_outputs_caching_strategy) self.cache_text_encoder_outputs_if_needed(args, accelerator, unet, vae, text_encoders, train_dataset_group, weight_dtype) + if val_dataset_group is not None: + self.cache_text_encoder_outputs_if_needed(args, accelerator, unet, vae, text_encoders, val_dataset_group, weight_dtype) # prepare network net_kwargs = {} @@ -567,6 +578,8 @@ class NetworkTrainer: # strategies are set here because they cannot be referenced in another process. Copy them with the dataset # some strategies can be None train_dataset_group.set_current_strategies() + if val_dataset_group is not None: + val_dataset_group.set_current_strategies() # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers @@ -580,6 +593,17 @@ class NetworkTrainer: persistent_workers=args.persistent_data_loader_workers, ) + val_dataloader = torch.utils.data.DataLoader( + val_dataset_group if val_dataset_group is not None else [], + batch_size=1, + shuffle=False, + collate_fn=collator, + num_workers=n_workers, + persistent_workers=args.persistent_data_loader_workers, + ) + + cyclic_val_dataloader = itertools.cycle(val_dataloader) + # 学習ステップ数を計算する if args.max_train_epochs is not None: args.max_train_steps = args.max_train_epochs * math.ceil( @@ -592,6 +616,10 @@ class NetworkTrainer: # データセット側にも学習ステップを送信 train_dataset_group.set_max_train_steps(args.max_train_steps) + # Not for sure here. + # if val_dataset_group is not None: + # val_dataset_group.set_max_train_steps(args.max_train_steps) + # lr schedulerを用意する lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) @@ -1064,7 +1092,11 @@ class NetworkTrainer: ) loss_recorder = train_util.LossRecorder() + # val_loss_recorder = train_util.LossRecorder() + del train_dataset_group + if val_dataset_group is not None: + del val_dataset_group # callback for step start if hasattr(accelerator.unwrap_model(network), "on_step_start"): @@ -1308,6 +1340,77 @@ class NetworkTrainer: ) accelerator.log(logs, step=global_step) + if len(val_dataloader) > 0: + if ((args.validation_every_n_step is not None and global_step % args.validation_every_n_step == 0) or (args.validation_every_n_step is None and step == len(train_dataloader) - 1) or global_step >= args.max_train_steps): + accelerator.print("\nValidating バリデーション処理...") + + total_loss = 0.0 + + with torch.no_grad(): + validation_steps = min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader) + for val_step in tqdm(range(validation_steps), desc="Validation Steps バリデーションテップ"): + batch = next(cyclic_val_dataloader) + + timesteps_list = [10, 350, 500, 650, 990] + + val_loss = 0.0 + + for fixed_timesteps in timesteps_list: + with torch.set_grad_enabled(False), accelerator.autocast(): + noise = torch.randn_like(latents, device=latents.device) + b_size = latents.shape[0] + + timesteps = torch.full((b_size,), fixed_timesteps, dtype=torch.long, device="cpu") + timesteps = timesteps.long().to(latents.device) + + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + with accelerator.autocast(): + noise_pred = self.call_unet( + args, + accelerator, + unet, + noisy_latents.requires_grad_(False), + timesteps, + text_encoder_conds, + batch, + weight_dtype, + ) + + if args.v_parameterization: + # v-parameterization training + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + target = noise + + huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) + loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c) + if weighting is not None: + loss = loss * weighting + if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): + loss = apply_masked_loss(loss, batch) + loss = loss.mean([1, 2, 3]) + + # min snr gamma, scale v pred loss like noise pred, v pred like loss, debiased estimation etc. + loss = self.post_process_loss(loss, args, timesteps, noise_scheduler) + + loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし + + val_loss += loss / len(timesteps_list) + + total_loss += val_loss.detach().item() + + current_val_loss = total_loss / validation_steps + # val_loss_recorder.add(epoch=epoch, step=global_step, loss=current_val_loss) + + if len(accelerator.trackers) > 0: + logs = {"loss/current_val_loss": current_val_loss} + accelerator.log(logs, step=global_step) + + # avr_loss: float = val_loss_recorder.moving_average + # logs = {"loss/average_val_loss": avr_loss} + # accelerator.log(logs, step=global_step) + if global_step >= args.max_train_steps: break @@ -1496,6 +1599,30 @@ def setup_parser() -> argparse.ArgumentParser: help="initial step number including all epochs, 0 means first step (same as not specifying). overwrites initial_epoch." + " / 初期ステップ数、全エポックを含むステップ数、0で最初のステップ(未指定時と同じ)。initial_epochを上書きする", ) + parser.add_argument( + "--validation_seed", + type=int, + default=None, + help="Validation seed / 検証シード" + ) + parser.add_argument( + "--validation_split", + type=float, + default=0.0, + help="Split for validation images out of the training dataset / 学習画像から検証画像に分割する割合" + ) + parser.add_argument( + "--validation_every_n_step", + type=int, + default=None, + help="Number of train steps for counting validation loss. By default, validation per train epoch is performed / 学習エポックごとに検証を行う場合はNoneを指定する" + ) + parser.add_argument( + "--max_validation_steps", + type=int, + default=None, + help="Number of max validation steps for counting validation loss. By default, validation will run entire validation dataset / 検証データセット全体を検証する場合はNoneを指定する" + ) # parser.add_argument("--loraplus_lr_ratio", default=None, type=float, help="LoRA+ learning rate ratio") # parser.add_argument("--loraplus_unet_lr_ratio", default=None, type=float, help="LoRA+ UNet learning rate ratio") # parser.add_argument("--loraplus_text_encoder_lr_ratio", default=None, type=float, help="LoRA+ text encoder learning rate ratio")