Update train_network.py

This commit is contained in:
gesen2egee
2024-03-11 19:23:48 +08:00
parent 63e58f78e3
commit a6c41c6bea

View File

@@ -174,7 +174,7 @@ class NetworkTrainer:
# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
for fixed_timesteps in tqdm(timesteps_list, desc='Training Progress'):
for fixed_timesteps in timesteps_list:
with torch.set_grad_enabled(is_train), accelerator.autocast():
noise = torch.randn_like(latents, device=latents.device)
b_size = latents.shape[0]
@@ -988,7 +988,7 @@ class NetworkTrainer:
total_loss = 0.0
with torch.no_grad():
validation_steps = min(args.validation_batches, len(val_dataloader)) if args.validation_batches is not None else len(val_dataloader)
for val_step in range(validation_steps):
for val_step in tqdm(range(validation_steps), desc='Validation Steps'):
is_train = False
batch = next(cyclic_val_dataloader)
loss = self.process_val_batch(batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args)
@@ -1016,7 +1016,7 @@ class NetworkTrainer:
total_loss = 0.0
with torch.no_grad():
validation_steps = min(args.validation_batches, len(val_dataloader)) if args.validation_batches is not None else len(val_dataloader)
for val_step in range(validation_steps):
for val_step in tqdm(range(validation_steps), desc='Validation Steps'):
is_train = False
batch = next(cyclic_val_dataloader)
loss = self.process_val_batch(batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args)