Add scaling alpha for LoRA

Kohya S
2023-01-21 20:37:34 +09:00
parent 22ee0ac467
commit b4636d4185
3 changed files with 59 additions and 26 deletions
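LoRA updates are now scaled by alpha / rank: LoRAModule.forward returns org_forward(x) + lora_up(lora_down(x)) * multiplier * scale, where scale = alpha / lora_dim. alpha is registered as a buffer so it is saved into the weights, and the new networks.lora.create_network_from_weights reads both dim (rank) and alpha back from a weights file. As a result, gen_img_diffusers.py no longer needs --network_dim, and train_network.py gains --network_alpha (default 1; 0 or None makes alpha equal to the rank, i.e. no scaling, matching the previous behavior).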

gen_img_diffusers.py

@@ -1981,7 +1981,6 @@ def main(args):
       imported_module = importlib.import_module(network_module)
 
       network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i]
-      network_dim = None if args.network_dim is None or len(args.network_dim) <= i else args.network_dim[i]
 
       net_kwargs = {}
       if args.network_args and i < len(args.network_args):
@@ -1992,22 +1991,21 @@ def main(args):
           key, value = net_arg.split("=")
           net_kwargs[key] = value
 
-      network = imported_module.create_network(network_mul, network_dim, vae, text_encoder, unet, **net_kwargs)
-      if network is None:
-        return
-
       if args.network_weights and i < len(args.network_weights):
         network_weight = args.network_weights[i]
         print("load network weights from:", network_weight)
 
-        if os.path.splitext(network_weight)[1] == '.safetensors':
-          from safetensors.torch import safe_open
-          with safe_open(network_weight, framework="pt") as f:
-            metadata = f.metadata()
-          if metadata is not None:
-            print(f"metadata for: {network_weight}: {metadata}")
+        from safetensors.torch import safe_open
+        with safe_open(network_weight, framework="pt") as f:
+          metadata = f.metadata()
+        if metadata is not None:
+          print(f"metadata for: {network_weight}: {metadata}")
 
-        network.load_weights(network_weight)
+        network = imported_module.create_network_from_weights(network_mul, network_weight, vae, text_encoder, unet, **net_kwargs)
+      else:
+        raise ValueError("No weight. Weight is required.")
+
+      if network is None:
+        return
 
       network.apply_to(text_encoder, unet)
@@ -2526,8 +2524,6 @@ if __name__ == '__main__':
   parser.add_argument("--network_weights", type=str, default=None, nargs='*',
                       help='Hypernetwork weights to load / Hypernetworkの重み')
   parser.add_argument("--network_mul", type=float, default=None, nargs='*', help='Hypernetwork multiplier / Hypernetworkの効果の倍率')
-  parser.add_argument("--network_dim", type=int, default=None, nargs='*',
-                      help='network dimensions (depends on each network) / モジュールの次元数(ネットワークにより定義は異なります)')
   parser.add_argument("--network_args", type=str, default=None, nargs='*',
                       help='additional arguments for network (key=value) / ネットワークへの追加の引数')
   parser.add_argument("--clip_skip", type=int, default=None, help='layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う')

networks/lora.py

@@ -13,9 +13,11 @@ class LoRAModule(torch.nn.Module):
   replaces forward method of the original Linear, instead of replacing the original Linear module.
   """
 
-  def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4):
+  def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4, alpha=1):
+    """ if alpha == 0 or None, alpha is rank (no scaling). """
     super().__init__()
     self.lora_name = lora_name
+    self.lora_dim = lora_dim
 
     if org_module.__class__.__name__ == 'Conv2d':
       in_dim = org_module.in_channels
@@ -28,6 +30,12 @@ class LoRAModule(torch.nn.Module):
       self.lora_down = torch.nn.Linear(in_dim, lora_dim, bias=False)
       self.lora_up = torch.nn.Linear(lora_dim, out_dim, bias=False)
 
+    if type(alpha) == torch.Tensor:
+      alpha = alpha.detach().numpy()
+    alpha = lora_dim if alpha is None or alpha == 0 else alpha
+    self.scale = alpha / self.lora_dim
+    self.register_buffer('alpha', torch.tensor(alpha))  # can be treated as a constant
+
     # same as microsoft's
     torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
     torch.nn.init.zeros_(self.lora_up.weight)
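The net effect of the block above: the low-rank update is shrunk by alpha / rank before being added to the original module's output. A minimal standalone sketch of the arithmetic (toy dimensions; names are illustrative, not from the repo):

    import torch

    lora_dim, alpha, multiplier = 4, 1.0, 1.0
    down = torch.nn.Linear(16, lora_dim, bias=False)   # lora_down
    up = torch.nn.Linear(lora_dim, 16, bias=False)     # lora_up

    # alpha == 0 or None falls back to alpha = lora_dim, giving scale = 1.0 (old behavior)
    scale = alpha / lora_dim   # 0.25 here

    x = torch.randn(2, 16)
    delta = up(down(x)) * multiplier * scale   # what forward() adds to org_forward(x)
    print(delta.shape)   # torch.Size([2, 16])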
@@ -41,13 +49,37 @@ class LoRAModule(torch.nn.Module):
     del self.org_module
 
   def forward(self, x):
-    return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier
+    return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
 
 
-def create_network(multiplier, network_dim, vae, text_encoder, unet, **kwargs):
+def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs):
   if network_dim is None:
     network_dim = 4  # default
-  network = LoRANetwork(text_encoder, unet, multiplier=multiplier, lora_dim=network_dim)
+  network = LoRANetwork(text_encoder, unet, multiplier=multiplier, lora_dim=network_dim, alpha=network_alpha)
+  return network
+
+
+def create_network_from_weights(multiplier, file, vae, text_encoder, unet, **kwargs):
+  if os.path.splitext(file)[1] == '.safetensors':
+    from safetensors.torch import load_file, safe_open
+    weights_sd = load_file(file)
+  else:
+    weights_sd = torch.load(file, map_location='cpu')
+
+  # get dim (rank)
+  network_alpha = None
+  network_dim = None
+  for key, value in weights_sd.items():
+    if network_alpha is None and 'alpha' in key:
+      network_alpha = value
+    if network_dim is None and 'lora_down' in key and len(value.size()) == 2:
+      network_dim = value.size()[0]
+
+  if network_alpha is None:
+    network_alpha = network_dim
+
+  network = LoRANetwork(text_encoder, unet, multiplier=multiplier, lora_dim=network_dim, alpha=network_alpha)
+  network.weights_sd = weights_sd
   return network
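create_network_from_weights recovers the hyperparameters purely from the checkpoint: rank is the first dimension of any 2-D lora_down weight, and alpha is the stored buffer (falling back to the rank for older files without one). A sketch against a toy state dict in the key format this module saves (the key name is illustrative):

    import torch

    weights_sd = {
        'lora_unet_mid_block_attn_to_q.lora_down.weight': torch.zeros(4, 320),   # rank 4
        'lora_unet_mid_block_attn_to_q.lora_up.weight': torch.zeros(320, 4),
        'lora_unet_mid_block_attn_to_q.alpha': torch.tensor(1.0),
    }

    network_alpha = None
    network_dim = None
    for key, value in weights_sd.items():
        if network_alpha is None and 'alpha' in key:
            network_alpha = value
        if network_dim is None and 'lora_down' in key and len(value.size()) == 2:
            network_dim = value.size()[0]
    if network_alpha is None:   # old file without an alpha buffer
        network_alpha = network_dim

    print(network_dim, network_alpha)   # 4 tensor(1.)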
@@ -57,10 +89,11 @@ class LoRANetwork(torch.nn.Module):
   LORA_PREFIX_UNET = 'lora_unet'
   LORA_PREFIX_TEXT_ENCODER = 'lora_te'
 
-  def __init__(self, text_encoder, unet, multiplier=1.0, lora_dim=4) -> None:
+  def __init__(self, text_encoder, unet, multiplier=1.0, lora_dim=4, alpha=1) -> None:
     super().__init__()
     self.multiplier = multiplier
     self.lora_dim = lora_dim
+    self.alpha = alpha
 
     # create module instances
     def create_modules(prefix, root_module: torch.nn.Module, target_replace_modules) -> list[LoRAModule]:
@@ -71,7 +104,7 @@ class LoRANetwork(torch.nn.Module):
           if child_module.__class__.__name__ == "Linear" or (child_module.__class__.__name__ == "Conv2d" and child_module.kernel_size == (1, 1)):
             lora_name = prefix + '.' + name + '.' + child_name
             lora_name = lora_name.replace('.', '_')
-            lora = LoRAModule(lora_name, child_module, self.multiplier, self.lora_dim)
+            lora = LoRAModule(lora_name, child_module, self.multiplier, self.lora_dim, self.alpha)
             loras.append(lora)
       return loras
@@ -149,21 +182,21 @@ class LoRANetwork(torch.nn.Module):
       return params
 
     self.requires_grad_(True)
-    params = []
+    all_params = []
 
     if self.text_encoder_loras:
       param_data = {'params': enumerate_params(self.text_encoder_loras)}
       if text_encoder_lr is not None:
         param_data['lr'] = text_encoder_lr
-      params.append(param_data)
+      all_params.append(param_data)
 
     if self.unet_loras:
       param_data = {'params': enumerate_params(self.unet_loras)}
       if unet_lr is not None:
         param_data['lr'] = unet_lr
-      params.append(param_data)
+      all_params.append(param_data)
 
-    return params
+    return all_params
 
   def prepare_grad_etc(self, text_encoder, unet):
     self.requires_grad_(True)

train_network.py

@@ -107,7 +107,8 @@ def train(args):
       key, value = net_arg.split('=')
       net_kwargs[key] = value
 
-  network = network_module.create_network(1.0, args.network_dim, vae, text_encoder, unet, **net_kwargs)
+  # if a new network is added in the future, add an if ~ then block for each network here (;'∀')
+  network = network_module.create_network(1.0, args.network_dim, args.network_alpha, vae, text_encoder, unet, **net_kwargs)
   if network is None:
     return
@@ -243,7 +244,8 @@ def train(args):
         "ss_lr_warmup_steps": args.lr_warmup_steps,
         "ss_lr_scheduler": args.lr_scheduler,
         "ss_network_module": args.network_module,
         "ss_network_dim": args.network_dim,  # None means default because another network than LoRA may have another default dim
+        "ss_network_alpha": args.network_alpha,  # some networks may not use this value
         "ss_mixed_precision": args.mixed_precision,
         "ss_full_fp16": bool(args.full_fp16),
         "ss_v2": bool(args.v2),
@@ -445,6 +447,8 @@ if __name__ == '__main__':
   parser.add_argument("--network_module", type=str, default=None, help='network module to train / 学習対象のネットワークのモジュール')
   parser.add_argument("--network_dim", type=int, default=None,
                       help='network dimensions (depends on each network) / モジュールの次元数(ネットワークにより定義は異なります)')
+  parser.add_argument("--network_alpha", type=float, default=1,
+                      help='alpha for LoRA weight scaling, 0 for no scaling (same as old version) / LoRAの重み調整のalpha値、0で調整なし(旧バージョンと同じ)')
   parser.add_argument("--network_args", type=str, default=None, nargs='*',
                       help='additional arguments for network (key=value) / ネットワークへの追加の引数')
   parser.add_argument("--network_train_unet_only", action="store_true", help="only training U-Net part / U-Net関連部分のみ学習する")