mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Add validate_every_n_epochs, change name validate_every_n_steps
This commit is contained in:
@@ -1199,7 +1199,8 @@ class NetworkTrainer:
|
||||
)
|
||||
|
||||
loss_recorder = train_util.LossRecorder()
|
||||
val_loss_recorder = train_util.LossRecorder()
|
||||
val_step_loss_recorder = train_util.LossRecorder()
|
||||
val_epoch_loss_recorder = train_util.LossRecorder()
|
||||
|
||||
del train_dataset_group
|
||||
if val_dataset_group is not None:
|
||||
@@ -1299,7 +1300,8 @@ class NetworkTrainer:
|
||||
# temporary, for batch processing
|
||||
self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype)
|
||||
|
||||
loss = self.process_batch(batch,
|
||||
loss = self.process_batch(
|
||||
batch,
|
||||
text_encoders,
|
||||
unet,
|
||||
network,
|
||||
@@ -1373,15 +1375,25 @@ class NetworkTrainer:
|
||||
|
||||
if is_tracking:
|
||||
logs = self.generate_step_logs(
|
||||
args, current_loss, avr_loss, lr_scheduler, lr_descriptions, optimizer, keys_scaled, mean_norm, maximum_norm
|
||||
args,
|
||||
current_loss,
|
||||
avr_loss,
|
||||
lr_scheduler,
|
||||
lr_descriptions,
|
||||
optimizer,
|
||||
keys_scaled,
|
||||
mean_norm,
|
||||
maximum_norm
|
||||
)
|
||||
# accelerator.log(logs, step=global_step)
|
||||
accelerator.log(logs)
|
||||
|
||||
# VALIDATION PER STEP
|
||||
should_validate = (args.validation_every_n_step is not None
|
||||
and global_step % args.validation_every_n_step == 0)
|
||||
if validation_steps > 0 and should_validate:
|
||||
should_validate_epoch = (
|
||||
args.validate_every_n_steps is not None
|
||||
and global_step % args.validate_every_n_steps == 0
|
||||
)
|
||||
if validation_steps > 0 and should_validate_epoch:
|
||||
accelerator.print("Validating バリデーション処理...")
|
||||
|
||||
val_progress_bar = tqdm(
|
||||
@@ -1409,16 +1421,17 @@ class NetworkTrainer:
|
||||
is_train=False
|
||||
)
|
||||
|
||||
val_loss_recorder.add(epoch=epoch, step=val_step, loss=loss.detach().item())
|
||||
current_loss = loss.detach().item()
|
||||
val_step_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss)
|
||||
val_progress_bar.update(1)
|
||||
val_progress_bar.set_postfix({ "val_avg_loss": val_loss_recorder.moving_average })
|
||||
val_progress_bar.set_postfix({ "val_avg_loss": val_step_loss_recorder.moving_average })
|
||||
|
||||
if is_tracking:
|
||||
logs = {"loss/current_val_loss": loss.detach().item()}
|
||||
logs = {"loss/step_validation_current": current_loss}
|
||||
# accelerator.log(logs, step=(len(val_dataloader) * epoch) + 1 + val_step)
|
||||
accelerator.log(logs)
|
||||
|
||||
logs = {"loss/average_val_loss": val_loss_recorder.moving_average}
|
||||
logs = {"loss/step_validation_average": val_step_loss_recorder.moving_average}
|
||||
# accelerator.log(logs, step=global_step)
|
||||
accelerator.log(logs)
|
||||
|
||||
@@ -1426,12 +1439,18 @@ class NetworkTrainer:
|
||||
break
|
||||
|
||||
# VALIDATION EPOCH
|
||||
if len(val_dataloader) > 0:
|
||||
should_validate_epoch = (
|
||||
(epoch + 1) % args.validate_every_n_epochs == 0
|
||||
if args.validate_every_n_epochs is not None
|
||||
else False
|
||||
)
|
||||
|
||||
if should_validate_epoch and len(val_dataloader) > 0:
|
||||
accelerator.print("Validating バリデーション処理...")
|
||||
val_progress_bar = tqdm(
|
||||
range(validation_steps), smoothing=0,
|
||||
disable=not accelerator.is_local_main_process,
|
||||
desc="validation steps"
|
||||
desc="epoch validation steps"
|
||||
)
|
||||
|
||||
for val_step, batch in enumerate(val_dataloader):
|
||||
@@ -1455,18 +1474,18 @@ class NetworkTrainer:
|
||||
)
|
||||
|
||||
current_loss = loss.detach().item()
|
||||
val_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss)
|
||||
val_epoch_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss)
|
||||
val_progress_bar.update(1)
|
||||
val_progress_bar.set_postfix({ "val_avg_loss": val_loss_recorder.moving_average })
|
||||
val_progress_bar.set_postfix({ "val_epoch_avg_loss": val_epoch_loss_recorder.moving_average })
|
||||
|
||||
if is_tracking:
|
||||
logs = {"loss/validation_current": current_loss}
|
||||
logs = {"loss/epoch_validation_current": current_loss}
|
||||
# accelerator.log(logs, step=(len(val_dataloader) * epoch) + 1 + val_step)
|
||||
accelerator.log(logs)
|
||||
|
||||
if is_tracking:
|
||||
avr_loss: float = val_loss_recorder.moving_average
|
||||
logs = {"loss/validation_average": avr_loss}
|
||||
avr_loss: float = val_epoch_loss_recorder.moving_average
|
||||
logs = {"loss/epoch_validation_average": avr_loss}
|
||||
# accelerator.log(logs, step=epoch + 1)
|
||||
accelerator.log(logs)
|
||||
|
||||
@@ -1475,12 +1494,6 @@ class NetworkTrainer:
|
||||
logs = {"loss/epoch_average": loss_recorder.moving_average}
|
||||
# accelerator.log(logs, step=epoch + 1)
|
||||
accelerator.log(logs)
|
||||
|
||||
if len(val_dataloader) > 0 and is_tracking:
|
||||
avr_loss: float = val_loss_recorder.moving_average
|
||||
logs = {"loss/validation_epoch_average": avr_loss}
|
||||
# accelerator.log(logs, step=epoch + 1)
|
||||
accelerator.log(logs)
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
@@ -1676,10 +1689,16 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
help="Split for validation images out of the training dataset / 学習画像から検証画像に分割する割合"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--validation_every_n_step",
|
||||
"--validate_every_n_steps",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Number of train steps for counting validation loss. By default, validation per train epoch is performed / 学習エポックごとに検証を行う場合はNoneを指定する"
|
||||
help="Run validation dataset every N steps"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--validate_every_n_epochs",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Run validation dataset every N epochs. By default, validation will run every epoch if a validation dataset is available"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_validation_steps",
|
||||
|
||||
Reference in New Issue
Block a user