formatting

Kohya S
2025-01-27 20:50:42 +09:00
parent 59b3b94faf
commit 532f5c58a6


@@ -100,9 +100,7 @@ class NetworkTrainer:
         if (
             args.optimizer_type.lower().endswith("ProdigyPlusScheduleFree".lower()) and optimizer is not None
         ): # tracking d*lr value of unet.
-            logs["lr/d*lr"] = (
-                optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"]
-            )
+            logs["lr/d*lr"] = optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"]
         else:
             idx = 0
             if not args.network_train_unet_only:
@@ -115,16 +113,17 @@ class NetworkTrainer:
logs[f"lr/d*lr/group{i}"] = ( logs[f"lr/d*lr/group{i}"] = (
lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"] lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"]
) )
if ( if args.optimizer_type.lower().endswith("ProdigyPlusScheduleFree".lower()) and optimizer is not None:
args.optimizer_type.lower().endswith("ProdigyPlusScheduleFree".lower()) and optimizer is not None logs[f"lr/d*lr/group{i}"] = optimizer.param_groups[i]["d"] * optimizer.param_groups[i]["lr"]
):
logs[f"lr/d*lr/group{i}"] = (
optimizer.param_groups[i]["d"] * optimizer.param_groups[i]["lr"]
)
return logs return logs
def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]): def assert_extra_args(
self,
args,
train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset],
val_dataset_group: Optional[train_util.DatasetGroup],
):
train_dataset_group.verify_bucket_reso_steps(64) train_dataset_group.verify_bucket_reso_steps(64)
if val_dataset_group is not None: if val_dataset_group is not None:
val_dataset_group.verify_bucket_reso_steps(64) val_dataset_group.verify_bucket_reso_steps(64)
@@ -219,7 +218,7 @@ class NetworkTrainer:
         network,
         weight_dtype,
         train_unet,
-        is_train=True
+        is_train=True,
     ):
         # Sample noise, sample a random timestep for each image, and add noise to the latents,
         # with noise offset and/or multires noise if specified
@@ -315,22 +314,22 @@ class NetworkTrainer:
     # endregion

     def process_batch(
         self,
         batch,
         text_encoders,
         unet,
         network,
         vae,
         noise_scheduler,
         vae_dtype,
         weight_dtype,
         accelerator,
         args,
         text_encoding_strategy: strategy_base.TextEncodingStrategy,
         tokenize_strategy: strategy_base.TokenizeStrategy,
         is_train=True,
         train_text_encoder=True,
-        train_unet=True
+        train_unet=True,
     ) -> torch.Tensor:
         """
         Process a batch for the network
@@ -397,7 +396,7 @@ class NetworkTrainer:
             network,
             weight_dtype,
             train_unet,
-            is_train=is_train
+            is_train=is_train,
         )

         huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
@@ -484,7 +483,7 @@ class NetworkTrainer:
         else:
             # use arbitrary dataset class
             train_dataset_group = train_util.load_arbitrary_dataset(args)
-            val_dataset_group = None # placeholder until validation dataset supported for arbitrary
+            val_dataset_group = None  # placeholder until validation dataset supported for arbitrary

         current_epoch = Value("i", 0)
         current_step = Value("i", 0)
@@ -701,7 +700,7 @@ class NetworkTrainer:
             num_workers=n_workers,
             persistent_workers=args.persistent_data_loader_workers,
         )

         val_dataloader = torch.utils.data.DataLoader(
             val_dataset_group if val_dataset_group is not None else [],
             shuffle=False,
@@ -900,7 +899,9 @@ class NetworkTrainer:
accelerator.print("running training / 学習開始") accelerator.print("running training / 学習開始")
accelerator.print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}") accelerator.print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
accelerator.print(f" num validation images * repeats / 学習画像の数×繰り返し回数: {val_dataset_group.num_train_images if val_dataset_group is not None else 0}") accelerator.print(
f" num validation images * repeats / 学習画像の数×繰り返し回数: {val_dataset_group.num_train_images if val_dataset_group is not None else 0}"
)
accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
accelerator.print(f" num epochs / epoch数: {num_train_epochs}") accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
@@ -968,11 +969,11 @@ class NetworkTrainer:
"ss_huber_c": args.huber_c, "ss_huber_c": args.huber_c,
"ss_fp8_base": bool(args.fp8_base), "ss_fp8_base": bool(args.fp8_base),
"ss_fp8_base_unet": bool(args.fp8_base_unet), "ss_fp8_base_unet": bool(args.fp8_base_unet),
"ss_validation_seed": args.validation_seed, "ss_validation_seed": args.validation_seed,
"ss_validation_split": args.validation_split, "ss_validation_split": args.validation_split,
"ss_max_validation_steps": args.max_validation_steps, "ss_max_validation_steps": args.max_validation_steps,
"ss_validate_every_n_epochs": args.validate_every_n_epochs, "ss_validate_every_n_epochs": args.validate_every_n_epochs,
"ss_validate_every_n_steps": args.validate_every_n_steps, "ss_validate_every_n_steps": args.validate_every_n_steps,
} }
self.update_metadata(metadata, args) # architecture specific metadata self.update_metadata(metadata, args) # architecture specific metadata
@@ -1248,9 +1249,7 @@ class NetworkTrainer:
             accelerator.log({}, step=0)

         validation_steps = (
-            min(args.max_validation_steps, len(val_dataloader))
-            if args.max_validation_steps is not None
-            else len(val_dataloader)
+            min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader)
         )

         # training loop
@@ -1298,21 +1297,21 @@ class NetworkTrainer:
                     self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype)

                     loss = self.process_batch(
                         batch,
                         text_encoders,
                         unet,
                         network,
                         vae,
                         noise_scheduler,
                         vae_dtype,
                         weight_dtype,
                         accelerator,
                         args,
                         text_encoding_strategy,
                         tokenize_strategy,
                         is_train=True,
                         train_text_encoder=train_text_encoder,
-                        train_unet=train_unet
+                        train_unet=train_unet,
                     )

                     accelerator.backward(loss)
@@ -1369,32 +1368,21 @@ class NetworkTrainer:
                 if args.scale_weight_norms:
                     progress_bar.set_postfix(**{**max_mean_logs, **logs})

                 if is_tracking:
                     logs = self.generate_step_logs(
-                        args,
-                        current_loss,
-                        avr_loss,
-                        lr_scheduler,
-                        lr_descriptions,
-                        optimizer,
-                        keys_scaled,
-                        mean_norm,
-                        maximum_norm
+                        args, current_loss, avr_loss, lr_scheduler, lr_descriptions, optimizer, keys_scaled, mean_norm, maximum_norm
                     )
                     accelerator.log(logs, step=global_step)

                 # VALIDATION PER STEP
                 should_validate_step = (
                     args.validate_every_n_steps is not None
                     and global_step != 0 # Skip first step
                     and global_step % args.validate_every_n_steps == 0
                 )
                 if accelerator.sync_gradients and validation_steps > 0 and should_validate_step:
                     val_progress_bar = tqdm(
-                        range(validation_steps), smoothing=0,
-                        disable=not accelerator.is_local_main_process,
-                        desc="validation steps"
+                        range(validation_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="validation steps"
                     )
                     for val_step, batch in enumerate(val_dataloader):
                         if val_step >= validation_steps:
@@ -1404,27 +1392,27 @@ class NetworkTrainer:
                         self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype)

                         loss = self.process_batch(
                             batch,
                             text_encoders,
                             unet,
                             network,
                             vae,
                             noise_scheduler,
                             vae_dtype,
                             weight_dtype,
                             accelerator,
                             args,
                             text_encoding_strategy,
                             tokenize_strategy,
                             is_train=False,
                             train_text_encoder=False,
-                            train_unet=False
+                            train_unet=False,
                         )

                         current_loss = loss.detach().item()
                         val_step_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss)
                         val_progress_bar.update(1)
-                        val_progress_bar.set_postfix({ "val_avg_loss": val_step_loss_recorder.moving_average })
+                        val_progress_bar.set_postfix({"val_avg_loss": val_step_loss_recorder.moving_average})

                         if is_tracking:
                             logs = {
@@ -1436,26 +1424,25 @@ class NetworkTrainer:
                     if is_tracking:
                         loss_validation_divergence = val_step_loss_recorder.moving_average - loss_recorder.moving_average
                         logs = {
                             "loss/validation/step_average": val_step_loss_recorder.moving_average,
                             "loss/validation/step_divergence": loss_validation_divergence,
                         }
                         accelerator.log(logs, step=global_step)

                 if global_step >= args.max_train_steps:
                     break

             # EPOCH VALIDATION
             should_validate_epoch = (
-                (epoch + 1) % args.validate_every_n_epochs == 0
-                if args.validate_every_n_epochs is not None
-                else True
+                (epoch + 1) % args.validate_every_n_epochs == 0 if args.validate_every_n_epochs is not None else True
             )
             if should_validate_epoch and len(val_dataloader) > 0:
                 val_progress_bar = tqdm(
-                    range(validation_steps), smoothing=0,
-                    disable=not accelerator.is_local_main_process,
-                    desc="epoch validation steps"
+                    range(validation_steps),
+                    smoothing=0,
+                    disable=not accelerator.is_local_main_process,
+                    desc="epoch validation steps",
                 )

                 for val_step, batch in enumerate(val_dataloader):
@@ -1466,43 +1453,43 @@ class NetworkTrainer:
                     self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype)

                     loss = self.process_batch(
                         batch,
                         text_encoders,
                         unet,
                         network,
                         vae,
                         noise_scheduler,
                         vae_dtype,
                         weight_dtype,
                         accelerator,
                         args,
                         text_encoding_strategy,
                         tokenize_strategy,
                         is_train=False,
                         train_text_encoder=False,
-                        train_unet=False
+                        train_unet=False,
                     )

                     current_loss = loss.detach().item()
                     val_epoch_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss)
                     val_progress_bar.update(1)
-                    val_progress_bar.set_postfix({ "val_epoch_avg_loss": val_epoch_loss_recorder.moving_average })
+                    val_progress_bar.set_postfix({"val_epoch_avg_loss": val_epoch_loss_recorder.moving_average})

                     if is_tracking:
                         logs = {
                             "loss/validation/epoch_current": current_loss,
                             "epoch": epoch + 1,
-                            "val_step": (epoch * validation_steps) + val_step
+                            "val_step": (epoch * validation_steps) + val_step,
                         }
                         accelerator.log(logs, step=global_step)

                 if is_tracking:
                     avr_loss: float = val_epoch_loss_recorder.moving_average
                     loss_validation_divergence = val_step_loss_recorder.moving_average - avr_loss
                     logs = {
                         "loss/validation/epoch_average": avr_loss,
                         "loss/validation/epoch_divergence": loss_validation_divergence,
-                        "epoch": epoch + 1
+                        "epoch": epoch + 1,
                     }
                     accelerator.log(logs, step=global_step)
@@ -1510,7 +1497,7 @@ class NetworkTrainer:
             if is_tracking:
                 logs = {"loss/epoch_average": loss_recorder.moving_average, "epoch": epoch + 1}
                 accelerator.log(logs, step=global_step)

             accelerator.wait_for_everyone()

             # 指定エポックごとにモデルを保存
@@ -1696,31 +1683,31 @@ def setup_parser() -> argparse.ArgumentParser:
"--validation_seed", "--validation_seed",
type=int, type=int,
default=None, default=None,
help="Validation seed for shuffling validation dataset, training `--seed` used otherwise / 検証データセットをシャッフルするための検証シード、それ以外の場合はトレーニング `--seed` を使用する" help="Validation seed for shuffling validation dataset, training `--seed` used otherwise / 検証データセットをシャッフルするための検証シード、それ以外の場合はトレーニング `--seed` を使用する",
) )
parser.add_argument( parser.add_argument(
"--validation_split", "--validation_split",
type=float, type=float,
default=0.0, default=0.0,
help="Split for validation images out of the training dataset / 学習画像から検証画像に分割する割合" help="Split for validation images out of the training dataset / 学習画像から検証画像に分割する割合",
) )
parser.add_argument( parser.add_argument(
"--validate_every_n_steps", "--validate_every_n_steps",
type=int, type=int,
default=None, default=None,
help="Run validation on validation dataset every N steps. By default, validation will only occur every epoch if a validation dataset is available / 検証データセットの検証をNステップごとに実行します。デフォルトでは、検証データセットが利用可能な場合にのみ、検証はエポックごとに実行されます" help="Run validation on validation dataset every N steps. By default, validation will only occur every epoch if a validation dataset is available / 検証データセットの検証をNステップごとに実行します。デフォルトでは、検証データセットが利用可能な場合にのみ、検証はエポックごとに実行されます",
) )
parser.add_argument( parser.add_argument(
"--validate_every_n_epochs", "--validate_every_n_epochs",
type=int, type=int,
default=None, default=None,
help="Run validation dataset every N epochs. By default, validation will run every epoch if a validation dataset is available / 検証データセットをNエポックごとに実行します。デフォルトでは、検証データセットが利用可能な場合、検証はエポックごとに実行されます" help="Run validation dataset every N epochs. By default, validation will run every epoch if a validation dataset is available / 検証データセットをNエポックごとに実行します。デフォルトでは、検証データセットが利用可能な場合、検証はエポックごとに実行されます",
) )
parser.add_argument( parser.add_argument(
"--max_validation_steps", "--max_validation_steps",
type=int, type=int,
default=None, default=None,
help="Max number of validation dataset items processed. By default, validation will run the entire validation dataset / 処理される検証データセット項目の最大数。デフォルトでは、検証は検証データセット全体を実行します" help="Max number of validation dataset items processed. By default, validation will run the entire validation dataset / 処理される検証データセット項目の最大数。デフォルトでは、検証は検証データセット全体を実行します",
) )
return parser return parser
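
For orientation, here is a minimal standalone sketch of the validation cadence these flags drive, mirroring the should_validate_step, should_validate_epoch, and validation_steps expressions reformatted in the hunks above. The args values and val_dataloader_len are hypothetical stand-ins, not the trainer's real state.

from argparse import Namespace

# Hypothetical CLI values standing in for the parsed args (not from the commit).
args = Namespace(validate_every_n_steps=200, validate_every_n_epochs=1, max_validation_steps=16)
val_dataloader_len = 50  # hypothetical length of the validation dataloader

# Cap the number of validation batches per pass at max_validation_steps if set.
validation_steps = (
    min(args.max_validation_steps, val_dataloader_len) if args.max_validation_steps is not None else val_dataloader_len
)

def should_validate_step(global_step: int) -> bool:
    # Step-based gate: skip step 0, then validate every N optimizer steps.
    return (
        args.validate_every_n_steps is not None
        and global_step != 0
        and global_step % args.validate_every_n_steps == 0
    )

def should_validate_epoch(epoch: int) -> bool:
    # Epoch-based gate: every N epochs, defaulting to every epoch when unset.
    return (epoch + 1) % args.validate_every_n_epochs == 0 if args.validate_every_n_epochs is not None else True

print(validation_steps)           # 16
print(should_validate_step(400))  # True
print(should_validate_epoch(0))   # True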