Print metadata for additional network

This commit is contained in:
Kohya S
2023-01-11 23:12:35 +09:00
parent e4f9b2b715
commit 9622082eb8

View File

@@ -46,11 +46,13 @@ VGG(
)
"""
import json
from typing import List, Optional, Union
import glob
import importlib
import inspect
import time
import zipfile
from diffusers.utils import deprecate
from diffusers.configuration_utils import FrozenDict
import argparse
@@ -1972,6 +1974,19 @@ def main(args):
if args.network_weights and i < len(args.network_weights):
network_weight = args.network_weights[i]
print("load network weights from:", network_weight)
metadata = None
if os.path.splitext(network_weight)[1] == '.safetensors':
from safetensors.torch import safe_open
with safe_open(network_weight, framework="pt") as f:
metadata = f.metadata()
else:
with zipfile.ZipFile(network_weight, "r") as zipf:
if "sd_scripts_metadata.json" in zipf.namelist():
with zipf.open("sd_scripts_metadata.json", "r") as jsfile:
metadata = json.load(jsfile)
print(f"metadata for: {network_weight}: {metadata}")
network.load_weights(network_weight)
network.apply_to(text_encoder, unet)