Merge branch 'sd3_5_support' of https://github.com/kohya-ss/sd-scripts into sd3_5_support

This commit is contained in:
Kohya S
2024-10-30 12:51:55 +09:00
3 changed files with 74 additions and 22 deletions

View File

@@ -54,6 +54,10 @@ def analyze_checkpoint_state(ckpt_path: str) -> Tuple[bool, bool, Tuple[int, int
with safe_open(ckpt_path, framework="pt") as f:
keys.extend(f.keys())
# if the first key starts with the "model.diffusion_model." prefix, remove that prefix from every key
if keys[0].startswith("model.diffusion_model."):
keys = [key.replace("model.diffusion_model.", "") for key in keys]
is_diffusers = "transformer_blocks.0.attn.add_k_proj.bias" in keys
is_schnell = not ("guidance_in.in_layer.bias" in keys or "time_text_embed.guidance_embedder.linear_1.bias" in keys)
@@ -122,6 +126,13 @@ def load_flow_model(
sd = convert_diffusers_sd_to_bfl(sd, num_double_blocks, num_single_blocks)
logger.info("Converted Diffusers to BFL")
# remove the "model.diffusion_model." prefix from the state dict keys;
# the loop stops at the first key that does not contain the prefix
for key in list(sd.keys()):
new_key = key.replace("model.diffusion_model.", "")
if new_key == key:
break # the model doesn't have annoying prefix
sd[new_key] = sd.pop(key)
info = model.load_state_dict(sd, strict=False, assign=True)
logger.info(f"Loaded Flux: {info}")
return is_schnell, model