Add NTK factor input (--ntk_factor) for RoPE to Flux training and sampling

This commit is contained in:
rockerBOO
2025-03-24 05:17:42 -04:00
parent 7a86668ca4
commit dbb7fcc0fa
2 changed files with 10 additions and 5 deletions

View File

@@ -380,7 +380,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
if not args.apply_t5_attn_mask:
t5_attn_mask = None
def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask, proportional_attention):
def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask, proportional_attention, ntk_factor):
# grad is enabled even if unet is not in train mode, because Text Encoder is in train mode
with torch.set_grad_enabled(is_train), accelerator.autocast():
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
@@ -393,7 +393,8 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
timesteps=timesteps / 1000,
guidance=guidance_vec,
txt_attention_mask=t5_attn_mask,
proportional_attention=proportional_attention
proportional_attention=proportional_attention,
ntk_factor=ntk_factor
)
return model_pred
@@ -407,6 +408,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
guidance_vec=guidance_vec,
t5_attn_mask=t5_attn_mask,
proportional_attention=args.proportional_attention,
ntk_factor=args.ntk_factor,
)
# unpack latents
@@ -439,6 +441,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
guidance_vec=guidance_vec[diff_output_pr_indices] if guidance_vec is not None else None,
t5_attn_mask=t5_attn_mask[diff_output_pr_indices] if t5_attn_mask is not None else None,
proportional_attention=args.proportional_attention,
ntk_factor=args.ntk_factor,
)
network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step

View File

@@ -235,7 +235,7 @@ def sample_image_inference(
controlnet_image = controlnet_image.permute(2, 0, 1).unsqueeze(0).to(weight_dtype).to(accelerator.device)
with accelerator.autocast(), torch.no_grad():
x = denoise(flux, noise, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=scale, t5_attn_mask=t5_attn_mask, controlnet=controlnet, controlnet_img=controlnet_image, proportional_attention=args.proportional_attention)
x = denoise(flux, noise, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=scale, t5_attn_mask=t5_attn_mask, controlnet=controlnet, controlnet_img=controlnet_image, proportional_attention=args.proportional_attention, ntk_factor=args.ntk_factor)
x = flux_utils.unpack_latents(x, packed_latent_height, packed_latent_width)
@@ -313,7 +313,8 @@ def denoise(
t5_attn_mask: Optional[torch.Tensor] = None,
controlnet: Optional[flux_models.ControlNetFlux] = None,
controlnet_img: Optional[torch.Tensor] = None,
proportional_attention: Optional[bool] =None
proportional_attention: Optional[bool] = None,
ntk_factor = 1.0
):
# this is ignored for schnell
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
@@ -333,7 +334,6 @@ def denoise(
timesteps=t_vec,
guidance=guidance_vec,
txt_attention_mask=t5_attn_mask,
proportional_attention=proportional_attention,
)
else:
block_samples = None
@@ -350,6 +350,7 @@ def denoise(
guidance=guidance_vec,
txt_attention_mask=t5_attn_mask,
proportional_attention=proportional_attention,
ntk_factor=ntk_factor,
)
img = img + (t_prev - t_curr) * pred
@@ -622,3 +623,4 @@ def add_flux_train_arguments(parser: argparse.ArgumentParser):
)
parser.add_argument("--proportional_attention", action="store_true", help="Proportional attention to the image sequence length URAE")
parser.add_argument("--ntk_factor", type=float, default=1.0, help="NTK Factor for increasing the embedding space for RoPE")