Mirror of https://github.com/kohya-ss/sd-scripts.git
Remove perturbation seed
@@ -29,42 +29,6 @@ logger = logging.getLogger(__name__)
 NUM_DOUBLE_BLOCKS = 19
 
 NUM_SINGLE_BLOCKS = 38
 
-@contextmanager
-def temp_random_seed(seed, device=None):
-    """
-    Context manager that temporarily sets a specific random seed and then
-    restores the original RNG state afterward.
-
-    Args:
-        seed (int): The random seed to set temporarily
-        device (torch.device, optional): The device to set the seed for.
-            If None, will detect from the current context.
-    """
-    # Save original RNG states
-    original_cpu_rng_state = torch.get_rng_state()
-    original_cuda_rng_states = None
-    if torch.cuda.is_available():
-        original_cuda_rng_states = torch.cuda.get_rng_state_all()
-
-    # Determine if we need to set CUDA seed
-    set_cuda = False
-    if device is not None:
-        set_cuda = device.type == 'cuda'
-    elif torch.cuda.is_available():
-        set_cuda = True
-
-    try:
-        # Set the temporary seed
-        torch.manual_seed(seed)
-        if set_cuda:
-            torch.cuda.manual_seed_all(seed)
-        yield
-    finally:
-        # Restore original RNG states
-        torch.set_rng_state(original_cpu_rng_state)
-        if torch.cuda.is_available() and original_cuda_rng_states is not None:
-            torch.cuda.set_rng_state_all(original_cuda_rng_states)
-
-
 class LoRAModule(torch.nn.Module):
     """
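For context, the deleted temp_random_seed helper gave each module a reproducible noise stream without disturbing the global RNG. A condensed, self-contained sketch of that behavior, adapted from the removed code above (no longer part of the repository after this commit):

import torch
from contextlib import contextmanager

@contextmanager
def temp_random_seed(seed):
    # Save the global RNG state, seed temporarily, and restore on exit.
    cpu_state = torch.get_rng_state()
    cuda_states = torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None
    try:
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)
        yield
    finally:
        torch.set_rng_state(cpu_state)
        if cuda_states is not None:
            torch.cuda.set_rng_state_all(cuda_states)

# The same seed replays the same draw, and the ambient RNG stream is unaffected.
with temp_random_seed(1234):
    a = torch.randn(3)
with temp_random_seed(1234):
    b = torch.randn(3)
assert torch.equal(a, b)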
@@ -150,7 +114,6 @@ class LoRAModule(torch.nn.Module):
         self.combined_weight_norms = None
         self.grad_norms = None
         self.perturbation_norm_factor = 1.0 / math.sqrt(org_module.weight.shape[0])
-        self.perturbation_seed = torch.randint(0, 2**32 - 1, (1,)).detach().item()
         self.initialize_norm_cache(org_module.weight)
         self.org_module_shape: tuple[int] = org_module.weight.shape
 
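The deleted constructor line drew one random seed per module, freezing that module's perturbation direction for the whole training run. A minimal sketch of what that seeding implied (values are illustrative; the .detach() on an integer draw was a no-op kept from the original):

import torch

# Before this commit: one 32-bit seed per LoRAModule, fixed at construction.
seed = torch.randint(0, 2**32 - 1, (1,)).detach().item()

# Replaying the seed reproduces an identical perturbation tensor every time.
torch.manual_seed(seed)
p1 = torch.randn(4, 4)
torch.manual_seed(seed)
p2 = torch.randn(4, 4)
assert torch.equal(p1, p2)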
@@ -193,8 +156,8 @@ class LoRAModule(torch.nn.Module):
         lx = self.lora_up(lx)
 
         # LoRA Gradient-Guided Perturbation Optimization
-        if self.training and hasattr(self, 'perturbation_seed') and self.ggpo_sigma is not None and self.ggpo_beta is not None and self.combined_weight_norms is not None and self.grad_norms is not None:
-            with torch.no_grad(), temp_random_seed(self.perturbation_seed):
+        if self.training and self.ggpo_sigma is not None and self.ggpo_beta is not None and self.combined_weight_norms is not None and self.grad_norms is not None:
+            with torch.no_grad():
                 perturbation_scale = (self.ggpo_sigma * torch.sqrt(self.combined_weight_norms ** 2)) + (self.ggpo_beta * (self.grad_norms ** 2))
                 perturbation_scale_factor = (perturbation_scale * self.perturbation_norm_factor).to(self.device)
                 perturbation = torch.randn(self.org_module_shape, dtype=self.dtype, device=self.device)
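With the seeded context manager gone, the perturbation noise is now drawn from the ambient global RNG on every forward pass instead of being replayed from a fixed per-module seed. A standalone sketch of the scale computation from the hunk above, with illustrative hyperparameters and shapes (names mirror the diff; the values are made up):

import math
import torch

ggpo_sigma, ggpo_beta = 0.03, 0.01      # illustrative GGPO hyperparameters
out_features, in_features = 320, 640    # illustrative original-module shape

combined_weight_norms = torch.tensor(1.5)  # cached combined weight norm (computed elsewhere)
grad_norms = torch.tensor(0.2)             # cached gradient norm

# sigma * sqrt(norm^2) + beta * grad_norm^2, exactly as in the hunk above;
# torch.sqrt(x ** 2) is simply |x| for a scalar norm.
perturbation_scale = (ggpo_sigma * torch.sqrt(combined_weight_norms ** 2)) + (ggpo_beta * (grad_norms ** 2))

# 1/sqrt(out_features) keeps the expected perturbation magnitude independent
# of the layer's output dimension (see perturbation_norm_factor in the diff).
perturbation_scale_factor = perturbation_scale * (1.0 / math.sqrt(out_features))

# Fresh Gaussian noise on every call now that the fixed per-module seed is gone.
perturbation = torch.randn(out_features, in_features) * perturbation_scale_factor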