Cache weight norm estimate on initialization. Move to updating norms every step

This commit is contained in:
rockerBOO
2025-03-18 14:25:09 -04:00
parent ea53290f62
commit 3647d065b5
2 changed files with 145 additions and 33 deletions

View File

@@ -15,6 +15,7 @@ from diffusers import AutoencoderKL
from transformers import CLIPTextModel from transformers import CLIPTextModel
import numpy as np import numpy as np
import torch import torch
from torch import Tensor
import re import re
from library.utils import setup_logging from library.utils import setup_logging
from library.sdxl_original_unet import SdxlUNet2DConditionModel from library.sdxl_original_unet import SdxlUNet2DConditionModel
@@ -145,8 +146,13 @@ class LoRAModule(torch.nn.Module):
self.ggpo_sigma = ggpo_sigma self.ggpo_sigma = ggpo_sigma
self.ggpo_beta = ggpo_beta self.ggpo_beta = ggpo_beta
self.perturbation_norm_factor = 1.0 / math.sqrt(org_module.weight.shape[0]) if self.ggpo_beta is not None and self.ggpo_sigma is not None:
self._org_module_weight = self.org_module.weight.detach() self.combined_weight_norms = None
self.grad_norms = None
self.perturbation_norm_factor = 1.0 / math.sqrt(org_module.weight.shape[0])
self.perturbation_seed = torch.randint(0, 2**32 - 1, (1,)).detach().item()
self.initialize_norm_cache(org_module.weight)
self.org_module_shape: tuple[int] = org_module.weight.shape
def apply_to(self): def apply_to(self):
self.org_forward = self.org_module.forward self.org_forward = self.org_module.forward
@@ -187,10 +193,12 @@ class LoRAModule(torch.nn.Module):
lx = self.lora_up(lx) lx = self.lora_up(lx)
# LoRA Gradient-Guided Perturbation Optimization # LoRA Gradient-Guided Perturbation Optimization
if self.training and hasattr(self, 'perturbation_seed') and self.ggpo_sigma is not None and self.ggpo_beta is not None: if self.training and hasattr(self, 'perturbation_seed') and self.ggpo_sigma is not None and self.ggpo_beta is not None and self.combined_weight_norms is not None and self.grad_norms is not None:
with torch.no_grad(), torch.autocast(self.device.type), temp_random_seed(self.perturbation_seed): with torch.no_grad(), temp_random_seed(self.perturbation_seed):
perturbation = torch.randn_like(self._org_module_weight, dtype=self.dtype, device=self.device) perturbation_scale = (self.ggpo_sigma * torch.sqrt(self.combined_weight_norms ** 2)) + (self.ggpo_beta * (self.grad_norms ** 2))
perturbation.mul_(self.perturbation_scale_factor) perturbation_scale_factor = (perturbation_scale * self.perturbation_norm_factor).to(self.device)
perturbation = torch.randn(self.org_module_shape, dtype=self.dtype, device=self.device)
perturbation.mul_(perturbation_scale_factor)
perturbation_output = x @ perturbation.T # Result: (batch × n) perturbation_output = x @ perturbation.T # Result: (batch × n)
return org_forwarded + (self.multiplier * scale * lx) + perturbation_output return org_forwarded + (self.multiplier * scale * lx) + perturbation_output
else: else:
@@ -221,6 +229,69 @@ class LoRAModule(torch.nn.Module):
return org_forwarded + torch.cat(lxs, dim=-1) * self.multiplier * scale return org_forwarded + torch.cat(lxs, dim=-1) * self.multiplier * scale
@torch.no_grad()
def initialize_norm_cache(self, org_module_weight: Tensor):
    """Estimate the mean L2 row norm of the original module weight by sampling.

    Stores the mean of the sampled row norms in ``self.org_weight_norm_estimate``
    and their standard deviation in ``self.org_weight_norm_std``.
    """
    total_rows = org_module_weight.shape[0]
    # Use every row for small matrices; otherwise cap the sample at 1000 rows.
    sample_count = min(1000, total_rows)
    # Draw distinct random row indices across the whole weight.
    row_indices = torch.randperm(total_rows)[:sample_count]
    # Cast to float32 before indexing (indexing ops need a supported dtype),
    # then move just the sampled rows onto this module's device.
    as_float32 = org_module_weight.to(dtype=torch.float32)
    sample = as_float32[row_indices].to(device=self.device)
    # Per-row L2 norms of the sample.
    row_norms = torch.norm(sample, dim=1, keepdim=True)
    self.org_weight_norm_estimate = row_norms.mean()
    # Spread of the sampled norms, available for confidence reporting.
    self.org_weight_norm_std = row_norms.std()
    # Drop the temporaries promptly to free memory.
    del sample, as_float32
@torch.no_grad()
def validate_norm_approximation(self, org_module_weight: Tensor, verbose=True):
    """Compare the cached norm estimate against the exact mean row norm.

    Computes the true mean L2 row norm of *org_module_weight* in chunks
    (validation only — intentionally thorough rather than fast) and returns a
    dict with the true/estimated norms plus absolute and relative error.
    """
    chunk_rows = 1024  # rows per chunk, bounds peak memory
    norm_chunks = []
    row_count = org_module_weight.shape[0]
    start = 0
    while start < row_count:
        stop = min(start + chunk_rows, row_count)
        rows = org_module_weight[start:stop].to(device=self.device, dtype=self.dtype)
        norm_chunks.append(torch.norm(rows, dim=1, keepdim=True).cpu())
        del rows
        start = stop
    true_mean_norm = torch.cat(norm_chunks, dim=0).mean().item()
    # The sampled estimate produced by initialize_norm_cache().
    estimated_norm = self.org_weight_norm_estimate.item()
    absolute_error = abs(true_mean_norm - estimated_norm)
    relative_error = absolute_error / true_mean_norm * 100  # as percentage
    if verbose:
        logger.info(f"True mean norm: {true_mean_norm:.6f}")
        logger.info(f"Estimated norm: {estimated_norm:.6f}")
        logger.info(f"Absolute error: {absolute_error:.6f}")
        logger.info(f"Relative error: {relative_error:.2f}%")
    return {
        'true_mean_norm': true_mean_norm,
        'estimated_norm': estimated_norm,
        'absolute_error': absolute_error,
        'relative_error': relative_error
    }
@torch.no_grad() @torch.no_grad()
def update_norms(self): def update_norms(self):
# Not running GGPO so not currently running update norms # Not running GGPO so not currently running update norms
@@ -228,8 +299,20 @@ class LoRAModule(torch.nn.Module):
return return
# only update norms when we are training # only update norms when we are training
if self.lora_down.weight.requires_grad is not True: if self.training is False:
print(f"skipping update_norms for {self.lora_name}") return
module_weights = self.lora_up.weight @ self.lora_down.weight
module_weights.mul(self.scale)
self.weight_norms = torch.norm(module_weights, dim=1, keepdim=True)
self.combined_weight_norms = torch.sqrt((self.org_weight_norm_estimate**2) +
torch.sum(module_weights**2, dim=1, keepdim=True))
@torch.no_grad()
def update_grad_norms(self):
if self.training is False:
print(f"skipping update_grad_norms for {self.lora_name}")
return return
lora_down_grad = None lora_down_grad = None
@@ -241,29 +324,12 @@ class LoRAModule(torch.nn.Module):
elif name == "lora_up.weight": elif name == "lora_up.weight":
lora_up_grad = param.grad lora_up_grad = param.grad
with torch.autocast(self.device.type):
module_weights = self.scale * (self.lora_up.weight @ self.lora_down.weight)
org_device = self._org_module_weight.device
org_dtype = self._org_module_weight.dtype
org_weight = self._org_module_weight.to(device=self.device, dtype=self.dtype)
combined_weight = org_weight + module_weights
self.combined_weight_norms = torch.norm(combined_weight, dim=1, keepdim=True)
self._org_module_weight.to(device=org_device, dtype=org_dtype)
# Calculate gradient norms if we have both gradients # Calculate gradient norms if we have both gradients
if lora_down_grad is not None and lora_up_grad is not None: if lora_down_grad is not None and lora_up_grad is not None:
with torch.autocast(self.device.type): with torch.autocast(self.device.type):
approx_grad = self.scale * ((self.lora_up.weight @ lora_down_grad) + (lora_up_grad @ self.lora_down.weight)) approx_grad = self.scale * ((self.lora_up.weight @ lora_down_grad) + (lora_up_grad @ self.lora_down.weight))
self.grad_norms = torch.norm(approx_grad, dim=1, keepdim=True) self.grad_norms = torch.norm(approx_grad, dim=1, keepdim=True)
self.perturbation_scale = (self.ggpo_sigma * torch.sqrt(self.combined_weight_norms ** 2)) + (self.ggpo_beta * (self.grad_norms ** 2))
self.perturbation_scale_factor = (self.perturbation_scale * self.perturbation_norm_factor).to(self.device)
# LoRA Gradient-Guided Perturbation Optimization
self.perturbation_seed = torch.randint(0, 2**32 - 1, (1,)).detach().item()
@property @property
def device(self): def device(self):
@@ -922,6 +988,32 @@ class LoRANetwork(torch.nn.Module):
for lora in self.text_encoder_loras + self.unet_loras: for lora in self.text_encoder_loras + self.unet_loras:
lora.update_norms() lora.update_norms()
def update_grad_norms(self):
    """Refresh the cached gradient-norm estimate on every LoRA module."""
    for group in (self.text_encoder_loras, self.unet_loras):
        for lora in group:
            lora.update_grad_norms()
def grad_norms(self) -> Tensor:
    """Stack the per-module mean gradient norms; empty tensor if none are set."""
    collected = [
        lora.grad_norms.mean(dim=0)
        for lora in self.text_encoder_loras + self.unet_loras
        if getattr(lora, "grad_norms", None) is not None
    ]
    if not collected:
        return torch.tensor([])
    return torch.stack(collected)
def weight_norms(self) -> Tensor:
    """Stack the per-module mean LoRA weight norms; empty tensor if none are set."""
    collected = [
        lora.weight_norms.mean(dim=0)
        for lora in self.text_encoder_loras + self.unet_loras
        if getattr(lora, "weight_norms", None) is not None
    ]
    if not collected:
        return torch.tensor([])
    return torch.stack(collected)
def combined_weight_norms(self) -> Tensor:
    """Stack per-module mean combined (original + LoRA) weight norms; empty tensor if none."""
    collected = [
        lora.combined_weight_norms.mean(dim=0)
        for lora in self.text_encoder_loras + self.unet_loras
        if getattr(lora, "combined_weight_norms", None) is not None
    ]
    if not collected:
        return torch.tensor([])
    return torch.stack(collected)
def load_weights(self, file): def load_weights(self, file):
if os.path.splitext(file)[1] == ".safetensors": if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import load_file from safetensors.torch import load_file

View File

@@ -69,13 +69,20 @@ class NetworkTrainer:
keys_scaled=None, keys_scaled=None,
mean_norm=None, mean_norm=None,
maximum_norm=None, maximum_norm=None,
mean_grad_norm=None,
mean_combined_norm=None
): ):
logs = {"loss/current": current_loss, "loss/average": avr_loss} logs = {"loss/current": current_loss, "loss/average": avr_loss}
if keys_scaled is not None: if keys_scaled is not None:
logs["max_norm/keys_scaled"] = keys_scaled logs["max_norm/keys_scaled"] = keys_scaled
logs["max_norm/average_key_norm"] = mean_norm
logs["max_norm/max_key_norm"] = maximum_norm logs["max_norm/max_key_norm"] = maximum_norm
if mean_norm is not None:
logs["norm/avg_key_norm"] = mean_norm
if mean_grad_norm is not None:
logs["norm/avg_grad_norm"] = mean_grad_norm
if mean_combined_norm is not None:
logs["norm/avg_combined_norm"] = mean_combined_norm
lrs = lr_scheduler.get_last_lr() lrs = lr_scheduler.get_last_lr()
for i, lr in enumerate(lrs): for i, lr in enumerate(lrs):
@@ -1400,10 +1407,12 @@ class NetworkTrainer:
params_to_clip = accelerator.unwrap_model(network).get_trainable_params() params_to_clip = accelerator.unwrap_model(network).get_trainable_params()
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
if global_step % 5 == 0: if hasattr(network, "update_grad_norms"):
network.update_grad_norms()
if hasattr(network, "update_norms"): if hasattr(network, "update_norms"):
network.update_norms() network.update_norms()
optimizer.step() optimizer.step()
lr_scheduler.step() lr_scheduler.step()
optimizer.zero_grad(set_to_none=True) optimizer.zero_grad(set_to_none=True)
@@ -1412,9 +1421,23 @@ class NetworkTrainer:
keys_scaled, mean_norm, maximum_norm = accelerator.unwrap_model(network).apply_max_norm_regularization( keys_scaled, mean_norm, maximum_norm = accelerator.unwrap_model(network).apply_max_norm_regularization(
args.scale_weight_norms, accelerator.device args.scale_weight_norms, accelerator.device
) )
mean_grad_norm = None
mean_combined_norm = None
max_mean_logs = {"Keys Scaled": keys_scaled, "Average key norm": mean_norm} max_mean_logs = {"Keys Scaled": keys_scaled, "Average key norm": mean_norm}
else: else:
keys_scaled, mean_norm, maximum_norm = None, None, None if hasattr(network, "weight_norms"):
mean_norm = network.weight_norms().mean().item()
mean_grad_norm = network.grad_norms().mean().item()
mean_combined_norm = network.combined_weight_norms().mean().item()
weight_norms = network.weight_norms()
maximum_norm = weight_norms.max().item() if weight_norms.numel() > 0 else None
keys_scaled = None
max_mean_logs = {"avg weight norm": mean_norm, "avg grad norm": mean_grad_norm, "avg comb norm": mean_combined_norm}
else:
keys_scaled, mean_norm, maximum_norm = None, None, None
mean_grad_norm = None
mean_combined_norm = None
max_mean_logs = {}
# Checks if the accelerator has performed an optimization step behind the scenes # Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients: if accelerator.sync_gradients:
@@ -1446,14 +1469,11 @@ class NetworkTrainer:
loss_recorder.add(epoch=epoch, step=step, loss=current_loss) loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
avr_loss: float = loss_recorder.moving_average avr_loss: float = loss_recorder.moving_average
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs) progress_bar.set_postfix(**{**max_mean_logs, **logs})
if args.scale_weight_norms:
progress_bar.set_postfix(**{**max_mean_logs, **logs})
if is_tracking: if is_tracking:
logs = self.generate_step_logs( logs = self.generate_step_logs(
args, current_loss, avr_loss, lr_scheduler, lr_descriptions, optimizer, keys_scaled, mean_norm, maximum_norm args, current_loss, avr_loss, lr_scheduler, lr_descriptions, optimizer, keys_scaled, mean_norm, maximum_norm, mean_grad_norm, mean_combined_norm
) )
self.step_logging(accelerator, logs, global_step, epoch + 1) self.step_logging(accelerator, logs, global_step, epoch + 1)