Mirror of https://github.com/kohya-ss/sd-scripts.git
Merge 10ff863b8a into 1dae34b0af
flux_train_network.py
@@ -36,6 +36,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
         self.is_schnell: Optional[bool] = None
         self.is_swapping_blocks: bool = False
         self.model_type: Optional[str] = None
+        self.args = None

     def assert_extra_args(
         self,
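Why pre-initialize `self.args`: the `get_tokenizers` fallback in the next hunk probes `self.args` via `hasattr`/`getattr`, and `train()` only stores the parsed args later (see the final hunk). A minimal sketch of that defensive access pattern; the `Trainer` class and field names here are illustrative, not part of the diff:

```python
from types import SimpleNamespace

class Trainer:
    def __init__(self):
        self.args = None  # populated later by train()

    def tokenizer_cache_dir(self):
        # Safe whether args is still None or fully populated.
        return getattr(self.args, "tokenizer_cache_dir", None)

t = Trainer()
assert t.tokenizer_cache_dir() is None  # before train() runs
t.args = SimpleNamespace(tokenizer_cache_dir="/tmp/tok")
assert t.tokenizer_cache_dir() == "/tmp/tok"  # after train() stores args
```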
@@ -162,25 +163,50 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
     def get_tokenize_strategy(self, args):
         # This method is called before `assert_extra_args`, so we cannot use `self.is_schnell` here.
         # Instead, we analyze the checkpoint state to determine if it is schnell.
-        if args.model_type != "chroma":
-            _, is_schnell, _, _ = flux_utils.analyze_checkpoint_state(args.pretrained_model_name_or_path)
-        else:
-            is_schnell = False
-        self.is_schnell = is_schnell
-
-        if args.t5xxl_max_token_length is None:
-            if self.is_schnell:
-                t5xxl_max_token_length = 256
-            else:
-                t5xxl_max_token_length = 512
-        else:
-            t5xxl_max_token_length = args.t5xxl_max_token_length
-
-        logger.info(f"t5xxl_max_token_length: {t5xxl_max_token_length}")
-        return strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length, args.tokenizer_cache_dir)
+        if args.model_type == "chroma":
+            is_schnell = False
+            self.is_schnell = is_schnell
+            t5xxl_max_token_length = args.t5xxl_max_token_length or 512
+            logger.info(f"t5xxl_max_token_length: {t5xxl_max_token_length}")
+            # Chroma doesn't use CLIP-L
+            return strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length, args.tokenizer_cache_dir, use_clip_l=False)
+        else:
+            _, is_schnell, _, _ = flux_utils.analyze_checkpoint_state(args.pretrained_model_name_or_path)
+            self.is_schnell = is_schnell
+
+            if args.t5xxl_max_token_length is None:
+                if self.is_schnell:
+                    t5xxl_max_token_length = 256
+                else:
+                    t5xxl_max_token_length = 512
+            else:
+                t5xxl_max_token_length = args.t5xxl_max_token_length
+
+            logger.info(f"t5xxl_max_token_length: {t5xxl_max_token_length}")
+            # FLUX models use both CLIP-L and T5
+            return strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length, args.tokenizer_cache_dir, use_clip_l=True)

     def get_tokenizers(self, tokenize_strategy: strategy_flux.FluxTokenizeStrategy):
-        return [tokenize_strategy.clip_l, tokenize_strategy.t5xxl]
+        # Try to access the tokenizers through the tokenize strategy's attributes
+        # First, check if the attributes exist directly
+        if hasattr(tokenize_strategy, 'clip_l') and hasattr(tokenize_strategy, 't5xxl'):
+            return [tokenize_strategy.clip_l, tokenize_strategy.t5xxl]
+        # If not, try to find them with different attribute names
+        elif hasattr(tokenize_strategy, 'clip_l_tokenizer') and hasattr(tokenize_strategy, 't5xxl_tokenizer'):
+            return [tokenize_strategy.clip_l_tokenizer, tokenize_strategy.t5xxl_tokenizer]
+        else:
+            # As a last resort, create new tokenizers
+            logger.warning("Tokenizers not found in tokenize strategy, creating new ones")
+            from transformers import CLIPTokenizer, T5TokenizerFast
+            clip_l_tokenizer = CLIPTokenizer.from_pretrained(
+                "openai/clip-vit-large-patch14",
+                cache_dir=getattr(self.args, 'tokenizer_cache_dir', None) if hasattr(self, 'args') else None
+            )
+            t5xxl_tokenizer = T5TokenizerFast.from_pretrained(
+                "google/t5-v1_1-xxl",
+                cache_dir=getattr(self.args, 'tokenizer_cache_dir', None) if hasattr(self, 'args') else None
+            )
+            return [clip_l_tokenizer, t5xxl_tokenizer]

     def get_latents_caching_strategy(self, args):
         latents_caching_strategy = strategy_flux.FluxLatentsCachingStrategy(args.cache_latents_to_disk, args.vae_batch_size, False)
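The rewritten `get_tokenize_strategy` dispatches on `args.model_type`: the Chroma branch skips checkpoint analysis, defaults the T5-XXL token length to 512, and builds the tokenize strategy without CLIP-L, while the FLUX branch keeps the schnell/dev length heuristic. A condensed sketch of the length rule the hunk implements; the function name and bare-argument form are illustrative, not part of the diff:

```python
def t5xxl_length(model_type: str, is_schnell: bool, override=None) -> int:
    # Chroma: 512 unless the user overrides; FLUX: schnell -> 256, dev -> 512.
    if model_type == "chroma":
        return override or 512
    if override is not None:
        return override
    return 256 if is_schnell else 512

assert t5xxl_length("chroma", False) == 512
assert t5xxl_length("flux", True) == 256
assert t5xxl_length("flux", False, override=300) == 300
```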
library/strategy_flux.py
@@ -21,19 +21,25 @@ T5_XXL_TOKENIZER_ID = "google/t5-v1_1-xxl"


 class FluxTokenizeStrategy(TokenizeStrategy):
-    def __init__(self, t5xxl_max_length: int = 512, tokenizer_cache_dir: Optional[str] = None) -> None:
+    def __init__(self, t5xxl_max_length: int = 512, tokenizer_cache_dir: Optional[str] = None, use_clip_l: bool = True) -> None:
         self.t5xxl_max_length = t5xxl_max_length
-        self.clip_l = self._load_tokenizer(CLIPTokenizer, CLIP_L_TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir)
+        self.use_clip_l = use_clip_l
+        if self.use_clip_l:
+            self.clip_l = self._load_tokenizer(CLIPTokenizer, CLIP_L_TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir)
         self.t5xxl = self._load_tokenizer(T5TokenizerFast, T5_XXL_TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir)

     def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]:
         text = [text] if isinstance(text, str) else text

-        l_tokens = self.clip_l(text, max_length=77, padding="max_length", truncation=True, return_tensors="pt")
-        t5_tokens = self.t5xxl(text, max_length=self.t5xxl_max_length, padding="max_length", truncation=True, return_tensors="pt")
+        if self.use_clip_l:
+            l_tokens = self.clip_l(text, max_length=77, padding="max_length", truncation=True, return_tensors="pt")
+            l_tokens = l_tokens["input_ids"]
+        else:
+            # For Chroma, return None for CLIP-L tokens
+            l_tokens = None
+
+        t5_tokens = self.t5xxl(text, max_length=self.t5xxl_max_length, padding="max_length", truncation=True, return_tensors="pt")

         t5_attn_mask = t5_tokens["attention_mask"]
-        l_tokens = l_tokens["input_ids"]
         t5_tokens = t5_tokens["input_ids"]

         return [l_tokens, t5_tokens, t5_attn_mask]
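With `use_clip_l=False`, `tokenize` now returns `[None, t5_tokens, t5_attn_mask]`, so every consumer of the 3-element list must tolerate `None` in the first slot. A self-contained sketch of that contract using placeholder tensors (no real tokenizers are loaded here; the helper is illustrative):

```python
import torch

def describe(tokens):
    l_tokens, t5_tokens, t5_attn_mask = tokens
    clip_part = "no CLIP-L ids (Chroma)" if l_tokens is None else f"CLIP-L ids {tuple(l_tokens.shape)}"
    return f"{clip_part}, T5 ids {tuple(t5_tokens.shape)}"

flux_like = [torch.zeros(1, 77, dtype=torch.long), torch.zeros(1, 512, dtype=torch.long), torch.ones(1, 512, dtype=torch.long)]
chroma_like = [None, torch.zeros(1, 512, dtype=torch.long), torch.ones(1, 512, dtype=torch.long)]

print(describe(flux_like))    # CLIP-L ids (1, 77), T5 ids (1, 512)
print(describe(chroma_like))  # no CLIP-L ids (Chroma), T5 ids (1, 512)
```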
@@ -63,24 +69,27 @@ class FluxTextEncodingStrategy(TextEncodingStrategy):
         l_tokens, t5_tokens = tokens[:2]
         t5_attn_mask = tokens[2] if len(tokens) > 2 else None

-        # clip_l is None when using T5 only
+        # Handle Chroma case where CLIP-L is not used
         if clip_l is not None and l_tokens is not None:
             l_pooled = clip_l(l_tokens.to(clip_l.device))["pooler_output"]
         else:
-            l_pooled = None
+            # For Chroma, create a dummy tensor with the right shape
+            if t5_tokens is not None:
+                batch_size = t5_tokens.shape[0]
+                l_pooled = torch.zeros(batch_size, 768, device=t5_tokens.device, dtype=torch.float32)
+            else:
+                l_pooled = None

         # t5xxl is None when using CLIP only
         if t5xxl is not None and t5_tokens is not None:
             # t5_out is [b, max length, 4096]
             attention_mask = None if not apply_t5_attn_mask else t5_attn_mask.to(t5xxl.device)
             t5_out, _ = t5xxl(t5_tokens.to(t5xxl.device), attention_mask, return_dict=False, output_hidden_states=True)
-            # if zero_pad_t5_output:
-            #     t5_out = t5_out * t5_attn_mask.to(t5_out.device).unsqueeze(-1)
             txt_ids = torch.zeros(t5_out.shape[0], t5_out.shape[1], 3, device=t5_out.device)
         else:
             t5_out = None
             txt_ids = None
-            t5_attn_mask = None  # caption may be dropped/shuffled, so t5_attn_mask should not be used to make sure the mask is same as the cached one
+            t5_attn_mask = None

         return [l_pooled, t5_out, txt_ids, t5_attn_mask]  # returns t5_attn_mask for attention mask in transformer
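Instead of propagating `l_pooled = None`, the Chroma path substitutes a zero vector sized to CLIP-L's 768-dim pooled output, which keeps downstream shape expectations intact while contributing no signal. A standalone sketch of the fallback (the helper name is illustrative, not in the diff):

```python
import torch

def pooled_or_zeros(l_pooled, t5_tokens):
    # Mirrors the hunk: fabricate a zero CLIP-L pooled embedding (hidden size 768)
    # when CLIP-L is absent but T5 tokens exist; otherwise keep None.
    if l_pooled is not None:
        return l_pooled
    if t5_tokens is None:
        return None
    return torch.zeros(t5_tokens.shape[0], 768, device=t5_tokens.device, dtype=torch.float32)

assert pooled_or_zeros(None, torch.zeros(4, 512, dtype=torch.long)).shape == (4, 768)
assert pooled_or_zeros(None, None) is None
```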
train_network.py
@@ -400,8 +400,19 @@ class NetworkTrainer:
         text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
         if text_encoder_outputs_list is not None:
             text_encoder_conds = text_encoder_outputs_list  # List of text encoder outputs
+        else:
+            # For debugging
+            logger.debug(f"text_encoder_outputs_list is None, batch keys: {list(batch.keys())}")

-        if len(text_encoder_conds) == 0 or text_encoder_conds[0] is None or train_text_encoder:
+        # For Chroma, text_encoder_conds might be set up differently
+        # Check if we need to encode text encoders
+        need_to_encode = len(text_encoder_conds) == 0 or text_encoder_conds[0] is None or train_text_encoder
+        # Also check if input_ids_list is None (for Chroma)
+        if "input_ids_list" in batch and batch["input_ids_list"] is None:
+            # If input_ids_list is None, we might already have the text encoder outputs cached
+            need_to_encode = False
+
+        if need_to_encode:
             # TODO this does not work if 'some text_encoders are trained' and 'some are not and not cached'
             with torch.set_grad_enabled(is_train and train_text_encoder), accelerator.autocast():
                 # Get the text embedding for conditioning
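The single boolean condition becomes a two-step decision: compute the old re-encode condition, then veto it when the batch carries an `input_ids_list` key whose value is `None` (the Chroma cached-outputs case). Extracted as a pure function for clarity; the function itself is illustrative, not in the diff:

```python
def need_to_encode(text_encoder_conds, train_text_encoder, batch):
    # Original rule: encode when nothing is cached or the encoders are trained.
    need = len(text_encoder_conds) == 0 or text_encoder_conds[0] is None or train_text_encoder
    # New veto: an explicit None input_ids_list means outputs were cached upstream.
    if "input_ids_list" in batch and batch["input_ids_list"] is None:
        need = False
    return need

assert need_to_encode([], False, {"input_ids_list": ["ids"]}) is True
assert need_to_encode([], False, {"input_ids_list": None}) is False
assert need_to_encode(["cond"], True, {}) is True
```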
@@ -414,12 +425,27 @@ class NetworkTrainer:
                         weights_list,
                     )
                 else:
-                    input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]]
-                    encoded_text_encoder_conds = text_encoding_strategy.encode_tokens(
-                        tokenize_strategy,
-                        self.get_models_for_text_encoding(args, accelerator, text_encoders),
-                        input_ids,
-                    )
+                    # Handle Chroma case where CLIP-L tokens might be None
+                    # Check if input_ids_list exists and is not None
+                    if "input_ids_list" in batch and batch["input_ids_list"] is not None:
+                        input_ids = []
+                        for ids in batch["input_ids_list"]:
+                            if ids is not None:  # Skip None values (CLIP-L tokens for Chroma)
+                                input_ids.append(ids.to(accelerator.device))
+                        encoded_text_encoder_conds = text_encoding_strategy.encode_tokens(
+                            tokenize_strategy,
+                            self.get_models_for_text_encoding(args, accelerator, text_encoders),
+                            input_ids,
+                        )
+                    else:
+                        # For Chroma, we might have a different way to get the input ids
+                        # Since input_ids_list is None, we need to handle this case
+                        # Let's assume the text encoding strategy can handle this
+                        encoded_text_encoder_conds = text_encoding_strategy.encode_tokens(
+                            tokenize_strategy,
+                            self.get_models_for_text_encoding(args, accelerator, text_encoders),
+                            [],  # Pass empty list or handle differently
+                        )
                 if args.full_fp16:
                     encoded_text_encoder_conds = [c.to(weight_dtype) for c in encoded_text_encoder_conds]
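The list comprehension is replaced by a loop that drops `None` entries (Chroma's missing CLIP-L ids) before moving tensors to the accelerator device; note the filtered list is shorter, so the encoding strategy presumably pairs ids with encoders by something other than fixed position. A minimal sketch of the filter (the helper name is illustrative):

```python
import torch

def move_ids(ids_list, device):
    # Skip the None placeholder left for CLIP-L and move the remaining id tensors.
    return [ids.to(device) for ids in ids_list if ids is not None]

batch_ids = [None, torch.zeros(2, 512, dtype=torch.long)]  # [clip_l, t5xxl] for Chroma
assert len(move_ids(batch_ids, "cpu")) == 1
```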
@@ -472,6 +498,7 @@ class NetworkTrainer:
         return True  # default for other than HunyuanImage

     def train(self, args):
+        self.args = args  # store args for later use
         session_id = random.randint(0, 2**32)
         training_started_at = time.time()
         train_util.verify_training_args(args)