Update train_network.py

This commit is contained in:
gesen2egee
2024-03-11 18:47:04 +08:00
parent 7d84ac2177
commit befbec5335

View File

@@ -174,7 +174,7 @@ class NetworkTrainer:
# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
-for fixed_timesteps in timesteps_list:
+for fixed_timesteps in tqdm(timesteps_list, desc='Training Progress'):
with torch.set_grad_enabled(is_train), accelerator.autocast():
noise = torch.randn_like(latents, device=latents.device)
b_size = latents.shape[0]
@@ -985,7 +985,7 @@ class NetworkTrainer:
if args.validation_every_n_step is not None:
if global_step % (args.validation_every_n_step) == 0:
if len(val_dataloader) > 0:
-print("Validating バリデーション処理...")
+print(f"\nValidating バリデーション処理...")
total_loss = 0.0
with torch.no_grad():
validation_steps = min(args.validation_batches, len(val_dataloader)) if args.validation_batches is not None else len(val_dataloader)
@@ -998,6 +998,8 @@ class NetworkTrainer:
val_loss_recorder.add(epoch=epoch, step=global_step, loss=current_loss)
if args.logging_dir is not None:
logs = {"loss/current_val_loss": current_loss}
accelerator.log(logs, step=global_step)
avr_loss: float = val_loss_recorder.moving_average
logs = {"loss/average_val_loss": avr_loss}
accelerator.log(logs, step=global_step)
@@ -1011,7 +1013,7 @@ class NetworkTrainer:
if args.validation_every_n_step is None:
if len(val_dataloader) > 0:
-print("Validating バリデーション処理...")
+print(f"\nValidating バリデーション処理...")
total_loss = 0.0
with torch.no_grad():
validation_steps = min(args.validation_batches, len(val_dataloader)) if args.validation_batches is not None else len(val_dataloader)
@@ -1025,7 +1027,7 @@ class NetworkTrainer:
if args.logging_dir is not None:
avr_loss: float = val_loss_recorder.moving_average
-logs = {"loss/val_epoch_average": avr_loss}
+logs = {"loss/epoch_val_average": avr_loss}
accelerator.log(logs, step=epoch + 1)
accelerator.wait_for_everyone()