diff --git a/library/flux_utils.py b/library/flux_utils.py index 95df71cd..309f0772 100644 --- a/library/flux_utils.py +++ b/library/flux_utils.py @@ -24,32 +24,13 @@ MODEL_VERSION_FLUX_V1 = "flux1" MODEL_NAME_DEV = "dev" MODEL_NAME_SCHNELL = "schnell" -def guidance_embed_bypass_forward(self, timestep, guidance, pooled_projection): - timesteps_proj = self.time_proj(timestep) - timesteps_emb = self.timestep_embedder( - timesteps_proj.to(dtype=pooled_projection.dtype)) # (N, D) - pooled_projections = self.text_embedder(pooled_projection) - conditioning = timesteps_emb + pooled_projections - return conditioning - -# bypass the forward function +# bypass guidance def bypass_flux_guidance(transformer): - if hasattr(transformer.time_text_embed, '_bfg_orig_forward'): - return - # dont bypass if it doesnt have the guidance embedding - if not hasattr(transformer.time_text_embed, 'guidance_embedder'): - return - transformer.time_text_embed._bfg_orig_forward = transformer.time_text_embed.forward - transformer.time_text_embed.forward = partial( - guidance_embed_bypass_forward, transformer.time_text_embed - ) + transformer.params.guidance_embed = False # restore the forward function def restore_flux_guidance(transformer): - if not hasattr(transformer.time_text_embed, '_bfg_orig_forward'): - return - transformer.time_text_embed.forward = transformer.time_text_embed._bfg_orig_forward - del transformer.time_text_embed._bfg_orig_forward + transformer.params.guidance_embed = True def analyze_checkpoint_state(ckpt_path: str) -> Tuple[bool, bool, Tuple[int, int], List[str]]: """ @@ -86,7 +67,6 @@ def analyze_checkpoint_state(ckpt_path: str) -> Tuple[bool, bool, Tuple[int, int is_diffusers = "transformer_blocks.0.attn.add_k_proj.bias" in keys is_schnell = not ("guidance_in.in_layer.bias" in keys or "time_text_embed.guidance_embedder.linear_1.bias" in keys) - # is_schnell = True # check number of double and single blocks if not is_diffusers: