mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
fix
This commit is contained in:
@@ -981,20 +981,19 @@ class NetworkTrainer:
|
|||||||
logs = self.generate_step_logs(args, current_loss, avr_loss, lr_scheduler, keys_scaled, mean_norm, maximum_norm)
|
logs = self.generate_step_logs(args, current_loss, avr_loss, lr_scheduler, keys_scaled, mean_norm, maximum_norm)
|
||||||
accelerator.log(logs, step=global_step)
|
accelerator.log(logs, step=global_step)
|
||||||
|
|
||||||
if args.validation_every_n_step is not None:
|
if len(val_dataloader) > 0:
|
||||||
if global_step % (args.validation_every_n_step) == 0:
|
if (args.validation_every_n_step is not None and global_step % args.validation_every_n_step == 0) or step == len(train_dataloader) - 1 or global_step >= args.max_train_steps:
|
||||||
if len(val_dataloader) > 0:
|
|
||||||
print(f"\nValidating バリデーション処理...")
|
print(f"\nValidating バリデーション処理...")
|
||||||
total_loss = 0.0
|
total_loss = 0.0
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
validation_steps = min(args.validation_batches, len(val_dataloader)) if args.validation_batches is not None else len(val_dataloader)
|
validation_steps = min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader)
|
||||||
for val_step in tqdm(range(validation_steps), desc='Validation Steps'):
|
for val_step in tqdm(range(validation_steps), desc='Validation Steps'):
|
||||||
is_train = False
|
is_train = False
|
||||||
batch = next(cyclic_val_dataloader)
|
batch = next(cyclic_val_dataloader)
|
||||||
loss = self.process_val_batch(batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args)
|
loss = self.process_val_batch(batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args)
|
||||||
total_loss += loss.detach().item()
|
total_loss += loss.detach().item()
|
||||||
current_loss = total_loss / args.validation_batches
|
current_loss = total_loss / validation_steps
|
||||||
val_loss_recorder.add(epoch=epoch, step=global_step, loss=current_loss)
|
val_loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
|
||||||
|
|
||||||
if args.logging_dir is not None:
|
if args.logging_dir is not None:
|
||||||
logs = {"loss/current_val_loss": current_loss}
|
logs = {"loss/current_val_loss": current_loss}
|
||||||
@@ -1009,25 +1008,6 @@ class NetworkTrainer:
|
|||||||
if args.logging_dir is not None:
|
if args.logging_dir is not None:
|
||||||
logs = {"loss/epoch_average": loss_recorder.moving_average}
|
logs = {"loss/epoch_average": loss_recorder.moving_average}
|
||||||
accelerator.log(logs, step=epoch + 1)
|
accelerator.log(logs, step=epoch + 1)
|
||||||
|
|
||||||
if args.validation_every_n_step is None:
|
|
||||||
if len(val_dataloader) > 0:
|
|
||||||
print(f"\nValidating バリデーション処理...")
|
|
||||||
total_loss = 0.0
|
|
||||||
with torch.no_grad():
|
|
||||||
validation_steps = min(args.validation_batches, len(val_dataloader)) if args.validation_batches is not None else len(val_dataloader)
|
|
||||||
for val_step in tqdm(range(validation_steps), desc='Validation Steps'):
|
|
||||||
is_train = False
|
|
||||||
batch = next(cyclic_val_dataloader)
|
|
||||||
loss = self.process_val_batch(batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args)
|
|
||||||
total_loss += loss.detach().item()
|
|
||||||
current_loss = total_loss / args.validation_batches
|
|
||||||
val_loss_recorder.add(epoch=epoch, step=global_step, loss=current_loss)
|
|
||||||
|
|
||||||
if args.logging_dir is not None:
|
|
||||||
avr_loss: float = val_loss_recorder.moving_average
|
|
||||||
logs = {"loss/epoch_val_average": avr_loss}
|
|
||||||
accelerator.log(logs, step=epoch + 1)
|
|
||||||
|
|
||||||
accelerator.wait_for_everyone()
|
accelerator.wait_for_everyone()
|
||||||
|
|
||||||
@@ -1184,14 +1164,14 @@ def setup_parser() -> argparse.ArgumentParser:
|
|||||||
"--validation_every_n_step",
|
"--validation_every_n_step",
|
||||||
type=int,
|
type=int,
|
||||||
default=None,
|
default=None,
|
||||||
help="Number of steps for counting validation loss. By default, validation per epoch is performed"
|
help="Number of train steps for counting validation loss. By default, validation per train epoch is performed"
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--validation_batches",
|
"--max_validation_steps",
|
||||||
type=int,
|
type=int,
|
||||||
default=None,
|
default=None,
|
||||||
help="Number of val steps for counting validation loss. By default, validation for all val_dataset is performed"
|
help="Number of max validation steps for counting validation loss. By default, validation will run entire validation dataset"
|
||||||
)
|
)
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user