Merge pull request #899 from shirayu/use_moving_average

Show moving average loss in the progress bar
This commit is contained in:
Kohya S
2023-10-29 14:37:58 +09:00
committed by GitHub
8 changed files with 53 additions and 67 deletions

View File

@@ -324,8 +324,7 @@ def train(args):
"lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs
)
loss_list = []
loss_total = 0.0
loss_recorder = train_util.LossRecorder()
del train_dataset_group
# function for saving/removing
@@ -473,14 +472,9 @@ def train(args):
remove_model(remove_ckpt_name)
current_loss = loss.detach().item()
if epoch == 0:
loss_list.append(current_loss)
else:
loss_total -= loss_list[step]
loss_list[step] = current_loss
loss_total += current_loss
avr_loss = loss_total / len(loss_list)
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
avr_loss: float = loss_recorder.moving_average
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
if args.logging_dir is not None:
@@ -491,7 +485,7 @@ def train(args):
break
if args.logging_dir is not None:
logs = {"loss/epoch": loss_total / len(loss_list)}
logs = {"loss/epoch": loss_recorder.moving_average}
accelerator.log(logs, step=epoch + 1)
accelerator.wait_for_everyone()