Use token downsampling to control Redux strength.

This commit is contained in:
recris
2025-01-06 21:06:30 +00:00
parent 9d28701699
commit 1b77a4a77c
4 changed files with 28 additions and 20 deletions

View File

@@ -191,7 +191,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
args.skip_cache_check,
is_partial=self.train_clip_l or self.train_t5xxl,
apply_t5_attn_mask=args.apply_t5_attn_mask,
vision_cond_ratio=args.vision_cond_ratio,
vision_cond_size=args.vision_cond_downsample,
redux_path=args.redux_model_path
)
else:
@@ -379,10 +379,11 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
if args.vision_cond_dropout < 1.0:
if random.uniform(0,1) > args.vision_cond_dropout:
vision_encoder_conds = batch.get("vision_encoder_outputs_list", None)
vis_t5_out, vis_txt_ids = vision_encoder_conds
t5_out = vis_t5_out
txt_ids = vis_txt_ids
t5_attn_mask = None
vis_t5_out, vis_txt_ids, vis_attn_mask = vision_encoder_conds
t5_out = torch.cat([t5_out, vis_t5_out], dim=1)
txt_ids = torch.cat([txt_ids, vis_txt_ids], dim=1)
if args.apply_t5_attn_mask:
t5_attn_mask = torch.cat([t5_attn_mask, vis_attn_mask], dim=1)
def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask):
# if not args.split_mode:

View File

@@ -624,10 +624,10 @@ def add_flux_train_arguments(parser: argparse.ArgumentParser):
help="path to Redux model (*.sft or *.safetensors), should be float16",
)
parser.add_argument(
"--vision_cond_ratio",
type=float,
default=0.0,
help="Ratio of conditioning for Redux embeddings, averaged with text encoder embeddings. Zero disables vision conditioning, maximum is 1.0",
"--vision_cond_downsample",
type=int,
default=0,
help="Downsample Redux tokens to the specified grid size (the native Redux grid is 27, the maximum). Zero disables this feature.",
)
parser.add_argument(

View File

@@ -1,4 +1,5 @@
import os
from math import sqrt
from typing import Any, List, Optional, Tuple, Union
import safetensors
@@ -129,12 +130,12 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
skip_disk_cache_validity_check: bool,
is_partial: bool = False,
apply_t5_attn_mask: bool = False,
vision_cond_ratio: float = 0.0,
vision_cond_size: int = 0,
redux_path: str = None,
) -> None:
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial)
self.apply_t5_attn_mask = apply_t5_attn_mask
self.vision_cond_ratio = vision_cond_ratio
self.vision_cond_size = vision_cond_size
self.redux_path = redux_path
self.warn_fp8_weights = False
@@ -179,7 +180,7 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
# apply_t5_attn_mask should be same as self.apply_t5_attn_mask
return [l_pooled, t5_out, txt_ids, t5_attn_mask]
def encode_vision(self, infos, ratio, t5_out, txt_ids):
def encode_vision(self, infos, grid_size, t5_out, txt_ids):
global siglip_model
global siglip_processor
global redux_encoder
@@ -205,16 +206,21 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
with torch.no_grad(), torch.autocast("cuda"):
siglip_out = siglip_model(**siglip_in)
new_embed = redux_encoder(siglip_out.last_hidden_state).float().cpu().numpy()
new_embed = redux_encoder(siglip_out.last_hidden_state).float()
(b, t, h) = new_embed.shape
s = int(sqrt(t))
new_embed = torch.nn.functional.interpolate(new_embed.view(b, s, s, h).transpose(1, -1),
size=(grid_size, grid_size),
mode="bicubic")
new_embed = new_embed.transpose(1, -1).reshape(b, -1, h).cpu().numpy()
new_ids = np.zeros(shape=(bsz, new_embed.shape[1], txt_ids.shape[2]))
t5_out_ext = np.concatenate([t5_out] + [np.zeros((bsz, new_embed.shape[1] - t5_out.shape[1], t5_out.shape[2]))], axis=1)
new_embed = new_embed * ratio + t5_out_ext * (1.0 - ratio)
attn_mask = np.ones((bsz, new_embed.shape[1]))
for i, info in enumerate(infos):
new_embed_i = new_embed[i]
new_ids_i = new_ids[i]
info.vision_encoder_outputs = (new_embed_i, new_ids_i)
attn_mask_i = attn_mask[i]
info.vision_encoder_outputs = (new_embed_i, new_ids_i, attn_mask_i)
def cache_batch_outputs(
@@ -248,8 +254,9 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
txt_ids = txt_ids.cpu().numpy()
t5_attn_mask = tokens_and_masks[2].cpu().numpy()
if self.vision_cond_ratio > 0.0:
self.encode_vision(infos, self.vision_cond_ratio, t5_out, txt_ids)
if self.vision_cond_size > 0:
assert self.vision_cond_size <= 27, "Downsample size must not be greater than 27."
self.encode_vision(infos, self.vision_cond_size, t5_out, txt_ids)
for i, info in enumerate(infos):
l_pooled_i = l_pooled[i]

View File

@@ -176,7 +176,7 @@ class ImageInfo:
self.alpha_mask: Optional[torch.Tensor] = None # alpha mask can be flipped in runtime
self.vision_encoder_outputs: Optional[torch.Tensor] = None
self.vision_encoder_outputs: Optional[List[torch.Tensor]] = None
class BucketManager: