recris
2025-07-19 10:49:54 +02:00
committed by GitHub
4 changed files with 127 additions and 3 deletions

flux_train_network.py

@@ -8,6 +8,7 @@ import torch
from accelerate import Accelerator
from library.device_utils import clean_memory_on_device, init_ipex
from library.strategy_flux import move_vision_encoder_to_device

init_ipex()
@@ -197,6 +198,8 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
                args.skip_cache_check,
                is_partial=self.train_clip_l or self.train_t5xxl,
                apply_t5_attn_mask=args.apply_t5_attn_mask,
                vision_cond_size=args.vision_cond_downsample,
                redux_path=args.redux_model_path,
            )
        else:
            return None
@@ -257,6 +260,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
text_encoders[0].to("cpu")
logger.info("move t5XXL back to cpu")
text_encoders[1].to("cpu")
move_vision_encoder_to_device("cpu")
clean_memory_on_device(accelerator.device)
if not args.lowram:
@@ -380,6 +384,15 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
        if not args.apply_t5_attn_mask:
            t5_attn_mask = None

        if args.vision_cond_dropout < 1.0:
            # keep the Redux conditioning with probability (1 - vision_cond_dropout);
            # assumes vision outputs were cached (i.e. vision_cond_downsample > 0)
            if random.uniform(0, 1) > args.vision_cond_dropout:
                vision_encoder_conds = batch.get("vision_encoder_outputs_list", None)
                vis_t5_out, vis_txt_ids, vis_attn_mask = vision_encoder_conds
                t5_out = torch.cat([t5_out, vis_t5_out], dim=1)
                txt_ids = torch.cat([txt_ids, vis_txt_ids], dim=1)
                if args.apply_t5_attn_mask:
                    t5_attn_mask = torch.cat([t5_attn_mask, vis_attn_mask], dim=1)

        def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask):
            # grad is enabled even if unet is not in train mode, because Text Encoder is in train mode
            with torch.set_grad_enabled(is_train), accelerator.autocast():
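
For reference, a standalone sketch (not part of this commit; shapes are hypothetical) of the concatenation above: the cached Redux tokens are appended to the T5-XXL stream along the sequence dimension, so the DiT consumes them as extra text tokens.

import torch

t5_out = torch.randn(2, 512, 4096)       # (batch, t5_tokens, features)
vis_t5_out = torch.randn(2, 81, 4096)    # 9x9 downsampled Redux grid
t5_attn_mask = torch.ones(2, 512)
vis_attn_mask = torch.ones(2, 81)

t5_out = torch.cat([t5_out, vis_t5_out], dim=1)                  # (2, 593, 4096)
t5_attn_mask = torch.cat([t5_attn_mask, vis_attn_mask], dim=1)   # (2, 593)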

library/flux_train_utils.py

@@ -680,3 +680,22 @@ def add_flux_train_arguments(parser: argparse.ArgumentParser):
        default=3.0,
        help="Discrete flow shift for the Euler Discrete Scheduler, default is 3.0. / Euler Discrete Schedulerの離散フローシフト、デフォルトは3.0。",
    )
    parser.add_argument(
        "--redux_model_path",
        type=str,
        help="Path to Redux model (*.sft or *.safetensors), should be float16",
    )
    parser.add_argument(
        "--vision_cond_downsample",
        type=int,
        default=0,
        help="Downsample Redux tokens to the specified grid size; maximum is 27 (the native SigLIP grid). 0 (default) disables vision conditioning.",
    )
    parser.add_argument(
        "--vision_cond_dropout",
        type=float,
        default=1.0,
        help="Per-step probability of dropping the Redux conditioning; 1.0 (default) always drops it, disabling vision conditioning during training.",
    )
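
A hypothetical invocation of the three new options (example values, not from this commit; assumes this hunk lives in library/flux_train_utils.py):

import argparse
from library.flux_train_utils import add_flux_train_arguments

parser = argparse.ArgumentParser()
add_flux_train_arguments(parser)
args = parser.parse_args([
    "--redux_model_path", "flux1-redux-dev.safetensors",  # float16 Redux weights (example path)
    "--vision_cond_downsample", "9",                      # 27x27 grid -> 9x9 = 81 vision tokens
    "--vision_cond_dropout", "0.3",                       # conditioning kept on ~70% of steps
])
assert args.vision_cond_downsample == 9 and args.vision_cond_dropout == 0.3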

library/strategy_flux.py

@@ -1,9 +1,12 @@
import os
import glob
from math import sqrt
from typing import Any, List, Optional, Tuple, Union
import safetensors
import torch
import numpy as np
import PIL.Image
from transformers import CLIPTokenizer, T5TokenizerFast, SiglipVisionModel, AutoProcessor
from library import flux_utils, train_util
from library.strategy_base import LatentsCachingStrategy, TextEncodingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy
@@ -20,6 +23,38 @@ CLIP_L_TOKENIZER_ID = "openai/clip-vit-large-patch14"
T5_XXL_TOKENIZER_ID = "google/t5-v1_1-xxl"

# FIXME: this is a very hacky way of handling the encoder model
siglip_model = None
siglip_processor = None
redux_encoder = None


def move_vision_encoder_to_device(device):
    if siglip_model is not None:
        siglip_model.to(device)
    if redux_encoder is not None:
        redux_encoder.to(device)


class ReduxImageEncoder(torch.nn.Module):
    def __init__(
        self,
        redux_dim: int = 1152,
        txt_in_features: int = 4096,
        device=None,
        dtype=None,
    ) -> None:
        super().__init__()
        self.redux_dim = redux_dim
        self.device = device
        self.dtype = dtype
        self.redux_up = torch.nn.Linear(redux_dim, txt_in_features * 3, device=device, dtype=dtype)
        self.redux_down = torch.nn.Linear(txt_in_features * 3, txt_in_features, device=device, dtype=dtype)

    def forward(self, sigclip_embeds) -> torch.Tensor:
        # project SigLIP hidden states (redux_dim) up to 3x the T5 width, then back down to txt_in_features
        return self.redux_down(torch.nn.functional.silu(self.redux_up(sigclip_embeds)))
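
A quick shape check for the projector above (illustrative only, not part of the commit): SigLIP so400m hidden states are 1152-wide and get lifted to the 4096-wide T5-XXL token space that the FLUX text stream expects.

import torch

enc = ReduxImageEncoder()                   # 1152 -> 12288 -> 4096
siglip_hidden = torch.randn(1, 729, 1152)   # 27x27 = 729 SigLIP patch tokens
txt_like = enc(siglip_hidden)
print(txt_like.shape)                       # torch.Size([1, 729, 4096])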

class FluxTokenizeStrategy(TokenizeStrategy):
    def __init__(self, t5xxl_max_length: int = 512, tokenizer_cache_dir: Optional[str] = None) -> None:
        self.t5xxl_max_length = t5xxl_max_length
@@ -95,10 +130,13 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
        skip_disk_cache_validity_check: bool,
        is_partial: bool = False,
        apply_t5_attn_mask: bool = False,
        vision_cond_size: int = 0,
        redux_path: Optional[str] = None,
    ) -> None:
        super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial)
        self.apply_t5_attn_mask = apply_t5_attn_mask
        self.vision_cond_size = vision_cond_size
        self.redux_path = redux_path
        self.warn_fp8_weights = False

    def get_outputs_npz_path(self, image_abs_path: str) -> str:
@@ -142,6 +180,49 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
        # apply_t5_attn_mask should be same as self.apply_t5_attn_mask
        return [l_pooled, t5_out, txt_ids, t5_attn_mask]

    def encode_vision(self, infos, grid_size, t5_out, txt_ids):
        # t5_out is currently unused here; txt_ids supplies the batch size and position-id width
        global siglip_model, siglip_processor, redux_encoder

        # lazily load the SigLIP vision tower and the Redux projector on first use
        if siglip_model is None:
            model_id = "google/siglip-so400m-patch14-384"
            siglip_model = SiglipVisionModel.from_pretrained(model_id, attn_implementation="sdpa", device_map="cuda")
            siglip_processor = AutoProcessor.from_pretrained(model_id)

        if redux_encoder is None:
            if self.redux_path is None:
                raise ValueError("Vision encoding requires a Redux model, but no file was provided.")
            model_data = safetensors.torch.load_file(self.redux_path, device="cpu")
            redux_encoder = ReduxImageEncoder()
            redux_encoder.load_state_dict(model_data)
            redux_encoder = redux_encoder.to(device="cuda")

        bsz = txt_ids.shape[0]
        imgs = [PIL.Image.open(nfo.absolute_path) for nfo in infos]
        siglip_in = siglip_processor(images=imgs, padding="max_length", return_tensors="pt")
        siglip_in = siglip_in.to(device="cuda")

        with torch.no_grad(), torch.autocast("cuda"):
            siglip_out = siglip_model(**siglip_in)
            new_embed = redux_encoder(siglip_out.last_hidden_state).float()

        # treat the 729 Redux tokens as a 27x27 grid and downsample it to grid_size x grid_size
        (b, t, h) = new_embed.shape
        s = int(sqrt(t))
        new_embed = torch.nn.functional.interpolate(
            new_embed.view(b, s, s, h).transpose(1, -1), size=(grid_size, grid_size), mode="bicubic"
        )
        new_embed = new_embed.transpose(1, -1).reshape(b, -1, h).cpu().numpy()

        # vision tokens get zero position ids and an all-ones attention mask
        new_ids = np.zeros(shape=(bsz, new_embed.shape[1], txt_ids.shape[2]))
        attn_mask = np.ones((bsz, new_embed.shape[1]))

        for i, info in enumerate(infos):
            info.vision_encoder_outputs = (new_embed[i], new_ids[i], attn_mask[i])
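
To make the tensor reshaping above concrete, a standalone walk-through with hypothetical values:

import torch
from math import sqrt

b, t, h = 2, 729, 4096                      # batch, Redux tokens, feature width
grid_size = 9
x = torch.randn(b, t, h)
s = int(sqrt(t))                            # 27: the tokens form a square grid
x = x.view(b, s, s, h).transpose(1, -1)     # (b, h, 27, 27): channels-first "image"
x = torch.nn.functional.interpolate(x, size=(grid_size, grid_size), mode="bicubic")
x = x.transpose(1, -1).reshape(b, -1, h)    # (b, 81, h): back to a token sequence
assert x.shape == (b, grid_size * grid_size, h)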

    def cache_batch_outputs(
        self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, infos: List
    ):
@@ -173,6 +254,10 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
        txt_ids = txt_ids.cpu().numpy()
        t5_attn_mask = tokens_and_masks[2].cpu().numpy()

        if self.vision_cond_size > 0:
            assert self.vision_cond_size <= 27, "Downsample grid size must not be greater than 27 (the native SigLIP token grid)."
            self.encode_vision(infos, self.vision_cond_size, t5_out, txt_ids)

        for i, info in enumerate(infos):
            l_pooled_i = l_pooled[i]
            t5_out_i = t5_out[i]

library/train_util.py

@@ -209,6 +209,8 @@ class ImageInfo:
        self.alpha_mask: Optional[torch.Tensor] = None  # alpha mask can be flipped in runtime
        self.resize_interpolation: Optional[str] = None
        self.vision_encoder_outputs: Optional[Tuple[np.ndarray, np.ndarray, np.ndarray]] = None


class BucketManager:
    def __init__(self, no_upscale, max_reso, min_size, max_size, reso_steps) -> None:
@@ -1566,6 +1568,7 @@ class BaseDataset(torch.utils.data.Dataset):
        target_sizes_hw = []
        flippeds = []  # variable name is a bit awkward
        text_encoder_outputs_list = []
        vision_encoder_outputs_list = []
        custom_attributes = []

        for image_key in bucket[image_index : image_index + bucket_batch_size]:
@@ -1690,6 +1693,9 @@ class BaseDataset(torch.utils.data.Dataset):
            text_encoder_outputs = None
            input_ids = None

            if image_info.vision_encoder_outputs is not None:
                vision_encoder_outputs_list.append(image_info.vision_encoder_outputs)

            if image_info.text_encoder_outputs is not None:
                # cached
                text_encoder_outputs = image_info.text_encoder_outputs
@@ -1745,6 +1751,7 @@ class BaseDataset(torch.utils.data.Dataset):
example["custom_attributes"] = custom_attributes # may be list of empty dict
example["loss_weights"] = torch.FloatTensor(loss_weights)
example["text_encoder_outputs_list"] = none_or_stack_elements(text_encoder_outputs_list, torch.FloatTensor)
example["vision_encoder_outputs_list"] = none_or_stack_elements(vision_encoder_outputs_list, torch.FloatTensor)
example["input_ids_list"] = none_or_stack_elements(input_ids_list, lambda x: x)
# if one of alpha_masks is not None, we need to replace None with ones
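
For intuition, a sketch of what none_or_stack_elements is assumed to do with the vision outputs gathered above (hypothetical reimplementation, not the repo's code): each per-image (embed, ids, mask) tuple is stacked element-wise into batch tensors, with None returned when nothing was cached.

import numpy as np
import torch

def stack_elements(tuples, converter):
    # hypothetical stand-in for none_or_stack_elements
    if len(tuples) == 0:
        return None
    return [torch.stack([converter(t[i]) for t in tuples]) for i in range(len(tuples[0]))]

per_image = [(np.zeros((81, 4096)), np.zeros((81, 3)), np.ones(81)) for _ in range(4)]
embeds, ids, masks = stack_elements(per_image, torch.FloatTensor)
print(embeds.shape, ids.shape, masks.shape)  # (4, 81, 4096) (4, 81, 3) (4, 81)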