Store metadata to .ckpt as value of state dict

This commit is contained in:
Kohya S
2023-01-12 10:54:21 +09:00
parent 9622082eb8
commit 9fd91d26a3

View File

@@ -58,7 +58,7 @@ class LoRANetwork(torch.nn.Module):
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"] TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
LORA_PREFIX_UNET = 'lora_unet' LORA_PREFIX_UNET = 'lora_unet'
LORA_PREFIX_TEXT_ENCODER = 'lora_te' LORA_PREFIX_TEXT_ENCODER = 'lora_te'
METADATA_FILENAME = "sd_scripts_metadata.json" METADATA_KEY_NAME = "sd_scripts_metadata"
def __init__(self, text_encoder, unet, multiplier=1.0, lora_dim=4) -> None: def __init__(self, text_encoder, unet, multiplier=1.0, lora_dim=4) -> None:
super().__init__() super().__init__()
@@ -178,7 +178,7 @@ class LoRANetwork(torch.nn.Module):
return self.parameters() return self.parameters()
def save_weights(self, file, dtype, metadata): def save_weights(self, file, dtype, metadata):
if len(metadata) == 0: if metadata is not None and len(metadata) == 0:
metadata = None metadata = None
state_dict = self.state_dict() state_dict = self.state_dict()
@@ -193,7 +193,6 @@ class LoRANetwork(torch.nn.Module):
from safetensors.torch import save_file from safetensors.torch import save_file
save_file(state_dict, file, metadata) save_file(state_dict, file, metadata)
else: else:
torch.save(state_dict, file)
if metadata is not None: if metadata is not None:
with zipfile.ZipFile(file, "w") as zipf: state_dict[LoRANetwork.METADATA_KEY_NAME] = metadata
zipf.writestr(LoRANetwork.METADATA_FILENAME, json.dumps(metadata)) torch.save(state_dict, file)