mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Dropout and Max Norm Regularization for LoRA training (#545)
* Instantiate max_norm
* minor
* Move to end of step
* argparse
* metadata
* phrasing
* Sqrt ratio and logging
* fix logging
* Dropout test
* Dropout Args
* Dropout changed to affect LoRA only

---------

Co-authored-by: Kohya S <52813779+kohya-ss@users.noreply.github.com>
This commit is contained in:
@@ -434,3 +434,36 @@ def perlin_noise(noise, device, octaves):
|
||||
noise += noise_perlin # broadcast for each batch
|
||||
return noise / noise.std() # Scaled back to roughly unit variance
|
||||
"""
|
||||
|
||||
def max_norm(state_dict, max_norm_value, device="cuda"):
    """Apply max-norm regularization to LoRA weight pairs in ``state_dict``.

    For every ``lora_down``/``lora_up`` weight pair, the combined weight
    ``up @ down`` is computed and, when its Frobenius norm exceeds
    ``max_norm_value``, both factors are scaled down **in place** by the
    square root of the required ratio (so their product shrinks by the full
    ratio).

    Args:
        state_dict: mapping of parameter names to tensors; scaled pairs are
            modified in place.
        max_norm_value: maximum allowed norm for each combined LoRA weight.
        device: device used for the norm computation. Defaults to ``"cuda"``
            (matching the original hard-coded ``.cuda()`` behavior); pass
            ``"cpu"`` for CPU-only runs.

    Returns:
        Tuple ``(keys_scaled, mean_norm, max_norm)``: the number of pairs
        that were rescaled and the mean/max of the post-scaling norms.
        Returns ``(0, 0.0, 0.0)`` when no LoRA pairs are found.
    """
    downkeys = []
    upkeys = []
    norms = []
    keys_scaled = 0

    # Collect matching down/up weight key pairs.
    for key in state_dict.keys():
        if "lora_down" in key and "weight" in key:
            downkeys.append(key)
            upkeys.append(key.replace("lora_down", "lora_up"))

    for i in range(len(downkeys)):
        # Move to the compute device; was hard-coded `.cuda()` originally.
        down = state_dict[downkeys[i]].to(device)
        up = state_dict[upkeys[i]].to(device)

        if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
            # 1x1 conv LoRA: equivalent to a matmul on the channel dims.
            updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
        elif up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3):
            # 3x3 conv LoRA: compose the two convolutions into one kernel.
            updown = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3)
        else:
            # Linear LoRA: plain matrix product.
            updown = up @ down

        # Floor the denominator at max_norm_value/2 so the division below
        # never sees a near-zero norm.
        norm = updown.norm().clamp(min=max_norm_value / 2)
        desired = torch.clamp(norm, max=max_norm_value)
        ratio = desired.cpu() / norm.cpu()
        sqrt_ratio = ratio ** 0.5
        if ratio != 1:
            keys_scaled += 1
            # Split the scaling across both factors so up @ down scales by `ratio`.
            state_dict[upkeys[i]] *= sqrt_ratio
            state_dict[downkeys[i]] *= sqrt_ratio
        scalednorm = updown.norm() * ratio
        norms.append(scalednorm.item())

    if not norms:
        # No LoRA pairs present: avoid ZeroDivisionError / max() on empty list.
        return keys_scaled, 0.0, 0.0
    return keys_scaled, sum(norms) / len(norms), max(norms)
|
||||
@@ -3638,4 +3638,4 @@ class collater_class:
|
||||
# set epoch and step
|
||||
dataset.set_current_epoch(self.current_epoch.value)
|
||||
dataset.set_current_step(self.current_step.value)
|
||||
return examples[0]
|
||||
return examples[0]
|
||||
Reference in New Issue
Block a user