From 985761ca439f0ff5a3ee6f544cde8e3ece3a8e87 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 20 Feb 2024 20:33:03 +0900 Subject: [PATCH] fix to work without network module --- stable_cascade_gen_img.py | 49 ++++++++++++++++++++------------------- 1 file changed, 25 insertions(+), 24 deletions(-) diff --git a/stable_cascade_gen_img.py b/stable_cascade_gen_img.py index a98c1a68..76fe3b39 100644 --- a/stable_cascade_gen_img.py +++ b/stable_cascade_gen_img.py @@ -68,37 +68,38 @@ def main(args): previewer = None # LoRA - for i, network_module in enumerate(args.network_module): - print("import network module:", network_module) - imported_module = importlib.import_module(network_module) + if args.network_module: + for i, network_module in enumerate(args.network_module): + print("import network module:", network_module) + imported_module = importlib.import_module(network_module) - network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i] + network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i] - net_kwargs = {} - if args.network_args and i < len(args.network_args): - network_args = args.network_args[i] - # TODO escape special chars - network_args = network_args.split(";") - for net_arg in network_args: - key, value = net_arg.split("=") - net_kwargs[key] = value + net_kwargs = {} + if args.network_args and i < len(args.network_args): + network_args = args.network_args[i] + # TODO escape special chars + network_args = network_args.split(";") + for net_arg in network_args: + key, value = net_arg.split("=") + net_kwargs[key] = value - if args.network_weights is None or len(args.network_weights) <= i: - raise ValueError("No weight. Weight is required.") + if args.network_weights is None or len(args.network_weights) <= i: + raise ValueError("No weight. Weight is required.") - network_weight = args.network_weights[i] - print("load network weights from:", network_weight) + network_weight = args.network_weights[i] + print("load network weights from:", network_weight) - network, weights_sd = imported_module.create_network_from_weights( - network_mul, network_weight, effnet, text_model, generator_c, for_inference=True, **net_kwargs - ) - if network is None: - return + network, weights_sd = imported_module.create_network_from_weights( + network_mul, network_weight, effnet, text_model, generator_c, for_inference=True, **net_kwargs + ) + if network is None: + return - mergeable = network.is_mergeable() - assert mergeable, "not-mergeable network is not supported yet." + mergeable = network.is_mergeable() + assert mergeable, "not-mergeable network is not supported yet." - network.merge_to(text_model, generator_c, weights_sd, dtype, device) + network.merge_to(text_model, generator_c, weights_sd, dtype, device) # 謎のクラス gdf gdf_c = sc.GDF(