diff --git a/networks/lora.py b/networks/lora.py index 77fe26a7..de87d064 100644 --- a/networks/lora.py +++ b/networks/lora.py @@ -58,7 +58,7 @@ class LoRANetwork(torch.nn.Module): TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"] LORA_PREFIX_UNET = 'lora_unet' LORA_PREFIX_TEXT_ENCODER = 'lora_te' - METADATA_FILENAME = "sd_scripts_metadata.json" + METADATA_KEY_NAME = "sd_scripts_metadata" def __init__(self, text_encoder, unet, multiplier=1.0, lora_dim=4) -> None: super().__init__() @@ -178,7 +178,7 @@ class LoRANetwork(torch.nn.Module): return self.parameters() def save_weights(self, file, dtype, metadata): - if len(metadata) == 0: + if metadata is not None and len(metadata) == 0: metadata = None state_dict = self.state_dict() @@ -193,7 +193,6 @@ class LoRANetwork(torch.nn.Module): from safetensors.torch import save_file save_file(state_dict, file, metadata) else: - torch.save(state_dict, file) if metadata is not None: - with zipfile.ZipFile(file, "w") as zipf: - zipf.writestr(LoRANetwork.METADATA_FILENAME, json.dumps(metadata)) + state_dict[LoRANetwork.METADATA_KEY_NAME] = metadata + torch.save(state_dict, file)