From a51723cc2a3dd50b45e60945f97bc5adfe753d1f Mon Sep 17 00:00:00 2001
From: gesen2egee <79357052+gesen2egee@users.noreply.github.com>
Date: Mon, 11 Mar 2024 09:42:58 +0800
Subject: [PATCH] fix timesteps

---
 train_network.py | 20 ++++++++++----------
 1 file changed, 10 insertions(+), 10 deletions(-)

diff --git a/train_network.py b/train_network.py
index d549378c..f0f27ea7 100644
--- a/train_network.py
+++ b/train_network.py
@@ -141,7 +141,6 @@ class NetworkTrainer:
 
         total_loss = 0.0
         timesteps_list = [10, 350, 500, 650, 990]
-
         with torch.no_grad():
             if "latents" in batch and batch["latents"] is not None:
                 latents = batch["latents"].to(accelerator.device)
@@ -174,16 +173,17 @@ class NetworkTrainer:
 
         # Sample noise, sample a random timestep for each image, and add noise to the latents,
         # with noise offset and/or multires noise if specified
-        noise, noisy_latents, _ = train_util.get_noise_noisy_latents_and_timesteps(
-            args, noise_scheduler, latents
-        )
-        for timesteps in timesteps_list:
-            # Predict the noise residual
+
+        for fixed_timesteps in timesteps_list:
             with torch.set_grad_enabled(is_train), accelerator.autocast():
+                noise = torch.randn_like(latents, device=latents.device)
+                b_size = latents.shape[0]
+                timesteps = torch.randint(fixed_timesteps, fixed_timesteps + 1, (b_size,), device=latents.device)
+                timesteps = timesteps.long()
+                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
                 noise_pred = self.call_unet(
                     args, accelerator, unet, noisy_latents, timesteps, text_encoder_conds, batch, weight_dtype
                 )
-
                 if args.v_parameterization:
                     # v-parameterization training
                     target = noise_scheduler.get_velocity(latents, noise, timesteps)
@@ -988,7 +988,7 @@ class NetworkTrainer:
                     print("Validating バリデーション処理...")
                     total_loss = 0.0
                     with torch.no_grad():
-                        validation_steps = args.validation_batches if args.validation_batches is not None else len(val_dataloader)
+                        validation_steps = min(args.validation_batches, len(val_dataloader)) if args.validation_batches is not None else len(val_dataloader)
                         for val_step in range(validation_steps):
                             is_train = False
                             batch = next(cyclic_val_dataloader)
@@ -999,7 +999,7 @@ class NetworkTrainer:
 
                     if args.logging_dir is not None:
                         avr_loss: float = val_loss_recorder.moving_average
-                        logs = {"loss/avr_val_loss": avr_loss}
+                        logs = {"loss/average_val_loss": avr_loss}
                         accelerator.log(logs, step=global_step)
 
                 if global_step >= args.max_train_steps:
@@ -1014,7 +1014,7 @@ class NetworkTrainer:
                 print("Validating バリデーション処理...")
                 total_loss = 0.0
                 with torch.no_grad():
-                    validation_steps = args.validation_batches if args.validation_batches is not None else len(val_dataloader)
+                    validation_steps = min(args.validation_batches, len(val_dataloader)) if args.validation_batches is not None else len(val_dataloader)
                     for val_step in range(validation_steps):
                         is_train = False
                         batch = next(cyclic_val_dataloader)