Update train_network.py

This commit is contained in:
gesen2egee
2024-03-10 20:17:40 +08:00
parent 923b761ce3
commit 47359b8fac

View File

@@ -989,7 +989,7 @@ class NetworkTrainer:
total_loss = 0.0 total_loss = 0.0
with torch.no_grad(): with torch.no_grad():
validation_steps = args.validation_batches if args.validation_batches is not None else len(val_dataloader) validation_steps = args.validation_batches if args.validation_batches is not None else len(val_dataloader)
for val_step in min(len(val_dataloader), args.validation_batches): for val_step in range(validation_steps):
is_train = False is_train = False
batch = next(cyclic_val_dataloader) batch = next(cyclic_val_dataloader)
loss = self.process_val_batch(batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args) loss = self.process_val_batch(batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args)
@@ -1015,7 +1015,7 @@ class NetworkTrainer:
total_loss = 0.0 total_loss = 0.0
with torch.no_grad(): with torch.no_grad():
validation_steps = args.validation_batches if args.validation_batches is not None else len(val_dataloader) validation_steps = args.validation_batches if args.validation_batches is not None else len(val_dataloader)
for val_step in min(len(val_dataloader), args.validation_batches): for val_step in range(validation_steps):
is_train = False is_train = False
batch = next(cyclic_val_dataloader) batch = next(cyclic_val_dataloader)
loss = self.process_val_batch(batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args) loss = self.process_val_batch(batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args)