Dropout and Max Norm Regularization for LoRA training (#545)

* Instantiate max_norm

* minor

* Move to end of step

* argparse (see the CLI sketch after the commit metadata below)

* metadata

* phrasing

* Sqrt ratio and logging

* fix logging

* Dropout test

* Dropout Args

* Dropout changed to affect LoRA only (see the sketch at the end of this diff)

---------

Co-authored-by: Kohya S <52813779+kohya-ss@users.noreply.github.com>
Author:     AI-Casanova
Date:       2023-06-01 00:58:38 -05:00
Committed:  GitHub
Parent:     5931948adb
Commit:     9c7237157d
4 changed files with 77 additions and 9 deletions
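The "argparse" and "Dropout Args" commits wire these options into the trainer's command line; the files that add them are not part of the excerpt below. A minimal sketch of what that wiring could look like — the flag names and defaults here are hypothetical, not the commit's actual arguments:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--max_norm_value", type=float, default=None,
    help="hypothetical flag: rescale any LoRA up/down pair whose merged weight norm exceeds this value",
)
parser.add_argument(
    "--lora_dropout", type=float, default=None,
    help="hypothetical flag: dropout probability applied to the LoRA branch only, never the base model",
)
args = parser.parse_args()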


@@ -434,3 +434,36 @@ def perlin_noise(noise, device, octaves):
    noise += noise_perlin  # broadcast for each batch
    return noise / noise.std()  # Scaled back to roughly unit variance
"""
def max_norm(state_dict, max_norm_value):
    downkeys = []
    upkeys = []
    norms = []
    keys_scaled = 0

    # Collect the matching lora_down / lora_up weight pairs from the state dict.
    for key in state_dict.keys():
        if "lora_down" in key and "weight" in key:
            downkeys.append(key)
            upkeys.append(key.replace("lora_down", "lora_up"))

    for i in range(len(downkeys)):
        # Copies are moved to the GPU to compute the merged norm; the in-place
        # rescale below still touches the original tensors in state_dict.
        down = state_dict[downkeys[i]].cuda()
        up = state_dict[upkeys[i]].cuda()

        # Merge the pair into the effective weight delta (up @ down), handling
        # 1x1 and 3x3 conv LoRA modules as well as plain linear ones.
        if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
            updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
        elif up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3):
            updown = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3)
        else:
            updown = up @ down

        # Clamping from below avoids division by zero and keeps ratio at exactly 1
        # (no rescale) whenever the norm is within max_norm_value; ratio is never > 1.
        norm = updown.norm().clamp(min=max_norm_value / 2)
        desired = torch.clamp(norm, max=max_norm_value)
        ratio = desired.cpu() / norm.cpu()
        # Scale both matrices by sqrt(ratio) so their product scales by ratio.
        sqrt_ratio = ratio**0.5
        if ratio != 1:
            keys_scaled += 1
            state_dict[upkeys[i]] *= sqrt_ratio
            state_dict[downkeys[i]] *= sqrt_ratio
        # Record the post-rescale norm for logging.
        scalednorm = updown.norm() * ratio
        norms.append(scalednorm.item())

    return keys_scaled, sum(norms) / len(norms), max(norms)
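A minimal sketch of how these return values might be consumed at the end of a training step, per the "Move to end of step" and "fix logging" commits. The helper name, the `network` object, and the Accelerate-style logger are assumptions; only `max_norm` itself comes from the diff above:

import torch

def apply_and_log_max_norm(network, max_norm_value, accelerator, global_step):
    # Hypothetical helper: "network" is assumed to be the LoRA network whose
    # state_dict() holds the lora_down / lora_up weights being trained.
    with torch.no_grad():  # defensive: keep the in-place rescale out of autograd
        keys_scaled, mean_norm, maximum_norm = max_norm(network.state_dict(), max_norm_value)
    accelerator.log(
        {
            "keys_scaled": keys_scaled,      # pairs rescaled this step
            "average_key_norm": mean_norm,   # mean merged-weight norm after rescaling
            "max_key_norm": maximum_norm,    # largest merged-weight norm after rescaling
        },
        step=global_step,
    )
    return keys_scaled, mean_norm, maximum_norm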


@@ -3638,4 +3638,4 @@ class collater_class:
        # set epoch and step
        dataset.set_current_epoch(self.current_epoch.value)
        dataset.set_current_step(self.current_step.value)
-        return examples[0]
\ No newline at end of file
+        return examples[0]
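The two hunks above cover only the max-norm helper and a collater newline fix; the dropout half of the change ("Dropout changed to affect LoRA only") sits in the two remaining changed files, which this page does not show. A minimal sketch of the idea, assuming a typical LoRA module layout — the class and attribute names are illustrative, not the commit's actual code:

import torch

class LoRAModuleSketch(torch.nn.Module):
    # Illustrative only: dropout is applied to the LoRA branch alone, so the
    # frozen base layer's output passes through untouched.
    def __init__(self, org_module, lora_down, lora_up, scale, dropout=None):
        super().__init__()
        self.org_module = org_module  # frozen base layer
        self.lora_down = lora_down
        self.lora_up = lora_up
        self.scale = scale
        self.dropout = dropout        # float probability or None

    def forward(self, x):
        lx = self.lora_down(x)
        if self.dropout is not None and self.training:
            # dropout on the low-rank activations only; the base path is unaffected
            lx = torch.nn.functional.dropout(lx, p=self.dropout)
        return self.org_module(x) + self.lora_up(lx) * self.scale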