mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
improve
This commit is contained in:
@@ -44,6 +44,7 @@ from library.utils import setup_logging, add_logging_arguments
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
import itertools
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -438,6 +439,7 @@ class NetworkTrainer:
|
||||
num_workers=n_workers,
|
||||
persistent_workers=args.persistent_data_loader_workers,
|
||||
)
|
||||
cyclic_val_dataloader = itertools.cycle(val_dataloader)
|
||||
|
||||
# 学習ステップ数を計算する
|
||||
if args.max_train_epochs is not None:
|
||||
@@ -979,23 +981,24 @@ class NetworkTrainer:
|
||||
if args.logging_dir is not None:
|
||||
logs = self.generate_step_logs(args, current_loss, avr_loss, lr_scheduler, keys_scaled, mean_norm, maximum_norm)
|
||||
accelerator.log(logs, step=global_step)
|
||||
|
||||
if global_step % 25 == 0:
|
||||
if len(val_dataloader) > 0:
|
||||
print("Validating バリデーション処理...")
|
||||
|
||||
with torch.no_grad():
|
||||
val_dataloader_iter = iter(val_dataloader)
|
||||
batch = next(val_dataloader_iter)
|
||||
is_train = False
|
||||
loss = self.process_val_batch(batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args)
|
||||
|
||||
current_loss = loss.detach().item()
|
||||
val_loss_recorder.add(epoch=epoch, step=global_step, loss=current_loss)
|
||||
|
||||
if args.validation_every_n_step is not None:
|
||||
if global_step % (args.validation_every_n_step) == 0:
|
||||
if len(val_dataloader) > 0:
|
||||
print("Validating バリデーション処理...")
|
||||
total_loss = 0.0
|
||||
with torch.no_grad():
|
||||
for val_step in min(len(val_dataloader), args.validation_batches):
|
||||
is_train = False
|
||||
batch = next(cyclic_val_dataloader)
|
||||
loss = self.process_val_batch(batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args)
|
||||
total_loss += loss.detach().item()
|
||||
current_loss = total_loss / args.validation_batches
|
||||
val_loss_recorder.add(epoch=epoch, step=global_step, loss=current_loss)
|
||||
|
||||
if args.logging_dir is not None:
|
||||
avr_loss: float = val_loss_recorder.moving_average
|
||||
logs = {"loss/validation_current": current_loss}
|
||||
logs = {"loss/avr_val_loss": avr_loss}
|
||||
accelerator.log(logs, step=global_step)
|
||||
|
||||
if global_step >= args.max_train_steps:
|
||||
@@ -1005,12 +1008,24 @@ class NetworkTrainer:
|
||||
logs = {"loss/epoch_average": loss_recorder.moving_average}
|
||||
accelerator.log(logs, step=epoch + 1)
|
||||
|
||||
if len(val_dataloader) > 0:
|
||||
if args.logging_dir is not None:
|
||||
avr_loss: float = val_loss_recorder.moving_average
|
||||
logs = {"loss/validation_epoch_average": avr_loss}
|
||||
accelerator.log(logs, step=epoch + 1)
|
||||
|
||||
if args.validation_every_n_step is None:
|
||||
if len(val_dataloader) > 0:
|
||||
print("Validating バリデーション処理...")
|
||||
total_loss = 0.0
|
||||
with torch.no_grad():
|
||||
for val_step in min(len(val_dataloader), args.validation_batches):
|
||||
is_train = False
|
||||
batch = next(cyclic_val_dataloader)
|
||||
loss = self.process_val_batch(batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args)
|
||||
total_loss += loss.detach().item()
|
||||
current_loss = total_loss / args.validation_batches
|
||||
val_loss_recorder.add(epoch=epoch, step=global_step, loss=current_loss)
|
||||
|
||||
if args.logging_dir is not None:
|
||||
avr_loss: float = val_loss_recorder.moving_average
|
||||
logs = {"loss/val_epoch_average": avr_loss}
|
||||
accelerator.log(logs, step=epoch + 1)
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
# 指定エポックごとにモデルを保存
|
||||
@@ -1162,6 +1177,18 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
default=0.0,
|
||||
help="Split for validation images out of the training dataset"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--validation_every_n_step",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Number of steps for counting validation loss. By default, validation per epoch is performed"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--validation_batches",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of val steps for counting validation loss. By default, validation one batch is performed"
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user