From 5b19bda85c2ce01e4a1c7f324b7ef14bffed3315 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Sun, 5 Nov 2023 12:35:46 -0500 Subject: [PATCH] Add validation loss --- library/train_util.py | 4 ++ train_network.py | 117 +++++++++++++++++++++++++++++++++++++++++- 2 files changed, 120 insertions(+), 1 deletion(-) diff --git a/library/train_util.py b/library/train_util.py index cc9ac455..e26f3979 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -4736,6 +4736,10 @@ class collator_class: else: dataset = self.dataset + # If we split a dataset we will get a Subset + if type(dataset) is torch.utils.data.Subset: + dataset = dataset.dataset + # set epoch and step dataset.set_current_epoch(self.current_epoch.value) dataset.set_current_step(self.current_step.value) diff --git a/train_network.py b/train_network.py index d50916b7..58767b6f 100644 --- a/train_network.py +++ b/train_network.py @@ -345,8 +345,21 @@ class NetworkTrainer: # DataLoaderのプロセス数:0はメインプロセスになる n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで + if args.validation_ratio > 0.0: + train_ratio = 1 - args.validation_ratio + validation_ratio = args.validation_ratio + train, val = torch.utils.data.random_split( + train_dataset_group, + [train_ratio, validation_ratio] + ) + print(f"split dataset by ratio: train {train_ratio}, validation {validation_ratio}") + print(f"train images: {len(train)}, validation images: {len(val)}") + else: + train = train_dataset_group + val = [] + train_dataloader = torch.utils.data.DataLoader( - train_dataset_group, + train, batch_size=1, shuffle=True, collate_fn=collator, @@ -354,6 +367,15 @@ class NetworkTrainer: persistent_workers=args.persistent_data_loader_workers, ) + val_dataloader = torch.utils.data.DataLoader( + val, + shuffle=False, + batch_size=1, + collate_fn=collator, + num_workers=n_workers, + persistent_workers=args.persistent_data_loader_workers, + ) + # 学習ステップ数を計算する if args.max_train_epochs is not None: args.max_train_steps = args.max_train_epochs * math.ceil( @@ -711,6 +733,8 @@ class NetworkTrainer: ) loss_recorder = train_util.LossRecorder() + val_loss_recorder = train_util.LossRecorder() + del train_dataset_group # callback for step start @@ -752,6 +776,8 @@ class NetworkTrainer: network.on_epoch_start(text_encoder, unet) + # TRAINING + for step, batch in enumerate(train_dataloader): current_step.value = global_step with accelerator.accumulate(network): @@ -877,6 +903,87 @@ class NetworkTrainer: if global_step >= args.max_train_steps: break + # VALIDATION + + if len(val_dataloader) > 0: + print("Validating バリデーション処理...") + + with torch.no_grad(): + for val_step, batch in enumerate(val_dataloader): + if "latents" in batch and batch["latents"] is not None: + latents = batch["latents"].to(accelerator.device) + else: + # latentに変換 + latents = vae.encode(batch["images"].to(device=accelerator.device, dtype=vae_dtype)).latent_dist.sample() + + # NaNが含まれていれば警告を表示し0に置き換える + if torch.any(torch.isnan(latents)): + accelerator.print("NaN found in latents, replacing with zeros") + latents = torch.where(torch.isnan(latents), torch.zeros_like(latents), latents) + latents = latents * self.vae_scale_factor + b_size = latents.shape[0] + + # Get the text embedding for conditioning + if args.weighted_captions: + text_encoder_conds = get_weighted_text_embeddings( + tokenizer, + text_encoder, + batch["captions"], + accelerator.device, + args.max_token_length // 75 if args.max_token_length else 1, + clip_skip=args.clip_skip, + ) + else: + text_encoder_conds = self.get_text_cond( + args, accelerator, batch, tokenizers, text_encoders, weight_dtype + ) + + # Sample noise, sample a random timestep for each image, and add noise to the latents, + # with noise offset and/or multires noise if specified + noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps( + args, noise_scheduler, latents + ) + + # Predict the noise residual + with accelerator.autocast(): + noise_pred = self.call_unet( + args, accelerator, unet, noisy_latents, timesteps, text_encoder_conds, batch, weight_dtype + ) + + if args.v_parameterization: + # v-parameterization training + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + target = noise + + loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") + loss = loss.mean([1, 2, 3]) + + loss_weights = batch["loss_weights"].to(accelerator.device) # 各sampleごとのweight + + loss = loss * loss_weights + + if args.min_snr_gamma: + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) + if args.scale_v_pred_loss_like_noise_pred: + loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) + if args.v_pred_like_loss: + loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss) + + loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし + + current_loss = loss.detach().item() + + val_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss) + + if len(val_dataloader) > 0: + avr_loss: float = val_loss_recorder.moving_average + + if args.logging_dir is not None: + logs = {"loss/validation": avr_loss} + accelerator.log(logs, step=epoch + 1) + + if args.logging_dir is not None: logs = {"loss/epoch": loss_recorder.moving_average} accelerator.log(logs, step=epoch + 1) @@ -999,6 +1106,14 @@ def setup_parser() -> argparse.ArgumentParser: action="store_true", help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う", ) + + parser.add_argument( + "--validation_ratio", + type=float, + default=0.0, + help="Ratio for validation images out of the training dataset" + ) + return parser