mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-15 00:32:25 +00:00
Add NTK factor input
This commit is contained in:
@@ -380,7 +380,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
if not args.apply_t5_attn_mask:
|
||||
t5_attn_mask = None
|
||||
|
||||
def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask, proportional_attention):
|
||||
def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask, proportional_attention, ntk_factor):
|
||||
# grad is enabled even if unet is not in train mode, because Text Encoder is in train mode
|
||||
with torch.set_grad_enabled(is_train), accelerator.autocast():
|
||||
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
|
||||
@@ -393,7 +393,8 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
timesteps=timesteps / 1000,
|
||||
guidance=guidance_vec,
|
||||
txt_attention_mask=t5_attn_mask,
|
||||
proportional_attention=proportional_attention
|
||||
proportional_attention=proportional_attention,
|
||||
ntk_factor=ntk_factor
|
||||
)
|
||||
return model_pred
|
||||
|
||||
@@ -407,6 +408,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
guidance_vec=guidance_vec,
|
||||
t5_attn_mask=t5_attn_mask,
|
||||
proportional_attention=args.proportional_attention,
|
||||
ntk_factor=args.ntk_factor,
|
||||
)
|
||||
|
||||
# unpack latents
|
||||
@@ -439,6 +441,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
guidance_vec=guidance_vec[diff_output_pr_indices] if guidance_vec is not None else None,
|
||||
t5_attn_mask=t5_attn_mask[diff_output_pr_indices] if t5_attn_mask is not None else None,
|
||||
proportional_attention=args.proportional_attention,
|
||||
ntk_factor=args.ntk_factor,
|
||||
)
|
||||
network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step
|
||||
|
||||
|
||||
@@ -235,7 +235,7 @@ def sample_image_inference(
|
||||
controlnet_image = controlnet_image.permute(2, 0, 1).unsqueeze(0).to(weight_dtype).to(accelerator.device)
|
||||
|
||||
with accelerator.autocast(), torch.no_grad():
|
||||
x = denoise(flux, noise, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=scale, t5_attn_mask=t5_attn_mask, controlnet=controlnet, controlnet_img=controlnet_image, proportional_attention=args.proportional_attention)
|
||||
x = denoise(flux, noise, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=scale, t5_attn_mask=t5_attn_mask, controlnet=controlnet, controlnet_img=controlnet_image, proportional_attention=args.proportional_attention, ntk_factor=args.ntk_factor)
|
||||
|
||||
x = flux_utils.unpack_latents(x, packed_latent_height, packed_latent_width)
|
||||
|
||||
@@ -313,7 +313,8 @@ def denoise(
|
||||
t5_attn_mask: Optional[torch.Tensor] = None,
|
||||
controlnet: Optional[flux_models.ControlNetFlux] = None,
|
||||
controlnet_img: Optional[torch.Tensor] = None,
|
||||
proportional_attention: Optional[bool] =None
|
||||
proportional_attention: Optional[bool] = None,
|
||||
ntk_factor = 1.0
|
||||
):
|
||||
# this is ignored for schnell
|
||||
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
|
||||
@@ -333,7 +334,6 @@ def denoise(
|
||||
timesteps=t_vec,
|
||||
guidance=guidance_vec,
|
||||
txt_attention_mask=t5_attn_mask,
|
||||
proportional_attention=proportional_attention,
|
||||
)
|
||||
else:
|
||||
block_samples = None
|
||||
@@ -350,6 +350,7 @@ def denoise(
|
||||
guidance=guidance_vec,
|
||||
txt_attention_mask=t5_attn_mask,
|
||||
proportional_attention=proportional_attention,
|
||||
ntk_factor=ntk_factor,
|
||||
)
|
||||
|
||||
img = img + (t_prev - t_curr) * pred
|
||||
@@ -622,3 +623,4 @@ def add_flux_train_arguments(parser: argparse.ArgumentParser):
|
||||
)
|
||||
|
||||
parser.add_argument("--proportional_attention", action="store_true", help="Proportional attention to the image sequence length URAE")
|
||||
parser.add_argument("--ntk_factor", type=float, default=1.0, help="NTK Factor for increasing the embedding space for RoPE")
|
||||
|
||||
Reference in New Issue
Block a user