diff --git a/flux_minimal_inference.py b/flux_minimal_inference.py index 7ab224f1..c470bcb4 100644 --- a/flux_minimal_inference.py +++ b/flux_minimal_inference.py @@ -20,6 +20,8 @@ from library import device_utils from library.device_utils import init_ipex, get_preferred_device from networks import oft_flux +from library.flux_utils import bypass_flux_guidance, restore_flux_guidance + init_ipex() @@ -151,6 +153,10 @@ def do_sample( logger.info(f"num_steps: {num_steps}") timesteps = get_schedule(num_steps, img.shape[1], shift=not is_schnell) + # bypass guidance module + if args.bypass_flux_guidance: + bypass_flux_guidance(model) + # denoise initial noise if accelerator: with accelerator.autocast(), torch.no_grad(): @@ -364,6 +370,10 @@ def generate_image( x = x.permute(0, 2, 3, 1) img = Image.fromarray((127.5 * (x + 1.0)).float().cpu().numpy().astype(np.uint8)[0]) + # restore guidance module + if args.bypass_flux_guidance: + restore_flux_guidance(model) + # save image output_dir = args.output_dir os.makedirs(output_dir, exist_ok=True) diff --git a/flux_train.py b/flux_train.py index 6f98adea..7d3faea2 100644 --- a/flux_train.py +++ b/flux_train.py @@ -643,6 +643,9 @@ def train(args): l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds if not args.apply_t5_attn_mask: t5_attn_mask = None + + if args.bypass_flux_guidance: + flux_utils.bypass_flux_guidance(flux) with accelerator.autocast(): # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing) @@ -660,6 +663,9 @@ def train(args): # unpack latents model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width) + if args.bypass_flux_guidance: + flux_utils.restore_flux_guidance(flux) + # apply model prediction type model_pred, weighting = flux_train_utils.apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas) diff --git a/flux_train_network.py 
b/flux_train_network.py index def44155..1f725a59 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -395,6 +395,8 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): txt_attention_mask=t5_attn_mask, ) return model_pred + if args.bypass_flux_guidance: + flux_utils.bypass_flux_guidance(unet) model_pred = call_dit( img=packed_noisy_model_input, @@ -409,6 +411,9 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): # unpack latents model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width) + + if args.bypass_flux_guidance: + flux_utils.restore_flux_guidance(unet) # apply model prediction type model_pred, weighting = flux_train_utils.apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas) diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index 8392e559..3ca21eb2 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -680,3 +680,10 @@ def add_flux_train_arguments(parser: argparse.ArgumentParser): default=3.0, help="Discrete flow shift for the Euler Discrete Scheduler, default is 3.0. / Euler Discrete Schedulerの離散フローシフト、デフォルトは3.0。", ) + + # bypass guidance module for flux + parser.add_argument( + "--bypass_flux_guidance" + , action="store_true" + , help="bypass flux guidance module for Flex.1-Alpha Training / Flex.1-Alpha トレーニング用バイパス フラックス ガイダンス モジュール" + ) diff --git a/library/flux_utils.py b/library/flux_utils.py index 8be1d63e..309f0772 100644 --- a/library/flux_utils.py +++ b/library/flux_utils.py @@ -24,6 +24,13 @@ MODEL_VERSION_FLUX_V1 = "flux1" MODEL_NAME_DEV = "dev" MODEL_NAME_SCHNELL = "schnell" +# bypass guidance +def bypass_flux_guidance(transformer): + transformer.params.guidance_embed = False + +# restore guidance: re-enable the guidance_embed flag disabled by bypass_flux_guidance +def restore_flux_guidance(transformer): + transformer.params.guidance_embed = True def analyze_checkpoint_state(ckpt_path: str) -> Tuple[bool, bool, Tuple[int, int], List[str]]: """