This commit is contained in:
gesen2egee
2024-03-10 18:55:48 +08:00
parent b558a5b73d
commit 78cfb01922
2 changed files with 230 additions and 89 deletions

View File

@@ -44,6 +44,7 @@ from library.utils import setup_logging, add_logging_arguments
setup_logging()
import logging
import itertools
logger = logging.getLogger(__name__)
@@ -438,6 +439,7 @@ class NetworkTrainer:
num_workers=n_workers,
persistent_workers=args.persistent_data_loader_workers,
)
cyclic_val_dataloader = itertools.cycle(val_dataloader)
# 学習ステップ数を計算する
if args.max_train_epochs is not None:
@@ -979,23 +981,24 @@ class NetworkTrainer:
if args.logging_dir is not None:
logs = self.generate_step_logs(args, current_loss, avr_loss, lr_scheduler, keys_scaled, mean_norm, maximum_norm)
accelerator.log(logs, step=global_step)
if global_step % 25 == 0:
if len(val_dataloader) > 0:
print("Validating バリデーション処理...")
with torch.no_grad():
val_dataloader_iter = iter(val_dataloader)
batch = next(val_dataloader_iter)
is_train = False
loss = self.process_val_batch(batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args)
current_loss = loss.detach().item()
val_loss_recorder.add(epoch=epoch, step=global_step, loss=current_loss)
if args.validation_every_n_step is not None:
if global_step % (args.validation_every_n_step) == 0:
if len(val_dataloader) > 0:
print("Validating バリデーション処理...")
total_loss = 0.0
with torch.no_grad():
for val_step in range(min(len(val_dataloader), args.validation_batches)):
is_train = False
batch = next(cyclic_val_dataloader)
loss = self.process_val_batch(batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args)
total_loss += loss.detach().item()
current_loss = total_loss / min(len(val_dataloader), args.validation_batches)
val_loss_recorder.add(epoch=epoch, step=global_step, loss=current_loss)
if args.logging_dir is not None:
avr_loss: float = val_loss_recorder.moving_average
logs = {"loss/validation_current": current_loss, "loss/avr_val_loss": avr_loss}
accelerator.log(logs, step=global_step)
if global_step >= args.max_train_steps:
@@ -1005,12 +1008,24 @@ class NetworkTrainer:
logs = {"loss/epoch_average": loss_recorder.moving_average}
accelerator.log(logs, step=epoch + 1)
if len(val_dataloader) > 0:
if args.logging_dir is not None:
avr_loss: float = val_loss_recorder.moving_average
logs = {"loss/validation_epoch_average": avr_loss}
accelerator.log(logs, step=epoch + 1)
if args.validation_every_n_step is None:
if len(val_dataloader) > 0:
print("Validating バリデーション処理...")
total_loss = 0.0
with torch.no_grad():
for val_step in range(min(len(val_dataloader), args.validation_batches)):
is_train = False
batch = next(cyclic_val_dataloader)
loss = self.process_val_batch(batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args)
total_loss += loss.detach().item()
current_loss = total_loss / min(len(val_dataloader), args.validation_batches)
val_loss_recorder.add(epoch=epoch, step=global_step, loss=current_loss)
if args.logging_dir is not None:
avr_loss: float = val_loss_recorder.moving_average
logs = {"loss/val_epoch_average": avr_loss}
accelerator.log(logs, step=epoch + 1)
accelerator.wait_for_everyone()
# 指定エポックごとにモデルを保存
@@ -1162,6 +1177,18 @@ def setup_parser() -> argparse.ArgumentParser:
default=0.0,
help="Split for validation images out of the training dataset"
)
parser.add_argument(
"--validation_every_n_step",
type=int,
default=None,
help="Run validation every N training steps. By default, validation is performed once per epoch"
)
parser.add_argument(
"--validation_batches",
type=int,
default=1,
help="Number of validation batches used to compute the validation loss. By default, one batch is used"
)
return parser