diff --git a/train_network.py b/train_network.py index 77cbe211..af13135d 100644 --- a/train_network.py +++ b/train_network.py @@ -1435,8 +1435,7 @@ class NetworkTrainer: optimizer.step() lr_scheduler.step() - # optimizer.zero_grad(set_to_none=True) - optimizer.zero_grad(set_to_none=False) + optimizer.zero_grad(set_to_none=True) if args.scale_weight_norms: keys_scaled, mean_norm, maximum_norm = accelerator.unwrap_model(network).apply_max_norm_regularization(