mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
do not save metadata in .pt/.ckpt
This commit is contained in:
@@ -1975,16 +1975,11 @@ def main(args):
|
|||||||
network_weight = args.network_weights[i]
|
network_weight = args.network_weights[i]
|
||||||
print("load network weights from:", network_weight)
|
print("load network weights from:", network_weight)
|
||||||
|
|
||||||
metadata = None
|
|
||||||
if os.path.splitext(network_weight)[1] == '.safetensors':
|
if os.path.splitext(network_weight)[1] == '.safetensors':
|
||||||
from safetensors.torch import safe_open
|
from safetensors.torch import safe_open
|
||||||
with safe_open(network_weight, framework="pt") as f:
|
with safe_open(network_weight, framework="pt") as f:
|
||||||
metadata = f.metadata()
|
metadata = f.metadata()
|
||||||
else:
|
if metadata is not None:
|
||||||
with zipfile.ZipFile(network_weight, "r") as zipf:
|
|
||||||
if "sd_scripts_metadata.json" in zipf.namelist():
|
|
||||||
with zipf.open("sd_scripts_metadata.json", "r") as jsfile:
|
|
||||||
metadata = json.load(jsfile)
|
|
||||||
print(f"metadata for: {network_weight}: {metadata}")
|
print(f"metadata for: {network_weight}: {metadata}")
|
||||||
|
|
||||||
network.load_weights(network_weight)
|
network.load_weights(network_weight)
|
||||||
|
|||||||
@@ -6,8 +6,6 @@
|
|||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import torch
|
import torch
|
||||||
import zipfile
|
|
||||||
import json
|
|
||||||
|
|
||||||
|
|
||||||
class LoRAModule(torch.nn.Module):
|
class LoRAModule(torch.nn.Module):
|
||||||
@@ -58,7 +56,6 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
|
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
|
||||||
LORA_PREFIX_UNET = 'lora_unet'
|
LORA_PREFIX_UNET = 'lora_unet'
|
||||||
LORA_PREFIX_TEXT_ENCODER = 'lora_te'
|
LORA_PREFIX_TEXT_ENCODER = 'lora_te'
|
||||||
METADATA_KEY_NAME = "sd_scripts_metadata"
|
|
||||||
|
|
||||||
def __init__(self, text_encoder, unet, multiplier=1.0, lora_dim=4) -> None:
|
def __init__(self, text_encoder, unet, multiplier=1.0, lora_dim=4) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -193,6 +190,4 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
from safetensors.torch import save_file
|
from safetensors.torch import save_file
|
||||||
save_file(state_dict, file, metadata)
|
save_file(state_dict, file, metadata)
|
||||||
else:
|
else:
|
||||||
if metadata is not None:
|
|
||||||
state_dict[LoRANetwork.METADATA_KEY_NAME] = metadata
|
|
||||||
torch.save(state_dict, file)
|
torch.save(state_dict, file)
|
||||||
|
|||||||
Reference in New Issue
Block a user