Mirror of https://github.com/kohya-ss/sd-scripts.git

Commit: formatting
@@ -100,9 +100,7 @@ class NetworkTrainer:
             if (
                 args.optimizer_type.lower().endswith("ProdigyPlusScheduleFree".lower()) and optimizer is not None
             ): # tracking d*lr value of unet.
-                logs["lr/d*lr"] = (
-                    optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"]
-                )
+                logs["lr/d*lr"] = optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"]
         else:
             idx = 0
             if not args.network_train_unet_only:
@@ -115,16 +113,17 @@ class NetworkTrainer:
                     logs[f"lr/d*lr/group{i}"] = (
                         lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"]
                     )
-                if (
-                    args.optimizer_type.lower().endswith("ProdigyPlusScheduleFree".lower()) and optimizer is not None
-                ):
-                    logs[f"lr/d*lr/group{i}"] = (
-                        optimizer.param_groups[i]["d"] * optimizer.param_groups[i]["lr"]
-                    )
+                if args.optimizer_type.lower().endswith("ProdigyPlusScheduleFree".lower()) and optimizer is not None:
+                    logs[f"lr/d*lr/group{i}"] = optimizer.param_groups[i]["d"] * optimizer.param_groups[i]["lr"]

         return logs

-    def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]):
+    def assert_extra_args(
+        self,
+        args,
+        train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset],
+        val_dataset_group: Optional[train_util.DatasetGroup],
+    ):
         train_dataset_group.verify_bucket_reso_steps(64)
         if val_dataset_group is not None:
             val_dataset_group.verify_bucket_reso_steps(64)
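Note: the two hunks above collapse the d*lr logging for ProdigyPlusScheduleFree onto single lines. For context, here is a minimal sketch of the quantity being logged, assuming (as the diff implies) a Prodigy-style optimizer that stores its learned step-size scalar under the "d" key of each param group; the helper name is illustrative, not from the repository:

```python
def prodigy_effective_lrs(optimizer):
    """Sketch: per-group effective learning rate d * lr for a Prodigy-style optimizer."""
    return [
        group["d"] * group["lr"]  # adaptive multiplier times the base learning rate
        for group in optimizer.param_groups
        if "d" in group  # groups without a "d" key are not Prodigy-managed
    ]
```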
@@ -219,7 +218,7 @@ class NetworkTrainer:
         network,
         weight_dtype,
         train_unet,
-        is_train=True
+        is_train=True,
     ):
         # Sample noise, sample a random timestep for each image, and add noise to the latents,
         # with noise offset and/or multires noise if specified
@@ -330,7 +329,7 @@ class NetworkTrainer:
         tokenize_strategy: strategy_base.TokenizeStrategy,
         is_train=True,
         train_text_encoder=True,
-        train_unet=True
+        train_unet=True,
     ) -> torch.Tensor:
         """
         Process a batch for the network
@@ -397,7 +396,7 @@ class NetworkTrainer:
             network,
             weight_dtype,
             train_unet,
-            is_train=is_train
+            is_train=is_train,
         )

         huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
@@ -900,7 +899,9 @@ class NetworkTrainer:

         accelerator.print("running training / 学習開始")
         accelerator.print(f"  num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
-        accelerator.print(f"  num validation images * repeats / 学習画像の数×繰り返し回数: {val_dataset_group.num_train_images if val_dataset_group is not None else 0}")
+        accelerator.print(
+            f"  num validation images * repeats / 学習画像の数×繰り返し回数: {val_dataset_group.num_train_images if val_dataset_group is not None else 0}"
+        )
         accelerator.print(f"  num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
         accelerator.print(f"  num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
         accelerator.print(f"  num epochs / epoch数: {num_train_epochs}")
@@ -1248,9 +1249,7 @@ class NetworkTrainer:
             accelerator.log({}, step=0)

         validation_steps = (
-            min(args.max_validation_steps, len(val_dataloader))
-            if args.max_validation_steps is not None
-            else len(val_dataloader)
+            min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader)
         )

         # training loop
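Note: the collapsed conditional above clamps the validation pass to at most --max_validation_steps batches. A standalone sketch of the same logic (function name illustrative, not from the repository):

```python
def resolve_validation_steps(max_validation_steps, num_val_batches):
    # None means no cap: run the whole validation dataloader once.
    if max_validation_steps is None:
        return num_val_batches
    # Otherwise never request more steps than the dataloader can supply.
    return min(max_validation_steps, num_val_batches)
```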
@@ -1312,7 +1311,7 @@ class NetworkTrainer:
                         tokenize_strategy,
                         is_train=True,
                         train_text_encoder=train_text_encoder,
-                        train_unet=train_unet
+                        train_unet=train_unet,
                     )

                 accelerator.backward(loss)
@@ -1369,18 +1368,9 @@ class NetworkTrainer:
                     if args.scale_weight_norms:
                         progress_bar.set_postfix(**{**max_mean_logs, **logs})

-
                     if is_tracking:
                         logs = self.generate_step_logs(
-                            args,
-                            current_loss,
-                            avr_loss,
-                            lr_scheduler,
-                            lr_descriptions,
-                            optimizer,
-                            keys_scaled,
-                            mean_norm,
-                            maximum_norm
+                            args, current_loss, avr_loss, lr_scheduler, lr_descriptions, optimizer, keys_scaled, mean_norm, maximum_norm
                         )
                         accelerator.log(logs, step=global_step)

@@ -1392,9 +1382,7 @@ class NetworkTrainer:
                 )
                 if accelerator.sync_gradients and validation_steps > 0 and should_validate_step:
                     val_progress_bar = tqdm(
-                        range(validation_steps), smoothing=0,
-                        disable=not accelerator.is_local_main_process,
-                        desc="validation steps"
+                        range(validation_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="validation steps"
                     )
                     for val_step, batch in enumerate(val_dataloader):
                         if val_step >= validation_steps:
@@ -1418,7 +1406,7 @@ class NetworkTrainer:
                             tokenize_strategy,
                             is_train=False,
                             train_text_encoder=False,
-                            train_unet=False
+                            train_unet=False,
                         )

                         current_loss = loss.detach().item()
@@ -1446,16 +1434,15 @@ class NetworkTrainer:

             # EPOCH VALIDATION
             should_validate_epoch = (
-                (epoch + 1) % args.validate_every_n_epochs == 0
-                if args.validate_every_n_epochs is not None
-                else True
+                (epoch + 1) % args.validate_every_n_epochs == 0 if args.validate_every_n_epochs is not None else True
             )

             if should_validate_epoch and len(val_dataloader) > 0:
                 val_progress_bar = tqdm(
-                    range(validation_steps), smoothing=0,
+                    range(validation_steps),
+                    smoothing=0,
                     disable=not accelerator.is_local_main_process,
-                    desc="epoch validation steps"
+                    desc="epoch validation steps",
                 )

                 for val_step, batch in enumerate(val_dataloader):
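Note: the epoch-validation loop above walks the validation dataloader under a tqdm bar and stops at validation_steps. A minimal self-contained sketch of that pattern, assuming a caller-supplied evaluate_batch callable (illustrative, not the repository's API):

```python
from tqdm import tqdm

def run_capped_validation(val_dataloader, validation_steps, evaluate_batch, is_main_process=True):
    progress = tqdm(
        range(validation_steps),
        smoothing=0,
        disable=not is_main_process,
        desc="epoch validation steps",
    )
    total = 0.0
    for val_step, batch in enumerate(val_dataloader):
        if val_step >= validation_steps:  # cap reached; skip remaining batches
            break
        total += evaluate_batch(batch)  # e.g. a forward pass with is_train=False
        progress.update(1)
    progress.close()
    return total / max(validation_steps, 1)  # average validation loss
```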
@@ -1480,7 +1467,7 @@ class NetworkTrainer:
                         tokenize_strategy,
                         is_train=False,
                         train_text_encoder=False,
-                        train_unet=False
+                        train_unet=False,
                     )

                     current_loss = loss.detach().item()
@@ -1492,7 +1479,7 @@ class NetworkTrainer:
                     logs = {
                         "loss/validation/epoch_current": current_loss,
                         "epoch": epoch + 1,
-                        "val_step": (epoch * validation_steps) + val_step
+                        "val_step": (epoch * validation_steps) + val_step,
                     }
                     accelerator.log(logs, step=global_step)

@@ -1502,7 +1489,7 @@ class NetworkTrainer:
                 logs = {
                     "loss/validation/epoch_average": avr_loss,
                     "loss/validation/epoch_divergence": loss_validation_divergence,
-                    "epoch": epoch + 1
+                    "epoch": epoch + 1,
                 }
                 accelerator.log(logs, step=global_step)

@@ -1696,31 +1683,31 @@ def setup_parser() -> argparse.ArgumentParser:
         "--validation_seed",
         type=int,
         default=None,
-        help="Validation seed for shuffling validation dataset, training `--seed` used otherwise / 検証データセットをシャッフルするための検証シード、それ以外の場合はトレーニング `--seed` を使用する"
+        help="Validation seed for shuffling validation dataset, training `--seed` used otherwise / 検証データセットをシャッフルするための検証シード、それ以外の場合はトレーニング `--seed` を使用する",
     )
     parser.add_argument(
         "--validation_split",
         type=float,
         default=0.0,
-        help="Split for validation images out of the training dataset / 学習画像から検証画像に分割する割合"
+        help="Split for validation images out of the training dataset / 学習画像から検証画像に分割する割合",
     )
     parser.add_argument(
         "--validate_every_n_steps",
         type=int,
         default=None,
-        help="Run validation on validation dataset every N steps. By default, validation will only occur every epoch if a validation dataset is available / 検証データセットの検証をNステップごとに実行します。デフォルトでは、検証データセットが利用可能な場合にのみ、検証はエポックごとに実行されます"
+        help="Run validation on validation dataset every N steps. By default, validation will only occur every epoch if a validation dataset is available / 検証データセットの検証をNステップごとに実行します。デフォルトでは、検証データセットが利用可能な場合にのみ、検証はエポックごとに実行されます",
     )
     parser.add_argument(
         "--validate_every_n_epochs",
         type=int,
         default=None,
-        help="Run validation dataset every N epochs. By default, validation will run every epoch if a validation dataset is available / 検証データセットをNエポックごとに実行します。デフォルトでは、検証データセットが利用可能な場合、検証はエポックごとに実行されます"
+        help="Run validation dataset every N epochs. By default, validation will run every epoch if a validation dataset is available / 検証データセットをNエポックごとに実行します。デフォルトでは、検証データセットが利用可能な場合、検証はエポックごとに実行されます",
     )
     parser.add_argument(
         "--max_validation_steps",
         type=int,
         default=None,
-        help="Max number of validation dataset items processed. By default, validation will run the entire validation dataset / 処理される検証データセット項目の最大数。デフォルトでは、検証は検証データセット全体を実行します"
+        help="Max number of validation dataset items processed. By default, validation will run the entire validation dataset / 処理される検証データセット項目の最大数。デフォルトでは、検証は検証データセット全体を実行します",
     )
     return parser

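Note: the validation flags above default to None, meaning validation runs once per epoch whenever a validation dataset exists. A hedged sketch of how the two cadence flags gate validation; the epoch helper mirrors the conditional in the training loop above, while the step helper is an assumption based on the help text, not copied from the repository:

```python
def should_validate_step(global_step, validate_every_n_steps):
    # Step-based validation only happens when --validate_every_n_steps is set.
    return validate_every_n_steps is not None and global_step % validate_every_n_steps == 0

def should_validate_epoch(epoch, validate_every_n_epochs):
    # Default (None): validate every epoch, matching the training-loop conditional above.
    return validate_every_n_epochs is None or (epoch + 1) % validate_every_n_epochs == 0
```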