mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
add scaling to max norm
This commit is contained in:
@@ -435,9 +435,11 @@ def perlin_noise(noise, device, octaves):
|
|||||||
return noise / noise.std() # Scaled back to roughly unit variance
|
return noise / noise.std() # Scaled back to roughly unit variance
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def max_norm(state_dict, max_norm_value):
|
|
||||||
|
def max_norm(state_dict, max_norm_value, device):
|
||||||
downkeys = []
|
downkeys = []
|
||||||
upkeys = []
|
upkeys = []
|
||||||
|
alphakeys = []
|
||||||
norms = []
|
norms = []
|
||||||
keys_scaled = 0
|
keys_scaled = 0
|
||||||
|
|
||||||
@@ -445,15 +447,24 @@ def max_norm(state_dict, max_norm_value):
|
|||||||
if "lora_down" in key and "weight" in key:
|
if "lora_down" in key and "weight" in key:
|
||||||
downkeys.append(key)
|
downkeys.append(key)
|
||||||
upkeys.append(key.replace("lora_down", "lora_up"))
|
upkeys.append(key.replace("lora_down", "lora_up"))
|
||||||
|
alphakeys.append(key.replace("lora_down.weight", "alpha"))
|
||||||
|
|
||||||
for i in range(len(downkeys)):
|
for i in range(len(downkeys)):
|
||||||
down = state_dict[downkeys[i]].cuda()
|
down = state_dict[downkeys[i]].to(device)
|
||||||
up = state_dict[upkeys[i]].cuda()
|
up = state_dict[upkeys[i]].to(device)
|
||||||
|
alpha = state_dict[alphakeys[i]].to(device)
|
||||||
|
dim = down.shape[0]
|
||||||
|
scale = alpha / dim
|
||||||
|
|
||||||
if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
|
if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
|
||||||
updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
|
updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
|
||||||
elif up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3):
|
elif up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3):
|
||||||
updown = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3)
|
updown = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3)
|
||||||
else:
|
else:
|
||||||
updown = up @ down
|
updown = up @ down
|
||||||
|
|
||||||
|
updown *= scale
|
||||||
|
|
||||||
norm = updown.norm().clamp(min=max_norm_value / 2)
|
norm = updown.norm().clamp(min=max_norm_value / 2)
|
||||||
desired = torch.clamp(norm, max=max_norm_value)
|
desired = torch.clamp(norm, max=max_norm_value)
|
||||||
ratio = desired.cpu() / norm.cpu()
|
ratio = desired.cpu() / norm.cpu()
|
||||||
@@ -466,4 +477,3 @@ def max_norm(state_dict, max_norm_value):
|
|||||||
norms.append(scalednorm.item())
|
norms.append(scalednorm.item())
|
||||||
|
|
||||||
return keys_scaled, sum(norms) / len(norms), max(norms)
|
return keys_scaled, sum(norms) / len(norms), max(norms)
|
||||||
|
|
||||||
|
|||||||
@@ -670,7 +670,7 @@ def train(args):
|
|||||||
optimizer.zero_grad(set_to_none=True)
|
optimizer.zero_grad(set_to_none=True)
|
||||||
|
|
||||||
if args.scale_weight_norms:
|
if args.scale_weight_norms:
|
||||||
keys_scaled, mean_norm, maximum_norm = max_norm(network.state_dict(), args.scale_weight_norms)
|
keys_scaled, mean_norm, maximum_norm = max_norm(network.state_dict(), args.scale_weight_norms, accelerator.device)
|
||||||
max_mean_logs = {"Keys Scaled": keys_scaled, "Average key norm": mean_norm}
|
max_mean_logs = {"Keys Scaled": keys_scaled, "Average key norm": mean_norm}
|
||||||
else:
|
else:
|
||||||
keys_scaled, mean_norm, maximum_norm = None, None, None
|
keys_scaled, mean_norm, maximum_norm = None, None, None
|
||||||
|
|||||||
Reference in New Issue
Block a user