From 05fd3f763f2338cf8cdf1860bb9ed465751a8016 Mon Sep 17 00:00:00 2001 From: stepfunction83 <32859451+stepfunction83@users.noreply.github.com> Date: Wed, 22 Jan 2025 19:32:39 -0500 Subject: [PATCH] Add command line argument for bypassing flux guidance --- flux_minimal_inference.py | 6 ++++-- flux_train_network.py | 9 +++++---- library/train_util.py | 8 ++++++++ 3 files changed, 17 insertions(+), 6 deletions(-) diff --git a/flux_minimal_inference.py b/flux_minimal_inference.py index b4021bd1..c470bcb4 100644 --- a/flux_minimal_inference.py +++ b/flux_minimal_inference.py @@ -154,7 +154,8 @@ def do_sample( timesteps = get_schedule(num_steps, img.shape[1], shift=not is_schnell) # bypass guidance module - bypass_flux_guidance(model) + if args.bypass_flux_guidance: + bypass_flux_guidance(model) # denoise initial noise if accelerator: @@ -370,7 +371,8 @@ def generate_image( img = Image.fromarray((127.5 * (x + 1.0)).float().cpu().numpy().astype(np.uint8)[0]) # restore guidance module - restore_flux_guidance(model) + if args.bypass_flux_guidance: + restore_flux_guidance(model) # save image output_dir = args.output_dir diff --git a/flux_train_network.py b/flux_train_network.py index 3035d716..ed578168 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -424,8 +424,8 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): """ return model_pred - - flux_utils.bypass_flux_guidance(unet) + if args.bypass_flux_guidance: + flux_utils.bypass_flux_guidance(unet) model_pred = call_dit( img=packed_noisy_model_input, @@ -440,8 +440,9 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): # unpack latents model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width) - - flux_utils.restore_flux_guidance(unet) + + if args.bypass_flux_guidance: + flux_utils.restore_flux_guidance(unet) # apply model prediction type model_pred, weighting = flux_train_utils.apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas) diff --git a/library/train_util.py b/library/train_util.py index 72b5b24d..3180d694 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -4103,6 +4103,14 @@ def add_dit_training_arguments(parser: argparse.ArgumentParser): "この数を増やすと、トレーニング中のVRAM使用量が減りますが、トレーニング速度(s/it)も低下します。", ) + # bypass guidance module for flux + parser.add_argument( + "--bypass_flux_guidance" + , action="store_true" + , help="bypass flux guidance module for Flex.1-Alpha Training / Flex.1-Alpha トレーニング用バイパス フラックス ガイダンス モジュール" + ) + + def get_sanitized_config_or_none(args: argparse.Namespace): # if `--log_config` is enabled, return args for logging. if not, return None.