chore: formatting, add TODO comment

This commit is contained in:
Kohya S
2025-03-30 21:15:37 +09:00
parent 59d98e45a9
commit d0b5c0e5cf

View File

@@ -70,7 +70,7 @@ class NetworkTrainer:
mean_norm=None,
maximum_norm=None,
mean_grad_norm=None,
mean_combined_norm=None
mean_combined_norm=None,
):
logs = {"loss/current": current_loss, "loss/average": avr_loss}
@@ -658,6 +658,10 @@ class NetworkTrainer:
return
network_has_multiplier = hasattr(network, "set_multiplier")
# TODO remove `hasattr`s by setting up methods if not defined in the network like (hacky but works):
# if not hasattr(network, "prepare_network"):
# network.prepare_network = lambda args: None
if hasattr(network, "prepare_network"):
network.prepare_network(args)
if args.scale_weight_norms and not hasattr(network, "apply_max_norm_regularization"):
@@ -1019,12 +1023,12 @@ class NetworkTrainer:
"ss_huber_c": args.huber_c,
"ss_fp8_base": bool(args.fp8_base),
"ss_fp8_base_unet": bool(args.fp8_base_unet),
"ss_validation_seed": args.validation_seed,
"ss_validation_split": args.validation_split,
"ss_max_validation_steps": args.max_validation_steps,
"ss_validate_every_n_epochs": args.validate_every_n_epochs,
"ss_validate_every_n_steps": args.validate_every_n_steps,
"ss_resize_interpolation": args.resize_interpolation
"ss_validation_seed": args.validation_seed,
"ss_validation_split": args.validation_split,
"ss_max_validation_steps": args.max_validation_steps,
"ss_validate_every_n_epochs": args.validate_every_n_epochs,
"ss_validate_every_n_steps": args.validate_every_n_steps,
"ss_resize_interpolation": args.resize_interpolation,
}
self.update_metadata(metadata, args) # architecture specific metadata
@@ -1415,7 +1419,6 @@ class NetworkTrainer:
if hasattr(network, "update_norms"):
network.update_norms()
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)
@@ -1476,7 +1479,17 @@ class NetworkTrainer:
if is_tracking:
logs = self.generate_step_logs(
args, current_loss, avr_loss, lr_scheduler, lr_descriptions, optimizer, keys_scaled, mean_norm, maximum_norm, mean_grad_norm, mean_combined_norm
args,
current_loss,
avr_loss,
lr_scheduler,
lr_descriptions,
optimizer,
keys_scaled,
mean_norm,
maximum_norm,
mean_grad_norm,
mean_combined_norm,
)
self.step_logging(accelerator, logs, global_step, epoch + 1)