Merge pull request #1974 from rockerBOO/lora-ggpo
Add LoRA-GGPO for Flux
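LoRA-GGPO (Gradient-Guided Perturbation Optimization) regularizes LoRA training by injecting a random weight perturbation on each training forward pass. Its per-output-row magnitude is controlled by two new network arguments: ggpo_sigma scales with the combined norm of the original and LoRA weights, and ggpo_beta scales with the squared norm of an approximated gradient of the merged LoRA delta. The diff below touches two files: the Flux LoRA network (networks/lora_flux.py: per-module perturbation and norm bookkeeping) and the trainer loop (train_network.py: norm updates around the optimizer step, plus logging).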
@@ -9,11 +9,13 @@

 import math
 import os
+from contextlib import contextmanager
 from typing import Dict, List, Optional, Tuple, Type, Union
 from diffusers import AutoencoderKL
 from transformers import CLIPTextModel
 import numpy as np
 import torch
+from torch import Tensor
 import re
 from library.utils import setup_logging
 from library.sdxl_original_unet import SdxlUNet2DConditionModel
@@ -44,6 +46,8 @@ class LoRAModule(torch.nn.Module):
         rank_dropout=None,
         module_dropout=None,
         split_dims: Optional[List[int]] = None,
+        ggpo_beta: Optional[float] = None,
+        ggpo_sigma: Optional[float] = None,
     ):
         """
         if alpha == 0 or None, alpha is rank (no scaling).
@@ -103,9 +107,20 @@ class LoRAModule(torch.nn.Module):
         self.rank_dropout = rank_dropout
         self.module_dropout = module_dropout
+
+        self.ggpo_sigma = ggpo_sigma
+        self.ggpo_beta = ggpo_beta
+
+        if self.ggpo_beta is not None and self.ggpo_sigma is not None:
+            self.combined_weight_norms = None
+            self.grad_norms = None
+            self.perturbation_norm_factor = 1.0 / math.sqrt(org_module.weight.shape[0])
+            self.initialize_norm_cache(org_module.weight)
+            self.org_module_shape: tuple[int] = org_module.weight.shape
+
     def apply_to(self):
         self.org_forward = self.org_module.forward
         self.org_module.forward = self.forward

         del self.org_module

     def forward(self, x):
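Note that the GGPO state above is initialized only when both ggpo_beta and ggpo_sigma are supplied; passing just one leaves the module on the plain LoRA path. perturbation_norm_factor = 1/sqrt(out_features) shrinks the injected noise as layers get wider; for a 3072-wide Flux linear layer, for example, the factor is about 0.018.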
@@ -140,7 +155,17 @@ class LoRAModule(torch.nn.Module):

             lx = self.lora_up(lx)

-            return org_forwarded + lx * self.multiplier * scale
+            # LoRA Gradient-Guided Perturbation Optimization
+            if self.training and self.ggpo_sigma is not None and self.ggpo_beta is not None and self.combined_weight_norms is not None and self.grad_norms is not None:
+                with torch.no_grad():
+                    perturbation_scale = (self.ggpo_sigma * torch.sqrt(self.combined_weight_norms**2)) + (self.ggpo_beta * (self.grad_norms**2))
+                    perturbation_scale_factor = (perturbation_scale * self.perturbation_norm_factor).to(self.device)
+                    perturbation = torch.randn(self.org_module_shape, dtype=self.dtype, device=self.device)
+                    perturbation.mul_(perturbation_scale_factor)
+                    perturbation_output = x @ perturbation.T  # result shape: (batch, n_out)
+                return org_forwarded + (self.multiplier * scale * lx) + perturbation_output
+            else:
+                return org_forwarded + lx * self.multiplier * scale
         else:
             lxs = [lora_down(x) for lora_down in self.lora_down]

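Pulled out of the module, the perturbation branch amounts to the minimal sketch below (a standalone toy, not part of the patch; the layer sizes, sigma/beta values, and stand-in norm tensors are invented for illustration). Note that torch.sqrt(self.combined_weight_norms ** 2) in the patch is simply the element-wise absolute value:

    import math

    import torch

    n_out, n_in, batch = 64, 32, 4      # toy layer dimensions
    ggpo_sigma, ggpo_beta = 0.03, 0.01  # illustrative hyperparameters, not recommendations

    x = torch.randn(batch, n_in)
    combined_weight_norms = torch.ones(n_out, 1)  # stand-in for sqrt(||W_org||^2 + ||up @ down||^2) per row
    grad_norms = torch.full((n_out, 1), 0.1)      # stand-in for the approximated per-row gradient norms

    with torch.no_grad():
        # per-row scale: sigma * |combined norm| + beta * grad_norm^2, normalized by 1/sqrt(n_out)
        perturbation_scale = ggpo_sigma * combined_weight_norms.abs() + ggpo_beta * grad_norms**2
        perturbation = torch.randn(n_out, n_in) * (perturbation_scale / math.sqrt(n_out))  # row-wise broadcast
        perturbation_output = x @ perturbation.T  # shape (batch, n_out)

    # module output = org_forward(x) + multiplier * scale * lora_delta + perturbation_output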
@@ -167,6 +192,116 @@ class LoRAModule(torch.nn.Module):

             return org_forwarded + torch.cat(lxs, dim=-1) * self.multiplier * scale

+    @torch.no_grad()
+    def initialize_norm_cache(self, org_module_weight: Tensor):
+        # Choose a reasonable sample size
+        n_rows = org_module_weight.shape[0]
+        sample_size = min(1000, n_rows)  # Cap at 1000 samples or use all if smaller
+
+        # Sample random indices across all rows
+        indices = torch.randperm(n_rows)[:sample_size]
+
+        # Convert to a supported data type first, then index
+        # Use float32 for indexing operations
+        weights_float32 = org_module_weight.to(dtype=torch.float32)
+        sampled_weights = weights_float32[indices].to(device=self.device)
+
+        # Calculate sampled norms
+        sampled_norms = torch.norm(sampled_weights, dim=1, keepdim=True)
+
+        # Store the mean norm as our estimate
+        self.org_weight_norm_estimate = sampled_norms.mean()
+
+        # Optional: store standard deviation for confidence intervals
+        self.org_weight_norm_std = sampled_norms.std()
+
+        # Free memory
+        del sampled_weights, weights_float32
+
+    @torch.no_grad()
+    def validate_norm_approximation(self, org_module_weight: Tensor, verbose=True):
+        # Calculate the true norm (this will be slow, but it is just for validation)
+        true_norms = []
+        chunk_size = 1024  # Process in chunks to avoid OOM
+
+        for i in range(0, org_module_weight.shape[0], chunk_size):
+            end_idx = min(i + chunk_size, org_module_weight.shape[0])
+            chunk = org_module_weight[i:end_idx].to(device=self.device, dtype=self.dtype)
+            chunk_norms = torch.norm(chunk, dim=1, keepdim=True)
+            true_norms.append(chunk_norms.cpu())
+            del chunk
+
+        true_norms = torch.cat(true_norms, dim=0)
+        true_mean_norm = true_norms.mean().item()
+
+        # Compare with our estimate
+        estimated_norm = self.org_weight_norm_estimate.item()
+
+        # Calculate error metrics
+        absolute_error = abs(true_mean_norm - estimated_norm)
+        relative_error = absolute_error / true_mean_norm * 100  # as a percentage
+
+        if verbose:
+            logger.info(f"True mean norm: {true_mean_norm:.6f}")
+            logger.info(f"Estimated norm: {estimated_norm:.6f}")
+            logger.info(f"Absolute error: {absolute_error:.6f}")
+            logger.info(f"Relative error: {relative_error:.2f}%")
+
+        return {
+            'true_mean_norm': true_mean_norm,
+            'estimated_norm': estimated_norm,
+            'absolute_error': absolute_error,
+            'relative_error': relative_error,
+        }
+
+    @torch.no_grad()
+    def update_norms(self):
+        # Not running GGPO, so there are no norms to update
+        if self.ggpo_beta is None or self.ggpo_sigma is None:
+            return
+
+        # Only update norms while training
+        if self.training is False:
+            return
+
+        module_weights = self.lora_up.weight @ self.lora_down.weight
+        module_weights.mul_(self.scale)  # in-place; plain .mul() would discard the scaled result
+
+        self.weight_norms = torch.norm(module_weights, dim=1, keepdim=True)
+        self.combined_weight_norms = torch.sqrt(
+            (self.org_weight_norm_estimate**2) + torch.sum(module_weights**2, dim=1, keepdim=True)
+        )
+
+    @torch.no_grad()
+    def update_grad_norms(self):
+        if self.training is False:
+            logger.info(f"skipping update_grad_norms for {self.lora_name}")
+            return
+
+        lora_down_grad = None
+        lora_up_grad = None
+
+        for name, param in self.named_parameters():
+            if name == "lora_down.weight":
+                lora_down_grad = param.grad
+            elif name == "lora_up.weight":
+                lora_up_grad = param.grad
+
+        # Calculate gradient norms if we have both gradients
+        if lora_down_grad is not None and lora_up_grad is not None:
+            with torch.autocast(self.device.type):
+                approx_grad = self.scale * ((self.lora_up.weight @ lora_down_grad) + (lora_up_grad @ self.lora_down.weight))
+                self.grad_norms = torch.norm(approx_grad, dim=1, keepdim=True)
+
+    @property
+    def device(self):
+        return next(self.parameters()).device
+
+    @property
+    def dtype(self):
+        return next(self.parameters()).dtype
+
+
 class LoRAInfModule(LoRAModule):
     def __init__(
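The sampled-row estimator in initialize_norm_cache can be sanity-checked in isolation. A minimal sketch mirroring validate_norm_approximation (toy Gaussian weight; the shapes are assumed for illustration):

    import torch

    weight = torch.randn(4096, 3072)  # toy weight matrix
    sample_size = min(1000, weight.shape[0])
    indices = torch.randperm(weight.shape[0])[:sample_size]

    # Mean row norm over a random sample vs. over the full matrix.
    estimate = torch.norm(weight[indices].float(), dim=1, keepdim=True).mean()
    true_mean = torch.norm(weight.float(), dim=1, keepdim=True).mean()
    # For Gaussian weights the row norms concentrate tightly, so the relative
    # error is typically a small fraction of a percent.
    print(f"estimate={estimate:.4f} true={true_mean:.4f}")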
@@ -420,6 +555,16 @@ def create_network(
     if split_qkv is not None:
         split_qkv = True if split_qkv == "True" else False

+    ggpo_beta = kwargs.get("ggpo_beta", None)
+    ggpo_sigma = kwargs.get("ggpo_sigma", None)
+
+    if ggpo_beta is not None:
+        ggpo_beta = float(ggpo_beta)
+
+    if ggpo_sigma is not None:
+        ggpo_sigma = float(ggpo_sigma)
+
+
     # train T5XXL
     train_t5xxl = kwargs.get("train_t5xxl", False)
     if train_t5xxl is not None:
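Since create_network receives the --network_args key=value pairs as kwargs, the two options should be reachable from the training command line, e.g. --network_args "ggpo_beta=0.01" "ggpo_sigma=0.03" (the values here are placeholders, not tuned recommendations); both must be set for the perturbation to activate.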
@@ -449,6 +594,8 @@ def create_network(
         in_dims=in_dims,
         train_double_block_indices=train_double_block_indices,
         train_single_block_indices=train_single_block_indices,
+        ggpo_beta=ggpo_beta,
+        ggpo_sigma=ggpo_sigma,
         verbose=verbose,
     )

@@ -561,6 +708,8 @@ class LoRANetwork(torch.nn.Module):
         in_dims: Optional[List[int]] = None,
         train_double_block_indices: Optional[List[bool]] = None,
         train_single_block_indices: Optional[List[bool]] = None,
+        ggpo_beta: Optional[float] = None,
+        ggpo_sigma: Optional[float] = None,
         verbose: Optional[bool] = False,
     ) -> None:
         super().__init__()
@@ -599,10 +748,16 @@ class LoRANetwork(torch.nn.Module):
                 # logger.info(
                 #     f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}"
                 # )
+
+        if ggpo_beta is not None and ggpo_sigma is not None:
+            logger.info(f"LoRA-GGPO training sigma: {ggpo_sigma} beta: {ggpo_beta}")
+
         if self.split_qkv:
             logger.info(f"split qkv for LoRA")
         if self.train_blocks is not None:
             logger.info(f"train {self.train_blocks} blocks only")
+
+
         if train_t5xxl:
             logger.info(f"train T5XXL as well")

@@ -722,6 +877,8 @@ class LoRANetwork(torch.nn.Module):
                        rank_dropout=rank_dropout,
                        module_dropout=module_dropout,
                        split_dims=split_dims,
+                        ggpo_beta=ggpo_beta,
+                        ggpo_sigma=ggpo_sigma,
                    )
                    loras.append(lora)

@@ -790,6 +947,36 @@ class LoRANetwork(torch.nn.Module):
         for lora in self.text_encoder_loras + self.unet_loras:
             lora.enabled = is_enabled

+    def update_norms(self):
+        for lora in self.text_encoder_loras + self.unet_loras:
+            lora.update_norms()
+
+    def update_grad_norms(self):
+        for lora in self.text_encoder_loras + self.unet_loras:
+            lora.update_grad_norms()
+
+    def grad_norms(self) -> Tensor:
+        grad_norms = []
+        for lora in self.text_encoder_loras + self.unet_loras:
+            if hasattr(lora, "grad_norms") and lora.grad_norms is not None:
+                grad_norms.append(lora.grad_norms.mean(dim=0))
+        return torch.stack(grad_norms) if len(grad_norms) > 0 else torch.tensor([])
+
+    def weight_norms(self) -> Tensor:
+        weight_norms = []
+        for lora in self.text_encoder_loras + self.unet_loras:
+            if hasattr(lora, "weight_norms") and lora.weight_norms is not None:
+                weight_norms.append(lora.weight_norms.mean(dim=0))
+        return torch.stack(weight_norms) if len(weight_norms) > 0 else torch.tensor([])
+
+    def combined_weight_norms(self) -> Tensor:
+        combined_weight_norms = []
+        for lora in self.text_encoder_loras + self.unet_loras:
+            if hasattr(lora, "combined_weight_norms") and lora.combined_weight_norms is not None:
+                combined_weight_norms.append(lora.combined_weight_norms.mean(dim=0))
+        return torch.stack(combined_weight_norms) if len(combined_weight_norms) > 0 else torch.tensor([])
+
+
     def load_weights(self, file):
         if os.path.splitext(file)[1] == ".safetensors":
             from safetensors.torch import load_file
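Each of the three aggregators returns a 1-D tensor with one entry per LoRA module that currently holds the statistic (and an empty tensor otherwise), so a scalar for logging is one more reduction away, e.g. network.weight_norms().mean().item() as the trainer does below. Callers that cannot guarantee GGPO is active may want to check numel() first, as the maximum_norm computation in the trainer does, since .mean() on the empty-tensor fallback yields NaN.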
@@ -69,13 +69,20 @@ class NetworkTrainer:
         keys_scaled=None,
         mean_norm=None,
         maximum_norm=None,
+        mean_grad_norm=None,
+        mean_combined_norm=None
     ):
         logs = {"loss/current": current_loss, "loss/average": avr_loss}

         if keys_scaled is not None:
             logs["max_norm/keys_scaled"] = keys_scaled
-            logs["max_norm/average_key_norm"] = mean_norm
             logs["max_norm/max_key_norm"] = maximum_norm
+        if mean_norm is not None:
+            logs["norm/avg_key_norm"] = mean_norm
+        if mean_grad_norm is not None:
+            logs["norm/avg_grad_norm"] = mean_grad_norm
+        if mean_combined_norm is not None:
+            logs["norm/avg_combined_norm"] = mean_combined_norm

         lrs = lr_scheduler.get_last_lr()
         for i, lr in enumerate(lrs):
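With this change the scale-weight-norms path stops logging max_norm/average_key_norm; the average key norm now goes out as norm/avg_key_norm whenever it is available, joined by norm/avg_grad_norm and norm/avg_combined_norm when the GGPO statistics exist.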
@@ -1403,6 +1410,12 @@ class NetworkTrainer:
                     params_to_clip = accelerator.unwrap_model(network).get_trainable_params()
                     accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

+                if hasattr(network, "update_grad_norms"):
+                    network.update_grad_norms()
+                if hasattr(network, "update_norms"):
+                    network.update_norms()
+
+
                 optimizer.step()
                 lr_scheduler.step()
                 optimizer.zero_grad(set_to_none=True)
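The ordering here is deliberate: both hooks run after gradient clipping but before optimizer.step() and optimizer.zero_grad(set_to_none=True), i.e. while param.grad is still populated, which is exactly what update_grad_norms reads.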
@@ -1411,9 +1424,23 @@ class NetworkTrainer:
                    keys_scaled, mean_norm, maximum_norm = accelerator.unwrap_model(network).apply_max_norm_regularization(
                        args.scale_weight_norms, accelerator.device
                    )
+                    mean_grad_norm = None
+                    mean_combined_norm = None
                    max_mean_logs = {"Keys Scaled": keys_scaled, "Average key norm": mean_norm}
                else:
-                    keys_scaled, mean_norm, maximum_norm = None, None, None
+                    if hasattr(network, "weight_norms"):
+                        mean_norm = network.weight_norms().mean().item()
+                        mean_grad_norm = network.grad_norms().mean().item()
+                        mean_combined_norm = network.combined_weight_norms().mean().item()
+                        weight_norms = network.weight_norms()
+                        maximum_norm = weight_norms.max().item() if weight_norms.numel() > 0 else None
+                        keys_scaled = None
+                        max_mean_logs = {}
+                    else:
+                        keys_scaled, mean_norm, maximum_norm = None, None, None
+                        mean_grad_norm = None
+                        mean_combined_norm = None
+                        max_mean_logs = {}

                # Checks if the accelerator has performed an optimization step behind the scenes
                if accelerator.sync_gradients:
@@ -1445,14 +1472,11 @@ class NetworkTrainer:
                loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
                avr_loss: float = loss_recorder.moving_average
                logs = {"avr_loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
-                progress_bar.set_postfix(**logs)
-
-                if args.scale_weight_norms:
-                    progress_bar.set_postfix(**{**max_mean_logs, **logs})
+                progress_bar.set_postfix(**{**max_mean_logs, **logs})

                if is_tracking:
                    logs = self.generate_step_logs(
-                        args, current_loss, avr_loss, lr_scheduler, lr_descriptions, optimizer, keys_scaled, mean_norm, maximum_norm
+                        args, current_loss, avr_loss, lr_scheduler, lr_descriptions, optimizer, keys_scaled, mean_norm, maximum_norm, mean_grad_norm, mean_combined_norm
                    )
                    self.step_logging(accelerator, logs, global_step, epoch + 1)

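Net effect on reporting: the progress bar now unconditionally merges max_mean_logs into its postfix (an empty dict outside the scale_weight_norms path), and generate_step_logs receives the two extra GGPO statistics whenever a tracker is attached.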