Fix network_weights not working in train_network

This commit is contained in:
Kohya S
2023-04-03 22:45:28 +09:00
parent 959561473c
commit 83c7e03d05
3 changed files with 16 additions and 5 deletions

View File

@@ -593,6 +593,17 @@ class LoRANetwork(torch.nn.Module):
for lora in self.text_encoder_loras + self.unet_loras:
lora.multiplier = self.multiplier
def load_weights(self, file):
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import load_file
weights_sd = load_file(file)
else:
weights_sd = torch.load(file, map_location="cpu")
info = self.load_state_dict(weights_sd, False)
return info
def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True):
if apply_text_encoder:
print("enable LoRA for text encoder")