From a6c41c6bea0465112c7bd472dff68b7e8ecea46e Mon Sep 17 00:00:00 2001 From: gesen2egee <79357052+gesen2egee@users.noreply.github.com> Date: Mon, 11 Mar 2024 19:23:48 +0800 Subject: [PATCH] Update train_network.py --- train_network.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/train_network.py b/train_network.py index 82d72df2..6eefdb2b 100644 --- a/train_network.py +++ b/train_network.py @@ -174,7 +174,7 @@ class NetworkTrainer: # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified - for fixed_timesteps in tqdm(timesteps_list, desc='Training Progress'): + for fixed_timesteps in timesteps_list: with torch.set_grad_enabled(is_train), accelerator.autocast(): noise = torch.randn_like(latents, device=latents.device) b_size = latents.shape[0] @@ -988,7 +988,7 @@ class NetworkTrainer: total_loss = 0.0 with torch.no_grad(): validation_steps = min(args.validation_batches, len(val_dataloader)) if args.validation_batches is not None else len(val_dataloader) - for val_step in range(validation_steps): + for val_step in tqdm(range(validation_steps), desc='Validation Steps'): is_train = False batch = next(cyclic_val_dataloader) loss = self.process_val_batch(batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args) @@ -1016,7 +1016,7 @@ class NetworkTrainer: total_loss = 0.0 with torch.no_grad(): validation_steps = min(args.validation_batches, len(val_dataloader)) if args.validation_batches is not None else len(val_dataloader) - for val_step in range(validation_steps): + for val_step in tqdm(range(validation_steps), desc='Validation Steps'): is_train = False batch = next(cyclic_val_dataloader) loss = self.process_val_batch(batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args)