mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 06:28:48 +00:00
Print metadata for additional network
This commit is contained in:
@@ -46,11 +46,13 @@ VGG(
|
||||
)
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import List, Optional, Union
|
||||
import glob
|
||||
import importlib
|
||||
import inspect
|
||||
import time
|
||||
import zipfile
|
||||
from diffusers.utils import deprecate
|
||||
from diffusers.configuration_utils import FrozenDict
|
||||
import argparse
|
||||
@@ -1972,6 +1974,19 @@ def main(args):
|
||||
if args.network_weights and i < len(args.network_weights):
|
||||
network_weight = args.network_weights[i]
|
||||
print("load network weights from:", network_weight)
|
||||
|
||||
metadata = None
|
||||
if os.path.splitext(network_weight)[1] == '.safetensors':
|
||||
from safetensors.torch import safe_open
|
||||
with safe_open(network_weight, framework="pt") as f:
|
||||
metadata = f.metadata()
|
||||
else:
|
||||
with zipfile.ZipFile(network_weight, "r") as zipf:
|
||||
if "sd_scripts_metadata.json" in zipf.namelist():
|
||||
with zipf.open("sd_scripts_metadata.json", "r") as jsfile:
|
||||
metadata = json.load(jsfile)
|
||||
print(f"metadata for: {network_weight}: {metadata}")
|
||||
|
||||
network.load_weights(network_weight)
|
||||
|
||||
network.apply_to(text_encoder, unet)
|
||||
|
||||
Reference in New Issue
Block a user