diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py
index 925e7d97..f32f050e 100644
--- a/library/custom_train_functions.py
+++ b/library/custom_train_functions.py
@@ -435,35 +435,45 @@ def perlin_noise(noise, device, octaves):
     return noise / noise.std()  # Scaled back to roughly unit variance
 """
 
 
-def max_norm(state_dict, max_norm_value):
-    downkeys = []
-    upkeys = []
-    norms = []
-    keys_scaled = 0
-    for key in state_dict.keys():
-        if "lora_down" in key and "weight" in key:
-            downkeys.append(key)
-            upkeys.append(key.replace("lora_down","lora_up"))
-    for i in range(len(downkeys)):
-        down = state_dict[downkeys[i]].cuda()
-        up = state_dict[upkeys[i]].cuda()
-        if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
-            updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
-        elif up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3):
-            updown = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3)
-        else:
-            updown = up @ down
-        norm = updown.norm().clamp(min=max_norm_value/2)
-        desired = torch.clamp(norm, max=max_norm_value)
-        ratio = desired.cpu() / norm.cpu()
-        sqrt_ratio = ratio **0.5
-        if ratio != 1:
-            keys_scaled +=1
-            state_dict[upkeys[i]] *= sqrt_ratio
-            state_dict[downkeys[i]] *= sqrt_ratio
-        scalednorm = updown.norm()*ratio
-        norms.append(scalednorm.item())
-
-    return keys_scaled, sum(norms)/len(norms), max(norms)
+def max_norm(state_dict, max_norm_value, device):
+    downkeys = []
+    upkeys = []
+    alphakeys = []
+    norms = []
+    keys_scaled = 0
+    for key in state_dict.keys():
+        if "lora_down" in key and "weight" in key:
+            downkeys.append(key)
+            upkeys.append(key.replace("lora_down", "lora_up"))
+            alphakeys.append(key.replace("lora_down.weight", "alpha"))
+
+    for i in range(len(downkeys)):
+        down = state_dict[downkeys[i]].to(device)
+        up = state_dict[upkeys[i]].to(device)
+        alpha = state_dict[alphakeys[i]].to(device)
+        dim = down.shape[0]
+        scale = alpha / dim
+
+        if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
+            updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
+        elif up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3):
+            updown = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3)
+        else:
+            updown = up @ down
+
+        updown *= scale
+
+        norm = updown.norm().clamp(min=max_norm_value / 2)
+        desired = torch.clamp(norm, max=max_norm_value)
+        ratio = desired.cpu() / norm.cpu()
+        sqrt_ratio = ratio**0.5
+        if ratio != 1:
+            keys_scaled += 1
+            state_dict[upkeys[i]] *= sqrt_ratio
+            state_dict[downkeys[i]] *= sqrt_ratio
+        scalednorm = updown.norm() * ratio
+        norms.append(scalednorm.item())
+
+    return keys_scaled, sum(norms) / len(norms), max(norms)
diff --git a/train_network.py b/train_network.py
index edbc915e..3a9c474d 100644
--- a/train_network.py
+++ b/train_network.py
@@ -670,7 +670,7 @@ def train(args):
                 optimizer.zero_grad(set_to_none=True)
 
             if args.scale_weight_norms:
-                keys_scaled, mean_norm, maximum_norm = max_norm(network.state_dict(), args.scale_weight_norms)
+                keys_scaled, mean_norm, maximum_norm = max_norm(network.state_dict(), args.scale_weight_norms, accelerator.device)
                 max_mean_logs = {"Keys Scaled": keys_scaled, "Average key norm": mean_norm}
             else:
                 keys_scaled, mean_norm, maximum_norm = None, None, None
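
For reviewers, a minimal standalone sketch of how the updated max_norm behaves outside the training loop. The layer name, tensor shapes, values, and import path below are illustrative assumptions, not part of the patch; the key naming follows the "lora_down.weight" / "lora_up.weight" / "alpha" convention the new code expects.

import torch

from library.custom_train_functions import max_norm  # assumed import path

# Toy state_dict with one rank-4 linear LoRA pair. Values are inflated
# so the pair's effective weight norm exceeds the cap and gets rescaled.
state_dict = {
    "lora_unet_demo.lora_down.weight": torch.randn(4, 320) * 5.0,
    "lora_unet_demo.lora_up.weight": torch.randn(320, 4) * 5.0,
    "lora_unet_demo.alpha": torch.tensor(4.0),  # alpha == rank, so scale = 1
}

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# max_norm mutates state_dict in place: each tensor of an over-norm pair is
# multiplied by sqrt(ratio), so their product shrinks by ratio overall.
keys_scaled, mean_norm, max_key_norm = max_norm(state_dict, 1.0, device)
print(f"scaled {keys_scaled} pair(s); mean norm {mean_norm:.3f}, max norm {max_key_norm:.3f}")

Note that the function scales the up/down factors symmetrically rather than clipping the composed weight directly, which keeps the two matrices balanced in magnitude.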