Add IVON optimizer support

This commit is contained in:
rockerBOO
2025-06-18 15:45:31 -04:00
parent 3e6935a07e
commit 7ef68b5dc6
2 changed files with 37 additions and 20 deletions

15
library/network_utils.py Normal file
View File

@@ -0,0 +1,15 @@
from contextlib import nullcontext


def maybe_sample_params(optimizer):
    """Return a parameter-sampling context for IVON-style optimizers.

    The IVON optimizer (installable via ``pip install ivon-opt``) exposes a
    ``sampled_params(train=...)`` context manager that temporarily replaces the
    model parameters with samples from the learned posterior. For any optimizer
    that does not provide ``sampled_params``, a no-op ``nullcontext()`` is
    returned so callers can always write ``with maybe_sample_params(opt):``.

    Args:
        optimizer: PyTorch optimizer instance (possibly an IVON optimizer).

    Returns:
        The optimizer's ``sampled_params(train=True)`` context manager when the
        optimizer provides one, otherwise ``contextlib.nullcontext()``.
    """
    # Duck-typed check keeps ivon-opt an optional dependency: no import of the
    # package is ever attempted here.
    if hasattr(optimizer, "sampled_params"):
        return optimizer.sampled_params(train=True)
    return nullcontext()

View File

@@ -17,6 +17,7 @@ from tqdm import tqdm
import torch
from torch.types import Number
from library.device_utils import init_ipex, clean_memory_on_device
from library.network_utils import maybe_sample_params
init_ipex()
@@ -1399,26 +1400,26 @@ class NetworkTrainer:
# preprocess batch for each model
self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=True)
with maybe_sample_params(optimizer.optimizer):
loss = self.process_batch(
batch,
text_encoders,
unet,
network,
vae,
noise_scheduler,
vae_dtype,
weight_dtype,
accelerator,
args,
text_encoding_strategy,
tokenize_strategy,
is_train=True,
train_text_encoder=train_text_encoder,
train_unet=train_unet,
)
loss = self.process_batch(
batch,
text_encoders,
unet,
network,
vae,
noise_scheduler,
vae_dtype,
weight_dtype,
accelerator,
args,
text_encoding_strategy,
tokenize_strategy,
is_train=True,
train_text_encoder=train_text_encoder,
train_unet=train_unet,
)
accelerator.backward(loss)
accelerator.backward(loss)
if accelerator.sync_gradients:
self.all_reduce_network(accelerator, network) # sync DDP grad manually
if args.max_grad_norm != 0.0:
@@ -1432,7 +1433,8 @@ class NetworkTrainer:
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)
# optimizer.zero_grad(set_to_none=True)
optimizer.zero_grad(set_to_none=False)
if args.scale_weight_norms:
keys_scaled, mean_norm, maximum_norm = accelerator.unwrap_model(network).apply_max_norm_regularization(