Experimental Redux conditioning for Flux LoRA training

This commit is contained in:
John Doe
2024-12-15 19:31:03 +00:00
committed by recris
parent e425996a59
commit 9d28701699
4 changed files with 119 additions and 3 deletions

View File

@@ -8,6 +8,7 @@ import torch
from accelerate import Accelerator
from library.device_utils import clean_memory_on_device, init_ipex
from library.strategy_flux import move_vision_encoder_to_device
init_ipex()
@@ -190,6 +191,8 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
args.skip_cache_check,
is_partial=self.train_clip_l or self.train_t5xxl,
apply_t5_attn_mask=args.apply_t5_attn_mask,
vision_cond_ratio=args.vision_cond_ratio,
redux_path=args.redux_model_path
)
else:
return None
@@ -250,6 +253,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
text_encoders[0].to("cpu")
logger.info("move t5XXL back to cpu")
text_encoders[1].to("cpu")
move_vision_encoder_to_device("cpu")
clean_memory_on_device(accelerator.device)
if not args.lowram:
@@ -372,6 +376,14 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
if not args.apply_t5_attn_mask:
t5_attn_mask = None
if args.vision_cond_dropout < 1.0:
if random.uniform(0,1) > args.vision_cond_dropout:
vision_encoder_conds = batch.get("vision_encoder_outputs_list", None)
vis_t5_out, vis_txt_ids = vision_encoder_conds
t5_out = vis_t5_out
txt_ids = vis_txt_ids
t5_attn_mask = None
def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask):
# if not args.split_mode:
# normal forward

View File

@@ -617,3 +617,22 @@ def add_flux_train_arguments(parser: argparse.ArgumentParser):
default=3.0,
help="Discrete flow shift for the Euler Discrete Scheduler, default is 3.0. / Euler Discrete Schedulerの離散フローシフト、デフォルトは3.0。",
)
parser.add_argument(
"--redux_model_path",
type=str,
help="path to Redux model (*.sft or *.safetensors), should be float16",
)
parser.add_argument(
"--vision_cond_ratio",
type=float,
default=0.0,
help="Ratio of conditioning for Redux embeddings, averaged with text encoder embeddings. Zero disables vision conditioning, maximum is 1.0",
)
parser.add_argument(
"--vision_cond_dropout",
type=float,
default=1.0,
help="Probability of dropout for Redux conditioning.",
)

View File

@@ -1,9 +1,11 @@
import os
import glob
from typing import Any, List, Optional, Tuple, Union
import safetensors
import torch
import numpy as np
from transformers import CLIPTokenizer, T5TokenizerFast
import PIL.Image
from transformers import CLIPTokenizer, T5TokenizerFast, SiglipVisionModel, AutoProcessor
from library import flux_utils, train_util
from library.strategy_base import LatentsCachingStrategy, TextEncodingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy
@@ -20,6 +22,38 @@ CLIP_L_TOKENIZER_ID = "openai/clip-vit-large-patch14"
T5_XXL_TOKENIZER_ID = "google/t5-v1_1-xxl"
# FIXME: this is a very hacky way of handling the encoder model
siglip_model = None
siglip_processor = None
redux_encoder = None
def move_vision_encoder_to_device(device):
    """Relocate any loaded vision-conditioning modules to *device*.

    Modules that have not been lazily initialized yet (still ``None``)
    are left untouched; the processor holds no weights and needs no move.
    """
    for module in (siglip_model, redux_encoder):
        if module is not None:
            module.to(device)
class ReduxImageEncoder(torch.nn.Module):
    """Projection head mapping SigLIP vision embeddings into T5 text space.

    A two-layer MLP (up-project to 3x width, SiLU, down-project) that turns
    ``redux_dim``-dimensional image tokens into ``txt_in_features``-dimensional
    tokens compatible with the Flux text stream.

    Args:
        redux_dim: input feature size of the SigLIP embeddings (default 1152,
            the SigLIP-so400m hidden size).
        txt_in_features: output feature size (default 4096, the T5-XXL hidden size).
        device: optional device on which to allocate the projection weights.
        dtype: optional dtype for the projection weights.
    """

    def __init__(
        self,
        redux_dim: int = 1152,
        txt_in_features: int = 4096,
        device=None,
        dtype=None,
    ) -> None:
        super().__init__()
        self.redux_dim = redux_dim
        # Kept as plain attributes for backward compatibility with callers
        # that inspect them; nn.Module itself does not track these.
        self.device = device
        self.dtype = dtype
        # Fix: honor the `device` argument when allocating the layers.
        # Previously it was accepted but ignored, so the weights always
        # landed on the default device regardless of what was requested.
        self.redux_up = torch.nn.Linear(redux_dim, txt_in_features * 3, device=device, dtype=dtype)
        self.redux_down = torch.nn.Linear(txt_in_features * 3, txt_in_features, device=device, dtype=dtype)

    def forward(self, sigclip_embeds) -> torch.Tensor:
        """Project embeddings of shape (..., redux_dim) to (..., txt_in_features)."""
        projected_x = self.redux_down(torch.nn.functional.silu(self.redux_up(sigclip_embeds)))
        return projected_x
class FluxTokenizeStrategy(TokenizeStrategy):
def __init__(self, t5xxl_max_length: int = 512, tokenizer_cache_dir: Optional[str] = None) -> None:
self.t5xxl_max_length = t5xxl_max_length
@@ -95,10 +129,13 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
skip_disk_cache_validity_check: bool,
is_partial: bool = False,
apply_t5_attn_mask: bool = False,
vision_cond_ratio: float = 0.0,
redux_path: str = None,
) -> None:
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial)
self.apply_t5_attn_mask = apply_t5_attn_mask
self.vision_cond_ratio = vision_cond_ratio
self.redux_path = redux_path
self.warn_fp8_weights = False
def get_outputs_npz_path(self, image_abs_path: str) -> str:
@@ -142,6 +179,44 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
# apply_t5_attn_mask should be same as self.apply_t5_attn_mask
return [l_pooled, t5_out, txt_ids, t5_attn_mask]
def encode_vision(self, infos, ratio, t5_out, txt_ids):
    """Blend Redux (SigLIP) image embeddings into the cached T5 text embeddings.

    Lazily loads the SigLIP vision tower and the Redux projection head on first
    use (kept in module-level globals so one copy is shared), encodes each image
    referenced by *infos*, blends the projected image tokens with the padded T5
    embeddings using *ratio*, and stores the per-image result on
    ``info.vision_encoder_outputs`` as ``(embed, ids)`` numpy arrays.

    Args:
        infos: image-info records; each must expose ``absolute_path``.
        ratio: blend weight — 1.0 is pure vision conditioning, 0.0 pure text.
        t5_out: cached T5 embeddings, numpy array of shape (bsz, seq, dim).
        txt_ids: cached text ids, numpy array of shape (bsz, seq, id_dim).

    Raises:
        Exception: if no Redux checkpoint path was configured.
    """
    global siglip_model
    global siglip_processor
    global redux_encoder

    if siglip_model is None:
        model_id = "google/siglip-so400m-patch14-384"
        siglip_model = SiglipVisionModel.from_pretrained(
            model_id, attn_implementation="sdpa", device_map="cuda")
        siglip_processor = AutoProcessor.from_pretrained(model_id)

    if redux_encoder is None:
        if self.redux_path is None:
            raise Exception("Vision encoding requires Redux model, but no file was provided.")
        model_data = safetensors.torch.load_file(self.redux_path, device=torch.device("cpu").type)
        redux_encoder = ReduxImageEncoder()
        redux_encoder.load_state_dict(model_data)
        redux_encoder = redux_encoder.to(device="cuda")

    bsz = txt_ids.shape[0]
    imgs = [PIL.Image.open(nfo.absolute_path) for nfo in infos]
    try:
        siglip_in = siglip_processor(images=imgs, padding="max_length", return_tensors="pt")
    finally:
        # Fix: PIL opens files lazily and keeps the handles alive; close them
        # explicitly so caching a large dataset does not leak file descriptors.
        for img in imgs:
            img.close()
    siglip_in = siglip_in.to(device="cuda")

    with torch.no_grad(), torch.autocast("cuda"):
        siglip_out = siglip_model(**siglip_in)
        new_embed = redux_encoder(siglip_out.last_hidden_state).float().cpu().numpy()

    # Redux tokens get all-zero position ids, matching the text-token convention.
    new_ids = np.zeros(shape=(bsz, new_embed.shape[1], txt_ids.shape[2]))
    # Zero-pad the T5 embeddings up to the vision sequence length, then average.
    # NOTE(review): assumes the Redux sequence is at least as long as the T5
    # sequence — np.concatenate would fail otherwise; confirm for short images.
    t5_out_ext = np.concatenate([t5_out] + [np.zeros((bsz, new_embed.shape[1] - t5_out.shape[1], t5_out.shape[2]))], axis=1)
    new_embed = new_embed * ratio + t5_out_ext * (1.0 - ratio)

    for i, info in enumerate(infos):
        new_embed_i = new_embed[i]
        new_ids_i = new_ids[i]
        info.vision_encoder_outputs = (new_embed_i, new_ids_i)
def cache_batch_outputs(
self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, infos: List
):
@@ -173,6 +248,9 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
txt_ids = txt_ids.cpu().numpy()
t5_attn_mask = tokens_and_masks[2].cpu().numpy()
if self.vision_cond_ratio > 0.0:
self.encode_vision(infos, self.vision_cond_ratio, t5_out, txt_ids)
for i, info in enumerate(infos):
l_pooled_i = l_pooled[i]
t5_out_i = t5_out[i]

View File

@@ -176,6 +176,8 @@ class ImageInfo:
self.alpha_mask: Optional[torch.Tensor] = None # alpha mask can be flipped in runtime
self.vision_encoder_outputs: Optional[torch.Tensor] = None
class BucketManager:
def __init__(self, no_upscale, max_reso, min_size, max_size, reso_steps) -> None:
@@ -1497,6 +1499,7 @@ class BaseDataset(torch.utils.data.Dataset):
target_sizes_hw = []
flippeds = [] # 変数名が微妙
text_encoder_outputs_list = []
vision_encoder_outputs_list = []
custom_attributes = []
for image_key in bucket[image_index : image_index + bucket_batch_size]:
@@ -1621,6 +1624,9 @@ class BaseDataset(torch.utils.data.Dataset):
text_encoder_outputs = None
input_ids = None
if image_info.vision_encoder_outputs is not None:
vision_encoder_outputs_list.append(image_info.vision_encoder_outputs)
if image_info.text_encoder_outputs is not None:
# cached
text_encoder_outputs = image_info.text_encoder_outputs
@@ -1676,6 +1682,7 @@ class BaseDataset(torch.utils.data.Dataset):
example["custom_attributes"] = custom_attributes # may be list of empty dict
example["loss_weights"] = torch.FloatTensor(loss_weights)
example["text_encoder_outputs_list"] = none_or_stack_elements(text_encoder_outputs_list, torch.FloatTensor)
example["vision_encoder_outputs_list"] = none_or_stack_elements(vision_encoder_outputs_list, torch.FloatTensor)
example["input_ids_list"] = none_or_stack_elements(input_ids_list, lambda x: x)
# if one of alpha_masks is not None, we need to replace None with ones