support LoRA merge in advance

This commit is contained in:
Kohya S
2023-03-30 21:34:36 +09:00
parent cb53a77334
commit 2d6faa9860
2 changed files with 81 additions and 16 deletions

View File

@@ -2262,7 +2262,7 @@ def main(args):
metadata = f.metadata()
if metadata is not None:
print(f"metadata for: {network_weight}: {metadata}")
network = imported_module.create_network_from_weights(
network_mul, network_weight, vae, text_encoder, unet, **net_kwargs
)
@@ -2271,13 +2271,17 @@ def main(args):
if network is None:
return
network.apply_to(text_encoder, unet)
if not args.network_merge:
network.apply_to(text_encoder, unet)
if args.opt_channels_last:
network.to(memory_format=torch.channels_last)
network.to(dtype).to(device)
if args.opt_channels_last:
network.to(memory_format=torch.channels_last)
network.to(dtype).to(device)
networks.append(network)
else:
network.merge_to(text_encoder, unet, dtype, device)
networks.append(network)
else:
networks = []
@@ -3074,6 +3078,7 @@ def setup_parser() -> argparse.ArgumentParser:
"--network_args", type=str, default=None, nargs="*", help="additional argmuments for network (key=value) / ネットワークへの追加の引数"
)
parser.add_argument("--network_show_meta", action="store_true", help="show metadata of network model / ネットワークモデルのメタデータを表示する")
parser.add_argument("--network_merge", action="store_true", help="merge network weights to original model / ネットワークの重みをマージする")
parser.add_argument(
"--textual_inversion_embeddings",
type=str,