mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Fix network_weights not working in train_network
This commit is contained in:
@@ -2285,7 +2285,7 @@ def main(args):
|
||||
|
||||
if not args.network_merge:
|
||||
network.apply_to(text_encoder, unet)
|
||||
info = network.load_state_dict(weights_sd, False)
|
||||
info = network.load_state_dict(weights_sd, False) # network.load_weightsを使うようにするとよい
|
||||
print(f"weights are loaded: {info}")
|
||||
|
||||
if args.opt_channels_last:
|
||||
|
||||
@@ -593,6 +593,17 @@ class LoRANetwork(torch.nn.Module):
|
||||
for lora in self.text_encoder_loras + self.unet_loras:
|
||||
lora.multiplier = self.multiplier
|
||||
|
||||
def load_weights(self, file):
|
||||
if os.path.splitext(file)[1] == ".safetensors":
|
||||
from safetensors.torch import load_file
|
||||
|
||||
weights_sd = load_file(file)
|
||||
else:
|
||||
weights_sd = torch.load(file, map_location="cpu")
|
||||
|
||||
info = self.load_state_dict(weights_sd, False)
|
||||
return info
|
||||
|
||||
def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True):
|
||||
if apply_text_encoder:
|
||||
print("enable LoRA for text encoder")
|
||||
|
||||
@@ -194,14 +194,14 @@ def train(args):
|
||||
if network is None:
|
||||
return
|
||||
|
||||
if args.network_weights is not None:
|
||||
print("load network weights from:", args.network_weights)
|
||||
network.load_weights(args.network_weights)
|
||||
|
||||
train_unet = not args.network_train_text_encoder_only
|
||||
train_text_encoder = not args.network_train_unet_only
|
||||
network.apply_to(text_encoder, unet, train_text_encoder, train_unet)
|
||||
|
||||
if args.network_weights is not None:
|
||||
info = network.load_weights(args.network_weights)
|
||||
print(f"load network weights from {args.network_weights}: {info}")
|
||||
|
||||
if args.gradient_checkpointing:
|
||||
unet.enable_gradient_checkpointing()
|
||||
text_encoder.gradient_checkpointing_enable()
|
||||
|
||||
Reference in New Issue
Block a user