mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Merge pull request #899 from shirayu/use_moving_average
Show moving average loss in the progress bar
This commit is contained in:
11
fine_tune.py
11
fine_tune.py
@@ -289,6 +289,7 @@ def train(args):
|
|||||||
init_kwargs = toml.load(args.log_tracker_config)
|
init_kwargs = toml.load(args.log_tracker_config)
|
||||||
accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
|
accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
|
||||||
|
|
||||||
|
loss_recorder = train_util.LossRecorder()
|
||||||
for epoch in range(num_train_epochs):
|
for epoch in range(num_train_epochs):
|
||||||
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
|
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
|
||||||
current_epoch.value = epoch + 1
|
current_epoch.value = epoch + 1
|
||||||
@@ -296,7 +297,6 @@ def train(args):
|
|||||||
for m in training_models:
|
for m in training_models:
|
||||||
m.train()
|
m.train()
|
||||||
|
|
||||||
loss_total = 0
|
|
||||||
for step, batch in enumerate(train_dataloader):
|
for step, batch in enumerate(train_dataloader):
|
||||||
current_step.value = global_step
|
current_step.value = global_step
|
||||||
with accelerator.accumulate(training_models[0]): # 複数モデルに対応していない模様だがとりあえずこうしておく
|
with accelerator.accumulate(training_models[0]): # 複数モデルに対応していない模様だがとりあえずこうしておく
|
||||||
@@ -408,17 +408,16 @@ def train(args):
|
|||||||
)
|
)
|
||||||
accelerator.log(logs, step=global_step)
|
accelerator.log(logs, step=global_step)
|
||||||
|
|
||||||
# TODO moving averageにする
|
loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
|
||||||
loss_total += current_loss
|
avr_loss: float = loss_recorder.moving_average
|
||||||
avr_loss = loss_total / (step + 1)
|
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
|
||||||
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
|
|
||||||
progress_bar.set_postfix(**logs)
|
progress_bar.set_postfix(**logs)
|
||||||
|
|
||||||
if global_step >= args.max_train_steps:
|
if global_step >= args.max_train_steps:
|
||||||
break
|
break
|
||||||
|
|
||||||
if args.logging_dir is not None:
|
if args.logging_dir is not None:
|
||||||
logs = {"loss/epoch": loss_total / len(train_dataloader)}
|
logs = {"loss/epoch": loss_recorder.moving_average}
|
||||||
accelerator.log(logs, step=epoch + 1)
|
accelerator.log(logs, step=epoch + 1)
|
||||||
|
|
||||||
accelerator.wait_for_everyone()
|
accelerator.wait_for_everyone()
|
||||||
|
|||||||
@@ -4697,3 +4697,21 @@ class collator_class:
|
|||||||
dataset.set_current_epoch(self.current_epoch.value)
|
dataset.set_current_epoch(self.current_epoch.value)
|
||||||
dataset.set_current_step(self.current_step.value)
|
dataset.set_current_step(self.current_step.value)
|
||||||
return examples[0]
|
return examples[0]
|
||||||
|
|
||||||
|
|
||||||
|
class LossRecorder:
    """Records per-step losses and exposes their moving average.

    During epoch 0 each loss is appended, growing the window; from epoch 1
    onward the value stored at the same step index of the previous epoch is
    overwritten, so ``moving_average`` is an average over the most recent
    ``len(loss_list)`` training steps.
    """

    def __init__(self) -> None:
        # Latest loss observed at each step index (one slot per step).
        self.loss_list: List[float] = []
        # Running sum of loss_list, maintained incrementally so queries
        # do not re-sum the whole list.
        self.loss_total: float = 0.0

    def add(self, *, epoch: int, step: int, loss: float) -> None:
        """Record ``loss`` for ``step`` of ``epoch`` (keyword-only args).

        Args:
            epoch: 0-based epoch index; 0 grows the window, >0 overwrites.
            step: 0-based step index within the epoch. For epoch > 0 it must
                index an existing slot (i.e. epochs have equal step counts).
            loss: loss value for this step.
        """
        if epoch == 0:
            self.loss_list.append(loss)
        else:
            # Replacing last epoch's value at this step: remove its
            # contribution from the running total before overwriting.
            self.loss_total -= self.loss_list[step]
            self.loss_list[step] = loss
        self.loss_total += loss

    @property
    def moving_average(self) -> float:
        """Average of the recorded losses; 0.0 when nothing is recorded yet.

        The empty-case guard prevents a ZeroDivisionError when the recorder
        is queried before the first ``add`` call (e.g. logging at step 0).
        """
        if not self.loss_list:
            return 0.0
        return self.loss_total / len(self.loss_list)
|
||||||
|
|||||||
@@ -451,6 +451,7 @@ def train(args):
|
|||||||
init_kwargs = toml.load(args.log_tracker_config)
|
init_kwargs = toml.load(args.log_tracker_config)
|
||||||
accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
|
accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
|
||||||
|
|
||||||
|
loss_recorder = train_util.LossRecorder()
|
||||||
for epoch in range(num_train_epochs):
|
for epoch in range(num_train_epochs):
|
||||||
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
|
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
|
||||||
current_epoch.value = epoch + 1
|
current_epoch.value = epoch + 1
|
||||||
@@ -458,7 +459,6 @@ def train(args):
|
|||||||
for m in training_models:
|
for m in training_models:
|
||||||
m.train()
|
m.train()
|
||||||
|
|
||||||
loss_total = 0
|
|
||||||
for step, batch in enumerate(train_dataloader):
|
for step, batch in enumerate(train_dataloader):
|
||||||
current_step.value = global_step
|
current_step.value = global_step
|
||||||
with accelerator.accumulate(training_models[0]): # 複数モデルに対応していない模様だがとりあえずこうしておく
|
with accelerator.accumulate(training_models[0]): # 複数モデルに対応していない模様だがとりあえずこうしておく
|
||||||
@@ -633,17 +633,16 @@ def train(args):
|
|||||||
|
|
||||||
accelerator.log(logs, step=global_step)
|
accelerator.log(logs, step=global_step)
|
||||||
|
|
||||||
# TODO moving averageにする
|
loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
|
||||||
loss_total += current_loss
|
avr_loss: float = loss_recorder.moving_average
|
||||||
avr_loss = loss_total / (step + 1)
|
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
|
||||||
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
|
|
||||||
progress_bar.set_postfix(**logs)
|
progress_bar.set_postfix(**logs)
|
||||||
|
|
||||||
if global_step >= args.max_train_steps:
|
if global_step >= args.max_train_steps:
|
||||||
break
|
break
|
||||||
|
|
||||||
if args.logging_dir is not None:
|
if args.logging_dir is not None:
|
||||||
logs = {"loss/epoch": loss_total / len(train_dataloader)}
|
logs = {"loss/epoch": loss_recorder.moving_average}
|
||||||
accelerator.log(logs, step=epoch + 1)
|
accelerator.log(logs, step=epoch + 1)
|
||||||
|
|
||||||
accelerator.wait_for_everyone()
|
accelerator.wait_for_everyone()
|
||||||
|
|||||||
@@ -351,8 +351,7 @@ def train(args):
|
|||||||
"lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs
|
"lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
loss_list = []
|
loss_recorder = train_util.LossRecorder()
|
||||||
loss_total = 0.0
|
|
||||||
del train_dataset_group
|
del train_dataset_group
|
||||||
|
|
||||||
# function for saving/removing
|
# function for saving/removing
|
||||||
@@ -503,14 +502,9 @@ def train(args):
|
|||||||
remove_model(remove_ckpt_name)
|
remove_model(remove_ckpt_name)
|
||||||
|
|
||||||
current_loss = loss.detach().item()
|
current_loss = loss.detach().item()
|
||||||
if epoch == 0:
|
loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
|
||||||
loss_list.append(current_loss)
|
avr_loss: float = loss_recorder.moving_average
|
||||||
else:
|
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
|
||||||
loss_total -= loss_list[step]
|
|
||||||
loss_list[step] = current_loss
|
|
||||||
loss_total += current_loss
|
|
||||||
avr_loss = loss_total / len(loss_list)
|
|
||||||
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
|
|
||||||
progress_bar.set_postfix(**logs)
|
progress_bar.set_postfix(**logs)
|
||||||
|
|
||||||
if args.logging_dir is not None:
|
if args.logging_dir is not None:
|
||||||
@@ -521,7 +515,7 @@ def train(args):
|
|||||||
break
|
break
|
||||||
|
|
||||||
if args.logging_dir is not None:
|
if args.logging_dir is not None:
|
||||||
logs = {"loss/epoch": loss_total / len(loss_list)}
|
logs = {"loss/epoch": loss_recorder.moving_average}
|
||||||
accelerator.log(logs, step=epoch + 1)
|
accelerator.log(logs, step=epoch + 1)
|
||||||
|
|
||||||
accelerator.wait_for_everyone()
|
accelerator.wait_for_everyone()
|
||||||
|
|||||||
@@ -324,8 +324,7 @@ def train(args):
|
|||||||
"lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs
|
"lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
loss_list = []
|
loss_recorder = train_util.LossRecorder()
|
||||||
loss_total = 0.0
|
|
||||||
del train_dataset_group
|
del train_dataset_group
|
||||||
|
|
||||||
# function for saving/removing
|
# function for saving/removing
|
||||||
@@ -473,14 +472,9 @@ def train(args):
|
|||||||
remove_model(remove_ckpt_name)
|
remove_model(remove_ckpt_name)
|
||||||
|
|
||||||
current_loss = loss.detach().item()
|
current_loss = loss.detach().item()
|
||||||
if epoch == 0:
|
loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
|
||||||
loss_list.append(current_loss)
|
avr_loss: float = loss_recorder.moving_average
|
||||||
else:
|
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
|
||||||
loss_total -= loss_list[step]
|
|
||||||
loss_list[step] = current_loss
|
|
||||||
loss_total += current_loss
|
|
||||||
avr_loss = loss_total / len(loss_list)
|
|
||||||
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
|
|
||||||
progress_bar.set_postfix(**logs)
|
progress_bar.set_postfix(**logs)
|
||||||
|
|
||||||
if args.logging_dir is not None:
|
if args.logging_dir is not None:
|
||||||
@@ -491,7 +485,7 @@ def train(args):
|
|||||||
break
|
break
|
||||||
|
|
||||||
if args.logging_dir is not None:
|
if args.logging_dir is not None:
|
||||||
logs = {"loss/epoch": loss_total / len(loss_list)}
|
logs = {"loss/epoch": loss_recorder.moving_average}
|
||||||
accelerator.log(logs, step=epoch + 1)
|
accelerator.log(logs, step=epoch + 1)
|
||||||
|
|
||||||
accelerator.wait_for_everyone()
|
accelerator.wait_for_everyone()
|
||||||
|
|||||||
@@ -337,8 +337,7 @@ def train(args):
|
|||||||
init_kwargs = toml.load(args.log_tracker_config)
|
init_kwargs = toml.load(args.log_tracker_config)
|
||||||
accelerator.init_trackers("controlnet_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
|
accelerator.init_trackers("controlnet_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
|
||||||
|
|
||||||
loss_list = []
|
loss_recorder = train_util.LossRecorder()
|
||||||
loss_total = 0.0
|
|
||||||
del train_dataset_group
|
del train_dataset_group
|
||||||
|
|
||||||
# function for saving/removing
|
# function for saving/removing
|
||||||
@@ -500,14 +499,9 @@ def train(args):
|
|||||||
remove_model(remove_ckpt_name)
|
remove_model(remove_ckpt_name)
|
||||||
|
|
||||||
current_loss = loss.detach().item()
|
current_loss = loss.detach().item()
|
||||||
if epoch == 0:
|
loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
|
||||||
loss_list.append(current_loss)
|
avr_loss: float = loss_recorder.moving_average
|
||||||
else:
|
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
|
||||||
loss_total -= loss_list[step]
|
|
||||||
loss_list[step] = current_loss
|
|
||||||
loss_total += current_loss
|
|
||||||
avr_loss = loss_total / len(loss_list)
|
|
||||||
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
|
|
||||||
progress_bar.set_postfix(**logs)
|
progress_bar.set_postfix(**logs)
|
||||||
|
|
||||||
if args.logging_dir is not None:
|
if args.logging_dir is not None:
|
||||||
@@ -518,7 +512,7 @@ def train(args):
|
|||||||
break
|
break
|
||||||
|
|
||||||
if args.logging_dir is not None:
|
if args.logging_dir is not None:
|
||||||
logs = {"loss/epoch": loss_total / len(loss_list)}
|
logs = {"loss/epoch": loss_recorder.moving_average}
|
||||||
accelerator.log(logs, step=epoch + 1)
|
accelerator.log(logs, step=epoch + 1)
|
||||||
|
|
||||||
accelerator.wait_for_everyone()
|
accelerator.wait_for_everyone()
|
||||||
|
|||||||
16
train_db.py
16
train_db.py
@@ -265,8 +265,7 @@ def train(args):
|
|||||||
init_kwargs = toml.load(args.log_tracker_config)
|
init_kwargs = toml.load(args.log_tracker_config)
|
||||||
accelerator.init_trackers("dreambooth" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
|
accelerator.init_trackers("dreambooth" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
|
||||||
|
|
||||||
loss_list = []
|
loss_recorder = train_util.LossRecorder()
|
||||||
loss_total = 0.0
|
|
||||||
for epoch in range(num_train_epochs):
|
for epoch in range(num_train_epochs):
|
||||||
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
|
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
|
||||||
current_epoch.value = epoch + 1
|
current_epoch.value = epoch + 1
|
||||||
@@ -395,21 +394,16 @@ def train(args):
|
|||||||
)
|
)
|
||||||
accelerator.log(logs, step=global_step)
|
accelerator.log(logs, step=global_step)
|
||||||
|
|
||||||
if epoch == 0:
|
loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
|
||||||
loss_list.append(current_loss)
|
avr_loss: float = loss_recorder.moving_average
|
||||||
else:
|
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
|
||||||
loss_total -= loss_list[step]
|
|
||||||
loss_list[step] = current_loss
|
|
||||||
loss_total += current_loss
|
|
||||||
avr_loss = loss_total / len(loss_list)
|
|
||||||
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
|
|
||||||
progress_bar.set_postfix(**logs)
|
progress_bar.set_postfix(**logs)
|
||||||
|
|
||||||
if global_step >= args.max_train_steps:
|
if global_step >= args.max_train_steps:
|
||||||
break
|
break
|
||||||
|
|
||||||
if args.logging_dir is not None:
|
if args.logging_dir is not None:
|
||||||
logs = {"loss/epoch": loss_total / len(loss_list)}
|
logs = {"loss/epoch": loss_recorder.moving_average}
|
||||||
accelerator.log(logs, step=epoch + 1)
|
accelerator.log(logs, step=epoch + 1)
|
||||||
|
|
||||||
accelerator.wait_for_everyone()
|
accelerator.wait_for_everyone()
|
||||||
|
|||||||
@@ -710,8 +710,7 @@ class NetworkTrainer:
|
|||||||
"network_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs
|
"network_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
loss_list = []
|
loss_recorder = train_util.LossRecorder()
|
||||||
loss_total = 0.0
|
|
||||||
del train_dataset_group
|
del train_dataset_group
|
||||||
|
|
||||||
# callback for step start
|
# callback for step start
|
||||||
@@ -863,14 +862,9 @@ class NetworkTrainer:
|
|||||||
remove_model(remove_ckpt_name)
|
remove_model(remove_ckpt_name)
|
||||||
|
|
||||||
current_loss = loss.detach().item()
|
current_loss = loss.detach().item()
|
||||||
if epoch == 0:
|
loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
|
||||||
loss_list.append(current_loss)
|
avr_loss: float = loss_recorder.moving_average
|
||||||
else:
|
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
|
||||||
loss_total -= loss_list[step]
|
|
||||||
loss_list[step] = current_loss
|
|
||||||
loss_total += current_loss
|
|
||||||
avr_loss = loss_total / len(loss_list)
|
|
||||||
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
|
|
||||||
progress_bar.set_postfix(**logs)
|
progress_bar.set_postfix(**logs)
|
||||||
|
|
||||||
if args.scale_weight_norms:
|
if args.scale_weight_norms:
|
||||||
@@ -884,7 +878,7 @@ class NetworkTrainer:
|
|||||||
break
|
break
|
||||||
|
|
||||||
if args.logging_dir is not None:
|
if args.logging_dir is not None:
|
||||||
logs = {"loss/epoch": loss_total / len(loss_list)}
|
logs = {"loss/epoch": loss_recorder.moving_average}
|
||||||
accelerator.log(logs, step=epoch + 1)
|
accelerator.log(logs, step=epoch + 1)
|
||||||
|
|
||||||
accelerator.wait_for_everyone()
|
accelerator.wait_for_everyone()
|
||||||
|
|||||||
Reference in New Issue
Block a user