mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
fix VAE becomes last one
This commit is contained in:
@@ -51,7 +51,7 @@ def merge(args):
|
||||
print(f"Model {model} does not exist")
|
||||
exit()
|
||||
|
||||
assert len(args.models) == len(args.ratios) or args.ratios is None, "ratios must be the same length as models"
|
||||
assert args.ratios is None or len(args.models) == len(args.ratios), "ratios must be the same length as models"
|
||||
|
||||
# load and merge
|
||||
ratio = 1.0 / len(args.models) # default
|
||||
@@ -113,13 +113,13 @@ def merge(args):
|
||||
# add supplementary keys' value (including VAE and TextEncoder)
|
||||
if len(supplementary_key_ratios) > 0:
|
||||
print("add first model's value")
|
||||
with safe_open(model, framework="pt", device=args.device) as f:
|
||||
with safe_open(args.models[0], framework="pt", device=args.device) as f:
|
||||
for key in tqdm(f.keys()):
|
||||
_, new_key = replace_text_encoder_key(key)
|
||||
if new_key not in supplementary_key_ratios:
|
||||
continue
|
||||
|
||||
if is_unet_key(new_key): # not VAE or TextEncoder
|
||||
if is_unet_key(new_key): # not VAE or TextEncoder
|
||||
print(f"Key {new_key} not in all models, ratio = {supplementary_key_ratios[new_key]}")
|
||||
|
||||
value = f.get_tensor(key) # original key
|
||||
|
||||
Reference in New Issue
Block a user