From 182544dcce383a433527e446bfc7fa8374e375a8 Mon Sep 17 00:00:00 2001
From: rockerBOO
Date: Wed, 26 Mar 2025 14:23:04 -0400
Subject: [PATCH] Remove perturbation seed

---
 networks/lora_flux.py | 41 ++---------------------------------------
 1 file changed, 2 insertions(+), 39 deletions(-)

diff --git a/networks/lora_flux.py b/networks/lora_flux.py
index 9f5f1916..92b3979a 100644
--- a/networks/lora_flux.py
+++ b/networks/lora_flux.py
@@ -29,42 +29,6 @@ logger = logging.getLogger(__name__)
 NUM_DOUBLE_BLOCKS = 19
 NUM_SINGLE_BLOCKS = 38
 
-@contextmanager
-def temp_random_seed(seed, device=None):
-    """
-    Context manager that temporarily sets a specific random seed and then
-    restores the original RNG state afterward.
-
-    Args:
-        seed (int): The random seed to set temporarily
-        device (torch.device, optional): The device to set the seed for.
-            If None, will detect from the current context.
-    """
-    # Save original RNG states
-    original_cpu_rng_state = torch.get_rng_state()
-    original_cuda_rng_states = None
-    if torch.cuda.is_available():
-        original_cuda_rng_states = torch.cuda.get_rng_state_all()
-
-    # Determine if we need to set CUDA seed
-    set_cuda = False
-    if device is not None:
-        set_cuda = device.type == 'cuda'
-    elif torch.cuda.is_available():
-        set_cuda = True
-
-    try:
-        # Set the temporary seed
-        torch.manual_seed(seed)
-        if set_cuda:
-            torch.cuda.manual_seed_all(seed)
-        yield
-    finally:
-        # Restore original RNG states
-        torch.set_rng_state(original_cpu_rng_state)
-        if torch.cuda.is_available() and original_cuda_rng_states is not None:
-            torch.cuda.set_rng_state_all(original_cuda_rng_states)
-
 
 class LoRAModule(torch.nn.Module):
     """
@@ -150,7 +114,6 @@ class LoRAModule(torch.nn.Module):
         self.combined_weight_norms = None
         self.grad_norms = None
         self.perturbation_norm_factor = 1.0 / math.sqrt(org_module.weight.shape[0])
-        self.perturbation_seed = torch.randint(0, 2**32 - 1, (1,)).detach().item()
         self.initialize_norm_cache(org_module.weight)
 
         self.org_module_shape: tuple[int] = org_module.weight.shape
@@ -193,8 +156,8 @@ class LoRAModule(torch.nn.Module):
             lx = self.lora_up(lx)
 
         # LoRA Gradient-Guided Perturbation Optimization
-        if self.training and hasattr(self, 'perturbation_seed') and self.ggpo_sigma is not None and self.ggpo_beta is not None and self.combined_weight_norms is not None and self.grad_norms is not None:
-            with torch.no_grad(), temp_random_seed(self.perturbation_seed):
+        if self.training and self.ggpo_sigma is not None and self.ggpo_beta is not None and self.combined_weight_norms is not None and self.grad_norms is not None:
+            with torch.no_grad():
                 perturbation_scale = (self.ggpo_sigma * torch.sqrt(self.combined_weight_norms ** 2)) + (self.ggpo_beta * (self.grad_norms ** 2))
                 perturbation_scale_factor = (perturbation_scale * self.perturbation_norm_factor).to(self.device)
                 perturbation = torch.randn(self.org_module_shape, dtype=self.dtype, device=self.device)
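
Note: removing the per-module perturbation_seed changes how GGPO samples its noise. Previously, temp_random_seed re-seeded the generator before each draw, so a module replayed the identical perturbation on every forward pass while leaving the surrounding RNG state untouched. After this patch, torch.randn draws fresh noise from the ambient global RNG stream each step and advances that stream. A minimal standalone sketch of the behavioral difference (the shape and seed values below are arbitrary placeholders, not taken from lora_flux.py):

    import torch

    shape = (4, 4)

    # Old behavior: re-seeding before each draw replays the same
    # perturbation every forward pass (temp_random_seed also saved
    # and restored the global RNG state around the draw).
    torch.manual_seed(1234)
    a = torch.randn(shape)
    torch.manual_seed(1234)
    b = torch.randn(shape)
    assert torch.equal(a, b)  # identical noise each step

    # New behavior: sampling from the ambient RNG stream yields fresh
    # noise on every forward pass and advances the global RNG state.
    c = torch.randn(shape)
    d = torch.randn(shape)
    assert not torch.equal(c, d)  # fresh noise each step (almost surely)

Dropping the seed trades exact replay of a module's perturbation for simpler code and one less save/restore of global RNG state per forward pass.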