diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py index 208b1b70..1912e720 100644 --- a/gen_img_diffusers.py +++ b/gen_img_diffusers.py @@ -46,11 +46,13 @@ VGG( ) """ +import json from typing import List, Optional, Union import glob import importlib import inspect import time +import zipfile from diffusers.utils import deprecate from diffusers.configuration_utils import FrozenDict import argparse @@ -1972,6 +1974,19 @@ def main(args): if args.network_weights and i < len(args.network_weights): network_weight = args.network_weights[i] print("load network weights from:", network_weight) + + metadata = None + if os.path.splitext(network_weight)[1] == '.safetensors': + from safetensors.torch import safe_open + with safe_open(network_weight, framework="pt") as f: + metadata = f.metadata() + else: + with zipfile.ZipFile(network_weight, "r") as zipf: + if "sd_scripts_metadata.json" in zipf.namelist(): + with zipf.open("sd_scripts_metadata.json", "r") as jsfile: + metadata = json.load(jsfile) + print(f"metadata for: {network_weight}: {metadata}") + network.load_weights(network_weight) network.apply_to(text_encoder, unet)