From 7ef68b5dc69ea9f3594a9ab3880e485022981d3a Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Wed, 18 Jun 2025 15:45:31 -0400 Subject: [PATCH] Add IVON optimizer support --- library/network_utils.py | 15 ++++++++++++++ train_network.py | 42 +++++++++++++++++++++------------------- 2 files changed, 37 insertions(+), 20 deletions(-) create mode 100644 library/network_utils.py diff --git a/library/network_utils.py b/library/network_utils.py new file mode 100644 index 00000000..1dafede5 --- /dev/null +++ b/library/network_utils.py @@ -0,0 +1,15 @@ +def maybe_sample_params(optimizer): + """ + Returns parameter sampling context for IVON optimizers, otherwise returns no-op context. + + pip install ivon-opt + + Args: + optimizer: PyTorch optimizer instance. + + Returns: + Context manager for parameter sampling if optimizer supports it, otherwise nullcontext(). + """ + from contextlib import nullcontext + + return optimizer.sampled_params(train=True) if hasattr(optimizer, "sampled_params") else nullcontext() diff --git a/train_network.py b/train_network.py index 1336a0b1..d5812fc0 100644 --- a/train_network.py +++ b/train_network.py @@ -17,6 +17,7 @@ from tqdm import tqdm import torch from torch.types import Number from library.device_utils import init_ipex, clean_memory_on_device +from library.network_utils import maybe_sample_params init_ipex() @@ -1399,26 +1400,26 @@ class NetworkTrainer: # preprocess batch for each model self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=True) + with maybe_sample_params(optimizer.optimizer): + loss = self.process_batch( + batch, + text_encoders, + unet, + network, + vae, + noise_scheduler, + vae_dtype, + weight_dtype, + accelerator, + args, + text_encoding_strategy, + tokenize_strategy, + is_train=True, + train_text_encoder=train_text_encoder, + train_unet=train_unet, + ) - loss = self.process_batch( - batch, - text_encoders, - unet, - network, - vae, - noise_scheduler, - vae_dtype, - weight_dtype, - accelerator, - args, - text_encoding_strategy, - tokenize_strategy, - is_train=True, - train_text_encoder=train_text_encoder, - train_unet=train_unet, - ) - - accelerator.backward(loss) + accelerator.backward(loss) if accelerator.sync_gradients: self.all_reduce_network(accelerator, network) # sync DDP grad manually if args.max_grad_norm != 0.0: @@ -1432,7 +1433,8 @@ class NetworkTrainer: optimizer.step() lr_scheduler.step() - optimizer.zero_grad(set_to_none=True) + # optimizer.zero_grad(set_to_none=True) + optimizer.zero_grad(set_to_none=False) if args.scale_weight_norms: keys_scaled, mean_norm, maximum_norm = accelerator.unwrap_model(network).apply_max_norm_regularization(