diff --git a/flux_train_network.py b/flux_train_network.py
index 3c072bc6..297d0207 100644
--- a/flux_train_network.py
+++ b/flux_train_network.py
@@ -191,7 +191,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
                 args.skip_cache_check,
                 is_partial=self.train_clip_l or self.train_t5xxl,
                 apply_t5_attn_mask=args.apply_t5_attn_mask,
-                vision_cond_ratio=args.vision_cond_ratio,
+                vision_cond_size=args.vision_cond_downsample,
                 redux_path=args.redux_model_path
             )
         else:
@@ -379,10 +379,11 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
         if args.vision_cond_dropout < 1.0:
             if random.uniform(0,1) > args.vision_cond_dropout:
                 vision_encoder_conds = batch.get("vision_encoder_outputs_list", None)
-                vis_t5_out, vis_txt_ids = vision_encoder_conds
-                t5_out = vis_t5_out
-                txt_ids = vis_txt_ids
-                t5_attn_mask = None
+                vis_t5_out, vis_txt_ids, vis_attn_mask = vision_encoder_conds
+                t5_out = torch.cat([t5_out, vis_t5_out], dim=1)
+                txt_ids = torch.cat([txt_ids, vis_txt_ids], dim=1)
+                if args.apply_t5_attn_mask:
+                    t5_attn_mask = torch.cat([t5_attn_mask, vis_attn_mask], dim=1)
 
         def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask):
             # if not args.split_mode:
diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py
index 1239997b..a70f6932 100644
--- a/library/flux_train_utils.py
+++ b/library/flux_train_utils.py
@@ -624,10 +624,10 @@ def add_flux_train_arguments(parser: argparse.ArgumentParser):
         help="path to Redux model (*.sft or *.safetensors), should be float16",
     )
     parser.add_argument(
-        "--vision_cond_ratio",
-        type=float,
-        default=0.0,
-        help="Ratio of conditioning for Redux embeddings, averaged with text encoder embeddings. Zero disables vision conditioning, maximum is 1.0",
+        "--vision_cond_downsample",
+        type=int,
+        default=0,
+        help="Downsample Redux tokens to the specified grid size (default is 27). Zero disables this feature.",
     )
 
     parser.add_argument(
diff --git a/library/strategy_flux.py b/library/strategy_flux.py
index f3abb555..627a8d42 100644
--- a/library/strategy_flux.py
+++ b/library/strategy_flux.py
@@ -1,4 +1,5 @@
 import os
+from math import sqrt
 from typing import Any, List, Optional, Tuple, Union
 
 import safetensors
@@ -129,12 +130,12 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
         skip_disk_cache_validity_check: bool,
         is_partial: bool = False,
         apply_t5_attn_mask: bool = False,
-        vision_cond_ratio: float = 0.0,
+        vision_cond_size: int = 0,
         redux_path: str = None,
     ) -> None:
         super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial)
         self.apply_t5_attn_mask = apply_t5_attn_mask
-        self.vision_cond_ratio = vision_cond_ratio
+        self.vision_cond_size = vision_cond_size
         self.redux_path = redux_path
         self.warn_fp8_weights = False
 
@@ -179,7 +180,7 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
         # apply_t5_attn_mask should be same as self.apply_t5_attn_mask
         return [l_pooled, t5_out, txt_ids, t5_attn_mask]
 
-    def encode_vision(self, infos, ratio, t5_out, txt_ids):
+    def encode_vision(self, infos, grid_size, t5_out, txt_ids):
         global siglip_model
         global siglip_processor
         global redux_encoder
@@ -205,16 +206,21 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
 
         with torch.no_grad(), torch.autocast("cuda"):
             siglip_out = siglip_model(**siglip_in)
-            new_embed = redux_encoder(siglip_out.last_hidden_state).float().cpu().numpy()
+            new_embed = redux_encoder(siglip_out.last_hidden_state).float()
+            (b, t, h) = new_embed.shape
+            s = int(sqrt(t))
+            new_embed = torch.nn.functional.interpolate(new_embed.view(b, s, s, h).transpose(1, -1),
+                                                        size=(grid_size, grid_size),
+                                                        mode="bicubic")
+            new_embed = new_embed.transpose(1, -1).reshape(b, -1, h).cpu().numpy()
         new_ids = np.zeros(shape=(bsz, new_embed.shape[1], txt_ids.shape[2]))
-
-        t5_out_ext = np.concatenate([t5_out] + [np.zeros((bsz, new_embed.shape[1] - t5_out.shape[1], t5_out.shape[2]))], axis=1)
-        new_embed = new_embed * ratio + t5_out_ext * (1.0 - ratio)
+        attn_mask = np.ones((bsz, new_embed.shape[1]))
 
         for i, info in enumerate(infos):
             new_embed_i = new_embed[i]
             new_ids_i = new_ids[i]
-            info.vision_encoder_outputs = (new_embed_i, new_ids_i)
+            attn_mask_i = attn_mask[i]
+            info.vision_encoder_outputs = (new_embed_i, new_ids_i, attn_mask_i)
 
 
     def cache_batch_outputs(
@@ -248,8 +254,9 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
         txt_ids = txt_ids.cpu().numpy()
         t5_attn_mask = tokens_and_masks[2].cpu().numpy()
 
-        if self.vision_cond_ratio > 0.0:
-            self.encode_vision(infos, self.vision_cond_ratio, t5_out, txt_ids)
+        if self.vision_cond_size > 0:
+            assert self.vision_cond_size <= 27, "Downsample ratio must not be greater than 27."
+            self.encode_vision(infos, self.vision_cond_size, t5_out, txt_ids)
 
         for i, info in enumerate(infos):
             l_pooled_i = l_pooled[i]
diff --git a/library/train_util.py b/library/train_util.py
index f25c68f8..30151daf 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -176,7 +176,7 @@ class ImageInfo:
         self.alpha_mask: Optional[torch.Tensor] = None  # alpha mask can be flipped in runtime
 
-        self.vision_encoder_outputs: Optional[torch.Tensor] = None
+        self.vision_encoder_outputs: Optional[List[torch.Tensor]] = None
 
 
 class BucketManager: