mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
Fix metadata loading
This commit is contained in:
@@ -58,12 +58,12 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
|
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
|
||||||
LORA_PREFIX_UNET = 'lora_unet'
|
LORA_PREFIX_UNET = 'lora_unet'
|
||||||
LORA_PREFIX_TEXT_ENCODER = 'lora_te'
|
LORA_PREFIX_TEXT_ENCODER = 'lora_te'
|
||||||
|
METADATA_FILENAME = "sd_scripts_metadata.json"
|
||||||
|
|
||||||
def __init__(self, text_encoder, unet, multiplier=1.0, lora_dim=4) -> None:
|
def __init__(self, text_encoder, unet, multiplier=1.0, lora_dim=4) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.multiplier = multiplier
|
self.multiplier = multiplier
|
||||||
self.lora_dim = lora_dim
|
self.lora_dim = lora_dim
|
||||||
self.metadata = {}
|
|
||||||
|
|
||||||
# create module instances
|
# create module instances
|
||||||
def create_modules(prefix, root_module: torch.nn.Module, target_replace_modules) -> list[LoRAModule]:
|
def create_modules(prefix, root_module: torch.nn.Module, target_replace_modules) -> list[LoRAModule]:
|
||||||
@@ -94,17 +94,11 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
names.add(lora.lora_name)
|
names.add(lora.lora_name)
|
||||||
|
|
||||||
def load_weights(self, file):
|
def load_weights(self, file):
|
||||||
self.metadata = {}
|
|
||||||
if os.path.splitext(file)[1] == '.safetensors':
|
if os.path.splitext(file)[1] == '.safetensors':
|
||||||
from safetensors.torch import load_file
|
from safetensors.torch import load_file, safe_open
|
||||||
self.weights_sd = load_file(file)
|
self.weights_sd = load_file(file)
|
||||||
self.metadata = self.weights_sd.metadata()
|
|
||||||
else:
|
else:
|
||||||
self.weights_sd = torch.load(file, map_location='cpu')
|
self.weights_sd = torch.load(file, map_location='cpu')
|
||||||
with zipfile.ZipFile(file, "w") as zipf:
|
|
||||||
if "sd_scripts_metadata.json" in zipf.namelist():
|
|
||||||
with zipf.open("sd_scripts_metadata.json", "r") as jsfile:
|
|
||||||
self.metadata = json.load(jsfile)
|
|
||||||
|
|
||||||
def apply_to(self, text_encoder, unet, apply_text_encoder=None, apply_unet=None):
|
def apply_to(self, text_encoder, unet, apply_text_encoder=None, apply_unet=None):
|
||||||
if self.weights_sd:
|
if self.weights_sd:
|
||||||
@@ -184,7 +178,6 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
return self.parameters()
|
return self.parameters()
|
||||||
|
|
||||||
def save_weights(self, file, dtype, metadata):
|
def save_weights(self, file, dtype, metadata):
|
||||||
self.metadata = metadata
|
|
||||||
state_dict = self.state_dict()
|
state_dict = self.state_dict()
|
||||||
|
|
||||||
if dtype is not None:
|
if dtype is not None:
|
||||||
@@ -199,4 +192,4 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
else:
|
else:
|
||||||
torch.save(state_dict, file)
|
torch.save(state_dict, file)
|
||||||
with zipfile.ZipFile(file, "w") as zipf:
|
with zipfile.ZipFile(file, "w") as zipf:
|
||||||
zipf.writestr("sd_scripts_metadata.json", json.dumps(metadata))
|
zipf.writestr(LoRANetwork.METADATA_FILENAME, json.dumps(metadata))
|
||||||
|
|||||||
Reference in New Issue
Block a user