From 2d8fa3387a4adfdc2e36f2582e4ffc21864569f0 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 22 Aug 2024 19:56:27 +0900 Subject: [PATCH] Fix to remove zero pad for t5xxl output --- README.md | 5 +++++ library/strategy_flux.py | 23 +++++++++++------------ 2 files changed, 16 insertions(+), 12 deletions(-) diff --git a/README.md b/README.md index 5125c663..33b3a9a9 100644 --- a/README.md +++ b/README.md @@ -9,6 +9,11 @@ __Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchv The command to install PyTorch is as follows: `pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124` +Aug 22, 2024 (update 2): +Fixed a bug that the embedding was zero-padded when `--apply_t5_attn_mask` option was applied. Also, the cache file for text encoder outputs now records whether the mask is applied or not. Please note that the cache file will be recreated when switching the `--apply_t5_attn_mask` option. + +Added a script to extract LoRA from the difference between the two models of FLUX.1. Use `networks/flux_extract_lora.py`. See `--help` for details. Normally, more than 50GB of memory is required, but specifying the `--mem_eff_safe_open` option significantly reduces memory usage. However, this option is a custom implementation, so unexpected problems may occur. Please always check if the model is loaded correctly. + Aug 22, 2024: Fixed a bug in multi-GPU training. It should work with fine-tuning and LoRA training. `--double_blocks_to_swap` and `--single_blocks_to_swap` cannot be used in multi-GPU training. diff --git a/library/strategy_flux.py b/library/strategy_flux.py index b3643cbf..d52b3b8d 100644 --- a/library/strategy_flux.py +++ b/library/strategy_flux.py @@ -22,7 +22,7 @@ T5_XXL_TOKENIZER_ID = "google/t5-v1_1-xxl" class FluxTokenizeStrategy(TokenizeStrategy): - def __init__(self, t5xxl_max_length: int = 256, tokenizer_cache_dir: Optional[str] = None) -> None: + def __init__(self, t5xxl_max_length: int = 512, tokenizer_cache_dir: Optional[str] = None) -> None: self.t5xxl_max_length = t5xxl_max_length self.clip_l = self._load_tokenizer(CLIPTokenizer, CLIP_L_TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir) self.t5xxl = self._load_tokenizer(T5TokenizerFast, T5_XXL_TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir) @@ -120,25 +120,24 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy): return False if "t5_attn_mask" not in npz: return False + if "apply_t5_attn_mask" not in npz: + return False + npz_apply_t5_attn_mask = npz["apply_t5_attn_mask"] + if npz_apply_t5_attn_mask != self.apply_t5_attn_mask: + return False except Exception as e: logger.error(f"Error loading file: {npz_path}") raise e return True - def mask_t5_attn(self, t5_out: np.ndarray, t5_attn_mask: np.ndarray) -> np.ndarray: - return t5_out * np.expand_dims(t5_attn_mask, -1) - def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]: data = np.load(npz_path) l_pooled = data["l_pooled"] t5_out = data["t5_out"] txt_ids = data["txt_ids"] t5_attn_mask = data["t5_attn_mask"] - - if self.apply_t5_attn_mask: - t5_out = self.mask_t5_attn(t5_out, t5_attn_mask) # FIXME do not mask here!!! - + # apply_t5_attn_mask should be same as self.apply_t5_attn_mask return [l_pooled, t5_out, txt_ids, t5_attn_mask] def cache_batch_outputs( @@ -149,10 +148,8 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy): tokens_and_masks = tokenize_strategy.tokenize(captions) with torch.no_grad(): - # attn_mask is not applied when caching to disk: it is applied when loading from disk FIXME apply mask when loading - l_pooled, t5_out, txt_ids, _ = flux_text_encoding_strategy.encode_tokens( - tokenize_strategy, models, tokens_and_masks, not self.cache_to_disk - ) + # attn_mask is applied in text_encoding_strategy.encode_tokens if apply_t5_attn_mask is True + l_pooled, t5_out, txt_ids, _ = flux_text_encoding_strategy.encode_tokens(tokenize_strategy, models, tokens_and_masks) if l_pooled.dtype == torch.bfloat16: l_pooled = l_pooled.float() @@ -171,6 +168,7 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy): t5_out_i = t5_out[i] txt_ids_i = txt_ids[i] t5_attn_mask_i = t5_attn_mask[i] + apply_t5_attn_mask_i = self.apply_t5_attn_mask if self.cache_to_disk: np.savez( @@ -179,6 +177,7 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy): t5_out=t5_out_i, txt_ids=txt_ids_i, t5_attn_mask=t5_attn_mask_i, + apply_t5_attn_mask=apply_t5_attn_mask_i, ) else: info.text_encoder_outputs = (l_pooled_i, t5_out_i, txt_ids_i, t5_attn_mask_i)