fix VAE becomes last one

This commit is contained in:
Kohya S
2023-09-13 17:59:14 +09:00
parent 207fc8b256
commit 0ecfd91a20

View File

@@ -51,7 +51,7 @@ def merge(args):
print(f"Model {model} does not exist") print(f"Model {model} does not exist")
exit() exit()
assert len(args.models) == len(args.ratios) or args.ratios is None, "ratios must be the same length as models" assert args.ratios is None or len(args.models) == len(args.ratios), "ratios must be the same length as models"
# load and merge # load and merge
ratio = 1.0 / len(args.models) # default ratio = 1.0 / len(args.models) # default
@@ -113,13 +113,13 @@ def merge(args):
# add supplementary keys' value (including VAE and TextEncoder) # add supplementary keys' value (including VAE and TextEncoder)
if len(supplementary_key_ratios) > 0: if len(supplementary_key_ratios) > 0:
print("add first model's value") print("add first model's value")
with safe_open(model, framework="pt", device=args.device) as f: with safe_open(args.models[0], framework="pt", device=args.device) as f:
for key in tqdm(f.keys()): for key in tqdm(f.keys()):
_, new_key = replace_text_encoder_key(key) _, new_key = replace_text_encoder_key(key)
if new_key not in supplementary_key_ratios: if new_key not in supplementary_key_ratios:
continue continue
if is_unet_key(new_key): # not VAE or TextEncoder if is_unet_key(new_key): # not VAE or TextEncoder
print(f"Key {new_key} not in all models, ratio = {supplementary_key_ratios[new_key]}") print(f"Key {new_key} not in all models, ratio = {supplementary_key_ratios[new_key]}")
value = f.get_tensor(key) # original key value = f.get_tensor(key) # original key