diff --git a/networks/lora.py b/networks/lora.py
index dbef2aa1..98e8e4a4 100644
--- a/networks/lora.py
+++ b/networks/lora.py
@@ -58,12 +58,12 @@ class LoRANetwork(torch.nn.Module):
     TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
     LORA_PREFIX_UNET = 'lora_unet'
     LORA_PREFIX_TEXT_ENCODER = 'lora_te'
+    METADATA_FILENAME = "sd_scripts_metadata.json"
 
     def __init__(self, text_encoder, unet, multiplier=1.0, lora_dim=4) -> None:
         super().__init__()
         self.multiplier = multiplier
         self.lora_dim = lora_dim
-        self.metadata = {}
 
         # create module instances
         def create_modules(prefix, root_module: torch.nn.Module, target_replace_modules) -> list[LoRAModule]:
@@ -94,17 +94,11 @@
             names.add(lora.lora_name)
 
     def load_weights(self, file):
-        self.metadata = {}
         if os.path.splitext(file)[1] == '.safetensors':
-            from safetensors.torch import load_file
+            from safetensors.torch import load_file, safe_open
             self.weights_sd = load_file(file)
-            self.metadata = self.weights_sd.metadata()
         else:
             self.weights_sd = torch.load(file, map_location='cpu')
-            with zipfile.ZipFile(file, "w") as zipf:
-                if "sd_scripts_metadata.json" in zipf.namelist():
-                    with zipf.open("sd_scripts_metadata.json", "r") as jsfile:
-                        self.metadata = json.load(jsfile)
 
     def apply_to(self, text_encoder, unet, apply_text_encoder=None, apply_unet=None):
         if self.weights_sd:
@@ -184,7 +178,6 @@
         return self.parameters()
 
     def save_weights(self, file, dtype, metadata):
-        self.metadata = metadata
         state_dict = self.state_dict()
 
         if dtype is not None:
@@ -199,4 +192,4 @@
         else:
             torch.save(state_dict, file)
             with zipfile.ZipFile(file, "w") as zipf:
-                zipf.writestr("sd_scripts_metadata.json", json.dumps(metadata))
+                zipf.writestr(LoRANetwork.METADATA_FILENAME, json.dumps(metadata))