Merge pull request #1974 from rockerBOO/lora-ggpo

Add LoRA-GGPO for Flux
This commit is contained in:
Kohya S.
2025-03-30 21:07:31 +09:00
committed by GitHub
2 changed files with 219 additions and 8 deletions

View File

@@ -9,11 +9,13 @@
import math import math
import os import os
from contextlib import contextmanager
from typing import Dict, List, Optional, Tuple, Type, Union from typing import Dict, List, Optional, Tuple, Type, Union
from diffusers import AutoencoderKL from diffusers import AutoencoderKL
from transformers import CLIPTextModel from transformers import CLIPTextModel
import numpy as np import numpy as np
import torch import torch
from torch import Tensor
import re import re
from library.utils import setup_logging from library.utils import setup_logging
from library.sdxl_original_unet import SdxlUNet2DConditionModel from library.sdxl_original_unet import SdxlUNet2DConditionModel
@@ -44,6 +46,8 @@ class LoRAModule(torch.nn.Module):
rank_dropout=None, rank_dropout=None,
module_dropout=None, module_dropout=None,
split_dims: Optional[List[int]] = None, split_dims: Optional[List[int]] = None,
ggpo_beta: Optional[float] = None,
ggpo_sigma: Optional[float] = None,
): ):
""" """
if alpha == 0 or None, alpha is rank (no scaling). if alpha == 0 or None, alpha is rank (no scaling).
@@ -103,9 +107,20 @@ class LoRAModule(torch.nn.Module):
self.rank_dropout = rank_dropout self.rank_dropout = rank_dropout
self.module_dropout = module_dropout self.module_dropout = module_dropout
self.ggpo_sigma = ggpo_sigma
self.ggpo_beta = ggpo_beta
if self.ggpo_beta is not None and self.ggpo_sigma is not None:
self.combined_weight_norms = None
self.grad_norms = None
self.perturbation_norm_factor = 1.0 / math.sqrt(org_module.weight.shape[0])
self.initialize_norm_cache(org_module.weight)
self.org_module_shape: tuple[int] = org_module.weight.shape
def apply_to(self): def apply_to(self):
self.org_forward = self.org_module.forward self.org_forward = self.org_module.forward
self.org_module.forward = self.forward self.org_module.forward = self.forward
del self.org_module del self.org_module
def forward(self, x): def forward(self, x):
@@ -140,7 +155,17 @@ class LoRAModule(torch.nn.Module):
lx = self.lora_up(lx) lx = self.lora_up(lx)
return org_forwarded + lx * self.multiplier * scale # LoRA Gradient-Guided Perturbation Optimization
if self.training and self.ggpo_sigma is not None and self.ggpo_beta is not None and self.combined_weight_norms is not None and self.grad_norms is not None:
with torch.no_grad():
perturbation_scale = (self.ggpo_sigma * torch.sqrt(self.combined_weight_norms ** 2)) + (self.ggpo_beta * (self.grad_norms ** 2))
perturbation_scale_factor = (perturbation_scale * self.perturbation_norm_factor).to(self.device)
perturbation = torch.randn(self.org_module_shape, dtype=self.dtype, device=self.device)
perturbation.mul_(perturbation_scale_factor)
perturbation_output = x @ perturbation.T # Result: (batch × n)
return org_forwarded + (self.multiplier * scale * lx) + perturbation_output
else:
return org_forwarded + lx * self.multiplier * scale
else: else:
lxs = [lora_down(x) for lora_down in self.lora_down] lxs = [lora_down(x) for lora_down in self.lora_down]
@@ -167,6 +192,116 @@ class LoRAModule(torch.nn.Module):
return org_forwarded + torch.cat(lxs, dim=-1) * self.multiplier * scale return org_forwarded + torch.cat(lxs, dim=-1) * self.multiplier * scale
@torch.no_grad()
def initialize_norm_cache(self, org_module_weight: Tensor):
# Choose a reasonable sample size
n_rows = org_module_weight.shape[0]
sample_size = min(1000, n_rows) # Cap at 1000 samples or use all if smaller
# Sample random indices across all rows
indices = torch.randperm(n_rows)[:sample_size]
# Convert to a supported data type first, then index
# Use float32 for indexing operations
weights_float32 = org_module_weight.to(dtype=torch.float32)
sampled_weights = weights_float32[indices].to(device=self.device)
# Calculate sampled norms
sampled_norms = torch.norm(sampled_weights, dim=1, keepdim=True)
# Store the mean norm as our estimate
self.org_weight_norm_estimate = sampled_norms.mean()
# Optional: store standard deviation for confidence intervals
self.org_weight_norm_std = sampled_norms.std()
# Free memory
del sampled_weights, weights_float32
@torch.no_grad()
def validate_norm_approximation(self, org_module_weight: Tensor, verbose=True):
# Calculate the true norm (this will be slow but it's just for validation)
true_norms = []
chunk_size = 1024 # Process in chunks to avoid OOM
for i in range(0, org_module_weight.shape[0], chunk_size):
end_idx = min(i + chunk_size, org_module_weight.shape[0])
chunk = org_module_weight[i:end_idx].to(device=self.device, dtype=self.dtype)
chunk_norms = torch.norm(chunk, dim=1, keepdim=True)
true_norms.append(chunk_norms.cpu())
del chunk
true_norms = torch.cat(true_norms, dim=0)
true_mean_norm = true_norms.mean().item()
# Compare with our estimate
estimated_norm = self.org_weight_norm_estimate.item()
# Calculate error metrics
absolute_error = abs(true_mean_norm - estimated_norm)
relative_error = absolute_error / true_mean_norm * 100 # as percentage
if verbose:
logger.info(f"True mean norm: {true_mean_norm:.6f}")
logger.info(f"Estimated norm: {estimated_norm:.6f}")
logger.info(f"Absolute error: {absolute_error:.6f}")
logger.info(f"Relative error: {relative_error:.2f}%")
return {
'true_mean_norm': true_mean_norm,
'estimated_norm': estimated_norm,
'absolute_error': absolute_error,
'relative_error': relative_error
}
@torch.no_grad()
def update_norms(self):
# Not running GGPO so not currently running update norms
if self.ggpo_beta is None or self.ggpo_sigma is None:
return
# only update norms when we are training
if self.training is False:
return
module_weights = self.lora_up.weight @ self.lora_down.weight
module_weights.mul(self.scale)
self.weight_norms = torch.norm(module_weights, dim=1, keepdim=True)
self.combined_weight_norms = torch.sqrt((self.org_weight_norm_estimate**2) +
torch.sum(module_weights**2, dim=1, keepdim=True))
@torch.no_grad()
def update_grad_norms(self):
if self.training is False:
print(f"skipping update_grad_norms for {self.lora_name}")
return
lora_down_grad = None
lora_up_grad = None
for name, param in self.named_parameters():
if name == "lora_down.weight":
lora_down_grad = param.grad
elif name == "lora_up.weight":
lora_up_grad = param.grad
# Calculate gradient norms if we have both gradients
if lora_down_grad is not None and lora_up_grad is not None:
with torch.autocast(self.device.type):
approx_grad = self.scale * ((self.lora_up.weight @ lora_down_grad) + (lora_up_grad @ self.lora_down.weight))
self.grad_norms = torch.norm(approx_grad, dim=1, keepdim=True)
@property
def device(self):
return next(self.parameters()).device
@property
def dtype(self):
return next(self.parameters()).dtype
class LoRAInfModule(LoRAModule): class LoRAInfModule(LoRAModule):
def __init__( def __init__(
@@ -420,6 +555,16 @@ def create_network(
if split_qkv is not None: if split_qkv is not None:
split_qkv = True if split_qkv == "True" else False split_qkv = True if split_qkv == "True" else False
ggpo_beta = kwargs.get("ggpo_beta", None)
ggpo_sigma = kwargs.get("ggpo_sigma", None)
if ggpo_beta is not None:
ggpo_beta = float(ggpo_beta)
if ggpo_sigma is not None:
ggpo_sigma = float(ggpo_sigma)
# train T5XXL # train T5XXL
train_t5xxl = kwargs.get("train_t5xxl", False) train_t5xxl = kwargs.get("train_t5xxl", False)
if train_t5xxl is not None: if train_t5xxl is not None:
@@ -449,6 +594,8 @@ def create_network(
in_dims=in_dims, in_dims=in_dims,
train_double_block_indices=train_double_block_indices, train_double_block_indices=train_double_block_indices,
train_single_block_indices=train_single_block_indices, train_single_block_indices=train_single_block_indices,
ggpo_beta=ggpo_beta,
ggpo_sigma=ggpo_sigma,
verbose=verbose, verbose=verbose,
) )
@@ -561,6 +708,8 @@ class LoRANetwork(torch.nn.Module):
in_dims: Optional[List[int]] = None, in_dims: Optional[List[int]] = None,
train_double_block_indices: Optional[List[bool]] = None, train_double_block_indices: Optional[List[bool]] = None,
train_single_block_indices: Optional[List[bool]] = None, train_single_block_indices: Optional[List[bool]] = None,
ggpo_beta: Optional[float] = None,
ggpo_sigma: Optional[float] = None,
verbose: Optional[bool] = False, verbose: Optional[bool] = False,
) -> None: ) -> None:
super().__init__() super().__init__()
@@ -599,10 +748,16 @@ class LoRANetwork(torch.nn.Module):
# logger.info( # logger.info(
# f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}" # f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}"
# ) # )
if ggpo_beta is not None and ggpo_sigma is not None:
logger.info(f"LoRA-GGPO training sigma: {ggpo_sigma} beta: {ggpo_beta}")
if self.split_qkv: if self.split_qkv:
logger.info(f"split qkv for LoRA") logger.info(f"split qkv for LoRA")
if self.train_blocks is not None: if self.train_blocks is not None:
logger.info(f"train {self.train_blocks} blocks only") logger.info(f"train {self.train_blocks} blocks only")
if train_t5xxl: if train_t5xxl:
logger.info(f"train T5XXL as well") logger.info(f"train T5XXL as well")
@@ -722,6 +877,8 @@ class LoRANetwork(torch.nn.Module):
rank_dropout=rank_dropout, rank_dropout=rank_dropout,
module_dropout=module_dropout, module_dropout=module_dropout,
split_dims=split_dims, split_dims=split_dims,
ggpo_beta=ggpo_beta,
ggpo_sigma=ggpo_sigma,
) )
loras.append(lora) loras.append(lora)
@@ -790,6 +947,36 @@ class LoRANetwork(torch.nn.Module):
for lora in self.text_encoder_loras + self.unet_loras: for lora in self.text_encoder_loras + self.unet_loras:
lora.enabled = is_enabled lora.enabled = is_enabled
def update_norms(self):
for lora in self.text_encoder_loras + self.unet_loras:
lora.update_norms()
def update_grad_norms(self):
for lora in self.text_encoder_loras + self.unet_loras:
lora.update_grad_norms()
def grad_norms(self) -> Tensor:
grad_norms = []
for lora in self.text_encoder_loras + self.unet_loras:
if hasattr(lora, "grad_norms") and lora.grad_norms is not None:
grad_norms.append(lora.grad_norms.mean(dim=0))
return torch.stack(grad_norms) if len(grad_norms) > 0 else torch.tensor([])
def weight_norms(self) -> Tensor:
weight_norms = []
for lora in self.text_encoder_loras + self.unet_loras:
if hasattr(lora, "weight_norms") and lora.weight_norms is not None:
weight_norms.append(lora.weight_norms.mean(dim=0))
return torch.stack(weight_norms) if len(weight_norms) > 0 else torch.tensor([])
def combined_weight_norms(self) -> Tensor:
combined_weight_norms = []
for lora in self.text_encoder_loras + self.unet_loras:
if hasattr(lora, "combined_weight_norms") and lora.combined_weight_norms is not None:
combined_weight_norms.append(lora.combined_weight_norms.mean(dim=0))
return torch.stack(combined_weight_norms) if len(combined_weight_norms) > 0 else torch.tensor([])
def load_weights(self, file): def load_weights(self, file):
if os.path.splitext(file)[1] == ".safetensors": if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import load_file from safetensors.torch import load_file

View File

@@ -69,13 +69,20 @@ class NetworkTrainer:
keys_scaled=None, keys_scaled=None,
mean_norm=None, mean_norm=None,
maximum_norm=None, maximum_norm=None,
mean_grad_norm=None,
mean_combined_norm=None
): ):
logs = {"loss/current": current_loss, "loss/average": avr_loss} logs = {"loss/current": current_loss, "loss/average": avr_loss}
if keys_scaled is not None: if keys_scaled is not None:
logs["max_norm/keys_scaled"] = keys_scaled logs["max_norm/keys_scaled"] = keys_scaled
logs["max_norm/average_key_norm"] = mean_norm
logs["max_norm/max_key_norm"] = maximum_norm logs["max_norm/max_key_norm"] = maximum_norm
if mean_norm is not None:
logs["norm/avg_key_norm"] = mean_norm
if mean_grad_norm is not None:
logs["norm/avg_grad_norm"] = mean_grad_norm
if mean_combined_norm is not None:
logs["norm/avg_combined_norm"] = mean_combined_norm
lrs = lr_scheduler.get_last_lr() lrs = lr_scheduler.get_last_lr()
for i, lr in enumerate(lrs): for i, lr in enumerate(lrs):
@@ -1403,6 +1410,12 @@ class NetworkTrainer:
params_to_clip = accelerator.unwrap_model(network).get_trainable_params() params_to_clip = accelerator.unwrap_model(network).get_trainable_params()
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
if hasattr(network, "update_grad_norms"):
network.update_grad_norms()
if hasattr(network, "update_norms"):
network.update_norms()
optimizer.step() optimizer.step()
lr_scheduler.step() lr_scheduler.step()
optimizer.zero_grad(set_to_none=True) optimizer.zero_grad(set_to_none=True)
@@ -1411,9 +1424,23 @@ class NetworkTrainer:
keys_scaled, mean_norm, maximum_norm = accelerator.unwrap_model(network).apply_max_norm_regularization( keys_scaled, mean_norm, maximum_norm = accelerator.unwrap_model(network).apply_max_norm_regularization(
args.scale_weight_norms, accelerator.device args.scale_weight_norms, accelerator.device
) )
mean_grad_norm = None
mean_combined_norm = None
max_mean_logs = {"Keys Scaled": keys_scaled, "Average key norm": mean_norm} max_mean_logs = {"Keys Scaled": keys_scaled, "Average key norm": mean_norm}
else: else:
keys_scaled, mean_norm, maximum_norm = None, None, None if hasattr(network, "weight_norms"):
mean_norm = network.weight_norms().mean().item()
mean_grad_norm = network.grad_norms().mean().item()
mean_combined_norm = network.combined_weight_norms().mean().item()
weight_norms = network.weight_norms()
maximum_norm = weight_norms.max().item() if weight_norms.numel() > 0 else None
keys_scaled = None
max_mean_logs = {}
else:
keys_scaled, mean_norm, maximum_norm = None, None, None
mean_grad_norm = None
mean_combined_norm = None
max_mean_logs = {}
# Checks if the accelerator has performed an optimization step behind the scenes # Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients: if accelerator.sync_gradients:
@@ -1445,14 +1472,11 @@ class NetworkTrainer:
loss_recorder.add(epoch=epoch, step=step, loss=current_loss) loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
avr_loss: float = loss_recorder.moving_average avr_loss: float = loss_recorder.moving_average
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs) progress_bar.set_postfix(**{**max_mean_logs, **logs})
if args.scale_weight_norms:
progress_bar.set_postfix(**{**max_mean_logs, **logs})
if is_tracking: if is_tracking:
logs = self.generate_step_logs( logs = self.generate_step_logs(
args, current_loss, avr_loss, lr_scheduler, lr_descriptions, optimizer, keys_scaled, mean_norm, maximum_norm args, current_loss, avr_loss, lr_scheduler, lr_descriptions, optimizer, keys_scaled, mean_norm, maximum_norm, mean_grad_norm, mean_combined_norm
) )
self.step_logging(accelerator, logs, global_step, epoch + 1) self.step_logging(accelerator, logs, global_step, epoch + 1)