From ea53290f625b29c2cfc1c63cc83d6dcd1492731c Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Thu, 6 Mar 2025 00:00:38 -0500 Subject: [PATCH 1/4] Add LoRA-GGPO for Flux --- networks/lora_flux.py | 134 +++++++++++++++++++++++++++++++++++++++++- train_network.py | 4 ++ 2 files changed, 137 insertions(+), 1 deletion(-) diff --git a/networks/lora_flux.py b/networks/lora_flux.py index 91e9cd77..98cf8c55 100644 --- a/networks/lora_flux.py +++ b/networks/lora_flux.py @@ -9,6 +9,7 @@ import math import os +from contextlib import contextmanager from typing import Dict, List, Optional, Tuple, Type, Union from diffusers import AutoencoderKL from transformers import CLIPTextModel @@ -27,6 +28,42 @@ logger = logging.getLogger(__name__) NUM_DOUBLE_BLOCKS = 19 NUM_SINGLE_BLOCKS = 38 +@contextmanager +def temp_random_seed(seed, device=None): + """ + Context manager that temporarily sets a specific random seed and then + restores the original RNG state afterward. + + Args: + seed (int): The random seed to set temporarily + device (torch.device, optional): The device to set the seed for. + If None, will detect from the current context. 
+ """ + # Save original RNG states + original_cpu_rng_state = torch.get_rng_state() + original_cuda_rng_states = None + if torch.cuda.is_available(): + original_cuda_rng_states = torch.cuda.get_rng_state_all() + + # Determine if we need to set CUDA seed + set_cuda = False + if device is not None: + set_cuda = device.type == 'cuda' + elif torch.cuda.is_available(): + set_cuda = True + + try: + # Set the temporary seed + torch.manual_seed(seed) + if set_cuda: + torch.cuda.manual_seed_all(seed) + yield + finally: + # Restore original RNG states + torch.set_rng_state(original_cpu_rng_state) + if torch.cuda.is_available() and original_cuda_rng_states is not None: + torch.cuda.set_rng_state_all(original_cuda_rng_states) + class LoRAModule(torch.nn.Module): """ @@ -44,6 +81,8 @@ class LoRAModule(torch.nn.Module): rank_dropout=None, module_dropout=None, split_dims: Optional[List[int]] = None, + ggpo_beta: Optional[float] = None, + ggpo_sigma: Optional[float] = None, ): """ if alpha == 0 or None, alpha is rank (no scaling). 
@@ -103,9 +142,16 @@ class LoRAModule(torch.nn.Module): self.rank_dropout = rank_dropout self.module_dropout = module_dropout + self.ggpo_sigma = ggpo_sigma + self.ggpo_beta = ggpo_beta + + self.perturbation_norm_factor = 1.0 / math.sqrt(org_module.weight.shape[0]) + self._org_module_weight = self.org_module.weight.detach() + def apply_to(self): self.org_forward = self.org_module.forward self.org_module.forward = self.forward + del self.org_module def forward(self, x): @@ -140,7 +186,15 @@ class LoRAModule(torch.nn.Module): lx = self.lora_up(lx) - return org_forwarded + lx * self.multiplier * scale + # LoRA Gradient-Guided Perturbation Optimization + if self.training and hasattr(self, 'perturbation_seed') and self.ggpo_sigma is not None and self.ggpo_beta is not None: + with torch.no_grad(), torch.autocast(self.device.type), temp_random_seed(self.perturbation_seed): + perturbation = torch.randn_like(self._org_module_weight, dtype=self.dtype, device=self.device) + perturbation.mul_(self.perturbation_scale_factor) + perturbation_output = x @ perturbation.T # Result: (batch × n) + return org_forwarded + (self.multiplier * scale * lx) + perturbation_output + else: + return org_forwarded + lx * self.multiplier * scale else: lxs = [lora_down(x) for lora_down in self.lora_down] @@ -167,6 +221,58 @@ class LoRAModule(torch.nn.Module): return org_forwarded + torch.cat(lxs, dim=-1) * self.multiplier * scale + @torch.no_grad() + def update_norms(self): + # Not running GGPO so not currently running update norms + if self.ggpo_beta is None or self.ggpo_sigma is None: + return + + # only update norms when we are training + if self.lora_down.weight.requires_grad is not True: + print(f"skipping update_norms for {self.lora_name}") + return + + lora_down_grad = None + lora_up_grad = None + + for name, param in self.named_parameters(): + if name == "lora_down.weight": + lora_down_grad = param.grad + elif name == "lora_up.weight": + lora_up_grad = param.grad + + with 
torch.autocast(self.device.type): + module_weights = self.scale * (self.lora_up.weight @ self.lora_down.weight) + org_device = self._org_module_weight.device + org_dtype = self._org_module_weight.dtype + org_weight = self._org_module_weight.to(device=self.device, dtype=self.dtype) + combined_weight = org_weight + module_weights + + self.combined_weight_norms = torch.norm(combined_weight, dim=1, keepdim=True) + + self._org_module_weight.to(device=org_device, dtype=org_dtype) + + + # Calculate gradient norms if we have both gradients + if lora_down_grad is not None and lora_up_grad is not None: + with torch.autocast(self.device.type): + approx_grad = self.scale * ((self.lora_up.weight @ lora_down_grad) + (lora_up_grad @ self.lora_down.weight)) + self.grad_norms = torch.norm(approx_grad, dim=1, keepdim=True) + + self.perturbation_scale = (self.ggpo_sigma * torch.sqrt(self.combined_weight_norms ** 2)) + (self.ggpo_beta * (self.grad_norms ** 2)) + self.perturbation_scale_factor = (self.perturbation_scale * self.perturbation_norm_factor).to(self.device) + + # LoRA Gradient-Guided Perturbation Optimization + self.perturbation_seed = torch.randint(0, 2**32 - 1, (1,)).detach().item() + + @property + def device(self): + return next(self.parameters()).device + + @property + def dtype(self): + return next(self.parameters()).dtype + class LoRAInfModule(LoRAModule): def __init__( @@ -420,6 +526,16 @@ def create_network( if split_qkv is not None: split_qkv = True if split_qkv == "True" else False + ggpo_beta = kwargs.get("ggpo_beta", None) + ggpo_sigma = kwargs.get("ggpo_sigma", None) + + if ggpo_beta is not None: + ggpo_beta = float(ggpo_beta) + + if ggpo_sigma is not None: + ggpo_sigma = float(ggpo_sigma) + + # train T5XXL train_t5xxl = kwargs.get("train_t5xxl", False) if train_t5xxl is not None: @@ -449,6 +565,8 @@ def create_network( in_dims=in_dims, train_double_block_indices=train_double_block_indices, train_single_block_indices=train_single_block_indices, + 
ggpo_beta=ggpo_beta, + ggpo_sigma=ggpo_sigma, verbose=verbose, ) @@ -561,6 +679,8 @@ class LoRANetwork(torch.nn.Module): in_dims: Optional[List[int]] = None, train_double_block_indices: Optional[List[bool]] = None, train_single_block_indices: Optional[List[bool]] = None, + ggpo_beta: Optional[float] = None, + ggpo_sigma: Optional[float] = None, verbose: Optional[bool] = False, ) -> None: super().__init__() @@ -599,10 +719,16 @@ class LoRANetwork(torch.nn.Module): # logger.info( # f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}" # ) + + if ggpo_beta is not None and ggpo_sigma is not None: + logger.info(f"LoRA-GGPO training sigma: {ggpo_sigma} beta: {ggpo_beta}") + if self.split_qkv: logger.info(f"split qkv for LoRA") if self.train_blocks is not None: logger.info(f"train {self.train_blocks} blocks only") + + if train_t5xxl: logger.info(f"train T5XXL as well") @@ -722,6 +848,8 @@ class LoRANetwork(torch.nn.Module): rank_dropout=rank_dropout, module_dropout=module_dropout, split_dims=split_dims, + ggpo_beta=ggpo_beta, + ggpo_sigma=ggpo_sigma, ) loras.append(lora) @@ -790,6 +918,10 @@ class LoRANetwork(torch.nn.Module): for lora in self.text_encoder_loras + self.unet_loras: lora.enabled = is_enabled + def update_norms(self): + for lora in self.text_encoder_loras + self.unet_loras: + lora.update_norms() + def load_weights(self, file): if os.path.splitext(file)[1] == ".safetensors": from safetensors.torch import load_file diff --git a/train_network.py b/train_network.py index 2d279b3b..9db335b0 100644 --- a/train_network.py +++ b/train_network.py @@ -1400,6 +1400,10 @@ class NetworkTrainer: params_to_clip = accelerator.unwrap_model(network).get_trainable_params() accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + if global_step % 5 == 0: + if hasattr(network, "update_norms"): + network.update_norms() + optimizer.step() lr_scheduler.step() optimizer.zero_grad(set_to_none=True) From 
3647d065b50d74ade3642edd0ec99a2ce1041edf Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Tue, 18 Mar 2025 14:25:09 -0400 Subject: [PATCH 2/4] Cache weight norms estimate on initialization. Move to update norms every step --- networks/lora_flux.py | 142 ++++++++++++++++++++++++++++++++++-------- train_network.py | 36 ++++++++--- 2 files changed, 145 insertions(+), 33 deletions(-) diff --git a/networks/lora_flux.py b/networks/lora_flux.py index 98cf8c55..9f5f1916 100644 --- a/networks/lora_flux.py +++ b/networks/lora_flux.py @@ -15,6 +15,7 @@ from diffusers import AutoencoderKL from transformers import CLIPTextModel import numpy as np import torch +from torch import Tensor import re from library.utils import setup_logging from library.sdxl_original_unet import SdxlUNet2DConditionModel @@ -145,8 +146,13 @@ class LoRAModule(torch.nn.Module): self.ggpo_sigma = ggpo_sigma self.ggpo_beta = ggpo_beta - self.perturbation_norm_factor = 1.0 / math.sqrt(org_module.weight.shape[0]) - self._org_module_weight = self.org_module.weight.detach() + if self.ggpo_beta is not None and self.ggpo_sigma is not None: + self.combined_weight_norms = None + self.grad_norms = None + self.perturbation_norm_factor = 1.0 / math.sqrt(org_module.weight.shape[0]) + self.perturbation_seed = torch.randint(0, 2**32 - 1, (1,)).detach().item() + self.initialize_norm_cache(org_module.weight) + self.org_module_shape: tuple[int] = org_module.weight.shape def apply_to(self): self.org_forward = self.org_module.forward @@ -187,10 +193,12 @@ class LoRAModule(torch.nn.Module): lx = self.lora_up(lx) # LoRA Gradient-Guided Perturbation Optimization - if self.training and hasattr(self, 'perturbation_seed') and self.ggpo_sigma is not None and self.ggpo_beta is not None: - with torch.no_grad(), torch.autocast(self.device.type), temp_random_seed(self.perturbation_seed): - perturbation = torch.randn_like(self._org_module_weight, dtype=self.dtype, device=self.device) - perturbation.mul_(self.perturbation_scale_factor) + 
if self.training and hasattr(self, 'perturbation_seed') and self.ggpo_sigma is not None and self.ggpo_beta is not None and self.combined_weight_norms is not None and self.grad_norms is not None: + with torch.no_grad(), temp_random_seed(self.perturbation_seed): + perturbation_scale = (self.ggpo_sigma * torch.sqrt(self.combined_weight_norms ** 2)) + (self.ggpo_beta * (self.grad_norms ** 2)) + perturbation_scale_factor = (perturbation_scale * self.perturbation_norm_factor).to(self.device) + perturbation = torch.randn(self.org_module_shape, dtype=self.dtype, device=self.device) + perturbation.mul_(perturbation_scale_factor) perturbation_output = x @ perturbation.T # Result: (batch × n) return org_forwarded + (self.multiplier * scale * lx) + perturbation_output else: @@ -221,6 +229,69 @@ class LoRAModule(torch.nn.Module): return org_forwarded + torch.cat(lxs, dim=-1) * self.multiplier * scale + @torch.no_grad() + def initialize_norm_cache(self, org_module_weight: Tensor): + # Choose a reasonable sample size + n_rows = org_module_weight.shape[0] + sample_size = min(1000, n_rows) # Cap at 1000 samples or use all if smaller + + # Sample random indices across all rows + indices = torch.randperm(n_rows)[:sample_size] + + # Convert to a supported data type first, then index + # Use float32 for indexing operations + weights_float32 = org_module_weight.to(dtype=torch.float32) + sampled_weights = weights_float32[indices].to(device=self.device) + + # Calculate sampled norms + sampled_norms = torch.norm(sampled_weights, dim=1, keepdim=True) + + # Store the mean norm as our estimate + self.org_weight_norm_estimate = sampled_norms.mean() + + # Optional: store standard deviation for confidence intervals + self.org_weight_norm_std = sampled_norms.std() + + # Free memory + del sampled_weights, weights_float32 + + @torch.no_grad() + def validate_norm_approximation(self, org_module_weight: Tensor, verbose=True): + # Calculate the true norm (this will be slow but it's just for validation) 
+ true_norms = [] + chunk_size = 1024 # Process in chunks to avoid OOM + + for i in range(0, org_module_weight.shape[0], chunk_size): + end_idx = min(i + chunk_size, org_module_weight.shape[0]) + chunk = org_module_weight[i:end_idx].to(device=self.device, dtype=self.dtype) + chunk_norms = torch.norm(chunk, dim=1, keepdim=True) + true_norms.append(chunk_norms.cpu()) + del chunk + + true_norms = torch.cat(true_norms, dim=0) + true_mean_norm = true_norms.mean().item() + + # Compare with our estimate + estimated_norm = self.org_weight_norm_estimate.item() + + # Calculate error metrics + absolute_error = abs(true_mean_norm - estimated_norm) + relative_error = absolute_error / true_mean_norm * 100 # as percentage + + if verbose: + logger.info(f"True mean norm: {true_mean_norm:.6f}") + logger.info(f"Estimated norm: {estimated_norm:.6f}") + logger.info(f"Absolute error: {absolute_error:.6f}") + logger.info(f"Relative error: {relative_error:.2f}%") + + return { + 'true_mean_norm': true_mean_norm, + 'estimated_norm': estimated_norm, + 'absolute_error': absolute_error, + 'relative_error': relative_error + } + + @torch.no_grad() def update_norms(self): # Not running GGPO so not currently running update norms @@ -228,8 +299,20 @@ class LoRAModule(torch.nn.Module): return # only update norms when we are training - if self.lora_down.weight.requires_grad is not True: - print(f"skipping update_norms for {self.lora_name}") + if self.training is False: + return + + module_weights = self.lora_up.weight @ self.lora_down.weight + module_weights.mul_(self.scale) # in-place: Tensor.mul() returns a new tensor, which left the norms below unscaled + + self.weight_norms = torch.norm(module_weights, dim=1, keepdim=True) + self.combined_weight_norms = torch.sqrt((self.org_weight_norm_estimate**2) + + torch.sum(module_weights**2, dim=1, keepdim=True)) + + @torch.no_grad() + def update_grad_norms(self): + if self.training is False: + print(f"skipping update_grad_norms for {self.lora_name}") return lora_down_grad = None @@ -241,29 +324,12 @@ class LoRAModule(torch.nn.Module): 
elif name == "lora_up.weight": lora_up_grad = param.grad - with torch.autocast(self.device.type): - module_weights = self.scale * (self.lora_up.weight @ self.lora_down.weight) - org_device = self._org_module_weight.device - org_dtype = self._org_module_weight.dtype - org_weight = self._org_module_weight.to(device=self.device, dtype=self.dtype) - combined_weight = org_weight + module_weights - - self.combined_weight_norms = torch.norm(combined_weight, dim=1, keepdim=True) - - self._org_module_weight.to(device=org_device, dtype=org_dtype) - - # Calculate gradient norms if we have both gradients if lora_down_grad is not None and lora_up_grad is not None: with torch.autocast(self.device.type): approx_grad = self.scale * ((self.lora_up.weight @ lora_down_grad) + (lora_up_grad @ self.lora_down.weight)) self.grad_norms = torch.norm(approx_grad, dim=1, keepdim=True) - self.perturbation_scale = (self.ggpo_sigma * torch.sqrt(self.combined_weight_norms ** 2)) + (self.ggpo_beta * (self.grad_norms ** 2)) - self.perturbation_scale_factor = (self.perturbation_scale * self.perturbation_norm_factor).to(self.device) - - # LoRA Gradient-Guided Perturbation Optimization - self.perturbation_seed = torch.randint(0, 2**32 - 1, (1,)).detach().item() @property def device(self): @@ -922,6 +988,32 @@ class LoRANetwork(torch.nn.Module): for lora in self.text_encoder_loras + self.unet_loras: lora.update_norms() + def update_grad_norms(self): + for lora in self.text_encoder_loras + self.unet_loras: + lora.update_grad_norms() + + def grad_norms(self) -> Tensor: + grad_norms = [] + for lora in self.text_encoder_loras + self.unet_loras: + if hasattr(lora, "grad_norms") and lora.grad_norms is not None: + grad_norms.append(lora.grad_norms.mean(dim=0)) + return torch.stack(grad_norms) if len(grad_norms) > 0 else torch.tensor([]) + + def weight_norms(self) -> Tensor: + weight_norms = [] + for lora in self.text_encoder_loras + self.unet_loras: + if hasattr(lora, "weight_norms") and lora.weight_norms is 
not None: + weight_norms.append(lora.weight_norms.mean(dim=0)) + return torch.stack(weight_norms) if len(weight_norms) > 0 else torch.tensor([]) + + def combined_weight_norms(self) -> Tensor: + combined_weight_norms = [] + for lora in self.text_encoder_loras + self.unet_loras: + if hasattr(lora, "combined_weight_norms") and lora.combined_weight_norms is not None: + combined_weight_norms.append(lora.combined_weight_norms.mean(dim=0)) + return torch.stack(combined_weight_norms) if len(combined_weight_norms) > 0 else torch.tensor([]) + + def load_weights(self, file): if os.path.splitext(file)[1] == ".safetensors": from safetensors.torch import load_file diff --git a/train_network.py b/train_network.py index 9db335b0..4898e798 100644 --- a/train_network.py +++ b/train_network.py @@ -69,13 +69,20 @@ class NetworkTrainer: keys_scaled=None, mean_norm=None, maximum_norm=None, + mean_grad_norm=None, + mean_combined_norm=None ): logs = {"loss/current": current_loss, "loss/average": avr_loss} if keys_scaled is not None: logs["max_norm/keys_scaled"] = keys_scaled - logs["max_norm/average_key_norm"] = mean_norm logs["max_norm/max_key_norm"] = maximum_norm + if mean_norm is not None: + logs["norm/avg_key_norm"] = mean_norm + if mean_grad_norm is not None: + logs["norm/avg_grad_norm"] = mean_grad_norm + if mean_combined_norm is not None: + logs["norm/avg_combined_norm"] = mean_combined_norm lrs = lr_scheduler.get_last_lr() for i, lr in enumerate(lrs): @@ -1400,10 +1407,12 @@ class NetworkTrainer: params_to_clip = accelerator.unwrap_model(network).get_trainable_params() accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) - if global_step % 5 == 0: + if hasattr(network, "update_grad_norms"): + network.update_grad_norms() if hasattr(network, "update_norms"): network.update_norms() + optimizer.step() lr_scheduler.step() optimizer.zero_grad(set_to_none=True) @@ -1412,9 +1421,23 @@ class NetworkTrainer: keys_scaled, mean_norm, maximum_norm = 
accelerator.unwrap_model(network).apply_max_norm_regularization( args.scale_weight_norms, accelerator.device ) + mean_grad_norm = None + mean_combined_norm = None max_mean_logs = {"Keys Scaled": keys_scaled, "Average key norm": mean_norm} else: - keys_scaled, mean_norm, maximum_norm = None, None, None + if hasattr(network, "weight_norms"): + mean_norm = network.weight_norms().mean().item() + mean_grad_norm = network.grad_norms().mean().item() + mean_combined_norm = network.combined_weight_norms().mean().item() + weight_norms = network.weight_norms() + maximum_norm = weight_norms.max().item() if weight_norms.numel() > 0 else None + keys_scaled = None + max_mean_logs = {"avg weight norm": mean_norm, "avg grad norm": mean_grad_norm, "avg comb norm": mean_combined_norm} + else: + keys_scaled, mean_norm, maximum_norm = None, None, None + mean_grad_norm = None + mean_combined_norm = None + max_mean_logs = {} # Checks if the accelerator has performed an optimization step behind the scenes if accelerator.sync_gradients: @@ -1446,14 +1469,11 @@ class NetworkTrainer: loss_recorder.add(epoch=epoch, step=step, loss=current_loss) avr_loss: float = loss_recorder.moving_average logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} - progress_bar.set_postfix(**logs) - - if args.scale_weight_norms: - progress_bar.set_postfix(**{**max_mean_logs, **logs}) + progress_bar.set_postfix(**{**max_mean_logs, **logs}) if is_tracking: logs = self.generate_step_logs( - args, current_loss, avr_loss, lr_scheduler, lr_descriptions, optimizer, keys_scaled, mean_norm, maximum_norm + args, current_loss, avr_loss, lr_scheduler, lr_descriptions, optimizer, keys_scaled, mean_norm, maximum_norm, mean_grad_norm, mean_combined_norm ) self.step_logging(accelerator, logs, global_step, epoch + 1) From 182544dcce383a433527e446bfc7fa8374e375a8 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Wed, 26 Mar 2025 14:23:04 -0400 Subject: [PATCH 3/4] Remove perturbation seed --- networks/lora_flux.py | 
41 ++--------------------------------------- 1 file changed, 2 insertions(+), 39 deletions(-) diff --git a/networks/lora_flux.py b/networks/lora_flux.py index 9f5f1916..92b3979a 100644 --- a/networks/lora_flux.py +++ b/networks/lora_flux.py @@ -29,42 +29,6 @@ logger = logging.getLogger(__name__) NUM_DOUBLE_BLOCKS = 19 NUM_SINGLE_BLOCKS = 38 -@contextmanager -def temp_random_seed(seed, device=None): - """ - Context manager that temporarily sets a specific random seed and then - restores the original RNG state afterward. - - Args: - seed (int): The random seed to set temporarily - device (torch.device, optional): The device to set the seed for. - If None, will detect from the current context. - """ - # Save original RNG states - original_cpu_rng_state = torch.get_rng_state() - original_cuda_rng_states = None - if torch.cuda.is_available(): - original_cuda_rng_states = torch.cuda.get_rng_state_all() - - # Determine if we need to set CUDA seed - set_cuda = False - if device is not None: - set_cuda = device.type == 'cuda' - elif torch.cuda.is_available(): - set_cuda = True - - try: - # Set the temporary seed - torch.manual_seed(seed) - if set_cuda: - torch.cuda.manual_seed_all(seed) - yield - finally: - # Restore original RNG states - torch.set_rng_state(original_cpu_rng_state) - if torch.cuda.is_available() and original_cuda_rng_states is not None: - torch.cuda.set_rng_state_all(original_cuda_rng_states) - class LoRAModule(torch.nn.Module): """ @@ -150,7 +114,6 @@ class LoRAModule(torch.nn.Module): self.combined_weight_norms = None self.grad_norms = None self.perturbation_norm_factor = 1.0 / math.sqrt(org_module.weight.shape[0]) - self.perturbation_seed = torch.randint(0, 2**32 - 1, (1,)).detach().item() self.initialize_norm_cache(org_module.weight) self.org_module_shape: tuple[int] = org_module.weight.shape @@ -193,8 +156,8 @@ class LoRAModule(torch.nn.Module): lx = self.lora_up(lx) # LoRA Gradient-Guided Perturbation Optimization - if self.training and hasattr(self, 
'perturbation_seed') and self.ggpo_sigma is not None and self.ggpo_beta is not None and self.combined_weight_norms is not None and self.grad_norms is not None: - with torch.no_grad(), temp_random_seed(self.perturbation_seed): + if self.training and self.ggpo_sigma is not None and self.ggpo_beta is not None and self.combined_weight_norms is not None and self.grad_norms is not None: + with torch.no_grad(): perturbation_scale = (self.ggpo_sigma * torch.sqrt(self.combined_weight_norms ** 2)) + (self.ggpo_beta * (self.grad_norms ** 2)) perturbation_scale_factor = (perturbation_scale * self.perturbation_norm_factor).to(self.device) perturbation = torch.randn(self.org_module_shape, dtype=self.dtype, device=self.device) From 0181b7a0425fd58012f7e3ece10345c86d9b6fc8 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Thu, 27 Mar 2025 03:28:33 -0400 Subject: [PATCH 4/4] Remove progress bar avg norms --- train_network.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train_network.py b/train_network.py index 4898e798..5b2f377a 100644 --- a/train_network.py +++ b/train_network.py @@ -1432,7 +1432,7 @@ class NetworkTrainer: weight_norms = network.weight_norms() maximum_norm = weight_norms.max().item() if weight_norms.numel() > 0 else None keys_scaled = None - max_mean_logs = {"avg weight norm": mean_norm, "avg grad norm": mean_grad_norm, "avg comb norm": mean_combined_norm} + max_mean_logs = {} else: keys_scaled, mean_norm, maximum_norm = None, None, None mean_grad_norm = None