Add block dim(rank) feature

This commit is contained in:
Kohya S
2023-04-03 21:19:49 +09:00
parent 817a9268ff
commit 6134619998
4 changed files with 361 additions and 256 deletions

View File

@@ -2275,7 +2275,7 @@ def main(args):
if metadata is not None:
print(f"metadata for: {network_weight}: {metadata}")
network = imported_module.create_network_from_weights(
network, weights_sd = imported_module.create_network_from_weights(
network_mul, network_weight, vae, text_encoder, unet, **net_kwargs
)
else:
@@ -2285,6 +2285,8 @@ def main(args):
if not args.network_merge:
network.apply_to(text_encoder, unet)
info = network.load_state_dict(weights_sd, False)
print(f"weights are loaded: {info}")
if args.opt_channels_last:
network.to(memory_format=torch.channels_last)
@@ -2292,7 +2294,7 @@ def main(args):
networks.append(network)
else:
network.merge_to(text_encoder, unet, dtype, device)
network.merge_to(text_encoder, unet, weights_sd, dtype, device)
else:
networks = []