Mirror of https://github.com/kohya-ss/sd-scripts.git
Add scaling alpha for LoRA

gen_img_diffusers.py

@@ -1981,7 +1981,6 @@ def main(args):
       imported_module = importlib.import_module(network_module)
 
       network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i]
-      network_dim = None if args.network_dim is None or len(args.network_dim) <= i else args.network_dim[i]
 
       net_kwargs = {}
       if args.network_args and i < len(args.network_args):
@@ -1992,22 +1991,21 @@ def main(args):
           key, value = net_arg.split("=")
           net_kwargs[key] = value
 
-      network = imported_module.create_network(network_mul, network_dim, vae, text_encoder, unet, **net_kwargs)
-      if network is None:
-        return
-
       if args.network_weights and i < len(args.network_weights):
         network_weight = args.network_weights[i]
         print("load network weights from:", network_weight)
 
         if os.path.splitext(network_weight)[1] == '.safetensors':
           from safetensors.torch import safe_open
           with safe_open(network_weight, framework="pt") as f:
             metadata = f.metadata()
           if metadata is not None:
             print(f"metadata for: {network_weight}: {metadata}")
 
-        network.load_weights(network_weight)
+        network = imported_module.create_network_from_weights(network_mul, network_weight, vae, text_encoder, unet, **net_kwargs)
+      else:
+        raise ValueError("No weight. Weight is required.")
+      if network is None:
+        return
 
       network.apply_to(text_encoder, unet)
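
Note: the metadata block kept as context above can be exercised on its own. A minimal sketch, assuming a hypothetical file name and the safetensors package; safe_open only parses the file header, so printing metadata does not load any tensors:

    import os
    from safetensors.torch import safe_open

    network_weight = "my_lora.safetensors"  # hypothetical path
    if os.path.splitext(network_weight)[1] == '.safetensors':
        with safe_open(network_weight, framework="pt") as f:
            metadata = f.metadata()  # dict of str -> str, or None
        if metadata is not None:
            print(f"metadata for: {network_weight}: {metadata}")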

@@ -2526,8 +2524,6 @@ if __name__ == '__main__':
   parser.add_argument("--network_weights", type=str, default=None, nargs='*',
                       help='Hypernetwork weights to load / Hypernetworkの重み')
   parser.add_argument("--network_mul", type=float, default=None, nargs='*', help='Hypernetwork multiplier / Hypernetworkの効果の倍率')
-  parser.add_argument("--network_dim", type=int, default=None, nargs='*',
-                      help='network dimensions (depends on each network) / モジュールの次元数(ネットワークにより定義は異なります)')
   parser.add_argument("--network_args", type=str, default=None, nargs='*',
                       help='additional argmuments for network (key=value) / ネットワークへの追加の引数')
   parser.add_argument("--clip_skip", type=int, default=None, help='layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う')

networks/lora.py

@@ -13,9 +13,11 @@ class LoRAModule(torch.nn.Module):
   replaces forward method of the original Linear, instead of replacing the original Linear module.
   """
 
-  def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4):
+  def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4, alpha=1):
+    """ if alpha == 0 or None, alpha is rank (no scaling). """
     super().__init__()
     self.lora_name = lora_name
+    self.lora_dim = lora_dim
 
     if org_module.__class__.__name__ == 'Conv2d':
       in_dim = org_module.in_channels
@@ -28,6 +30,12 @@ class LoRAModule(torch.nn.Module):
       self.lora_down = torch.nn.Linear(in_dim, lora_dim, bias=False)
       self.lora_up = torch.nn.Linear(lora_dim, out_dim, bias=False)
 
+    if type(alpha) == torch.Tensor:
+      alpha = alpha.detach().numpy()
+    alpha = lora_dim if alpha is None or alpha == 0 else alpha
+    self.scale = alpha / self.lora_dim
+    self.register_buffer('alpha', torch.tensor(alpha))  # 定数として扱える (can be treated as a constant)
+
     # same as microsoft's
     torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
     torch.nn.init.zeros_(self.lora_up.weight)
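
Note: this hunk implements the standard LoRA convention of scaling the low-rank update by alpha / rank, and registers alpha as a buffer so it is serialized with the weights and can be recovered at load time. With the new default (alpha=1, dim=4) the update is damped to 0.25x; alpha equal to the rank, or 0/None, gives scale 1.0, i.e. the old unscaled behavior. A minimal sketch of the arithmetic outside the module (dimensions are illustrative):

    import torch

    lora_dim = 4
    alpha = 1.0  # the --network_alpha default
    alpha = lora_dim if alpha is None or alpha == 0 else alpha
    scale = alpha / lora_dim  # 0.25 here; 1.0 when alpha == lora_dim

    lora_down = torch.nn.Linear(320, lora_dim, bias=False)
    lora_up = torch.nn.Linear(lora_dim, 320, bias=False)

    x = torch.randn(1, 320)
    multiplier = 1.0
    delta = lora_up(lora_down(x)) * multiplier * scale  # the term added to org_forward(x)
    print(delta.shape)  # torch.Size([1, 320])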
@@ -41,13 +49,37 @@ class LoRAModule(torch.nn.Module):
     del self.org_module
 
   def forward(self, x):
-    return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier
+    return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
 
 
-def create_network(multiplier, network_dim, vae, text_encoder, unet, **kwargs):
+def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs):
   if network_dim is None:
     network_dim = 4  # default
-  network = LoRANetwork(text_encoder, unet, multiplier=multiplier, lora_dim=network_dim)
+  network = LoRANetwork(text_encoder, unet, multiplier=multiplier, lora_dim=network_dim, alpha=network_alpha)
+  return network
+
+
+def create_network_from_weights(multiplier, file, vae, text_encoder, unet, **kwargs):
+  if os.path.splitext(file)[1] == '.safetensors':
+    from safetensors.torch import load_file, safe_open
+    weights_sd = load_file(file)
+  else:
+    weights_sd = torch.load(file, map_location='cpu')
+
+  # get dim (rank)
+  network_alpha = None
+  network_dim = None
+  for key, value in weights_sd.items():
+    if network_alpha is None and 'alpha' in key:
+      network_alpha = value
+    if network_dim is None and 'lora_down' in key and len(value.size()) == 2:
+      network_dim = value.size()[0]
+
+  if network_alpha is None:
+    network_alpha = network_dim
+
+  network = LoRANetwork(text_encoder, unet, multiplier=multiplier, lora_dim=network_dim, alpha=network_alpha)
+  network.weights_sd = weights_sd
   return network
 
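Note: create_network_from_weights recovers both hyperparameters from the checkpoint itself: the rank is the first dimension of any 2-D lora_down weight, and alpha comes from the stored alpha buffers. Files written before this commit have no alpha keys, so alpha falls back to the rank and the scale becomes 1.0, reproducing the old behavior. A self-contained sketch of the same loop over a synthetic state dict (key names are illustrative):

    import torch

    weights_sd = {
        'lora_unet_x.lora_down.weight': torch.zeros(4, 320),  # rank 4
        'lora_unet_x.lora_up.weight': torch.zeros(320, 4),
        'lora_unet_x.alpha': torch.tensor(1.0),
    }

    network_alpha = None
    network_dim = None
    for key, value in weights_sd.items():
        if network_alpha is None and 'alpha' in key:
            network_alpha = value
        if network_dim is None and 'lora_down' in key and len(value.size()) == 2:
            network_dim = value.size()[0]

    if network_alpha is None:  # pre-alpha checkpoints: fall back to the rank
        network_alpha = network_dim

    print(network_dim, float(network_alpha))  # 4 1.0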
@@ -57,10 +89,11 @@ class LoRANetwork(torch.nn.Module):
   LORA_PREFIX_UNET = 'lora_unet'
   LORA_PREFIX_TEXT_ENCODER = 'lora_te'
 
-  def __init__(self, text_encoder, unet, multiplier=1.0, lora_dim=4) -> None:
+  def __init__(self, text_encoder, unet, multiplier=1.0, lora_dim=4, alpha=1) -> None:
     super().__init__()
     self.multiplier = multiplier
     self.lora_dim = lora_dim
+    self.alpha = alpha
 
     # create module instances
     def create_modules(prefix, root_module: torch.nn.Module, target_replace_modules) -> list[LoRAModule]:
@@ -71,7 +104,7 @@ class LoRANetwork(torch.nn.Module):
           if child_module.__class__.__name__ == "Linear" or (child_module.__class__.__name__ == "Conv2d" and child_module.kernel_size == (1, 1)):
             lora_name = prefix + '.' + name + '.' + child_name
             lora_name = lora_name.replace('.', '_')
-            lora = LoRAModule(lora_name, child_module, self.multiplier, self.lora_dim)
+            lora = LoRAModule(lora_name, child_module, self.multiplier, self.lora_dim, self.alpha)
             loras.append(lora)
       return loras
 
@@ -149,21 +182,21 @@ class LoRANetwork(torch.nn.Module):
       return params
 
     self.requires_grad_(True)
-    params = []
+    all_params = []
 
     if self.text_encoder_loras:
       param_data = {'params': enumerate_params(self.text_encoder_loras)}
       if text_encoder_lr is not None:
         param_data['lr'] = text_encoder_lr
-      params.append(param_data)
+      all_params.append(param_data)
 
     if self.unet_loras:
       param_data = {'params': enumerate_params(self.unet_loras)}
       if unet_lr is not None:
         param_data['lr'] = unet_lr
-      params.append(param_data)
+      all_params.append(param_data)
 
-    return params
+    return all_params
 
   def prepare_grad_etc(self, text_encoder, unet):
     self.requires_grad_(True)
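
Note: renaming the outer list to all_params avoids reusing the name of the list built inside enumerate_params (its "return params" appears as context at the top of the hunk). The returned value is an ordinary torch.optim param-group list, where a per-group 'lr' overrides the optimizer default; a minimal sketch of how such groups behave (learning rates are illustrative):

    import torch

    te_params = [torch.nn.Parameter(torch.zeros(4, 4))]
    unet_params = [torch.nn.Parameter(torch.zeros(4, 4))]

    all_params = [
        {'params': te_params, 'lr': 5e-5},  # text encoder group
        {'params': unet_params},            # falls back to the default lr
    ]
    optimizer = torch.optim.AdamW(all_params, lr=1e-4)
    print([g['lr'] for g in optimizer.param_groups])  # [5e-05, 0.0001]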

train_network.py

@@ -107,7 +107,8 @@ def train(args):
       key, value = net_arg.split('=')
       net_kwargs[key] = value
 
-  network = network_module.create_network(1.0, args.network_dim, vae, text_encoder, unet, **net_kwargs)
+  # if a new network is added in future, add if ~ then blocks for each network (;'∀')
+  network = network_module.create_network(1.0, args.network_dim, args.network_alpha, vae, text_encoder, unet, **net_kwargs)
   if network is None:
     return
 
@@ -244,6 +245,7 @@ def train(args):
       "ss_lr_scheduler": args.lr_scheduler,
       "ss_network_module": args.network_module,
       "ss_network_dim": args.network_dim,  # None means default because another network than LoRA may have another default dim
+      "ss_network_alpha": args.network_alpha,  # some networks may not use this value
       "ss_mixed_precision": args.mixed_precision,
       "ss_full_fp16": bool(args.full_fp16),
       "ss_v2": bool(args.v2),
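
Note: ss_network_alpha travels with the saved model; for safetensors output this dictionary ends up as file-level metadata, which is what the generation script prints via f.metadata() in the first file. A sketch of that mechanism, not the trainer's actual save path (file name and values are illustrative; safetensors metadata must map str to str):

    import torch
    from safetensors.torch import save_file, safe_open

    tensors = {'lora_unet_x.alpha': torch.tensor(4.0)}
    metadata = {'ss_network_dim': '8', 'ss_network_alpha': '4.0'}  # str -> str only
    save_file(tensors, 'example_lora.safetensors', metadata=metadata)

    with safe_open('example_lora.safetensors', framework='pt') as f:
        print(f.metadata())  # includes ss_network_dim and ss_network_alpha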
@@ -445,6 +447,8 @@ if __name__ == '__main__':
   parser.add_argument("--network_module", type=str, default=None, help='network module to train / 学習対象のネットワークのモジュール')
   parser.add_argument("--network_dim", type=int, default=None,
                       help='network dimensions (depends on each network) / モジュールの次元数(ネットワークにより定義は異なります)')
+  parser.add_argument("--network_alpha", type=float, default=1,
+                      help='alpha for LoRA weight scaling, 0 for no scaling (same as old version) / LoRaの重み調整のalpha値、0で調整なし(旧バージョンと同じ)')
   parser.add_argument("--network_args", type=str, default=None, nargs='*',
                       help='additional argmuments for network (key=value) / ネットワークへの追加の引数')
   parser.add_argument("--network_train_unet_only", action="store_true", help="only training U-Net part / U-Net関連部分のみ学習する")