diff --git a/train_network.py b/train_network.py index 3bab0cad..f66cdeb4 100644 --- a/train_network.py +++ b/train_network.py @@ -70,7 +70,7 @@ class NetworkTrainer: mean_norm=None, maximum_norm=None, mean_grad_norm=None, - mean_combined_norm=None + mean_combined_norm=None, ): logs = {"loss/current": current_loss, "loss/average": avr_loss} @@ -658,6 +658,10 @@ class NetworkTrainer: return network_has_multiplier = hasattr(network, "set_multiplier") + # TODO remove `hasattr`s by setting up methods if not defined in the network like (hacky but works): + # if not hasattr(network, "prepare_network"): + # network.prepare_network = lambda args: None + if hasattr(network, "prepare_network"): network.prepare_network(args) if args.scale_weight_norms and not hasattr(network, "apply_max_norm_regularization"): @@ -1019,12 +1023,12 @@ class NetworkTrainer: "ss_huber_c": args.huber_c, "ss_fp8_base": bool(args.fp8_base), "ss_fp8_base_unet": bool(args.fp8_base_unet), - "ss_validation_seed": args.validation_seed, - "ss_validation_split": args.validation_split, - "ss_max_validation_steps": args.max_validation_steps, - "ss_validate_every_n_epochs": args.validate_every_n_epochs, - "ss_validate_every_n_steps": args.validate_every_n_steps, - "ss_resize_interpolation": args.resize_interpolation + "ss_validation_seed": args.validation_seed, + "ss_validation_split": args.validation_split, + "ss_max_validation_steps": args.max_validation_steps, + "ss_validate_every_n_epochs": args.validate_every_n_epochs, + "ss_validate_every_n_steps": args.validate_every_n_steps, + "ss_resize_interpolation": args.resize_interpolation, } self.update_metadata(metadata, args) # architecture specific metadata @@ -1415,7 +1419,6 @@ class NetworkTrainer: if hasattr(network, "update_norms"): network.update_norms() - optimizer.step() lr_scheduler.step() optimizer.zero_grad(set_to_none=True) @@ -1476,7 +1479,17 @@ class NetworkTrainer: if is_tracking: logs = self.generate_step_logs( - args, current_loss, avr_loss, lr_scheduler, lr_descriptions, optimizer, keys_scaled, mean_norm, maximum_norm, mean_grad_norm, mean_combined_norm + args, + current_loss, + avr_loss, + lr_scheduler, + lr_descriptions, + optimizer, + keys_scaled, + mean_norm, + maximum_norm, + mean_grad_norm, + mean_combined_norm, ) self.step_logging(accelerator, logs, global_step, epoch + 1)