Remove perturbation seed

This commit is contained in:
rockerBOO
2025-03-26 14:23:04 -04:00
parent 3647d065b5
commit 182544dcce

View File

@@ -29,42 +29,6 @@ logger = logging.getLogger(__name__)
NUM_DOUBLE_BLOCKS = 19
NUM_SINGLE_BLOCKS = 38
@contextmanager
def temp_random_seed(seed, device=None):
    """
    Temporarily seed PyTorch's RNGs, restoring the previous state on exit.

    On entry the CPU RNG state (and, when CUDA is available, the RNG state of
    every CUDA device) is captured, then ``seed`` is applied. On exit -- even
    when the body raises -- the captured states are written back, so code
    outside the ``with`` block sees the random stream it would have seen
    had the block never run.

    Args:
        seed (int): The random seed to set temporarily.
        device (torch.device | str, optional): Device the seed is meant for.
            CUDA generators are explicitly seeded only when its type is
            ``"cuda"``. If None, CUDA generators are seeded whenever CUDA
            is available.
    """
    # Accept "cuda" / "cuda:0" / "cpu" strings as well as torch.device
    # objects; a bare string has no ``.type`` attribute.
    if isinstance(device, str):
        device = torch.device(device)

    cuda_available = torch.cuda.is_available()

    # Save original RNG states before touching any generator.
    original_cpu_rng_state = torch.get_rng_state()
    original_cuda_rng_states = (
        torch.cuda.get_rng_state_all() if cuda_available else None
    )

    # Determine whether the CUDA generators need an explicit seed.
    if device is not None:
        set_cuda = device.type == "cuda"
    else:
        set_cuda = cuda_available

    try:
        # Set the temporary seed
        torch.manual_seed(seed)
        if set_cuda:
            torch.cuda.manual_seed_all(seed)
        yield
    finally:
        # Restore original RNG states, regardless of how the body exited.
        torch.set_rng_state(original_cpu_rng_state)
        if original_cuda_rng_states is not None:
            torch.cuda.set_rng_state_all(original_cuda_rng_states)
class LoRAModule(torch.nn.Module):
"""
@@ -150,7 +114,6 @@ class LoRAModule(torch.nn.Module):
self.combined_weight_norms = None
self.grad_norms = None
self.perturbation_norm_factor = 1.0 / math.sqrt(org_module.weight.shape[0])
self.perturbation_seed = torch.randint(0, 2**32 - 1, (1,)).detach().item()
self.initialize_norm_cache(org_module.weight)
self.org_module_shape: tuple[int] = org_module.weight.shape
@@ -193,8 +156,8 @@ class LoRAModule(torch.nn.Module):
lx = self.lora_up(lx)
# LoRA Gradient-Guided Perturbation Optimization
if self.training and hasattr(self, 'perturbation_seed') and self.ggpo_sigma is not None and self.ggpo_beta is not None and self.combined_weight_norms is not None and self.grad_norms is not None:
with torch.no_grad(), temp_random_seed(self.perturbation_seed):
if self.training and self.ggpo_sigma is not None and self.ggpo_beta is not None and self.combined_weight_norms is not None and self.grad_norms is not None:
with torch.no_grad():
perturbation_scale = (self.ggpo_sigma * torch.sqrt(self.combined_weight_norms ** 2)) + (self.ggpo_beta * (self.grad_norms ** 2))
perturbation_scale_factor = (perturbation_scale * self.perturbation_norm_factor).to(self.device)
perturbation = torch.randn(self.org_module_shape, dtype=self.dtype, device=self.device)