mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-10 15:00:23 +00:00
Update train_db.py
This commit is contained in:
34
train_db.py
34
train_db.py
@@ -503,28 +503,26 @@ def train(args):
|
||||
avr_loss: float = loss_recorder.moving_average
|
||||
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
|
||||
progress_bar.set_postfix(**logs)
|
||||
|
||||
if len(val_dataloader) > 0:
|
||||
if (args.validation_every_n_step is not None and global_step % args.validation_every_n_step == 0) or (args.validation_every_n_step is None and step == len(train_dataloader) - 1) or global_step >= args.max_train_steps:
|
||||
accelerator.print("Validating バリデーション処理...")
|
||||
total_loss = 0.0
|
||||
with torch.no_grad():
|
||||
validation_steps = min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader)
|
||||
for val_step in tqdm(range(validation_steps), desc='Validation Steps'):
|
||||
batch = next(cyclic_val_dataloader)
|
||||
loss = self.process_val_batch(batch, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args)
|
||||
total_loss += loss.detach().item()
|
||||
current_loss = total_loss / validation_steps
|
||||
val_loss_recorder.add(epoch=0, step=global_step, loss=current_loss)
|
||||
accelerator.print("Validating バリデーション処理...")
|
||||
total_loss = 0.0
|
||||
with torch.no_grad():
|
||||
validation_steps = min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader)
|
||||
for val_step in tqdm(range(validation_steps), desc='Validation Steps'):
|
||||
batch = next(cyclic_val_dataloader)
|
||||
loss = self.process_val_batch(batch, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args)
|
||||
total_loss += loss.detach().item()
|
||||
current_loss = total_loss / validation_steps
|
||||
val_loss_recorder.add(epoch=0, step=global_step, loss=current_loss)
|
||||
|
||||
if args.logging_dir is not None:
|
||||
logs = {"loss/current_val_loss": current_loss}
|
||||
accelerator.log(logs, step=global_step)
|
||||
avr_loss: float = val_loss_recorder.moving_average
|
||||
logs = {"loss/average_val_loss": avr_loss}
|
||||
accelerator.log(logs, step=global_step)
|
||||
if args.logging_dir is not None:
|
||||
logs = {"loss/current_val_loss": current_loss}
|
||||
accelerator.log(logs, step=global_step)
|
||||
avr_loss: float = val_loss_recorder.moving_average
|
||||
logs = {"loss/average_val_loss": avr_loss}
|
||||
accelerator.log(logs, step=global_step)
|
||||
|
||||
|
||||
if global_step >= args.max_train_steps:
|
||||
break
|
||||
|
||||
|
||||
Reference in New Issue
Block a user