mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
chore: formatting, add TODO comment
This commit is contained in:
@@ -70,7 +70,7 @@ class NetworkTrainer:
|
|||||||
mean_norm=None,
|
mean_norm=None,
|
||||||
maximum_norm=None,
|
maximum_norm=None,
|
||||||
mean_grad_norm=None,
|
mean_grad_norm=None,
|
||||||
mean_combined_norm=None
|
mean_combined_norm=None,
|
||||||
):
|
):
|
||||||
logs = {"loss/current": current_loss, "loss/average": avr_loss}
|
logs = {"loss/current": current_loss, "loss/average": avr_loss}
|
||||||
|
|
||||||
@@ -658,6 +658,10 @@ class NetworkTrainer:
|
|||||||
return
|
return
|
||||||
network_has_multiplier = hasattr(network, "set_multiplier")
|
network_has_multiplier = hasattr(network, "set_multiplier")
|
||||||
|
|
||||||
|
# TODO remove `hasattr`s by setting up methods if not defined in the network like (hacky but works):
|
||||||
|
# if not hasattr(network, "prepare_network"):
|
||||||
|
# network.prepare_network = lambda args: None
|
||||||
|
|
||||||
if hasattr(network, "prepare_network"):
|
if hasattr(network, "prepare_network"):
|
||||||
network.prepare_network(args)
|
network.prepare_network(args)
|
||||||
if args.scale_weight_norms and not hasattr(network, "apply_max_norm_regularization"):
|
if args.scale_weight_norms and not hasattr(network, "apply_max_norm_regularization"):
|
||||||
@@ -1024,7 +1028,7 @@ class NetworkTrainer:
|
|||||||
"ss_max_validation_steps": args.max_validation_steps,
|
"ss_max_validation_steps": args.max_validation_steps,
|
||||||
"ss_validate_every_n_epochs": args.validate_every_n_epochs,
|
"ss_validate_every_n_epochs": args.validate_every_n_epochs,
|
||||||
"ss_validate_every_n_steps": args.validate_every_n_steps,
|
"ss_validate_every_n_steps": args.validate_every_n_steps,
|
||||||
"ss_resize_interpolation": args.resize_interpolation
|
"ss_resize_interpolation": args.resize_interpolation,
|
||||||
}
|
}
|
||||||
|
|
||||||
self.update_metadata(metadata, args) # architecture specific metadata
|
self.update_metadata(metadata, args) # architecture specific metadata
|
||||||
@@ -1415,7 +1419,6 @@ class NetworkTrainer:
|
|||||||
if hasattr(network, "update_norms"):
|
if hasattr(network, "update_norms"):
|
||||||
network.update_norms()
|
network.update_norms()
|
||||||
|
|
||||||
|
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
lr_scheduler.step()
|
lr_scheduler.step()
|
||||||
optimizer.zero_grad(set_to_none=True)
|
optimizer.zero_grad(set_to_none=True)
|
||||||
@@ -1476,7 +1479,17 @@ class NetworkTrainer:
|
|||||||
|
|
||||||
if is_tracking:
|
if is_tracking:
|
||||||
logs = self.generate_step_logs(
|
logs = self.generate_step_logs(
|
||||||
args, current_loss, avr_loss, lr_scheduler, lr_descriptions, optimizer, keys_scaled, mean_norm, maximum_norm, mean_grad_norm, mean_combined_norm
|
args,
|
||||||
|
current_loss,
|
||||||
|
avr_loss,
|
||||||
|
lr_scheduler,
|
||||||
|
lr_descriptions,
|
||||||
|
optimizer,
|
||||||
|
keys_scaled,
|
||||||
|
mean_norm,
|
||||||
|
maximum_norm,
|
||||||
|
mean_grad_norm,
|
||||||
|
mean_combined_norm,
|
||||||
)
|
)
|
||||||
self.step_logging(accelerator, logs, global_step, epoch + 1)
|
self.step_logging(accelerator, logs, global_step, epoch + 1)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user