do not save metadata in .pt/.ckpt

This commit is contained in:
Kohya S
2023-01-12 21:52:55 +09:00
parent 9fd91d26a3
commit eba142ccb2
2 changed files with 2 additions and 12 deletions

View File

@@ -1975,16 +1975,11 @@ def main(args):
network_weight = args.network_weights[i] network_weight = args.network_weights[i]
print("load network weights from:", network_weight) print("load network weights from:", network_weight)
metadata = None
if os.path.splitext(network_weight)[1] == '.safetensors': if os.path.splitext(network_weight)[1] == '.safetensors':
from safetensors.torch import safe_open from safetensors.torch import safe_open
with safe_open(network_weight, framework="pt") as f: with safe_open(network_weight, framework="pt") as f:
metadata = f.metadata() metadata = f.metadata()
else: if metadata is not None:
with zipfile.ZipFile(network_weight, "r") as zipf:
if "sd_scripts_metadata.json" in zipf.namelist():
with zipf.open("sd_scripts_metadata.json", "r") as jsfile:
metadata = json.load(jsfile)
print(f"metadata for: {network_weight}: {metadata}") print(f"metadata for: {network_weight}: {metadata}")
network.load_weights(network_weight) network.load_weights(network_weight)

View File

@@ -6,8 +6,6 @@
import math import math
import os import os
import torch import torch
import zipfile
import json
class LoRAModule(torch.nn.Module): class LoRAModule(torch.nn.Module):
@@ -58,7 +56,6 @@ class LoRANetwork(torch.nn.Module):
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"] TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
LORA_PREFIX_UNET = 'lora_unet' LORA_PREFIX_UNET = 'lora_unet'
LORA_PREFIX_TEXT_ENCODER = 'lora_te' LORA_PREFIX_TEXT_ENCODER = 'lora_te'
METADATA_KEY_NAME = "sd_scripts_metadata"
def __init__(self, text_encoder, unet, multiplier=1.0, lora_dim=4) -> None: def __init__(self, text_encoder, unet, multiplier=1.0, lora_dim=4) -> None:
super().__init__() super().__init__()
@@ -193,6 +190,4 @@ class LoRANetwork(torch.nn.Module):
from safetensors.torch import save_file from safetensors.torch import save_file
save_file(state_dict, file, metadata) save_file(state_dict, file, metadata)
else: else:
if metadata is not None:
state_dict[LoRANetwork.METADATA_KEY_NAME] = metadata
torch.save(state_dict, file) torch.save(state_dict, file)