mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Cache weight norms estimate on initialization. Move to update norms every step
This commit is contained in:
@@ -15,6 +15,7 @@ from diffusers import AutoencoderKL
|
||||
from transformers import CLIPTextModel
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import Tensor
|
||||
import re
|
||||
from library.utils import setup_logging
|
||||
from library.sdxl_original_unet import SdxlUNet2DConditionModel
|
||||
@@ -145,8 +146,13 @@ class LoRAModule(torch.nn.Module):
|
||||
self.ggpo_sigma = ggpo_sigma
|
||||
self.ggpo_beta = ggpo_beta
|
||||
|
||||
self.perturbation_norm_factor = 1.0 / math.sqrt(org_module.weight.shape[0])
|
||||
self._org_module_weight = self.org_module.weight.detach()
|
||||
if self.ggpo_beta is not None and self.ggpo_sigma is not None:
|
||||
self.combined_weight_norms = None
|
||||
self.grad_norms = None
|
||||
self.perturbation_norm_factor = 1.0 / math.sqrt(org_module.weight.shape[0])
|
||||
self.perturbation_seed = torch.randint(0, 2**32 - 1, (1,)).detach().item()
|
||||
self.initialize_norm_cache(org_module.weight)
|
||||
self.org_module_shape: tuple[int] = org_module.weight.shape
|
||||
|
||||
def apply_to(self):
|
||||
self.org_forward = self.org_module.forward
|
||||
@@ -187,10 +193,12 @@ class LoRAModule(torch.nn.Module):
|
||||
lx = self.lora_up(lx)
|
||||
|
||||
# LoRA Gradient-Guided Perturbation Optimization
|
||||
if self.training and hasattr(self, 'perturbation_seed') and self.ggpo_sigma is not None and self.ggpo_beta is not None:
|
||||
with torch.no_grad(), torch.autocast(self.device.type), temp_random_seed(self.perturbation_seed):
|
||||
perturbation = torch.randn_like(self._org_module_weight, dtype=self.dtype, device=self.device)
|
||||
perturbation.mul_(self.perturbation_scale_factor)
|
||||
if self.training and hasattr(self, 'perturbation_seed') and self.ggpo_sigma is not None and self.ggpo_beta is not None and self.combined_weight_norms is not None and self.grad_norms is not None:
|
||||
with torch.no_grad(), temp_random_seed(self.perturbation_seed):
|
||||
perturbation_scale = (self.ggpo_sigma * torch.sqrt(self.combined_weight_norms ** 2)) + (self.ggpo_beta * (self.grad_norms ** 2))
|
||||
perturbation_scale_factor = (perturbation_scale * self.perturbation_norm_factor).to(self.device)
|
||||
perturbation = torch.randn(self.org_module_shape, dtype=self.dtype, device=self.device)
|
||||
perturbation.mul_(perturbation_scale_factor)
|
||||
perturbation_output = x @ perturbation.T # Result: (batch × n)
|
||||
return org_forwarded + (self.multiplier * scale * lx) + perturbation_output
|
||||
else:
|
||||
@@ -221,6 +229,69 @@ class LoRAModule(torch.nn.Module):
|
||||
|
||||
return org_forwarded + torch.cat(lxs, dim=-1) * self.multiplier * scale
|
||||
|
||||
@torch.no_grad()
|
||||
def initialize_norm_cache(self, org_module_weight: Tensor):
|
||||
# Choose a reasonable sample size
|
||||
n_rows = org_module_weight.shape[0]
|
||||
sample_size = min(1000, n_rows) # Cap at 1000 samples or use all if smaller
|
||||
|
||||
# Sample random indices across all rows
|
||||
indices = torch.randperm(n_rows)[:sample_size]
|
||||
|
||||
# Convert to a supported data type first, then index
|
||||
# Use float32 for indexing operations
|
||||
weights_float32 = org_module_weight.to(dtype=torch.float32)
|
||||
sampled_weights = weights_float32[indices].to(device=self.device)
|
||||
|
||||
# Calculate sampled norms
|
||||
sampled_norms = torch.norm(sampled_weights, dim=1, keepdim=True)
|
||||
|
||||
# Store the mean norm as our estimate
|
||||
self.org_weight_norm_estimate = sampled_norms.mean()
|
||||
|
||||
# Optional: store standard deviation for confidence intervals
|
||||
self.org_weight_norm_std = sampled_norms.std()
|
||||
|
||||
# Free memory
|
||||
del sampled_weights, weights_float32
|
||||
|
||||
@torch.no_grad()
|
||||
def validate_norm_approximation(self, org_module_weight: Tensor, verbose=True):
|
||||
# Calculate the true norm (this will be slow but it's just for validation)
|
||||
true_norms = []
|
||||
chunk_size = 1024 # Process in chunks to avoid OOM
|
||||
|
||||
for i in range(0, org_module_weight.shape[0], chunk_size):
|
||||
end_idx = min(i + chunk_size, org_module_weight.shape[0])
|
||||
chunk = org_module_weight[i:end_idx].to(device=self.device, dtype=self.dtype)
|
||||
chunk_norms = torch.norm(chunk, dim=1, keepdim=True)
|
||||
true_norms.append(chunk_norms.cpu())
|
||||
del chunk
|
||||
|
||||
true_norms = torch.cat(true_norms, dim=0)
|
||||
true_mean_norm = true_norms.mean().item()
|
||||
|
||||
# Compare with our estimate
|
||||
estimated_norm = self.org_weight_norm_estimate.item()
|
||||
|
||||
# Calculate error metrics
|
||||
absolute_error = abs(true_mean_norm - estimated_norm)
|
||||
relative_error = absolute_error / true_mean_norm * 100 # as percentage
|
||||
|
||||
if verbose:
|
||||
logger.info(f"True mean norm: {true_mean_norm:.6f}")
|
||||
logger.info(f"Estimated norm: {estimated_norm:.6f}")
|
||||
logger.info(f"Absolute error: {absolute_error:.6f}")
|
||||
logger.info(f"Relative error: {relative_error:.2f}%")
|
||||
|
||||
return {
|
||||
'true_mean_norm': true_mean_norm,
|
||||
'estimated_norm': estimated_norm,
|
||||
'absolute_error': absolute_error,
|
||||
'relative_error': relative_error
|
||||
}
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def update_norms(self):
|
||||
# Not running GGPO so not currently running update norms
|
||||
@@ -228,8 +299,20 @@ class LoRAModule(torch.nn.Module):
|
||||
return
|
||||
|
||||
# only update norms when we are training
|
||||
if self.lora_down.weight.requires_grad is not True:
|
||||
print(f"skipping update_norms for {self.lora_name}")
|
||||
if self.training is False:
|
||||
return
|
||||
|
||||
module_weights = self.lora_up.weight @ self.lora_down.weight
|
||||
module_weights.mul(self.scale)
|
||||
|
||||
self.weight_norms = torch.norm(module_weights, dim=1, keepdim=True)
|
||||
self.combined_weight_norms = torch.sqrt((self.org_weight_norm_estimate**2) +
|
||||
torch.sum(module_weights**2, dim=1, keepdim=True))
|
||||
|
||||
@torch.no_grad()
|
||||
def update_grad_norms(self):
|
||||
if self.training is False:
|
||||
print(f"skipping update_grad_norms for {self.lora_name}")
|
||||
return
|
||||
|
||||
lora_down_grad = None
|
||||
@@ -241,29 +324,12 @@ class LoRAModule(torch.nn.Module):
|
||||
elif name == "lora_up.weight":
|
||||
lora_up_grad = param.grad
|
||||
|
||||
with torch.autocast(self.device.type):
|
||||
module_weights = self.scale * (self.lora_up.weight @ self.lora_down.weight)
|
||||
org_device = self._org_module_weight.device
|
||||
org_dtype = self._org_module_weight.dtype
|
||||
org_weight = self._org_module_weight.to(device=self.device, dtype=self.dtype)
|
||||
combined_weight = org_weight + module_weights
|
||||
|
||||
self.combined_weight_norms = torch.norm(combined_weight, dim=1, keepdim=True)
|
||||
|
||||
self._org_module_weight.to(device=org_device, dtype=org_dtype)
|
||||
|
||||
|
||||
# Calculate gradient norms if we have both gradients
|
||||
if lora_down_grad is not None and lora_up_grad is not None:
|
||||
with torch.autocast(self.device.type):
|
||||
approx_grad = self.scale * ((self.lora_up.weight @ lora_down_grad) + (lora_up_grad @ self.lora_down.weight))
|
||||
self.grad_norms = torch.norm(approx_grad, dim=1, keepdim=True)
|
||||
|
||||
self.perturbation_scale = (self.ggpo_sigma * torch.sqrt(self.combined_weight_norms ** 2)) + (self.ggpo_beta * (self.grad_norms ** 2))
|
||||
self.perturbation_scale_factor = (self.perturbation_scale * self.perturbation_norm_factor).to(self.device)
|
||||
|
||||
# LoRA Gradient-Guided Perturbation Optimization
|
||||
self.perturbation_seed = torch.randint(0, 2**32 - 1, (1,)).detach().item()
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
@@ -922,6 +988,32 @@ class LoRANetwork(torch.nn.Module):
|
||||
for lora in self.text_encoder_loras + self.unet_loras:
|
||||
lora.update_norms()
|
||||
|
||||
def update_grad_norms(self):
|
||||
for lora in self.text_encoder_loras + self.unet_loras:
|
||||
lora.update_grad_norms()
|
||||
|
||||
def grad_norms(self) -> Tensor:
|
||||
grad_norms = []
|
||||
for lora in self.text_encoder_loras + self.unet_loras:
|
||||
if hasattr(lora, "grad_norms") and lora.grad_norms is not None:
|
||||
grad_norms.append(lora.grad_norms.mean(dim=0))
|
||||
return torch.stack(grad_norms) if len(grad_norms) > 0 else torch.tensor([])
|
||||
|
||||
def weight_norms(self) -> Tensor:
|
||||
weight_norms = []
|
||||
for lora in self.text_encoder_loras + self.unet_loras:
|
||||
if hasattr(lora, "weight_norms") and lora.weight_norms is not None:
|
||||
weight_norms.append(lora.weight_norms.mean(dim=0))
|
||||
return torch.stack(weight_norms) if len(weight_norms) > 0 else torch.tensor([])
|
||||
|
||||
def combined_weight_norms(self) -> Tensor:
|
||||
combined_weight_norms = []
|
||||
for lora in self.text_encoder_loras + self.unet_loras:
|
||||
if hasattr(lora, "combined_weight_norms") and lora.combined_weight_norms is not None:
|
||||
combined_weight_norms.append(lora.combined_weight_norms.mean(dim=0))
|
||||
return torch.stack(combined_weight_norms) if len(combined_weight_norms) > 0 else torch.tensor([])
|
||||
|
||||
|
||||
def load_weights(self, file):
|
||||
if os.path.splitext(file)[1] == ".safetensors":
|
||||
from safetensors.torch import load_file
|
||||
|
||||
@@ -69,13 +69,20 @@ class NetworkTrainer:
|
||||
keys_scaled=None,
|
||||
mean_norm=None,
|
||||
maximum_norm=None,
|
||||
mean_grad_norm=None,
|
||||
mean_combined_norm=None
|
||||
):
|
||||
logs = {"loss/current": current_loss, "loss/average": avr_loss}
|
||||
|
||||
if keys_scaled is not None:
|
||||
logs["max_norm/keys_scaled"] = keys_scaled
|
||||
logs["max_norm/average_key_norm"] = mean_norm
|
||||
logs["max_norm/max_key_norm"] = maximum_norm
|
||||
if mean_norm is not None:
|
||||
logs["norm/avg_key_norm"] = mean_norm
|
||||
if mean_grad_norm is not None:
|
||||
logs["norm/avg_grad_norm"] = mean_grad_norm
|
||||
if mean_combined_norm is not None:
|
||||
logs["norm/avg_combined_norm"] = mean_combined_norm
|
||||
|
||||
lrs = lr_scheduler.get_last_lr()
|
||||
for i, lr in enumerate(lrs):
|
||||
@@ -1400,10 +1407,12 @@ class NetworkTrainer:
|
||||
params_to_clip = accelerator.unwrap_model(network).get_trainable_params()
|
||||
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
||||
|
||||
if global_step % 5 == 0:
|
||||
if hasattr(network, "update_grad_norms"):
|
||||
network.update_grad_norms()
|
||||
if hasattr(network, "update_norms"):
|
||||
network.update_norms()
|
||||
|
||||
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
@@ -1412,9 +1421,23 @@ class NetworkTrainer:
|
||||
keys_scaled, mean_norm, maximum_norm = accelerator.unwrap_model(network).apply_max_norm_regularization(
|
||||
args.scale_weight_norms, accelerator.device
|
||||
)
|
||||
mean_grad_norm = None
|
||||
mean_combined_norm = None
|
||||
max_mean_logs = {"Keys Scaled": keys_scaled, "Average key norm": mean_norm}
|
||||
else:
|
||||
keys_scaled, mean_norm, maximum_norm = None, None, None
|
||||
if hasattr(network, "weight_norms"):
|
||||
mean_norm = network.weight_norms().mean().item()
|
||||
mean_grad_norm = network.grad_norms().mean().item()
|
||||
mean_combined_norm = network.combined_weight_norms().mean().item()
|
||||
weight_norms = network.weight_norms()
|
||||
maximum_norm = weight_norms.max().item() if weight_norms.numel() > 0 else None
|
||||
keys_scaled = None
|
||||
max_mean_logs = {"avg weight norm": mean_norm, "avg grad norm": mean_grad_norm, "avg comb norm": mean_combined_norm}
|
||||
else:
|
||||
keys_scaled, mean_norm, maximum_norm = None, None, None
|
||||
mean_grad_norm = None
|
||||
mean_combined_norm = None
|
||||
max_mean_logs = {}
|
||||
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if accelerator.sync_gradients:
|
||||
@@ -1446,14 +1469,11 @@ class NetworkTrainer:
|
||||
loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
|
||||
avr_loss: float = loss_recorder.moving_average
|
||||
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
|
||||
progress_bar.set_postfix(**logs)
|
||||
|
||||
if args.scale_weight_norms:
|
||||
progress_bar.set_postfix(**{**max_mean_logs, **logs})
|
||||
progress_bar.set_postfix(**{**max_mean_logs, **logs})
|
||||
|
||||
if is_tracking:
|
||||
logs = self.generate_step_logs(
|
||||
args, current_loss, avr_loss, lr_scheduler, lr_descriptions, optimizer, keys_scaled, mean_norm, maximum_norm
|
||||
args, current_loss, avr_loss, lr_scheduler, lr_descriptions, optimizer, keys_scaled, mean_norm, maximum_norm, mean_grad_norm, mean_combined_norm
|
||||
)
|
||||
self.step_logging(accelerator, logs, global_step, epoch + 1)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user