Add validate_every_n_epochs, change name validate_every_n_steps

This commit is contained in:
rockerBOO
2025-01-06 11:30:21 -05:00
parent 1c63e7cc49
commit c64d1a22fc

View File

@@ -1199,7 +1199,8 @@ class NetworkTrainer:
)
loss_recorder = train_util.LossRecorder()
val_loss_recorder = train_util.LossRecorder()
val_step_loss_recorder = train_util.LossRecorder()
val_epoch_loss_recorder = train_util.LossRecorder()
del train_dataset_group
if val_dataset_group is not None:
@@ -1299,7 +1300,8 @@ class NetworkTrainer:
# temporary, for batch processing
self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype)
loss = self.process_batch(batch,
loss = self.process_batch(
batch,
text_encoders,
unet,
network,
@@ -1373,15 +1375,25 @@ class NetworkTrainer:
if is_tracking:
logs = self.generate_step_logs(
args, current_loss, avr_loss, lr_scheduler, lr_descriptions, optimizer, keys_scaled, mean_norm, maximum_norm
args,
current_loss,
avr_loss,
lr_scheduler,
lr_descriptions,
optimizer,
keys_scaled,
mean_norm,
maximum_norm
)
# accelerator.log(logs, step=global_step)
accelerator.log(logs)
# VALIDATION PER STEP
should_validate = (args.validation_every_n_step is not None
and global_step % args.validation_every_n_step == 0)
if validation_steps > 0 and should_validate:
should_validate_epoch = (
args.validate_every_n_steps is not None
and global_step % args.validate_every_n_steps == 0
)
if validation_steps > 0 and should_validate_epoch:
accelerator.print("Validating バリデーション処理...")
val_progress_bar = tqdm(
@@ -1409,16 +1421,17 @@ class NetworkTrainer:
is_train=False
)
val_loss_recorder.add(epoch=epoch, step=val_step, loss=loss.detach().item())
current_loss = loss.detach().item()
val_step_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss)
val_progress_bar.update(1)
val_progress_bar.set_postfix({ "val_avg_loss": val_loss_recorder.moving_average })
val_progress_bar.set_postfix({ "val_avg_loss": val_step_loss_recorder.moving_average })
if is_tracking:
logs = {"loss/current_val_loss": loss.detach().item()}
logs = {"loss/step_validation_current": current_loss}
# accelerator.log(logs, step=(len(val_dataloader) * epoch) + 1 + val_step)
accelerator.log(logs)
logs = {"loss/average_val_loss": val_loss_recorder.moving_average}
logs = {"loss/step_validation_average": val_step_loss_recorder.moving_average}
# accelerator.log(logs, step=global_step)
accelerator.log(logs)
@@ -1426,12 +1439,18 @@ class NetworkTrainer:
break
# VALIDATION EPOCH
if len(val_dataloader) > 0:
should_validate_epoch = (
(epoch + 1) % args.validate_every_n_epochs == 0
if args.validate_every_n_epochs is not None
else False
)
if should_validate_epoch and len(val_dataloader) > 0:
accelerator.print("Validating バリデーション処理...")
val_progress_bar = tqdm(
range(validation_steps), smoothing=0,
disable=not accelerator.is_local_main_process,
desc="validation steps"
desc="epoch validation steps"
)
for val_step, batch in enumerate(val_dataloader):
@@ -1455,18 +1474,18 @@ class NetworkTrainer:
)
current_loss = loss.detach().item()
val_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss)
val_epoch_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss)
val_progress_bar.update(1)
val_progress_bar.set_postfix({ "val_avg_loss": val_loss_recorder.moving_average })
val_progress_bar.set_postfix({ "val_epoch_avg_loss": val_epoch_loss_recorder.moving_average })
if is_tracking:
logs = {"loss/validation_current": current_loss}
logs = {"loss/epoch_validation_current": current_loss}
# accelerator.log(logs, step=(len(val_dataloader) * epoch) + 1 + val_step)
accelerator.log(logs)
if is_tracking:
avr_loss: float = val_loss_recorder.moving_average
logs = {"loss/validation_average": avr_loss}
avr_loss: float = val_epoch_loss_recorder.moving_average
logs = {"loss/epoch_validation_average": avr_loss}
# accelerator.log(logs, step=epoch + 1)
accelerator.log(logs)
@@ -1475,12 +1494,6 @@ class NetworkTrainer:
logs = {"loss/epoch_average": loss_recorder.moving_average}
# accelerator.log(logs, step=epoch + 1)
accelerator.log(logs)
if len(val_dataloader) > 0 and is_tracking:
avr_loss: float = val_loss_recorder.moving_average
logs = {"loss/validation_epoch_average": avr_loss}
# accelerator.log(logs, step=epoch + 1)
accelerator.log(logs)
accelerator.wait_for_everyone()
@@ -1676,10 +1689,16 @@ def setup_parser() -> argparse.ArgumentParser:
help="Split for validation images out of the training dataset / 学習画像から検証画像に分割する割合"
)
parser.add_argument(
"--validation_every_n_step",
"--validate_every_n_steps",
type=int,
default=None,
help="Number of train steps for counting validation loss. By default, validation per train epoch is performed / 学習エポックごとに検証を行う場合はNoneを指定する"
help="Run validation dataset every N steps"
)
parser.add_argument(
"--validate_every_n_epochs",
type=int,
default=None,
help="Run validation dataset every N epochs. By default, validation will run every epoch if a validation dataset is available"
)
parser.add_argument(
"--max_validation_steps",