mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
Store metadata to .ckpt as value of state dict
This commit is contained in:
@@ -58,7 +58,7 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
|
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
|
||||||
LORA_PREFIX_UNET = 'lora_unet'
|
LORA_PREFIX_UNET = 'lora_unet'
|
||||||
LORA_PREFIX_TEXT_ENCODER = 'lora_te'
|
LORA_PREFIX_TEXT_ENCODER = 'lora_te'
|
||||||
METADATA_FILENAME = "sd_scripts_metadata.json"
|
METADATA_KEY_NAME = "sd_scripts_metadata"
|
||||||
|
|
||||||
def __init__(self, text_encoder, unet, multiplier=1.0, lora_dim=4) -> None:
|
def __init__(self, text_encoder, unet, multiplier=1.0, lora_dim=4) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -178,7 +178,7 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
return self.parameters()
|
return self.parameters()
|
||||||
|
|
||||||
def save_weights(self, file, dtype, metadata):
|
def save_weights(self, file, dtype, metadata):
|
||||||
if len(metadata) == 0:
|
if metadata is not None and len(metadata) == 0:
|
||||||
metadata = None
|
metadata = None
|
||||||
|
|
||||||
state_dict = self.state_dict()
|
state_dict = self.state_dict()
|
||||||
@@ -193,7 +193,6 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
from safetensors.torch import save_file
|
from safetensors.torch import save_file
|
||||||
save_file(state_dict, file, metadata)
|
save_file(state_dict, file, metadata)
|
||||||
else:
|
else:
|
||||||
torch.save(state_dict, file)
|
|
||||||
if metadata is not None:
|
if metadata is not None:
|
||||||
with zipfile.ZipFile(file, "w") as zipf:
|
state_dict[LoRANetwork.METADATA_KEY_NAME] = metadata
|
||||||
zipf.writestr(LoRANetwork.METADATA_FILENAME, json.dumps(metadata))
|
torch.save(state_dict, file)
|
||||||
|
|||||||
Reference in New Issue
Block a user