Mirror of https://github.com/kohya-ss/sd-scripts.git, synced 2026-04-16 08:52:45 +00:00.
Use token downsampling to control Redux strength.
This commit is contained in:
@@ -191,7 +191,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
args.skip_cache_check,
|
||||
is_partial=self.train_clip_l or self.train_t5xxl,
|
||||
apply_t5_attn_mask=args.apply_t5_attn_mask,
|
||||
vision_cond_ratio=args.vision_cond_ratio,
|
||||
vision_cond_size=args.vision_cond_downsample,
|
||||
redux_path=args.redux_model_path
|
||||
)
|
||||
else:
|
||||
@@ -379,10 +379,11 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
if args.vision_cond_dropout < 1.0:
|
||||
if random.uniform(0,1) > args.vision_cond_dropout:
|
||||
vision_encoder_conds = batch.get("vision_encoder_outputs_list", None)
|
||||
vis_t5_out, vis_txt_ids = vision_encoder_conds
|
||||
t5_out = vis_t5_out
|
||||
txt_ids = vis_txt_ids
|
||||
t5_attn_mask = None
|
||||
vis_t5_out, vis_txt_ids, vis_attn_mask = vision_encoder_conds
|
||||
t5_out = torch.cat([t5_out, vis_t5_out], dim=1)
|
||||
txt_ids = torch.cat([txt_ids, vis_txt_ids], dim=1)
|
||||
if args.apply_t5_attn_mask:
|
||||
t5_attn_mask = torch.cat([t5_attn_mask, vis_attn_mask], dim=1)
|
||||
|
||||
def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask):
|
||||
# if not args.split_mode:
|
||||
|
||||
@@ -624,10 +624,10 @@ def add_flux_train_arguments(parser: argparse.ArgumentParser):
|
||||
help="path to Redux model (*.sft or *.safetensors), should be float16",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--vision_cond_ratio",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="Ratio of conditioning for Redux embeddings, averaged with text encoder embeddings. Zero disables vision conditioning, maximum is 1.0",
|
||||
"--vision_cond_downsample",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Downsample Redux tokens to the specified grid size (default is 27). Zero disables this feature.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import os
|
||||
from math import sqrt
|
||||
from typing import Any, List, Optional, Tuple, Union
|
||||
|
||||
import safetensors
|
||||
@@ -129,12 +130,12 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
|
||||
skip_disk_cache_validity_check: bool,
|
||||
is_partial: bool = False,
|
||||
apply_t5_attn_mask: bool = False,
|
||||
vision_cond_ratio: float = 0.0,
|
||||
vision_cond_size: int = 0,
|
||||
redux_path: str = None,
|
||||
) -> None:
|
||||
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial)
|
||||
self.apply_t5_attn_mask = apply_t5_attn_mask
|
||||
self.vision_cond_ratio = vision_cond_ratio
|
||||
self.vision_cond_size = vision_cond_size
|
||||
self.redux_path = redux_path
|
||||
self.warn_fp8_weights = False
|
||||
|
||||
@@ -179,7 +180,7 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
|
||||
# apply_t5_attn_mask should be same as self.apply_t5_attn_mask
|
||||
return [l_pooled, t5_out, txt_ids, t5_attn_mask]
|
||||
|
||||
def encode_vision(self, infos, ratio, t5_out, txt_ids):
|
||||
def encode_vision(self, infos, grid_size, t5_out, txt_ids):
|
||||
global siglip_model
|
||||
global siglip_processor
|
||||
global redux_encoder
|
||||
@@ -205,16 +206,21 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
|
||||
|
||||
with torch.no_grad(), torch.autocast("cuda"):
|
||||
siglip_out = siglip_model(**siglip_in)
|
||||
new_embed = redux_encoder(siglip_out.last_hidden_state).float().cpu().numpy()
|
||||
new_embed = redux_encoder(siglip_out.last_hidden_state).float()
|
||||
(b, t, h) = new_embed.shape
|
||||
s = int(sqrt(t))
|
||||
new_embed = torch.nn.functional.interpolate(new_embed.view(b, s, s, h).transpose(1, -1),
|
||||
size=(grid_size, grid_size),
|
||||
mode="bicubic")
|
||||
new_embed = new_embed.transpose(1, -1).reshape(b, -1, h).cpu().numpy()
|
||||
new_ids = np.zeros(shape=(bsz, new_embed.shape[1], txt_ids.shape[2]))
|
||||
|
||||
t5_out_ext = np.concatenate([t5_out] + [np.zeros((bsz, new_embed.shape[1] - t5_out.shape[1], t5_out.shape[2]))], axis=1)
|
||||
new_embed = new_embed * ratio + t5_out_ext * (1.0 - ratio)
|
||||
attn_mask = np.ones((bsz, new_embed.shape[1]))
|
||||
|
||||
for i, info in enumerate(infos):
|
||||
new_embed_i = new_embed[i]
|
||||
new_ids_i = new_ids[i]
|
||||
info.vision_encoder_outputs = (new_embed_i, new_ids_i)
|
||||
attn_mask_i = attn_mask[i]
|
||||
info.vision_encoder_outputs = (new_embed_i, new_ids_i, attn_mask_i)
|
||||
|
||||
|
||||
def cache_batch_outputs(
|
||||
@@ -248,8 +254,9 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
|
||||
txt_ids = txt_ids.cpu().numpy()
|
||||
t5_attn_mask = tokens_and_masks[2].cpu().numpy()
|
||||
|
||||
if self.vision_cond_ratio > 0.0:
|
||||
self.encode_vision(infos, self.vision_cond_ratio, t5_out, txt_ids)
|
||||
if self.vision_cond_size > 0:
|
||||
assert self.vision_cond_size <= 27, "Downsample ratio must not be greater than 27."
|
||||
self.encode_vision(infos, self.vision_cond_size, t5_out, txt_ids)
|
||||
|
||||
for i, info in enumerate(infos):
|
||||
l_pooled_i = l_pooled[i]
|
||||
|
||||
@@ -176,7 +176,7 @@ class ImageInfo:
|
||||
|
||||
self.alpha_mask: Optional[torch.Tensor] = None # alpha mask can be flipped in runtime
|
||||
|
||||
self.vision_encoder_outputs: Optional[torch.Tensor] = None
|
||||
self.vision_encoder_outputs: Optional[List[torch.Tensor]] = None
|
||||
|
||||
|
||||
class BucketManager:
|
||||
|
||||
Reference in New Issue
Block a user