fix to work without network module

This commit is contained in:
Kohya S
2024-02-20 20:33:03 +09:00
parent 71e03559e2
commit 985761ca43

View File

@@ -68,37 +68,38 @@ def main(args):
previewer = None
# LoRA
for i, network_module in enumerate(args.network_module):
print("import network module:", network_module)
imported_module = importlib.import_module(network_module)
if args.network_module:
for i, network_module in enumerate(args.network_module):
print("import network module:", network_module)
imported_module = importlib.import_module(network_module)
network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i]
network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i]
net_kwargs = {}
if args.network_args and i < len(args.network_args):
network_args = args.network_args[i]
# TODO escape special chars
network_args = network_args.split(";")
for net_arg in network_args:
key, value = net_arg.split("=")
net_kwargs[key] = value
net_kwargs = {}
if args.network_args and i < len(args.network_args):
network_args = args.network_args[i]
# TODO escape special chars
network_args = network_args.split(";")
for net_arg in network_args:
key, value = net_arg.split("=")
net_kwargs[key] = value
if args.network_weights is None or len(args.network_weights) <= i:
raise ValueError("No weight. Weight is required.")
if args.network_weights is None or len(args.network_weights) <= i:
raise ValueError("No weight. Weight is required.")
network_weight = args.network_weights[i]
print("load network weights from:", network_weight)
network_weight = args.network_weights[i]
print("load network weights from:", network_weight)
network, weights_sd = imported_module.create_network_from_weights(
network_mul, network_weight, effnet, text_model, generator_c, for_inference=True, **net_kwargs
)
if network is None:
return
network, weights_sd = imported_module.create_network_from_weights(
network_mul, network_weight, effnet, text_model, generator_c, for_inference=True, **net_kwargs
)
if network is None:
return
mergeable = network.is_mergeable()
assert mergeable, "not-mergeable network is not supported yet."
mergeable = network.is_mergeable()
assert mergeable, "not-mergeable network is not supported yet."
network.merge_to(text_model, generator_c, weights_sd, dtype, device)
network.merge_to(text_model, generator_c, weights_sd, dtype, device)
# 謎のクラス gdf
gdf_c = sc.GDF(