Removed call of sum()

This commit is contained in:
Yuta Hayashibe
2023-02-14 21:11:30 +09:00
parent 21f5b618c3
commit 8aed5125de
2 changed files with 10 additions and 4 deletions

View File

@@ -207,6 +207,7 @@ def train(args):
accelerator.init_trackers("dreambooth")
loss_list = []
loss_total = 0.0
for epoch in range(num_train_epochs):
print(f"epoch {epoch+1}/{num_train_epochs}")
train_dataset.set_current_epoch(epoch + 1)
@@ -294,8 +295,10 @@ def train(args):
if epoch == 0:
loss_list.append(current_loss)
else:
loss_total -= loss_list[step]
loss_list[step] = current_loss
avr_loss = sum(loss_list) / len(loss_list)
loss_total += current_loss
avr_loss = loss_total / len(loss_list)
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
@@ -303,7 +306,7 @@ def train(args):
break
if args.logging_dir is not None:
logs = {"epoch_loss": sum(loss_list) / len(loss_list)}
logs = {"epoch_loss": loss_total / len(loss_list)}
accelerator.log(logs, step=epoch+1)
accelerator.wait_for_everyone()

View File

@@ -379,6 +379,7 @@ def train(args):
accelerator.init_trackers("network_train")
loss_list = []
loss_total = 0.0
for epoch in range(num_train_epochs):
print(f"epoch {epoch+1}/{num_train_epochs}")
train_dataset.set_current_epoch(epoch + 1)
@@ -449,8 +450,10 @@ def train(args):
if epoch == 0:
loss_list.append(current_loss)
else:
loss_total -= loss_list[step]
loss_list[step] = current_loss
avr_loss = sum(loss_list) / len(loss_list)
loss_total += current_loss
avr_loss = loss_total / len(loss_list)
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
@@ -462,7 +465,7 @@ def train(args):
break
if args.logging_dir is not None:
logs = {"loss/epoch": sum(loss_list) / len(loss_list)}
logs = {"loss/epoch": loss_total / len(loss_list)}
accelerator.log(logs, step=epoch+1)
accelerator.wait_for_everyone()