Add bypass flux guidance to flux_train.py

This commit is contained in:
stepfunction83
2025-01-23 13:39:45 -05:00
parent cafc5d78de
commit a768d53d77

View File

@@ -642,6 +642,9 @@ def train(args):
l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds
if not args.apply_t5_attn_mask:
t5_attn_mask = None
if args.bypass_flux_guidance:
flux_utils.bypass_flux_guidance(flux)
with accelerator.autocast():
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
@@ -659,6 +662,9 @@ def train(args):
# unpack latents
model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width)
if args.bypass_flux_guidance:
flux_utils.restore_flux_guidance(flux)
# apply model prediction type
model_pred, weighting = flux_train_utils.apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas)