mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Add validate_every_n_epochs, change name validate_every_n_steps
This commit is contained in:
@@ -1199,7 +1199,8 @@ class NetworkTrainer:
|
|||||||
)
|
)
|
||||||
|
|
||||||
loss_recorder = train_util.LossRecorder()
|
loss_recorder = train_util.LossRecorder()
|
||||||
val_loss_recorder = train_util.LossRecorder()
|
val_step_loss_recorder = train_util.LossRecorder()
|
||||||
|
val_epoch_loss_recorder = train_util.LossRecorder()
|
||||||
|
|
||||||
del train_dataset_group
|
del train_dataset_group
|
||||||
if val_dataset_group is not None:
|
if val_dataset_group is not None:
|
||||||
@@ -1299,7 +1300,8 @@ class NetworkTrainer:
|
|||||||
# temporary, for batch processing
|
# temporary, for batch processing
|
||||||
self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype)
|
self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype)
|
||||||
|
|
||||||
loss = self.process_batch(batch,
|
loss = self.process_batch(
|
||||||
|
batch,
|
||||||
text_encoders,
|
text_encoders,
|
||||||
unet,
|
unet,
|
||||||
network,
|
network,
|
||||||
@@ -1373,15 +1375,25 @@ class NetworkTrainer:
|
|||||||
|
|
||||||
if is_tracking:
|
if is_tracking:
|
||||||
logs = self.generate_step_logs(
|
logs = self.generate_step_logs(
|
||||||
args, current_loss, avr_loss, lr_scheduler, lr_descriptions, optimizer, keys_scaled, mean_norm, maximum_norm
|
args,
|
||||||
|
current_loss,
|
||||||
|
avr_loss,
|
||||||
|
lr_scheduler,
|
||||||
|
lr_descriptions,
|
||||||
|
optimizer,
|
||||||
|
keys_scaled,
|
||||||
|
mean_norm,
|
||||||
|
maximum_norm
|
||||||
)
|
)
|
||||||
# accelerator.log(logs, step=global_step)
|
# accelerator.log(logs, step=global_step)
|
||||||
accelerator.log(logs)
|
accelerator.log(logs)
|
||||||
|
|
||||||
# VALIDATION PER STEP
|
# VALIDATION PER STEP
|
||||||
should_validate = (args.validation_every_n_step is not None
|
should_validate_epoch = (
|
||||||
and global_step % args.validation_every_n_step == 0)
|
args.validate_every_n_steps is not None
|
||||||
if validation_steps > 0 and should_validate:
|
and global_step % args.validate_every_n_steps == 0
|
||||||
|
)
|
||||||
|
if validation_steps > 0 and should_validate_epoch:
|
||||||
accelerator.print("Validating バリデーション処理...")
|
accelerator.print("Validating バリデーション処理...")
|
||||||
|
|
||||||
val_progress_bar = tqdm(
|
val_progress_bar = tqdm(
|
||||||
@@ -1409,16 +1421,17 @@ class NetworkTrainer:
|
|||||||
is_train=False
|
is_train=False
|
||||||
)
|
)
|
||||||
|
|
||||||
val_loss_recorder.add(epoch=epoch, step=val_step, loss=loss.detach().item())
|
current_loss = loss.detach().item()
|
||||||
|
val_step_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss)
|
||||||
val_progress_bar.update(1)
|
val_progress_bar.update(1)
|
||||||
val_progress_bar.set_postfix({ "val_avg_loss": val_loss_recorder.moving_average })
|
val_progress_bar.set_postfix({ "val_avg_loss": val_step_loss_recorder.moving_average })
|
||||||
|
|
||||||
if is_tracking:
|
if is_tracking:
|
||||||
logs = {"loss/current_val_loss": loss.detach().item()}
|
logs = {"loss/step_validation_current": current_loss}
|
||||||
# accelerator.log(logs, step=(len(val_dataloader) * epoch) + 1 + val_step)
|
# accelerator.log(logs, step=(len(val_dataloader) * epoch) + 1 + val_step)
|
||||||
accelerator.log(logs)
|
accelerator.log(logs)
|
||||||
|
|
||||||
logs = {"loss/average_val_loss": val_loss_recorder.moving_average}
|
logs = {"loss/step_validation_average": val_step_loss_recorder.moving_average}
|
||||||
# accelerator.log(logs, step=global_step)
|
# accelerator.log(logs, step=global_step)
|
||||||
accelerator.log(logs)
|
accelerator.log(logs)
|
||||||
|
|
||||||
@@ -1426,12 +1439,18 @@ class NetworkTrainer:
|
|||||||
break
|
break
|
||||||
|
|
||||||
# VALIDATION EPOCH
|
# VALIDATION EPOCH
|
||||||
if len(val_dataloader) > 0:
|
should_validate_epoch = (
|
||||||
|
(epoch + 1) % args.validate_every_n_epochs == 0
|
||||||
|
if args.validate_every_n_epochs is not None
|
||||||
|
else False
|
||||||
|
)
|
||||||
|
|
||||||
|
if should_validate_epoch and len(val_dataloader) > 0:
|
||||||
accelerator.print("Validating バリデーション処理...")
|
accelerator.print("Validating バリデーション処理...")
|
||||||
val_progress_bar = tqdm(
|
val_progress_bar = tqdm(
|
||||||
range(validation_steps), smoothing=0,
|
range(validation_steps), smoothing=0,
|
||||||
disable=not accelerator.is_local_main_process,
|
disable=not accelerator.is_local_main_process,
|
||||||
desc="validation steps"
|
desc="epoch validation steps"
|
||||||
)
|
)
|
||||||
|
|
||||||
for val_step, batch in enumerate(val_dataloader):
|
for val_step, batch in enumerate(val_dataloader):
|
||||||
@@ -1455,18 +1474,18 @@ class NetworkTrainer:
|
|||||||
)
|
)
|
||||||
|
|
||||||
current_loss = loss.detach().item()
|
current_loss = loss.detach().item()
|
||||||
val_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss)
|
val_epoch_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss)
|
||||||
val_progress_bar.update(1)
|
val_progress_bar.update(1)
|
||||||
val_progress_bar.set_postfix({ "val_avg_loss": val_loss_recorder.moving_average })
|
val_progress_bar.set_postfix({ "val_epoch_avg_loss": val_epoch_loss_recorder.moving_average })
|
||||||
|
|
||||||
if is_tracking:
|
if is_tracking:
|
||||||
logs = {"loss/validation_current": current_loss}
|
logs = {"loss/epoch_validation_current": current_loss}
|
||||||
# accelerator.log(logs, step=(len(val_dataloader) * epoch) + 1 + val_step)
|
# accelerator.log(logs, step=(len(val_dataloader) * epoch) + 1 + val_step)
|
||||||
accelerator.log(logs)
|
accelerator.log(logs)
|
||||||
|
|
||||||
if is_tracking:
|
if is_tracking:
|
||||||
avr_loss: float = val_loss_recorder.moving_average
|
avr_loss: float = val_epoch_loss_recorder.moving_average
|
||||||
logs = {"loss/validation_average": avr_loss}
|
logs = {"loss/epoch_validation_average": avr_loss}
|
||||||
# accelerator.log(logs, step=epoch + 1)
|
# accelerator.log(logs, step=epoch + 1)
|
||||||
accelerator.log(logs)
|
accelerator.log(logs)
|
||||||
|
|
||||||
@@ -1475,12 +1494,6 @@ class NetworkTrainer:
|
|||||||
logs = {"loss/epoch_average": loss_recorder.moving_average}
|
logs = {"loss/epoch_average": loss_recorder.moving_average}
|
||||||
# accelerator.log(logs, step=epoch + 1)
|
# accelerator.log(logs, step=epoch + 1)
|
||||||
accelerator.log(logs)
|
accelerator.log(logs)
|
||||||
|
|
||||||
if len(val_dataloader) > 0 and is_tracking:
|
|
||||||
avr_loss: float = val_loss_recorder.moving_average
|
|
||||||
logs = {"loss/validation_epoch_average": avr_loss}
|
|
||||||
# accelerator.log(logs, step=epoch + 1)
|
|
||||||
accelerator.log(logs)
|
|
||||||
|
|
||||||
accelerator.wait_for_everyone()
|
accelerator.wait_for_everyone()
|
||||||
|
|
||||||
@@ -1676,10 +1689,16 @@ def setup_parser() -> argparse.ArgumentParser:
|
|||||||
help="Split for validation images out of the training dataset / 学習画像から検証画像に分割する割合"
|
help="Split for validation images out of the training dataset / 学習画像から検証画像に分割する割合"
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--validation_every_n_step",
|
"--validate_every_n_steps",
|
||||||
type=int,
|
type=int,
|
||||||
default=None,
|
default=None,
|
||||||
help="Number of train steps for counting validation loss. By default, validation per train epoch is performed / 学習エポックごとに検証を行う場合はNoneを指定する"
|
help="Run validation dataset every N steps"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--validate_every_n_epochs",
|
||||||
|
type=int,
|
||||||
|
default=None,
|
||||||
|
help="Run validation dataset every N epochs. By default, validation will run every epoch if a validation dataset is available"
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--max_validation_steps",
|
"--max_validation_steps",
|
||||||
|
|||||||
Reference in New Issue
Block a user