Add command line argument for bypassing flux guidance

This commit is contained in:
stepfunction83
2025-01-22 19:32:39 -05:00
parent b203e31877
commit 05fd3f763f
3 changed files with 17 additions and 6 deletions

View File

@@ -154,6 +154,7 @@ def do_sample(
timesteps = get_schedule(num_steps, img.shape[1], shift=not is_schnell)
# bypass guidance module
if args.bypass_flux_guidance:
bypass_flux_guidance(model)
# denoise initial noise
@@ -370,6 +371,7 @@ def generate_image(
img = Image.fromarray((127.5 * (x + 1.0)).float().cpu().numpy().astype(np.uint8)[0])
# restore guidance module
if args.bypass_flux_guidance:
restore_flux_guidance(model)
# save image

View File

@@ -424,7 +424,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
"""
return model_pred
if args.bypass_flux_guidance:
flux_utils.bypass_flux_guidance(unet)
model_pred = call_dit(
@@ -441,6 +441,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
# unpack latents
model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width)
if args.bypass_flux_guidance:
flux_utils.restore_flux_guidance(unet)
# apply model prediction type

View File

@@ -4103,6 +4103,14 @@ def add_dit_training_arguments(parser: argparse.ArgumentParser):
"この数を増やすと、トレーニング中のVRAM使用量が減りますが、トレーニング速度s/itも低下します。",
)
# bypass guidance module for flux
parser.add_argument(
"--bypass_flux_guidance"
, action="store_true"
, help="bypass flux guidance module for Flex.1-Alpha Training / Flex.1-Alpha トレーニング用バイパス フラックス ガイダンス モジュール"
)
def get_sanitized_config_or_none(args: argparse.Namespace):
# if `--log_config` is enabled, return args for logging. if not, return None.