Add validate_every_n_epochs, change name validate_every_n_steps

This commit is contained in:
rockerBOO
2025-01-06 11:30:21 -05:00
parent 1c63e7cc49
commit c64d1a22fc

View File

@@ -1199,7 +1199,8 @@ class NetworkTrainer:
) )
loss_recorder = train_util.LossRecorder() loss_recorder = train_util.LossRecorder()
val_loss_recorder = train_util.LossRecorder() val_step_loss_recorder = train_util.LossRecorder()
val_epoch_loss_recorder = train_util.LossRecorder()
del train_dataset_group del train_dataset_group
if val_dataset_group is not None: if val_dataset_group is not None:
@@ -1299,7 +1300,8 @@ class NetworkTrainer:
# temporary, for batch processing # temporary, for batch processing
self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype) self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype)
loss = self.process_batch(batch, loss = self.process_batch(
batch,
text_encoders, text_encoders,
unet, unet,
network, network,
@@ -1373,15 +1375,25 @@ class NetworkTrainer:
if is_tracking: if is_tracking:
logs = self.generate_step_logs( logs = self.generate_step_logs(
args, current_loss, avr_loss, lr_scheduler, lr_descriptions, optimizer, keys_scaled, mean_norm, maximum_norm args,
current_loss,
avr_loss,
lr_scheduler,
lr_descriptions,
optimizer,
keys_scaled,
mean_norm,
maximum_norm
) )
# accelerator.log(logs, step=global_step) # accelerator.log(logs, step=global_step)
accelerator.log(logs) accelerator.log(logs)
# VALIDATION PER STEP # VALIDATION PER STEP
should_validate = (args.validation_every_n_step is not None should_validate_epoch = (
and global_step % args.validation_every_n_step == 0) args.validate_every_n_steps is not None
if validation_steps > 0 and should_validate: and global_step % args.validate_every_n_steps == 0
)
if validation_steps > 0 and should_validate_epoch:
accelerator.print("Validating バリデーション処理...") accelerator.print("Validating バリデーション処理...")
val_progress_bar = tqdm( val_progress_bar = tqdm(
@@ -1409,16 +1421,17 @@ class NetworkTrainer:
is_train=False is_train=False
) )
val_loss_recorder.add(epoch=epoch, step=val_step, loss=loss.detach().item()) current_loss = loss.detach().item()
val_step_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss)
val_progress_bar.update(1) val_progress_bar.update(1)
val_progress_bar.set_postfix({ "val_avg_loss": val_loss_recorder.moving_average }) val_progress_bar.set_postfix({ "val_avg_loss": val_step_loss_recorder.moving_average })
if is_tracking: if is_tracking:
logs = {"loss/current_val_loss": loss.detach().item()} logs = {"loss/step_validation_current": current_loss}
# accelerator.log(logs, step=(len(val_dataloader) * epoch) + 1 + val_step) # accelerator.log(logs, step=(len(val_dataloader) * epoch) + 1 + val_step)
accelerator.log(logs) accelerator.log(logs)
logs = {"loss/average_val_loss": val_loss_recorder.moving_average} logs = {"loss/step_validation_average": val_step_loss_recorder.moving_average}
# accelerator.log(logs, step=global_step) # accelerator.log(logs, step=global_step)
accelerator.log(logs) accelerator.log(logs)
@@ -1426,12 +1439,18 @@ class NetworkTrainer:
break break
# VALIDATION EPOCH # VALIDATION EPOCH
if len(val_dataloader) > 0: should_validate_epoch = (
(epoch + 1) % args.validate_every_n_epochs == 0
if args.validate_every_n_epochs is not None
else False
)
if should_validate_epoch and len(val_dataloader) > 0:
accelerator.print("Validating バリデーション処理...") accelerator.print("Validating バリデーション処理...")
val_progress_bar = tqdm( val_progress_bar = tqdm(
range(validation_steps), smoothing=0, range(validation_steps), smoothing=0,
disable=not accelerator.is_local_main_process, disable=not accelerator.is_local_main_process,
desc="validation steps" desc="epoch validation steps"
) )
for val_step, batch in enumerate(val_dataloader): for val_step, batch in enumerate(val_dataloader):
@@ -1455,18 +1474,18 @@ class NetworkTrainer:
) )
current_loss = loss.detach().item() current_loss = loss.detach().item()
val_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss) val_epoch_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss)
val_progress_bar.update(1) val_progress_bar.update(1)
val_progress_bar.set_postfix({ "val_avg_loss": val_loss_recorder.moving_average }) val_progress_bar.set_postfix({ "val_epoch_avg_loss": val_epoch_loss_recorder.moving_average })
if is_tracking: if is_tracking:
logs = {"loss/validation_current": current_loss} logs = {"loss/epoch_validation_current": current_loss}
# accelerator.log(logs, step=(len(val_dataloader) * epoch) + 1 + val_step) # accelerator.log(logs, step=(len(val_dataloader) * epoch) + 1 + val_step)
accelerator.log(logs) accelerator.log(logs)
if is_tracking: if is_tracking:
avr_loss: float = val_loss_recorder.moving_average avr_loss: float = val_epoch_loss_recorder.moving_average
logs = {"loss/validation_average": avr_loss} logs = {"loss/epoch_validation_average": avr_loss}
# accelerator.log(logs, step=epoch + 1) # accelerator.log(logs, step=epoch + 1)
accelerator.log(logs) accelerator.log(logs)
@@ -1476,12 +1495,6 @@ class NetworkTrainer:
# accelerator.log(logs, step=epoch + 1) # accelerator.log(logs, step=epoch + 1)
accelerator.log(logs) accelerator.log(logs)
if len(val_dataloader) > 0 and is_tracking:
avr_loss: float = val_loss_recorder.moving_average
logs = {"loss/validation_epoch_average": avr_loss}
# accelerator.log(logs, step=epoch + 1)
accelerator.log(logs)
accelerator.wait_for_everyone() accelerator.wait_for_everyone()
# 指定エポックごとにモデルを保存 # 指定エポックごとにモデルを保存
@@ -1676,10 +1689,16 @@ def setup_parser() -> argparse.ArgumentParser:
help="Split for validation images out of the training dataset / 学習画像から検証画像に分割する割合" help="Split for validation images out of the training dataset / 学習画像から検証画像に分割する割合"
) )
parser.add_argument( parser.add_argument(
"--validation_every_n_step", "--validate_every_n_steps",
type=int, type=int,
default=None, default=None,
help="Number of train steps for counting validation loss. By default, validation per train epoch is performed / 学習エポックごとに検証を行う場合はNoneを指定する" help="Run validation dataset every N steps"
)
parser.add_argument(
"--validate_every_n_epochs",
type=int,
default=None,
help="Run validation dataset every N epochs. By default, validation will run every epoch if a validation dataset is available"
) )
parser.add_argument( parser.add_argument(
"--max_validation_steps", "--max_validation_steps",