Merge 1b77a4a77c into d53a532a82
@@ -8,6 +8,7 @@ import torch
 from accelerate import Accelerator

 from library.device_utils import clean_memory_on_device, init_ipex
+from library.strategy_flux import move_vision_encoder_to_device

 init_ipex()

@@ -197,6 +198,8 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
                 args.skip_cache_check,
                 is_partial=self.train_clip_l or self.train_t5xxl,
                 apply_t5_attn_mask=args.apply_t5_attn_mask,
+                vision_cond_size=args.vision_cond_downsample,
+                redux_path=args.redux_model_path,
             )
         else:
             return None

@@ -257,6 +260,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
             text_encoders[0].to("cpu")
             logger.info("move t5XXL back to cpu")
             text_encoders[1].to("cpu")
+            move_vision_encoder_to_device("cpu")
             clean_memory_on_device(accelerator.device)

         if not args.lowram:

@@ -380,6 +384,15 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
         if not args.apply_t5_attn_mask:
             t5_attn_mask = None

+        if args.vision_cond_dropout < 1.0:
+            if random.uniform(0, 1) > args.vision_cond_dropout:
+                vision_encoder_conds = batch.get("vision_encoder_outputs_list", None)
+                vis_t5_out, vis_txt_ids, vis_attn_mask = vision_encoder_conds
+                t5_out = torch.cat([t5_out, vis_t5_out], dim=1)
+                txt_ids = torch.cat([txt_ids, vis_txt_ids], dim=1)
+                if args.apply_t5_attn_mask:
+                    t5_attn_mask = torch.cat([t5_attn_mask, vis_attn_mask], dim=1)
+
         def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask):
             # grad is enabled even if unet is not in train mode, because Text Encoder is in train mode
             with torch.set_grad_enabled(is_train), accelerator.autocast():

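The hunk above gates the Redux conditioning with a uniform draw: the cached vision tokens are appended only when the sample exceeds vision_cond_dropout, so conditioning is applied with probability 1 - vision_cond_dropout, and never at the default of 1.0 (the outer if short-circuits). A minimal standalone sketch of the concatenation; the shapes here (512 text tokens, an 81-token 9x9 Redux grid) are illustrative assumptions, not values from the patch:

import torch

# Assumed shapes: T5-XXL text stream (batch, 512, 4096) and a 9x9 Redux grid
# (batch, 81, 4096); FLUX txt_ids carry 3 position-id channels per token.
t5_out = torch.randn(2, 512, 4096)
txt_ids = torch.zeros(2, 512, 3)
t5_attn_mask = torch.ones(2, 512)

vis_t5_out = torch.randn(2, 81, 4096)
vis_txt_ids = torch.zeros(2, 81, 3)
vis_attn_mask = torch.ones(2, 81)

# Same concatenation as the patch: vision tokens are appended along the
# sequence dimension, so the DiT treats them as extra text tokens.
t5_out = torch.cat([t5_out, vis_t5_out], dim=1)                 # (2, 593, 4096)
txt_ids = torch.cat([txt_ids, vis_txt_ids], dim=1)              # (2, 593, 3)
t5_attn_mask = torch.cat([t5_attn_mask, vis_attn_mask], dim=1)  # (2, 593)
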
@@ -680,3 +680,22 @@ def add_flux_train_arguments(parser: argparse.ArgumentParser):
         default=3.0,
         help="Discrete flow shift for the Euler Discrete Scheduler, default is 3.0. / Euler Discrete Schedulerの離散フローシフト、デフォルトは3.0。",
     )
+
+    parser.add_argument(
+        "--redux_model_path",
+        type=str,
+        help="path to the Redux model (*.sft or *.safetensors); should be float16",
+    )
+    parser.add_argument(
+        "--vision_cond_downsample",
+        type=int,
+        default=0,
+        help="downsample the Redux image tokens to the specified grid size; must be 27 (the native grid) or less. Zero (default) disables vision conditioning.",
+    )
+
+    parser.add_argument(
+        "--vision_cond_dropout",
+        type=float,
+        default=1.0,
+        help="probability of dropping the Redux conditioning at each training step; the default of 1.0 disables it entirely.",
+    )

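For reference, a self-contained sketch of how the three new flags interact; the flag names and defaults come from the hunk above, while the sample values (the model filename, grid size 9, dropout 0.5) are illustrative:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--redux_model_path", type=str)
parser.add_argument("--vision_cond_downsample", type=int, default=0)
parser.add_argument("--vision_cond_dropout", type=float, default=1.0)

# Defaults leave the feature off: downsample 0 skips vision encoding entirely,
# and dropout 1.0 means cached vision tokens would never be applied anyway.
args = parser.parse_args([])
assert args.vision_cond_downsample == 0 and args.vision_cond_dropout == 1.0

# A typical opt-in: cache a 9x9 token grid and condition on it half the time.
args = parser.parse_args(
    ["--redux_model_path", "flux1-redux-dev.safetensors",
     "--vision_cond_downsample", "9",
     "--vision_cond_dropout", "0.5"]
)
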
@@ -1,9 +1,12 @@
 import os
 import glob
+from math import sqrt
 from typing import Any, List, Optional, Tuple, Union

 import safetensors
 import torch
+import numpy as np
-from transformers import CLIPTokenizer, T5TokenizerFast
+import PIL.Image
+from transformers import CLIPTokenizer, T5TokenizerFast, SiglipVisionModel, AutoProcessor

 from library import flux_utils, train_util
 from library.strategy_base import LatentsCachingStrategy, TextEncodingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy

@@ -20,6 +23,38 @@ CLIP_L_TOKENIZER_ID = "openai/clip-vit-large-patch14"
 T5_XXL_TOKENIZER_ID = "google/t5-v1_1-xxl"


+# FIXME: this is a very hacky way of handling the encoder model
+siglip_model = None
+siglip_processor = None
+redux_encoder = None
+
+def move_vision_encoder_to_device(device):
+    if siglip_model is not None:
+        siglip_model.to(device)
+    if redux_encoder is not None:
+        redux_encoder.to(device)
+
+
+class ReduxImageEncoder(torch.nn.Module):
+    def __init__(
+        self,
+        redux_dim: int = 1152,
+        txt_in_features: int = 4096,
+        device=None,
+        dtype=None,
+    ) -> None:
+        super().__init__()
+        self.redux_dim = redux_dim
+        self.device = device
+        self.dtype = dtype
+        self.redux_up = torch.nn.Linear(redux_dim, txt_in_features * 3, dtype=dtype)
+        self.redux_down = torch.nn.Linear(txt_in_features * 3, txt_in_features, dtype=dtype)
+
+    def forward(self, sigclip_embeds) -> torch.Tensor:
+        projected_x = self.redux_down(torch.nn.functional.silu(self.redux_up(sigclip_embeds)))
+        return projected_x
+
+
 class FluxTokenizeStrategy(TokenizeStrategy):
     def __init__(self, t5xxl_max_length: int = 512, tokenizer_cache_dir: Optional[str] = None) -> None:
         self.t5xxl_max_length = t5xxl_max_length

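The encoder added above is a two-layer SiLU MLP: it lifts SigLIP patch embeddings (width 1152) to three times the T5-XXL hidden width and projects back down to 4096, so the result can be concatenated with the text token stream. A quick shape check, assuming the class above is in scope; the 729-token count follows from SigLIP so400m-patch14-384's 27x27 patch grid:

import torch

encoder = ReduxImageEncoder()              # redux_dim=1152, txt_in_features=4096
siglip_tokens = torch.randn(1, 729, 1152)  # 27 x 27 = 729 patch embeddings
projected = encoder(siglip_tokens)
assert projected.shape == (1, 729, 4096)   # matches the T5-XXL token width
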
@@ -95,10 +130,13 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
         skip_disk_cache_validity_check: bool,
         is_partial: bool = False,
         apply_t5_attn_mask: bool = False,
+        vision_cond_size: int = 0,
+        redux_path: str = None,
     ) -> None:
         super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial)
         self.apply_t5_attn_mask = apply_t5_attn_mask

+        self.vision_cond_size = vision_cond_size
+        self.redux_path = redux_path
         self.warn_fp8_weights = False

     def get_outputs_npz_path(self, image_abs_path: str) -> str:

@@ -142,6 +180,49 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
         # apply_t5_attn_mask should be same as self.apply_t5_attn_mask
         return [l_pooled, t5_out, txt_ids, t5_attn_mask]

+    def encode_vision(self, infos, grid_size, t5_out, txt_ids):
+        global siglip_model
+        global siglip_processor
+        global redux_encoder
+
+        if siglip_model is None:
+            model_id = "google/siglip-so400m-patch14-384"
+            siglip_model = SiglipVisionModel.from_pretrained(
+                model_id, attn_implementation="sdpa", device_map="cuda")
+            siglip_processor = AutoProcessor.from_pretrained(model_id)
+
+        if redux_encoder is None:
+            if self.redux_path is None:
+                raise Exception("Vision encoding requires a Redux model, but no file was provided.")
+            model_data = safetensors.torch.load_file(self.redux_path, device=torch.device("cpu").type)
+            redux_encoder = ReduxImageEncoder()
+            redux_encoder.load_state_dict(model_data)
+            redux_encoder = redux_encoder.to(device="cuda")
+
+        bsz = txt_ids.shape[0]
+        imgs = [PIL.Image.open(nfo.absolute_path) for nfo in infos]
+        siglip_in = siglip_processor(images=imgs, padding="max_length", return_tensors="pt")
+        siglip_in = siglip_in.to(device="cuda")
+
+        with torch.no_grad(), torch.autocast("cuda"):
+            siglip_out = siglip_model(**siglip_in)
+            new_embed = redux_encoder(siglip_out.last_hidden_state).float()
+            (b, t, h) = new_embed.shape
+            s = int(sqrt(t))
+            new_embed = torch.nn.functional.interpolate(new_embed.view(b, s, s, h).transpose(1, -1),
+                                                        size=(grid_size, grid_size),
+                                                        mode="bicubic")
+            new_embed = new_embed.transpose(1, -1).reshape(b, -1, h).cpu().numpy()
+        new_ids = np.zeros(shape=(bsz, new_embed.shape[1], txt_ids.shape[2]))
+        attn_mask = np.ones((bsz, new_embed.shape[1]))
+
+        for i, info in enumerate(infos):
+            new_embed_i = new_embed[i]
+            new_ids_i = new_ids[i]
+            attn_mask_i = attn_mask[i]
+            info.vision_encoder_outputs = (new_embed_i, new_ids_i, attn_mask_i)
+
+
     def cache_batch_outputs(
         self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, infos: List
     ):

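encode_vision downsamples by treating the 729 Redux tokens as a 27x27 feature grid, resizing it bicubically to grid_size x grid_size, and flattening back to a token sequence. A standalone sketch of just that step, with illustrative shapes:

import torch
from math import sqrt

def downsample_tokens(embed: torch.Tensor, grid_size: int) -> torch.Tensor:
    b, t, h = embed.shape
    s = int(sqrt(t))  # 27 for the 729 SigLIP tokens
    grid = embed.view(b, s, s, h).transpose(1, -1)  # -> (b, h, s, s) for interpolate
    grid = torch.nn.functional.interpolate(grid, size=(grid_size, grid_size), mode="bicubic")
    return grid.transpose(1, -1).reshape(b, -1, h)  # -> (b, grid_size**2, h)

tokens = torch.randn(2, 729, 4096)
assert downsample_tokens(tokens, 9).shape == (2, 81, 4096)  # 9x9 grid -> 81 tokens
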
@@ -173,6 +254,10 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
             txt_ids = txt_ids.cpu().numpy()
             t5_attn_mask = tokens_and_masks[2].cpu().numpy()

+        if self.vision_cond_size > 0:
+            assert self.vision_cond_size <= 27, "Downsample grid size must not be greater than 27."
+            self.encode_vision(infos, self.vision_cond_size, t5_out, txt_ids)
+
         for i, info in enumerate(infos):
             l_pooled_i = l_pooled[i]
             t5_out_i = t5_out[i]

@@ -209,6 +209,8 @@ class ImageInfo:
         self.alpha_mask: Optional[torch.Tensor] = None  # alpha mask can be flipped in runtime
         self.resize_interpolation: Optional[str] = None

+        self.vision_encoder_outputs: Optional[List[torch.Tensor]] = None
+

 class BucketManager:
     def __init__(self, no_upscale, max_reso, min_size, max_size, reso_steps) -> None:

@@ -1566,6 +1568,7 @@ class BaseDataset(torch.utils.data.Dataset):
         target_sizes_hw = []
         flippeds = []  # the variable name is a bit awkward
         text_encoder_outputs_list = []
+        vision_encoder_outputs_list = []
         custom_attributes = []

         for image_key in bucket[image_index : image_index + bucket_batch_size]:

@@ -1690,6 +1693,9 @@ class BaseDataset(torch.utils.data.Dataset):
             text_encoder_outputs = None
             input_ids = None

+            if image_info.vision_encoder_outputs is not None:
+                vision_encoder_outputs_list.append(image_info.vision_encoder_outputs)
+
             if image_info.text_encoder_outputs is not None:
                 # cached
                 text_encoder_outputs = image_info.text_encoder_outputs

@@ -1745,6 +1751,7 @@ class BaseDataset(torch.utils.data.Dataset):
         example["custom_attributes"] = custom_attributes  # may be list of empty dict
         example["loss_weights"] = torch.FloatTensor(loss_weights)
         example["text_encoder_outputs_list"] = none_or_stack_elements(text_encoder_outputs_list, torch.FloatTensor)
+        example["vision_encoder_outputs_list"] = none_or_stack_elements(vision_encoder_outputs_list, torch.FloatTensor)
         example["input_ids_list"] = none_or_stack_elements(input_ids_list, lambda x: x)

         # if one of alpha_masks is not None, we need to replace None with ones

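Downstream, none_or_stack_elements is assumed to transpose the list of per-image tuples and stack each position into one batch tensor, which is how batch["vision_encoder_outputs_list"] reaches the trainer as the (vis_t5_out, vis_txt_ids, vis_attn_mask) triple unpacked earlier. A sketch of that assumed collation, with illustrative shapes:

import numpy as np
import torch

# One (embed, ids, mask) tuple per image, as stored by encode_vision;
# an 81-token grid and 3 id channels are illustrative values.
per_image = [
    (np.zeros((81, 4096), dtype=np.float32),
     np.zeros((81, 3), dtype=np.float32),
     np.ones((81,), dtype=np.float32))
    for _ in range(4)  # batch of 4 images
]

# Transpose tuples, then stack each position into a (batch, ...) tensor.
stacked = [torch.stack([torch.FloatTensor(t[i]) for t in per_image]) for i in range(3)]
vis_t5_out, vis_txt_ids, vis_attn_mask = stacked
assert vis_t5_out.shape == (4, 81, 4096)
assert vis_attn_mask.shape == (4, 81)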