mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Cache weight norms estimate on initialization. Move to update norms every step
This commit is contained in:
@@ -15,6 +15,7 @@ from diffusers import AutoencoderKL
|
|||||||
from transformers import CLIPTextModel
|
from transformers import CLIPTextModel
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
from torch import Tensor
|
||||||
import re
|
import re
|
||||||
from library.utils import setup_logging
|
from library.utils import setup_logging
|
||||||
from library.sdxl_original_unet import SdxlUNet2DConditionModel
|
from library.sdxl_original_unet import SdxlUNet2DConditionModel
|
||||||
@@ -145,8 +146,13 @@ class LoRAModule(torch.nn.Module):
|
|||||||
self.ggpo_sigma = ggpo_sigma
|
self.ggpo_sigma = ggpo_sigma
|
||||||
self.ggpo_beta = ggpo_beta
|
self.ggpo_beta = ggpo_beta
|
||||||
|
|
||||||
self.perturbation_norm_factor = 1.0 / math.sqrt(org_module.weight.shape[0])
|
if self.ggpo_beta is not None and self.ggpo_sigma is not None:
|
||||||
self._org_module_weight = self.org_module.weight.detach()
|
self.combined_weight_norms = None
|
||||||
|
self.grad_norms = None
|
||||||
|
self.perturbation_norm_factor = 1.0 / math.sqrt(org_module.weight.shape[0])
|
||||||
|
self.perturbation_seed = torch.randint(0, 2**32 - 1, (1,)).detach().item()
|
||||||
|
self.initialize_norm_cache(org_module.weight)
|
||||||
|
self.org_module_shape: tuple[int] = org_module.weight.shape
|
||||||
|
|
||||||
def apply_to(self):
|
def apply_to(self):
|
||||||
self.org_forward = self.org_module.forward
|
self.org_forward = self.org_module.forward
|
||||||
@@ -187,10 +193,12 @@ class LoRAModule(torch.nn.Module):
|
|||||||
lx = self.lora_up(lx)
|
lx = self.lora_up(lx)
|
||||||
|
|
||||||
# LoRA Gradient-Guided Perturbation Optimization
|
# LoRA Gradient-Guided Perturbation Optimization
|
||||||
if self.training and hasattr(self, 'perturbation_seed') and self.ggpo_sigma is not None and self.ggpo_beta is not None:
|
if self.training and hasattr(self, 'perturbation_seed') and self.ggpo_sigma is not None and self.ggpo_beta is not None and self.combined_weight_norms is not None and self.grad_norms is not None:
|
||||||
with torch.no_grad(), torch.autocast(self.device.type), temp_random_seed(self.perturbation_seed):
|
with torch.no_grad(), temp_random_seed(self.perturbation_seed):
|
||||||
perturbation = torch.randn_like(self._org_module_weight, dtype=self.dtype, device=self.device)
|
perturbation_scale = (self.ggpo_sigma * torch.sqrt(self.combined_weight_norms ** 2)) + (self.ggpo_beta * (self.grad_norms ** 2))
|
||||||
perturbation.mul_(self.perturbation_scale_factor)
|
perturbation_scale_factor = (perturbation_scale * self.perturbation_norm_factor).to(self.device)
|
||||||
|
perturbation = torch.randn(self.org_module_shape, dtype=self.dtype, device=self.device)
|
||||||
|
perturbation.mul_(perturbation_scale_factor)
|
||||||
perturbation_output = x @ perturbation.T # Result: (batch × n)
|
perturbation_output = x @ perturbation.T # Result: (batch × n)
|
||||||
return org_forwarded + (self.multiplier * scale * lx) + perturbation_output
|
return org_forwarded + (self.multiplier * scale * lx) + perturbation_output
|
||||||
else:
|
else:
|
||||||
@@ -221,6 +229,69 @@ class LoRAModule(torch.nn.Module):
|
|||||||
|
|
||||||
return org_forwarded + torch.cat(lxs, dim=-1) * self.multiplier * scale
|
return org_forwarded + torch.cat(lxs, dim=-1) * self.multiplier * scale
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def initialize_norm_cache(self, org_module_weight: Tensor):
|
||||||
|
# Choose a reasonable sample size
|
||||||
|
n_rows = org_module_weight.shape[0]
|
||||||
|
sample_size = min(1000, n_rows) # Cap at 1000 samples or use all if smaller
|
||||||
|
|
||||||
|
# Sample random indices across all rows
|
||||||
|
indices = torch.randperm(n_rows)[:sample_size]
|
||||||
|
|
||||||
|
# Convert to a supported data type first, then index
|
||||||
|
# Use float32 for indexing operations
|
||||||
|
weights_float32 = org_module_weight.to(dtype=torch.float32)
|
||||||
|
sampled_weights = weights_float32[indices].to(device=self.device)
|
||||||
|
|
||||||
|
# Calculate sampled norms
|
||||||
|
sampled_norms = torch.norm(sampled_weights, dim=1, keepdim=True)
|
||||||
|
|
||||||
|
# Store the mean norm as our estimate
|
||||||
|
self.org_weight_norm_estimate = sampled_norms.mean()
|
||||||
|
|
||||||
|
# Optional: store standard deviation for confidence intervals
|
||||||
|
self.org_weight_norm_std = sampled_norms.std()
|
||||||
|
|
||||||
|
# Free memory
|
||||||
|
del sampled_weights, weights_float32
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def validate_norm_approximation(self, org_module_weight: Tensor, verbose=True):
|
||||||
|
# Calculate the true norm (this will be slow but it's just for validation)
|
||||||
|
true_norms = []
|
||||||
|
chunk_size = 1024 # Process in chunks to avoid OOM
|
||||||
|
|
||||||
|
for i in range(0, org_module_weight.shape[0], chunk_size):
|
||||||
|
end_idx = min(i + chunk_size, org_module_weight.shape[0])
|
||||||
|
chunk = org_module_weight[i:end_idx].to(device=self.device, dtype=self.dtype)
|
||||||
|
chunk_norms = torch.norm(chunk, dim=1, keepdim=True)
|
||||||
|
true_norms.append(chunk_norms.cpu())
|
||||||
|
del chunk
|
||||||
|
|
||||||
|
true_norms = torch.cat(true_norms, dim=0)
|
||||||
|
true_mean_norm = true_norms.mean().item()
|
||||||
|
|
||||||
|
# Compare with our estimate
|
||||||
|
estimated_norm = self.org_weight_norm_estimate.item()
|
||||||
|
|
||||||
|
# Calculate error metrics
|
||||||
|
absolute_error = abs(true_mean_norm - estimated_norm)
|
||||||
|
relative_error = absolute_error / true_mean_norm * 100 # as percentage
|
||||||
|
|
||||||
|
if verbose:
|
||||||
|
logger.info(f"True mean norm: {true_mean_norm:.6f}")
|
||||||
|
logger.info(f"Estimated norm: {estimated_norm:.6f}")
|
||||||
|
logger.info(f"Absolute error: {absolute_error:.6f}")
|
||||||
|
logger.info(f"Relative error: {relative_error:.2f}%")
|
||||||
|
|
||||||
|
return {
|
||||||
|
'true_mean_norm': true_mean_norm,
|
||||||
|
'estimated_norm': estimated_norm,
|
||||||
|
'absolute_error': absolute_error,
|
||||||
|
'relative_error': relative_error
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def update_norms(self):
|
def update_norms(self):
|
||||||
# Not running GGPO so not currently running update norms
|
# Not running GGPO so not currently running update norms
|
||||||
@@ -228,8 +299,20 @@ class LoRAModule(torch.nn.Module):
|
|||||||
return
|
return
|
||||||
|
|
||||||
# only update norms when we are training
|
# only update norms when we are training
|
||||||
if self.lora_down.weight.requires_grad is not True:
|
if self.training is False:
|
||||||
print(f"skipping update_norms for {self.lora_name}")
|
return
|
||||||
|
|
||||||
|
module_weights = self.lora_up.weight @ self.lora_down.weight
|
||||||
|
module_weights.mul(self.scale)
|
||||||
|
|
||||||
|
self.weight_norms = torch.norm(module_weights, dim=1, keepdim=True)
|
||||||
|
self.combined_weight_norms = torch.sqrt((self.org_weight_norm_estimate**2) +
|
||||||
|
torch.sum(module_weights**2, dim=1, keepdim=True))
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def update_grad_norms(self):
|
||||||
|
if self.training is False:
|
||||||
|
print(f"skipping update_grad_norms for {self.lora_name}")
|
||||||
return
|
return
|
||||||
|
|
||||||
lora_down_grad = None
|
lora_down_grad = None
|
||||||
@@ -241,29 +324,12 @@ class LoRAModule(torch.nn.Module):
|
|||||||
elif name == "lora_up.weight":
|
elif name == "lora_up.weight":
|
||||||
lora_up_grad = param.grad
|
lora_up_grad = param.grad
|
||||||
|
|
||||||
with torch.autocast(self.device.type):
|
|
||||||
module_weights = self.scale * (self.lora_up.weight @ self.lora_down.weight)
|
|
||||||
org_device = self._org_module_weight.device
|
|
||||||
org_dtype = self._org_module_weight.dtype
|
|
||||||
org_weight = self._org_module_weight.to(device=self.device, dtype=self.dtype)
|
|
||||||
combined_weight = org_weight + module_weights
|
|
||||||
|
|
||||||
self.combined_weight_norms = torch.norm(combined_weight, dim=1, keepdim=True)
|
|
||||||
|
|
||||||
self._org_module_weight.to(device=org_device, dtype=org_dtype)
|
|
||||||
|
|
||||||
|
|
||||||
# Calculate gradient norms if we have both gradients
|
# Calculate gradient norms if we have both gradients
|
||||||
if lora_down_grad is not None and lora_up_grad is not None:
|
if lora_down_grad is not None and lora_up_grad is not None:
|
||||||
with torch.autocast(self.device.type):
|
with torch.autocast(self.device.type):
|
||||||
approx_grad = self.scale * ((self.lora_up.weight @ lora_down_grad) + (lora_up_grad @ self.lora_down.weight))
|
approx_grad = self.scale * ((self.lora_up.weight @ lora_down_grad) + (lora_up_grad @ self.lora_down.weight))
|
||||||
self.grad_norms = torch.norm(approx_grad, dim=1, keepdim=True)
|
self.grad_norms = torch.norm(approx_grad, dim=1, keepdim=True)
|
||||||
|
|
||||||
self.perturbation_scale = (self.ggpo_sigma * torch.sqrt(self.combined_weight_norms ** 2)) + (self.ggpo_beta * (self.grad_norms ** 2))
|
|
||||||
self.perturbation_scale_factor = (self.perturbation_scale * self.perturbation_norm_factor).to(self.device)
|
|
||||||
|
|
||||||
# LoRA Gradient-Guided Perturbation Optimization
|
|
||||||
self.perturbation_seed = torch.randint(0, 2**32 - 1, (1,)).detach().item()
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def device(self):
|
def device(self):
|
||||||
@@ -922,6 +988,32 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
for lora in self.text_encoder_loras + self.unet_loras:
|
for lora in self.text_encoder_loras + self.unet_loras:
|
||||||
lora.update_norms()
|
lora.update_norms()
|
||||||
|
|
||||||
|
def update_grad_norms(self):
|
||||||
|
for lora in self.text_encoder_loras + self.unet_loras:
|
||||||
|
lora.update_grad_norms()
|
||||||
|
|
||||||
|
def grad_norms(self) -> Tensor:
|
||||||
|
grad_norms = []
|
||||||
|
for lora in self.text_encoder_loras + self.unet_loras:
|
||||||
|
if hasattr(lora, "grad_norms") and lora.grad_norms is not None:
|
||||||
|
grad_norms.append(lora.grad_norms.mean(dim=0))
|
||||||
|
return torch.stack(grad_norms) if len(grad_norms) > 0 else torch.tensor([])
|
||||||
|
|
||||||
|
def weight_norms(self) -> Tensor:
|
||||||
|
weight_norms = []
|
||||||
|
for lora in self.text_encoder_loras + self.unet_loras:
|
||||||
|
if hasattr(lora, "weight_norms") and lora.weight_norms is not None:
|
||||||
|
weight_norms.append(lora.weight_norms.mean(dim=0))
|
||||||
|
return torch.stack(weight_norms) if len(weight_norms) > 0 else torch.tensor([])
|
||||||
|
|
||||||
|
def combined_weight_norms(self) -> Tensor:
|
||||||
|
combined_weight_norms = []
|
||||||
|
for lora in self.text_encoder_loras + self.unet_loras:
|
||||||
|
if hasattr(lora, "combined_weight_norms") and lora.combined_weight_norms is not None:
|
||||||
|
combined_weight_norms.append(lora.combined_weight_norms.mean(dim=0))
|
||||||
|
return torch.stack(combined_weight_norms) if len(combined_weight_norms) > 0 else torch.tensor([])
|
||||||
|
|
||||||
|
|
||||||
def load_weights(self, file):
|
def load_weights(self, file):
|
||||||
if os.path.splitext(file)[1] == ".safetensors":
|
if os.path.splitext(file)[1] == ".safetensors":
|
||||||
from safetensors.torch import load_file
|
from safetensors.torch import load_file
|
||||||
|
|||||||
@@ -69,13 +69,20 @@ class NetworkTrainer:
|
|||||||
keys_scaled=None,
|
keys_scaled=None,
|
||||||
mean_norm=None,
|
mean_norm=None,
|
||||||
maximum_norm=None,
|
maximum_norm=None,
|
||||||
|
mean_grad_norm=None,
|
||||||
|
mean_combined_norm=None
|
||||||
):
|
):
|
||||||
logs = {"loss/current": current_loss, "loss/average": avr_loss}
|
logs = {"loss/current": current_loss, "loss/average": avr_loss}
|
||||||
|
|
||||||
if keys_scaled is not None:
|
if keys_scaled is not None:
|
||||||
logs["max_norm/keys_scaled"] = keys_scaled
|
logs["max_norm/keys_scaled"] = keys_scaled
|
||||||
logs["max_norm/average_key_norm"] = mean_norm
|
|
||||||
logs["max_norm/max_key_norm"] = maximum_norm
|
logs["max_norm/max_key_norm"] = maximum_norm
|
||||||
|
if mean_norm is not None:
|
||||||
|
logs["norm/avg_key_norm"] = mean_norm
|
||||||
|
if mean_grad_norm is not None:
|
||||||
|
logs["norm/avg_grad_norm"] = mean_grad_norm
|
||||||
|
if mean_combined_norm is not None:
|
||||||
|
logs["norm/avg_combined_norm"] = mean_combined_norm
|
||||||
|
|
||||||
lrs = lr_scheduler.get_last_lr()
|
lrs = lr_scheduler.get_last_lr()
|
||||||
for i, lr in enumerate(lrs):
|
for i, lr in enumerate(lrs):
|
||||||
@@ -1400,10 +1407,12 @@ class NetworkTrainer:
|
|||||||
params_to_clip = accelerator.unwrap_model(network).get_trainable_params()
|
params_to_clip = accelerator.unwrap_model(network).get_trainable_params()
|
||||||
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
||||||
|
|
||||||
if global_step % 5 == 0:
|
if hasattr(network, "update_grad_norms"):
|
||||||
|
network.update_grad_norms()
|
||||||
if hasattr(network, "update_norms"):
|
if hasattr(network, "update_norms"):
|
||||||
network.update_norms()
|
network.update_norms()
|
||||||
|
|
||||||
|
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
lr_scheduler.step()
|
lr_scheduler.step()
|
||||||
optimizer.zero_grad(set_to_none=True)
|
optimizer.zero_grad(set_to_none=True)
|
||||||
@@ -1412,9 +1421,23 @@ class NetworkTrainer:
|
|||||||
keys_scaled, mean_norm, maximum_norm = accelerator.unwrap_model(network).apply_max_norm_regularization(
|
keys_scaled, mean_norm, maximum_norm = accelerator.unwrap_model(network).apply_max_norm_regularization(
|
||||||
args.scale_weight_norms, accelerator.device
|
args.scale_weight_norms, accelerator.device
|
||||||
)
|
)
|
||||||
|
mean_grad_norm = None
|
||||||
|
mean_combined_norm = None
|
||||||
max_mean_logs = {"Keys Scaled": keys_scaled, "Average key norm": mean_norm}
|
max_mean_logs = {"Keys Scaled": keys_scaled, "Average key norm": mean_norm}
|
||||||
else:
|
else:
|
||||||
keys_scaled, mean_norm, maximum_norm = None, None, None
|
if hasattr(network, "weight_norms"):
|
||||||
|
mean_norm = network.weight_norms().mean().item()
|
||||||
|
mean_grad_norm = network.grad_norms().mean().item()
|
||||||
|
mean_combined_norm = network.combined_weight_norms().mean().item()
|
||||||
|
weight_norms = network.weight_norms()
|
||||||
|
maximum_norm = weight_norms.max().item() if weight_norms.numel() > 0 else None
|
||||||
|
keys_scaled = None
|
||||||
|
max_mean_logs = {"avg weight norm": mean_norm, "avg grad norm": mean_grad_norm, "avg comb norm": mean_combined_norm}
|
||||||
|
else:
|
||||||
|
keys_scaled, mean_norm, maximum_norm = None, None, None
|
||||||
|
mean_grad_norm = None
|
||||||
|
mean_combined_norm = None
|
||||||
|
max_mean_logs = {}
|
||||||
|
|
||||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||||
if accelerator.sync_gradients:
|
if accelerator.sync_gradients:
|
||||||
@@ -1446,14 +1469,11 @@ class NetworkTrainer:
|
|||||||
loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
|
loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
|
||||||
avr_loss: float = loss_recorder.moving_average
|
avr_loss: float = loss_recorder.moving_average
|
||||||
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
|
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
|
||||||
progress_bar.set_postfix(**logs)
|
progress_bar.set_postfix(**{**max_mean_logs, **logs})
|
||||||
|
|
||||||
if args.scale_weight_norms:
|
|
||||||
progress_bar.set_postfix(**{**max_mean_logs, **logs})
|
|
||||||
|
|
||||||
if is_tracking:
|
if is_tracking:
|
||||||
logs = self.generate_step_logs(
|
logs = self.generate_step_logs(
|
||||||
args, current_loss, avr_loss, lr_scheduler, lr_descriptions, optimizer, keys_scaled, mean_norm, maximum_norm
|
args, current_loss, avr_loss, lr_scheduler, lr_descriptions, optimizer, keys_scaled, mean_norm, maximum_norm, mean_grad_norm, mean_combined_norm
|
||||||
)
|
)
|
||||||
self.step_logging(accelerator, logs, global_step, epoch + 1)
|
self.step_logging(accelerator, logs, global_step, epoch + 1)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user