Fix metadata loading

This commit is contained in:
space-nuko
2023-01-10 02:55:25 -08:00
parent 0c4423d9dc
commit de37fd9906

View File

@@ -58,12 +58,12 @@ class LoRANetwork(torch.nn.Module):
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
LORA_PREFIX_UNET = 'lora_unet'
LORA_PREFIX_TEXT_ENCODER = 'lora_te'
METADATA_FILENAME = "sd_scripts_metadata.json"
def __init__(self, text_encoder, unet, multiplier=1.0, lora_dim=4) -> None:
super().__init__()
self.multiplier = multiplier
self.lora_dim = lora_dim
self.metadata = {}
# create module instances
def create_modules(prefix, root_module: torch.nn.Module, target_replace_modules) -> list[LoRAModule]:
@@ -94,17 +94,11 @@ class LoRANetwork(torch.nn.Module):
names.add(lora.lora_name)
def load_weights(self, file):
self.metadata = {}
if os.path.splitext(file)[1] == '.safetensors':
from safetensors.torch import load_file
from safetensors.torch import load_file, safe_open
self.weights_sd = load_file(file)
self.metadata = self.weights_sd.metadata()
else:
self.weights_sd = torch.load(file, map_location='cpu')
with zipfile.ZipFile(file, "w") as zipf:
if "sd_scripts_metadata.json" in zipf.namelist():
with zipf.open("sd_scripts_metadata.json", "r") as jsfile:
self.metadata = json.load(jsfile)
def apply_to(self, text_encoder, unet, apply_text_encoder=None, apply_unet=None):
if self.weights_sd:
@@ -184,7 +178,6 @@ class LoRANetwork(torch.nn.Module):
return self.parameters()
def save_weights(self, file, dtype, metadata):
self.metadata = metadata
state_dict = self.state_dict()
if dtype is not None:
@@ -199,4 +192,4 @@ class LoRANetwork(torch.nn.Module):
else:
torch.save(state_dict, file)
with zipfile.ZipFile(file, "w") as zipf:
zipf.writestr("sd_scripts_metadata.json", json.dumps(metadata))
zipf.writestr(LoRANetwork.METADATA_FILENAME, json.dumps(metadata))