Mirror of https://github.com/kohya-ss/sd-scripts.git, synced 2026-04-08 22:35:09 +00:00

Commit: move max_norm to lora to avoid crashing in lycoris
@@ -145,6 +145,7 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser
 - Max Norm Regularization is now available in `train_network.py`. [PR #545](https://github.com/kohya-ss/sd-scripts/pull/545) Thanks to AI-Casanova!
   - Max Norm Regularization is a technique to stabilize network training by limiting the norm of network weights. It may be effective in suppressing overfitting of LoRA and improving stability when used with other LoRAs. See PR for details.
   - Specify as `--scale_weight_norms=1.0`. It seems good to try from `1.0`.
+  - The networks other than LoRA in this repository (such as LyCORIS) do not support this option.
 
 - Three types of dropout have been added to `train_network.py` and LoRA network.
   - Dropout is a technique to suppress overfitting and improve network performance by randomly setting some of the network outputs to 0.
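As an editorial aside, the idea behind Max Norm Regularization can be sketched in a few lines. This is a hedged toy illustration, not the repository's API: for one LoRA pair, the effective delta weight is `up @ down`, and scaling both factors by the square root of a ratio scales the product's norm by that ratio, capping it at the threshold.

```python
import torch

def cap_pair_norm(up, down, max_norm_value):
    """Toy sketch: clamp the norm of a LoRA delta weight W = up @ down."""
    updown = up @ down                        # effective delta weight
    norm = updown.norm()
    desired = torch.clamp(norm, max=max_norm_value)
    ratio = desired / norm                    # 1.0 when already under the cap
    sqrt_ratio = ratio ** 0.5
    # scaling both factors by sqrt(ratio) scales their product by ratio
    return up * sqrt_ratio, down * sqrt_ratio

up = torch.randn(8, 4)
down = torch.randn(4, 16)
up2, down2 = cap_pair_norm(up, down, max_norm_value=1.0)
# (up2 @ down2).norm() is now at most 1.0, up to float error
```

The square-root split keeps the regularization symmetric between the two factors instead of shrinking only one of them.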
@@ -156,6 +157,7 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser
   - `--network_dropout`, `rank_dropout`, and `module_dropout` can be specified at the same time.
   - Values of 0.1 to 0.3 may be good to try. Values greater than 0.5 should not be specified.
   - `rank_dropout` and `module_dropout` are original techniques of this repository. Their effectiveness has not been verified yet.
+  - The networks other than LoRA in this repository (such as LyCORIS) do not support these options.
 
 - Added an option `--scale_v_pred_loss_like_noise_pred` to scale v-prediction loss like noise prediction in each training script.
   - By scaling the loss according to the time step, the weights of global noise prediction and local noise prediction become the same, and the improvement of details may be expected.
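The three dropout variants differ only in where the zeroing happens. The sketch below is a hypothetical illustration of the idea (the function and argument names are not taken from the repository's code): ordinary output dropout, dropout of individual rank channels in the low-rank bottleneck, and dropout of the whole LoRA module's contribution.

```python
import torch

def lora_forward(x, down, up, scale,
                 network_dropout=0.0, rank_dropout=0.0, module_dropout=0.0,
                 training=True):
    """Hypothetical sketch of where each of the three dropouts acts."""
    if training and module_dropout > 0 and torch.rand(()) < module_dropout:
        # drop the entire LoRA module: zero contribution this step
        return torch.zeros(x.shape[0], up.shape[0])
    h = x @ down.T                                    # (batch, rank)
    if training and rank_dropout > 0:
        mask = (torch.rand(h.shape[1]) > rank_dropout).float()
        h = h * mask / (1 - rank_dropout)             # drop rank channels
    out = h @ up.T
    if training and network_dropout > 0:
        out = torch.nn.functional.dropout(out, p=network_dropout)  # plain dropout
    return out * scale
```

Rank dropout is the most targeted of the three: it randomly disables directions of the rank-`r` subspace, while module dropout occasionally trains the base model's path without the LoRA delta at all.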
@@ -164,6 +166,7 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser
 - Max Norm Regularization can now be used in `train_network.py`. [PR #545](https://github.com/kohya-ss/sd-scripts/pull/545) Thanks to AI-Casanova!
   - Max Norm Regularization is a technique that stabilizes network training by limiting the norm of the network weights. It may help suppress overfitting of a LoRA and improve stability when used together with other LoRAs. See the PR for details.
   - Specify it with `--scale_weight_norms`, e.g. `--scale_weight_norms=1.0`. Starting from `1.0` seems to work well.
+  - Networks other than those in this repository, such as LyCORIS, are not supported at this time.
 
 - A total of three types of dropout have been added to `train_network.py` and the LoRA network.
   - Dropout is a technique that aims to suppress overfitting and improve network performance by randomly setting some of the network outputs to 0.
@@ -175,6 +178,7 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser
   - `--network_dropout`, `rank_dropout`, and `module_dropout` can be specified at the same time.
   - It may be good to try values of around 0.1 to 0.3 for each. Values above 0.5 should not be specified.
   - `rank_dropout` and `module_dropout` are techniques original to this repository. Their effectiveness has not been verified yet.
+  - These dropouts are not supported at this time by networks other than those in this repository, such as LyCORIS.
 
 - Added an option `--scale_v_pred_loss_like_noise_pred` to each training script that scales the v-prediction loss to the same magnitude as noise prediction.
   - By scaling the loss according to the timestep, the weights of global noise prediction and local noise prediction become equal, which may improve details.
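The per-timestep scaling mentioned for `--scale_v_pred_loss_like_noise_pred` can be sketched as follows. This is an editorial illustration under stated assumptions, not the repository's implementation: it assumes the standard signal-to-noise ratio of the forward diffusion process, snr(t) = alphas_cumprod[t] / (1 - alphas_cumprod[t]), and a snr/(snr+1) reweighting that downweights the high-noise timesteps where v-prediction loss is larger than the epsilon-prediction equivalent.

```python
import torch

def scale_v_loss_like_noise(loss, timesteps, alphas_cumprod):
    """Sketch: reweight a per-sample v-prediction loss by snr/(snr+1)."""
    a = alphas_cumprod[timesteps]
    snr = a / (1 - a)          # signal-to-noise ratio at each sampled timestep
    return loss * snr / (snr + 1)

# toy schedule: high SNR early, low SNR late
alphas_cumprod = torch.linspace(0.9999, 0.01, 1000)
loss = torch.ones(4)
t = torch.tensor([10, 300, 600, 900])
scaled = scale_v_loss_like_noise(loss, t, alphas_cumprod)
# early (high-SNR) timesteps keep nearly full weight; late ones are downweighted
```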
@@ -456,46 +456,3 @@ def perlin_noise(noise, device, octaves):
     noise += noise_perlin # broadcast for each batch
     return noise / noise.std() # Scaled back to roughly unit variance
 """
-
-
-def max_norm(state_dict, max_norm_value, device):
-    downkeys = []
-    upkeys = []
-    alphakeys = []
-    norms = []
-    keys_scaled = 0
-
-    for key in state_dict.keys():
-        if "lora_down" in key and "weight" in key:
-            downkeys.append(key)
-            upkeys.append(key.replace("lora_down", "lora_up"))
-            alphakeys.append(key.replace("lora_down.weight", "alpha"))
-
-    for i in range(len(downkeys)):
-        down = state_dict[downkeys[i]].to(device)
-        up = state_dict[upkeys[i]].to(device)
-        alpha = state_dict[alphakeys[i]].to(device)
-        dim = down.shape[0]
-        scale = alpha / dim
-
-        if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
-            updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
-        elif up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3):
-            updown = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3)
-        else:
-            updown = up @ down
-
-        updown *= scale
-
-        norm = updown.norm().clamp(min=max_norm_value / 2)
-        desired = torch.clamp(norm, max=max_norm_value)
-        ratio = desired.cpu() / norm.cpu()
-        sqrt_ratio = ratio**0.5
-        if ratio != 1:
-            keys_scaled += 1
-            state_dict[upkeys[i]] *= sqrt_ratio
-            state_dict[downkeys[i]] *= sqrt_ratio
-        scalednorm = updown.norm() * ratio
-        norms.append(scalednorm.item())
-
-    return keys_scaled, sum(norms) / len(norms), max(norms)
@@ -1126,3 +1126,46 @@ class LoRANetwork(torch.nn.Module):
                 org_module._lora_restored = False
                 lora.enabled = False
 
+
+    def apply_max_norm_regularization(self, max_norm_value, device):
+        downkeys = []
+        upkeys = []
+        alphakeys = []
+        norms = []
+        keys_scaled = 0
+
+        state_dict = self.state_dict()
+        for key in state_dict.keys():
+            if "lora_down" in key and "weight" in key:
+                downkeys.append(key)
+                upkeys.append(key.replace("lora_down", "lora_up"))
+                alphakeys.append(key.replace("lora_down.weight", "alpha"))
+
+        for i in range(len(downkeys)):
+            down = state_dict[downkeys[i]].to(device)
+            up = state_dict[upkeys[i]].to(device)
+            alpha = state_dict[alphakeys[i]].to(device)
+            dim = down.shape[0]
+            scale = alpha / dim
+
+            if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
+                updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
+            elif up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3):
+                updown = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3)
+            else:
+                updown = up @ down
+
+            updown *= scale
+
+            norm = updown.norm().clamp(min=max_norm_value / 2)
+            desired = torch.clamp(norm, max=max_norm_value)
+            ratio = desired.cpu() / norm.cpu()
+            sqrt_ratio = ratio**0.5
+            if ratio != 1:
+                keys_scaled += 1
+                state_dict[upkeys[i]] *= sqrt_ratio
+                state_dict[downkeys[i]] *= sqrt_ratio
+            scalednorm = updown.norm() * ratio
+            norms.append(scalednorm.item())
+
+        return keys_scaled, sum(norms) / len(norms), max(norms)
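The core arithmetic of the method above is worth checking in isolation: because the delta weight is the bilinear product `up @ down`, multiplying both factors by `sqrt_ratio` scales the product's norm by exactly `ratio`, landing it at the cap. A quick editorial check with toy tensors (the variable names mirror the diff; the shapes are arbitrary):

```python
import torch

torch.manual_seed(0)
up = torch.randn(8, 4)
down = torch.randn(4, 16)
max_norm_value = 1.0

updown = up @ down
# min-clamp avoids upscaling tiny weights (and division by zero)
norm = updown.norm().clamp(min=max_norm_value / 2)
desired = torch.clamp(norm, max=max_norm_value)
ratio = desired / norm
sqrt_ratio = ratio ** 0.5

capped = (up * sqrt_ratio) @ (down * sqrt_ratio)
# capped.norm() equals `desired`, i.e. the product norm was scaled by `ratio`
```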
@@ -31,7 +31,6 @@ from library.custom_train_functions import (
     prepare_scheduler_for_custom_training,
     pyramid_noise_like,
     apply_noise_offset,
-    max_norm,
     scale_v_prediction_loss_like_noise_prediction,
 )
 
@@ -220,6 +219,11 @@ def train(args):
 
     if hasattr(network, "prepare_network"):
        network.prepare_network(args)
+    if args.scale_weight_norms and not hasattr(network, "apply_max_norm_regularization"):
+        print(
+            "warning: scale_weight_norms is specified but the network does not support it / scale_weight_normsが指定されていますが、ネットワークが対応していません"
+        )
+        args.scale_weight_norms = False
 
     train_unet = not args.network_train_text_encoder_only
     train_text_encoder = not args.network_train_unet_only
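The guard added above is a duck-typing capability check: instead of testing the network's type, the script feature-detects the method it needs, so third-party networks such as LyCORIS fall back gracefully rather than crash. A minimal illustration (all class names here are hypothetical):

```python
class PlainNetwork:
    """Stands in for a third-party network without max-norm support."""
    pass

class LoRALikeNetwork:
    """Stands in for a network that implements the new method."""
    def apply_max_norm_regularization(self, max_norm_value, device):
        return 0, 0.0, 0.0

def supports_max_norm(network):
    # feature-detect the method rather than check isinstance()
    return hasattr(network, "apply_max_norm_regularization")
```

With this pattern, disabling the option (`args.scale_weight_norms = False`) is the graceful fallback when detection fails.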
@@ -677,7 +681,9 @@ def train(args):
                 optimizer.zero_grad(set_to_none=True)
 
                 if args.scale_weight_norms:
-                    keys_scaled, mean_norm, maximum_norm = max_norm(network.state_dict(), args.scale_weight_norms, accelerator.device)
+                    keys_scaled, mean_norm, maximum_norm = network.apply_max_norm_regularization(
+                        args.scale_weight_norms, accelerator.device
+                    )
                     max_mean_logs = {"Keys Scaled": keys_scaled, "Average key norm": mean_norm}
                 else:
                     keys_scaled, mean_norm, maximum_norm = None, None, None