diff --git a/flux_minimal_inference.py b/flux_minimal_inference.py index 7ab224f1..b4021bd1 100644 --- a/flux_minimal_inference.py +++ b/flux_minimal_inference.py @@ -20,6 +20,8 @@ from library import device_utils from library.device_utils import init_ipex, get_preferred_device from networks import oft_flux +from library.flux_utils import bypass_flux_guidance, restore_flux_guidance + init_ipex() @@ -151,6 +153,9 @@ def do_sample( logger.info(f"num_steps: {num_steps}") timesteps = get_schedule(num_steps, img.shape[1], shift=not is_schnell) + # bypass guidance module + bypass_flux_guidance(model) + # denoise initial noise if accelerator: with accelerator.autocast(), torch.no_grad(): @@ -364,6 +369,9 @@ def generate_image( x = x.permute(0, 2, 3, 1) img = Image.fromarray((127.5 * (x + 1.0)).float().cpu().numpy().astype(np.uint8)[0]) + # restore guidance module + restore_flux_guidance(model) + # save image output_dir = args.output_dir os.makedirs(output_dir, exist_ok=True) diff --git a/flux_train_network.py b/flux_train_network.py index 75e975ba..3035d716 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -425,6 +425,8 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): return model_pred + flux_utils.bypass_flux_guidance(unet) + model_pred = call_dit( img=packed_noisy_model_input, img_ids=img_ids, @@ -439,6 +441,8 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): # unpack latents model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width) + flux_utils.restore_flux_guidance(unet) + # apply model prediction type model_pred, weighting = flux_train_utils.apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas) diff --git a/library/flux_utils.py b/library/flux_utils.py index 8be1d63e..95df71cd 100644 --- a/library/flux_utils.py +++ b/library/flux_utils.py @@ -24,6 +24,32 @@ MODEL_VERSION_FLUX_V1 = "flux1" MODEL_NAME_DEV = "dev" MODEL_NAME_SCHNELL = "schnell" +def 
guidance_embed_bypass_forward(self, timestep, guidance, pooled_projection): + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder( + timesteps_proj.to(dtype=pooled_projection.dtype)) # (N, D) + pooled_projections = self.text_embedder(pooled_projection) + conditioning = timesteps_emb + pooled_projections + return conditioning + +# bypass the forward function +def bypass_flux_guidance(transformer): + if hasattr(transformer.time_text_embed, '_bfg_orig_forward'): + return + # don't bypass if it doesn't have the guidance embedding + if not hasattr(transformer.time_text_embed, 'guidance_embedder'): + return + transformer.time_text_embed._bfg_orig_forward = transformer.time_text_embed.forward + transformer.time_text_embed.forward = partial( + guidance_embed_bypass_forward, transformer.time_text_embed + ) + +# restore the forward function +def restore_flux_guidance(transformer): + if not hasattr(transformer.time_text_embed, '_bfg_orig_forward'): + return + transformer.time_text_embed.forward = transformer.time_text_embed._bfg_orig_forward + del transformer.time_text_embed._bfg_orig_forward def analyze_checkpoint_state(ckpt_path: str) -> Tuple[bool, bool, Tuple[int, int], List[str]]: """ @@ -60,6 +86,7 @@ def analyze_checkpoint_state(ckpt_path: str) -> Tuple[bool, bool, Tuple[int, int is_diffusers = "transformer_blocks.0.attn.add_k_proj.bias" in keys is_schnell = not ("guidance_in.in_layer.bias" in keys or "time_text_embed.guidance_embedder.linear_1.bias" in keys) + # is_schnell = True # check number of double and single blocks if not is_diffusers: