diff --git a/flux_train_network.py b/flux_train_network.py
index 66656918..b223cc0d 100644
--- a/flux_train_network.py
+++ b/flux_train_network.py
@@ -380,7 +380,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
         if not args.apply_t5_attn_mask:
             t5_attn_mask = None
 
-        def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask, proportional_attention):
+        def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask, proportional_attention, ntk_factor):
             # grad is enabled even if unet is not in train mode, because Text Encoder is in train mode
             with torch.set_grad_enabled(is_train), accelerator.autocast():
                 # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
@@ -393,7 +393,8 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
                     timesteps=timesteps / 1000,
                     guidance=guidance_vec,
                     txt_attention_mask=t5_attn_mask,
-                    proportional_attention=proportional_attention
+                    proportional_attention=proportional_attention,
+                    ntk_factor=ntk_factor,
                 )
             return model_pred
 
@@ -407,6 +408,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
             guidance_vec=guidance_vec,
             t5_attn_mask=t5_attn_mask,
             proportional_attention=args.proportional_attention,
+            ntk_factor=args.ntk_factor,
         )
 
         # unpack latents
@@ -439,6 +441,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
                 guidance_vec=guidance_vec[diff_output_pr_indices] if guidance_vec is not None else None,
                 t5_attn_mask=t5_attn_mask[diff_output_pr_indices] if t5_attn_mask is not None else None,
                 proportional_attention=args.proportional_attention,
+                ntk_factor=args.ntk_factor,
             )
 
             network.set_multiplier(1.0)  # may be overwritten by "network_multipliers" in the next step
diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py
index 88144a9f..73ea1624 100644
--- a/library/flux_train_utils.py
+++ b/library/flux_train_utils.py
@@ -235,7 +235,7 @@ def sample_image_inference(
         controlnet_image = controlnet_image.permute(2, 0, 1).unsqueeze(0).to(weight_dtype).to(accelerator.device)
 
     with accelerator.autocast(), torch.no_grad():
-        x = denoise(flux, noise, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=scale, t5_attn_mask=t5_attn_mask, controlnet=controlnet, controlnet_img=controlnet_image, proportional_attention=args.proportional_attention)
+        x = denoise(flux, noise, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=scale, t5_attn_mask=t5_attn_mask, controlnet=controlnet, controlnet_img=controlnet_image, proportional_attention=args.proportional_attention, ntk_factor=args.ntk_factor)
 
     x = flux_utils.unpack_latents(x, packed_latent_height, packed_latent_width)
 
@@ -313,7 +313,8 @@ def denoise(
     t5_attn_mask: Optional[torch.Tensor] = None,
     controlnet: Optional[flux_models.ControlNetFlux] = None,
     controlnet_img: Optional[torch.Tensor] = None,
-    proportional_attention: Optional[bool] =None
+    proportional_attention: Optional[bool] = None,
+    ntk_factor: float = 1.0,
 ):
     # this is ignored for schnell
     guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
@@ -333,7 +334,6 @@ def denoise(
                 timesteps=t_vec,
                 guidance=guidance_vec,
                 txt_attention_mask=t5_attn_mask,
-                proportional_attention=proportional_attention,
             )
         else:
             block_samples = None
@@ -350,6 +350,7 @@ def denoise(
             guidance=guidance_vec,
             txt_attention_mask=t5_attn_mask,
             proportional_attention=proportional_attention,
+            ntk_factor=ntk_factor,
         )
 
         img = img + (t_prev - t_curr) * pred
@@ -622,3 +623,4 @@ def add_flux_train_arguments(parser: argparse.ArgumentParser):
     )
 
     parser.add_argument("--proportional_attention", action="store_true", help="Proportional attention to the image sequence length URAE")
+    parser.add_argument("--ntk_factor", type=float, default=1.0, help="NTK factor for increasing the RoPE embedding space")
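
This patch only threads `ntk_factor` through the training and sampling paths; the change to the RoPE helper in `library/flux_models.py` that actually consumes it is not part of this diff. Below is a minimal sketch of what NTK-aware RoPE scaling typically looks like, assuming the factor is folded into the rotary base `theta` (the standard NTK-aware interpolation trick); the `rope()` signature and the exact formula used by the model code may differ:

```python
import torch
from torch import Tensor


def rope(pos: Tensor, dim: int, theta: float, ntk_factor: float = 1.0) -> Tensor:
    """Rotary embedding table for positions `pos`; sketch of NTK-aware scaling."""
    assert dim % 2 == 0
    # NTK-aware scaling: enlarging the rotary base lowers every rotation
    # frequency, so positions beyond the training resolution land in angle
    # ranges the pretrained weights have already seen. ntk_factor=1.0 is a no-op.
    theta = theta * ntk_factor
    scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
    omega = 1.0 / (theta ** scale)  # (dim/2,) frequency per channel pair
    angles = pos.unsqueeze(-1).to(torch.float64) * omega  # (..., n, dim/2)
    # one 2x2 rotation matrix per (position, frequency) pair
    out = torch.stack(
        [torch.cos(angles), -torch.sin(angles), torch.sin(angles), torch.cos(angles)],
        dim=-1,
    )
    return out.reshape(*angles.shape, 2, 2).float()
```

With the argument in place, a higher-resolution fine-tuning run would append something like `--proportional_attention --ntk_factor 2.0` to the usual `flux_train_network.py` command (the value 2.0 is illustrative); the default of 1.0 leaves RoPE at its pretrained behavior.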