mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
Remove perturbation seed
This commit is contained in:
@@ -29,42 +29,6 @@ logger = logging.getLogger(__name__)
|
||||
NUM_DOUBLE_BLOCKS = 19
|
||||
NUM_SINGLE_BLOCKS = 38
|
||||
|
||||
@contextmanager
|
||||
def temp_random_seed(seed, device=None):
|
||||
"""
|
||||
Context manager that temporarily sets a specific random seed and then
|
||||
restores the original RNG state afterward.
|
||||
|
||||
Args:
|
||||
seed (int): The random seed to set temporarily
|
||||
device (torch.device, optional): The device to set the seed for.
|
||||
If None, will detect from the current context.
|
||||
"""
|
||||
# Save original RNG states
|
||||
original_cpu_rng_state = torch.get_rng_state()
|
||||
original_cuda_rng_states = None
|
||||
if torch.cuda.is_available():
|
||||
original_cuda_rng_states = torch.cuda.get_rng_state_all()
|
||||
|
||||
# Determine if we need to set CUDA seed
|
||||
set_cuda = False
|
||||
if device is not None:
|
||||
set_cuda = device.type == 'cuda'
|
||||
elif torch.cuda.is_available():
|
||||
set_cuda = True
|
||||
|
||||
try:
|
||||
# Set the temporary seed
|
||||
torch.manual_seed(seed)
|
||||
if set_cuda:
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
yield
|
||||
finally:
|
||||
# Restore original RNG states
|
||||
torch.set_rng_state(original_cpu_rng_state)
|
||||
if torch.cuda.is_available() and original_cuda_rng_states is not None:
|
||||
torch.cuda.set_rng_state_all(original_cuda_rng_states)
|
||||
|
||||
|
||||
class LoRAModule(torch.nn.Module):
|
||||
"""
|
||||
@@ -150,7 +114,6 @@ class LoRAModule(torch.nn.Module):
|
||||
self.combined_weight_norms = None
|
||||
self.grad_norms = None
|
||||
self.perturbation_norm_factor = 1.0 / math.sqrt(org_module.weight.shape[0])
|
||||
self.perturbation_seed = torch.randint(0, 2**32 - 1, (1,)).detach().item()
|
||||
self.initialize_norm_cache(org_module.weight)
|
||||
self.org_module_shape: tuple[int] = org_module.weight.shape
|
||||
|
||||
@@ -193,8 +156,8 @@ class LoRAModule(torch.nn.Module):
|
||||
lx = self.lora_up(lx)
|
||||
|
||||
# LoRA Gradient-Guided Perturbation Optimization
|
||||
if self.training and hasattr(self, 'perturbation_seed') and self.ggpo_sigma is not None and self.ggpo_beta is not None and self.combined_weight_norms is not None and self.grad_norms is not None:
|
||||
with torch.no_grad(), temp_random_seed(self.perturbation_seed):
|
||||
if self.training and self.ggpo_sigma is not None and self.ggpo_beta is not None and self.combined_weight_norms is not None and self.grad_norms is not None:
|
||||
with torch.no_grad():
|
||||
perturbation_scale = (self.ggpo_sigma * torch.sqrt(self.combined_weight_norms ** 2)) + (self.ggpo_beta * (self.grad_norms ** 2))
|
||||
perturbation_scale_factor = (perturbation_scale * self.perturbation_norm_factor).to(self.device)
|
||||
perturbation = torch.randn(self.org_module_shape, dtype=self.dtype, device=self.device)
|
||||
|
||||
Reference in New Issue
Block a user