mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Merge branch 'sd3' into sd3_5_support
This commit is contained in:
@@ -54,6 +54,10 @@ def analyze_checkpoint_state(ckpt_path: str) -> Tuple[bool, bool, Tuple[int, int
|
|||||||
with safe_open(ckpt_path, framework="pt") as f:
|
with safe_open(ckpt_path, framework="pt") as f:
|
||||||
keys.extend(f.keys())
|
keys.extend(f.keys())
|
||||||
|
|
||||||
|
# if the key has annoying prefix, remove it
|
||||||
|
if keys[0].startswith("model.diffusion_model."):
|
||||||
|
keys = [key.replace("model.diffusion_model.", "") for key in keys]
|
||||||
|
|
||||||
is_diffusers = "transformer_blocks.0.attn.add_k_proj.bias" in keys
|
is_diffusers = "transformer_blocks.0.attn.add_k_proj.bias" in keys
|
||||||
is_schnell = not ("guidance_in.in_layer.bias" in keys or "time_text_embed.guidance_embedder.linear_1.bias" in keys)
|
is_schnell = not ("guidance_in.in_layer.bias" in keys or "time_text_embed.guidance_embedder.linear_1.bias" in keys)
|
||||||
|
|
||||||
@@ -122,6 +126,13 @@ def load_flow_model(
|
|||||||
sd = convert_diffusers_sd_to_bfl(sd, num_double_blocks, num_single_blocks)
|
sd = convert_diffusers_sd_to_bfl(sd, num_double_blocks, num_single_blocks)
|
||||||
logger.info("Converted Diffusers to BFL")
|
logger.info("Converted Diffusers to BFL")
|
||||||
|
|
||||||
|
# if the key has annoying prefix, remove it
|
||||||
|
for key in list(sd.keys()):
|
||||||
|
new_key = key.replace("model.diffusion_model.", "")
|
||||||
|
if new_key == key:
|
||||||
|
break # the model doesn't have annoying prefix
|
||||||
|
sd[new_key] = sd.pop(key)
|
||||||
|
|
||||||
info = model.load_state_dict(sd, strict=False, assign=True)
|
info = model.load_state_dict(sd, strict=False, assign=True)
|
||||||
logger.info(f"Loaded Flux: {info}")
|
logger.info(f"Loaded Flux: {info}")
|
||||||
return is_schnell, model
|
return is_schnell, model
|
||||||
|
|||||||
Reference in New Issue
Block a user