mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Use LossRecorder
This commit is contained in:
@@ -350,8 +350,7 @@ def train(args):
|
|||||||
"lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs
|
"lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
loss_list = []
|
loss_recorder = train_util.LossRecorder()
|
||||||
loss_total = 0.0
|
|
||||||
del train_dataset_group
|
del train_dataset_group
|
||||||
|
|
||||||
# function for saving/removing
|
# function for saving/removing
|
||||||
@@ -500,14 +499,9 @@ def train(args):
|
|||||||
remove_model(remove_ckpt_name)
|
remove_model(remove_ckpt_name)
|
||||||
|
|
||||||
current_loss = loss.detach().item()
|
current_loss = loss.detach().item()
|
||||||
if epoch == 0:
|
loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
|
||||||
loss_list.append(current_loss)
|
avr_loss: float = loss_recorder.moving_average
|
||||||
else:
|
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
|
||||||
loss_total -= loss_list[step]
|
|
||||||
loss_list[step] = current_loss
|
|
||||||
loss_total += current_loss
|
|
||||||
avr_loss = loss_total / len(loss_list)
|
|
||||||
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
|
|
||||||
progress_bar.set_postfix(**logs)
|
progress_bar.set_postfix(**logs)
|
||||||
|
|
||||||
if args.logging_dir is not None:
|
if args.logging_dir is not None:
|
||||||
@@ -518,7 +512,7 @@ def train(args):
|
|||||||
break
|
break
|
||||||
|
|
||||||
if args.logging_dir is not None:
|
if args.logging_dir is not None:
|
||||||
logs = {"loss/epoch": loss_total / len(loss_list)}
|
logs = {"loss/epoch": loss_recorder.moving_average}
|
||||||
accelerator.log(logs, step=epoch + 1)
|
accelerator.log(logs, step=epoch + 1)
|
||||||
|
|
||||||
accelerator.wait_for_everyone()
|
accelerator.wait_for_everyone()
|
||||||
|
|||||||
@@ -323,8 +323,7 @@ def train(args):
|
|||||||
"lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs
|
"lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
loss_list = []
|
loss_recorder = train_util.LossRecorder()
|
||||||
loss_total = 0.0
|
|
||||||
del train_dataset_group
|
del train_dataset_group
|
||||||
|
|
||||||
# function for saving/removing
|
# function for saving/removing
|
||||||
@@ -470,14 +469,9 @@ def train(args):
|
|||||||
remove_model(remove_ckpt_name)
|
remove_model(remove_ckpt_name)
|
||||||
|
|
||||||
current_loss = loss.detach().item()
|
current_loss = loss.detach().item()
|
||||||
if epoch == 0:
|
loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
|
||||||
loss_list.append(current_loss)
|
avr_loss: float = loss_recorder.moving_average
|
||||||
else:
|
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
|
||||||
loss_total -= loss_list[step]
|
|
||||||
loss_list[step] = current_loss
|
|
||||||
loss_total += current_loss
|
|
||||||
avr_loss = loss_total / len(loss_list)
|
|
||||||
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
|
|
||||||
progress_bar.set_postfix(**logs)
|
progress_bar.set_postfix(**logs)
|
||||||
|
|
||||||
if args.logging_dir is not None:
|
if args.logging_dir is not None:
|
||||||
@@ -488,7 +482,7 @@ def train(args):
|
|||||||
break
|
break
|
||||||
|
|
||||||
if args.logging_dir is not None:
|
if args.logging_dir is not None:
|
||||||
logs = {"loss/epoch": loss_total / len(loss_list)}
|
logs = {"loss/epoch": loss_recorder.moving_average}
|
||||||
accelerator.log(logs, step=epoch + 1)
|
accelerator.log(logs, step=epoch + 1)
|
||||||
|
|
||||||
accelerator.wait_for_everyone()
|
accelerator.wait_for_everyone()
|
||||||
|
|||||||
@@ -337,8 +337,7 @@ def train(args):
|
|||||||
init_kwargs = toml.load(args.log_tracker_config)
|
init_kwargs = toml.load(args.log_tracker_config)
|
||||||
accelerator.init_trackers("controlnet_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
|
accelerator.init_trackers("controlnet_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
|
||||||
|
|
||||||
loss_list = []
|
loss_recorder = train_util.LossRecorder()
|
||||||
loss_total = 0.0
|
|
||||||
del train_dataset_group
|
del train_dataset_group
|
||||||
|
|
||||||
# function for saving/removing
|
# function for saving/removing
|
||||||
@@ -500,14 +499,9 @@ def train(args):
|
|||||||
remove_model(remove_ckpt_name)
|
remove_model(remove_ckpt_name)
|
||||||
|
|
||||||
current_loss = loss.detach().item()
|
current_loss = loss.detach().item()
|
||||||
if epoch == 0:
|
loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
|
||||||
loss_list.append(current_loss)
|
avr_loss: float = loss_recorder.moving_average
|
||||||
else:
|
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
|
||||||
loss_total -= loss_list[step]
|
|
||||||
loss_list[step] = current_loss
|
|
||||||
loss_total += current_loss
|
|
||||||
avr_loss = loss_total / len(loss_list)
|
|
||||||
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
|
|
||||||
progress_bar.set_postfix(**logs)
|
progress_bar.set_postfix(**logs)
|
||||||
|
|
||||||
if args.logging_dir is not None:
|
if args.logging_dir is not None:
|
||||||
@@ -518,7 +512,7 @@ def train(args):
|
|||||||
break
|
break
|
||||||
|
|
||||||
if args.logging_dir is not None:
|
if args.logging_dir is not None:
|
||||||
logs = {"loss/epoch": loss_total / len(loss_list)}
|
logs = {"loss/epoch": loss_recorder.moving_average}
|
||||||
accelerator.log(logs, step=epoch + 1)
|
accelerator.log(logs, step=epoch + 1)
|
||||||
|
|
||||||
accelerator.wait_for_everyone()
|
accelerator.wait_for_everyone()
|
||||||
|
|||||||
Reference in New Issue
Block a user