From 31507b9901d1d9ab65ba79ebd747b7f35c7e0fc1 Mon Sep 17 00:00:00 2001 From: gesen2egee <79357052+gesen2egee@users.noreply.github.com> Date: Fri, 2 Aug 2024 13:15:21 +0800 Subject: [PATCH] Remove unnecessary is_train changes and use apply_debiased_estimation to calculate validation loss. Balances the influence of different time steps on training performance (without affecting actual training results) --- train_network.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/train_network.py b/train_network.py index 2a3a4482..4a5940cd 100644 --- a/train_network.py +++ b/train_network.py @@ -135,7 +135,7 @@ class NetworkTrainer: def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet): train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet) - def process_val_batch(self, batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, train_text_encoder=True): + def process_val_batch(self, batch, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, train_text_encoder=True): total_loss = 0.0 timesteps_list = [10, 350, 500, 650, 990] @@ -153,7 +153,7 @@ class NetworkTrainer: latents = latents * self.vae_scale_factor b_size = latents.shape[0] - with torch.set_grad_enabled(is_train and train_text_encoder), accelerator.autocast(): + with torch.set_grad_enabled(False), accelerator.autocast(): # Get the text embedding for conditioning if args.weighted_captions: text_encoder_conds = get_weighted_text_embeddings( @@ -173,7 +173,7 @@ class NetworkTrainer: # with noise offset and/or multires noise if specified for fixed_timesteps in timesteps_list: - with torch.set_grad_enabled(is_train), accelerator.autocast(): + with torch.set_grad_enabled(False), accelerator.autocast(): noise = torch.randn_like(latents, device=latents.device) b_size = latents.shape[0] timesteps = torch.full((b_size,), fixed_timesteps, dtype=torch.long, device=latents.device) @@ -189,6 +189,7 @@ class NetworkTrainer: loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") loss = loss.mean([1, 2, 3]) + loss = apply_debiased_estimation(loss, timesteps, noise_scheduler) loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし total_loss += loss @@ -885,8 +886,7 @@ class NetworkTrainer: for step, batch in enumerate(train_dataloader): current_step.value = global_step with accelerator.accumulate(training_model): - on_step_start(text_encoder, unet) - is_train = True + on_step_start(text_encoder, unet) if "latents" in batch and batch["latents"] is not None: latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype) else: @@ -911,7 +911,7 @@ class NetworkTrainer: # print(f"set multiplier: {multipliers}") accelerator.unwrap_model(network).set_multiplier(multipliers) - with torch.set_grad_enabled(is_train and train_text_encoder), accelerator.autocast(): + with torch.set_grad_enabled(train_text_encoder), accelerator.autocast(): # Get the text embedding for conditioning if args.weighted_captions: text_encoder_conds = get_weighted_text_embeddings( @@ -941,7 +941,7 @@ class NetworkTrainer: t.requires_grad_(True) # Predict the noise residual - with torch.set_grad_enabled(is_train), accelerator.autocast(): + with accelerator.autocast(): noise_pred = self.call_unet( args, accelerator, @@ -1040,10 +1040,9 @@ class NetworkTrainer: total_loss = 0.0 with torch.no_grad(): validation_steps = min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader) - for val_step in tqdm(range(validation_steps), desc='Validation Steps'): - is_train = False + for val_step in tqdm(range(validation_steps), desc='Validation Steps'): batch = next(cyclic_val_dataloader) - loss = self.process_val_batch(batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args) + loss = self.process_val_batch(batch, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args) total_loss += loss.detach().item() current_loss = total_loss / validation_steps val_loss_recorder.add(epoch=0, step=global_step, loss=current_loss)