mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Fix network_weights not working in train_network
This commit is contained in:
@@ -2285,7 +2285,7 @@ def main(args):
|
|||||||
|
|
||||||
if not args.network_merge:
|
if not args.network_merge:
|
||||||
network.apply_to(text_encoder, unet)
|
network.apply_to(text_encoder, unet)
|
||||||
info = network.load_state_dict(weights_sd, False)
|
info = network.load_state_dict(weights_sd, False) # network.load_weightsを使うようにするとよい
|
||||||
print(f"weights are loaded: {info}")
|
print(f"weights are loaded: {info}")
|
||||||
|
|
||||||
if args.opt_channels_last:
|
if args.opt_channels_last:
|
||||||
|
|||||||
@@ -593,6 +593,17 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
for lora in self.text_encoder_loras + self.unet_loras:
|
for lora in self.text_encoder_loras + self.unet_loras:
|
||||||
lora.multiplier = self.multiplier
|
lora.multiplier = self.multiplier
|
||||||
|
|
||||||
|
def load_weights(self, file):
|
||||||
|
if os.path.splitext(file)[1] == ".safetensors":
|
||||||
|
from safetensors.torch import load_file
|
||||||
|
|
||||||
|
weights_sd = load_file(file)
|
||||||
|
else:
|
||||||
|
weights_sd = torch.load(file, map_location="cpu")
|
||||||
|
|
||||||
|
info = self.load_state_dict(weights_sd, False)
|
||||||
|
return info
|
||||||
|
|
||||||
def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True):
|
def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True):
|
||||||
if apply_text_encoder:
|
if apply_text_encoder:
|
||||||
print("enable LoRA for text encoder")
|
print("enable LoRA for text encoder")
|
||||||
|
|||||||
@@ -194,14 +194,14 @@ def train(args):
|
|||||||
if network is None:
|
if network is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
if args.network_weights is not None:
|
|
||||||
print("load network weights from:", args.network_weights)
|
|
||||||
network.load_weights(args.network_weights)
|
|
||||||
|
|
||||||
train_unet = not args.network_train_text_encoder_only
|
train_unet = not args.network_train_text_encoder_only
|
||||||
train_text_encoder = not args.network_train_unet_only
|
train_text_encoder = not args.network_train_unet_only
|
||||||
network.apply_to(text_encoder, unet, train_text_encoder, train_unet)
|
network.apply_to(text_encoder, unet, train_text_encoder, train_unet)
|
||||||
|
|
||||||
|
if args.network_weights is not None:
|
||||||
|
info = network.load_weights(args.network_weights)
|
||||||
|
print(f"load network weights from {args.network_weights}: {info}")
|
||||||
|
|
||||||
if args.gradient_checkpointing:
|
if args.gradient_checkpointing:
|
||||||
unet.enable_gradient_checkpointing()
|
unet.enable_gradient_checkpointing()
|
||||||
text_encoder.gradient_checkpointing_enable()
|
text_encoder.gradient_checkpointing_enable()
|
||||||
|
|||||||
Reference in New Issue
Block a user