Fix network_weights not working in train_network

This commit is contained in:
Kohya S
2023-04-03 22:45:28 +09:00
parent 959561473c
commit 83c7e03d05
3 changed files with 16 additions and 5 deletions

View File

@@ -194,14 +194,14 @@ def train(args):
if network is None:
return
if args.network_weights is not None:
print("load network weights from:", args.network_weights)
network.load_weights(args.network_weights)
train_unet = not args.network_train_text_encoder_only
train_text_encoder = not args.network_train_unet_only
network.apply_to(text_encoder, unet, train_text_encoder, train_unet)
if args.network_weights is not None:
info = network.load_weights(args.network_weights)
print(f"load network weights from {args.network_weights}: {info}")
if args.gradient_checkpointing:
unet.enable_gradient_checkpointing()
text_encoder.gradient_checkpointing_enable()