mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Add scaling alpha for LoRA
This commit is contained in:
@@ -1981,7 +1981,6 @@ def main(args):
|
||||
imported_module = importlib.import_module(network_module)
|
||||
|
||||
network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i]
|
||||
network_dim = None if args.network_dim is None or len(args.network_dim) <= i else args.network_dim[i]
|
||||
|
||||
net_kwargs = {}
|
||||
if args.network_args and i < len(args.network_args):
|
||||
@@ -1992,22 +1991,21 @@ def main(args):
|
||||
key, value = net_arg.split("=")
|
||||
net_kwargs[key] = value
|
||||
|
||||
network = imported_module.create_network(network_mul, network_dim, vae, text_encoder, unet, **net_kwargs)
|
||||
if network is None:
|
||||
return
|
||||
|
||||
if args.network_weights and i < len(args.network_weights):
|
||||
network_weight = args.network_weights[i]
|
||||
print("load network weights from:", network_weight)
|
||||
|
||||
if os.path.splitext(network_weight)[1] == '.safetensors':
|
||||
from safetensors.torch import safe_open
|
||||
with safe_open(network_weight, framework="pt") as f:
|
||||
metadata = f.metadata()
|
||||
if metadata is not None:
|
||||
print(f"metadata for: {network_weight}: {metadata}")
|
||||
from safetensors.torch import safe_open
|
||||
with safe_open(network_weight, framework="pt") as f:
|
||||
metadata = f.metadata()
|
||||
if metadata is not None:
|
||||
print(f"metadata for: {network_weight}: {metadata}")
|
||||
|
||||
network.load_weights(network_weight)
|
||||
network = imported_module.create_network_from_weights(network_mul, network_weight, vae, text_encoder, unet, **net_kwargs)
|
||||
else:
|
||||
raise ValueError("No weight. Weight is required.")
|
||||
if network is None:
|
||||
return
|
||||
|
||||
network.apply_to(text_encoder, unet)
|
||||
|
||||
@@ -2526,8 +2524,6 @@ if __name__ == '__main__':
|
||||
parser.add_argument("--network_weights", type=str, default=None, nargs='*',
|
||||
help='Hypernetwork weights to load / Hypernetworkの重み')
|
||||
parser.add_argument("--network_mul", type=float, default=None, nargs='*', help='Hypernetwork multiplier / Hypernetworkの効果の倍率')
|
||||
parser.add_argument("--network_dim", type=int, default=None, nargs='*',
|
||||
help='network dimensions (depends on each network) / モジュールの次元数(ネットワークにより定義は異なります)')
|
||||
parser.add_argument("--network_args", type=str, default=None, nargs='*',
|
||||
help='additional argmuments for network (key=value) / ネットワークへの追加の引数')
|
||||
parser.add_argument("--clip_skip", type=int, default=None, help='layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う')
|
||||
|
||||
Reference in New Issue
Block a user