mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-10 06:54:17 +00:00
Add IVON optimizer support

library/network_utils.py (new file, +15 lines)
@@ -0,0 +1,15 @@
+def maybe_sample_params(optimizer):
+    """
+    Returns a parameter-sampling context for IVON optimizers; otherwise returns a no-op context.
+
+    IVON support requires the ivon-opt package (pip install ivon-opt).
+
+    Args:
+        optimizer: PyTorch optimizer instance.
+
+    Returns:
+        Context manager for parameter sampling if the optimizer supports it, otherwise nullcontext().
+    """
+    from contextlib import nullcontext
+
+    return optimizer.sampled_params(train=True) if hasattr(optimizer, "sampled_params") else nullcontext()
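This matches the usage documented for the ivon-opt package: an IVON optimizer exposes a sampled_params() context manager, and the forward and backward passes run inside it so gradients are evaluated at weights drawn from the learned posterior. A minimal standalone sketch of that loop (the model, data, and hyperparameter values below are illustrative placeholders, not part of this commit):

    import torch
    import ivon  # pip install ivon-opt

    model = torch.nn.Linear(10, 2)
    # ess ("effective sample size") is an IVON hyperparameter, commonly the training-set size
    optimizer = ivon.IVON(model.parameters(), lr=0.1, ess=1000)

    x, y = torch.randn(8, 10), torch.randint(0, 2, (8,))
    for _ in range(100):
        # forward/backward run under sampled weights; the mean weights are restored on exit
        with optimizer.sampled_params(train=True):
            optimizer.zero_grad()
            loss = torch.nn.functional.cross_entropy(model(x), y)
            loss.backward()
        optimizer.step()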

train_network.py

@@ -17,6 +17,7 @@ from tqdm import tqdm
 import torch
 from torch.types import Number
 from library.device_utils import init_ipex, clean_memory_on_device
+from library.network_utils import maybe_sample_params

 init_ipex()

@@ -1399,26 +1400,26 @@ class NetworkTrainer:

                 # preprocess batch for each model
                 self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=True)

-                loss = self.process_batch(
-                    batch,
-                    text_encoders,
-                    unet,
-                    network,
-                    vae,
-                    noise_scheduler,
-                    vae_dtype,
-                    weight_dtype,
-                    accelerator,
-                    args,
-                    text_encoding_strategy,
-                    tokenize_strategy,
-                    is_train=True,
-                    train_text_encoder=train_text_encoder,
-                    train_unet=train_unet,
-                )
-
-                accelerator.backward(loss)
+                with maybe_sample_params(optimizer.optimizer):
+                    loss = self.process_batch(
+                        batch,
+                        text_encoders,
+                        unet,
+                        network,
+                        vae,
+                        noise_scheduler,
+                        vae_dtype,
+                        weight_dtype,
+                        accelerator,
+                        args,
+                        text_encoding_strategy,
+                        tokenize_strategy,
+                        is_train=True,
+                        train_text_encoder=train_text_encoder,
+                        train_unet=train_unet,
+                    )
+
+                    accelerator.backward(loss)
                 if accelerator.sync_gradients:
                     self.all_reduce_network(accelerator, network)  # sync DDP grad manually
                     if args.max_grad_norm != 0.0:
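Two details worth noting here. The helper receives optimizer.optimizer because Accelerate's prepare() wraps the raw optimizer in an AcceleratedOptimizer, which exposes the wrapped instance through its .optimizer attribute, so the hasattr check must see the inner object. And for every non-IVON optimizer the returned context is a no-op, leaving existing configurations unaffected, as a quick standalone check shows (AdamW is just a stand-in for any optimizer without sampled_params):

    import torch
    from contextlib import nullcontext

    def maybe_sample_params(optimizer):
        # same dispatch as library/network_utils.py above
        return optimizer.sampled_params(train=True) if hasattr(optimizer, "sampled_params") else nullcontext()

    param = torch.nn.Parameter(torch.zeros(3))
    adamw = torch.optim.AdamW([param], lr=1e-3)

    ctx = maybe_sample_params(adamw)
    print(type(ctx).__name__)  # nullcontext

    with ctx:
        pass  # the training step runs here unchanged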
@@ -1432,7 +1433,8 @@ class NetworkTrainer:

                 optimizer.step()
                 lr_scheduler.step()
-                optimizer.zero_grad(set_to_none=True)
+                # optimizer.zero_grad(set_to_none=True)
+                optimizer.zero_grad(set_to_none=False)

                 if args.scale_weight_norms:
                     keys_scaled, mean_norm, maximum_norm = accelerator.unwrap_model(network).apply_max_norm_regularization(
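The zero_grad change is deliberate: with set_to_none=False, each .grad stays an allocated zero tensor between steps instead of becoming None, which presumably IVON's step() needs because it reads gradients unconditionally. The difference in one self-contained snippet:

    import torch

    p = torch.nn.Parameter(torch.ones(3))
    opt = torch.optim.SGD([p], lr=0.1)

    p.grad = torch.ones(3)
    opt.zero_grad(set_to_none=True)
    print(p.grad)  # None -- the gradient tensor is released

    p.grad = torch.ones(3)
    opt.zero_grad(set_to_none=False)
    print(p.grad)  # tensor([0., 0., 0.]) -- zeroed in place, allocation kept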