mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-16 17:02:45 +00:00
Add bypass flux guidance to flux_train.py
This commit is contained in:
@@ -642,6 +642,9 @@ def train(args):
|
||||
l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds
|
||||
if not args.apply_t5_attn_mask:
|
||||
t5_attn_mask = None
|
||||
|
||||
if args.bypass_flux_guidance:
|
||||
flux_utils.bypass_flux_guidance(flux)
|
||||
|
||||
with accelerator.autocast():
|
||||
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
|
||||
@@ -659,6 +662,9 @@ def train(args):
|
||||
# unpack latents
|
||||
model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width)
|
||||
|
||||
if args.bypass_flux_guidance:
|
||||
flux_utils.restore_flux_guidance(flux)
|
||||
|
||||
# apply model prediction type
|
||||
model_pred, weighting = flux_train_utils.apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user