Fix to remove zero pad for t5xxl output

This commit is contained in:
Kohya S
2024-08-22 19:56:27 +09:00
parent a4d27a232b
commit 2d8fa3387a
2 changed files with 16 additions and 12 deletions

View File

@@ -9,6 +9,11 @@ __Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchv
The command to install PyTorch is as follows: The command to install PyTorch is as follows:
`pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124` `pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124`
Aug 22, 2024 (update 2):
Fixed a bug where the embedding was zero-padded when the `--apply_t5_attn_mask` option was applied. Also, the cache file for text encoder outputs now records whether the mask was applied. Please note that the cache file will be recreated when the `--apply_t5_attn_mask` option is switched.
Added a script to extract a LoRA from the difference between two FLUX.1 models. Use `networks/flux_extract_lora.py`. See `--help` for details. Normally, more than 50GB of memory is required, but specifying the `--mem_eff_safe_open` option significantly reduces memory usage. However, this option is a custom implementation, so unexpected problems may occur. Please always check that the model is loaded correctly.
Aug 22, 2024: Aug 22, 2024:
Fixed a bug in multi-GPU training. It should work with fine-tuning and LoRA training. `--double_blocks_to_swap` and `--single_blocks_to_swap` cannot be used in multi-GPU training. Fixed a bug in multi-GPU training. It should work with fine-tuning and LoRA training. `--double_blocks_to_swap` and `--single_blocks_to_swap` cannot be used in multi-GPU training.

View File

@@ -22,7 +22,7 @@ T5_XXL_TOKENIZER_ID = "google/t5-v1_1-xxl"
class FluxTokenizeStrategy(TokenizeStrategy): class FluxTokenizeStrategy(TokenizeStrategy):
def __init__(self, t5xxl_max_length: int = 256, tokenizer_cache_dir: Optional[str] = None) -> None: def __init__(self, t5xxl_max_length: int = 512, tokenizer_cache_dir: Optional[str] = None) -> None:
self.t5xxl_max_length = t5xxl_max_length self.t5xxl_max_length = t5xxl_max_length
self.clip_l = self._load_tokenizer(CLIPTokenizer, CLIP_L_TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir) self.clip_l = self._load_tokenizer(CLIPTokenizer, CLIP_L_TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir)
self.t5xxl = self._load_tokenizer(T5TokenizerFast, T5_XXL_TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir) self.t5xxl = self._load_tokenizer(T5TokenizerFast, T5_XXL_TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir)
@@ -120,25 +120,24 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
return False return False
if "t5_attn_mask" not in npz: if "t5_attn_mask" not in npz:
return False return False
if "apply_t5_attn_mask" not in npz:
return False
npz_apply_t5_attn_mask = npz["apply_t5_attn_mask"]
if npz_apply_t5_attn_mask != self.apply_t5_attn_mask:
return False
except Exception as e: except Exception as e:
logger.error(f"Error loading file: {npz_path}") logger.error(f"Error loading file: {npz_path}")
raise e raise e
return True return True
def mask_t5_attn(self, t5_out: np.ndarray, t5_attn_mask: np.ndarray) -> np.ndarray:
return t5_out * np.expand_dims(t5_attn_mask, -1)
def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]: def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
data = np.load(npz_path) data = np.load(npz_path)
l_pooled = data["l_pooled"] l_pooled = data["l_pooled"]
t5_out = data["t5_out"] t5_out = data["t5_out"]
txt_ids = data["txt_ids"] txt_ids = data["txt_ids"]
t5_attn_mask = data["t5_attn_mask"] t5_attn_mask = data["t5_attn_mask"]
# apply_t5_attn_mask should be same as self.apply_t5_attn_mask
if self.apply_t5_attn_mask:
t5_out = self.mask_t5_attn(t5_out, t5_attn_mask) # FIXME do not mask here!!!
return [l_pooled, t5_out, txt_ids, t5_attn_mask] return [l_pooled, t5_out, txt_ids, t5_attn_mask]
def cache_batch_outputs( def cache_batch_outputs(
@@ -149,10 +148,8 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
tokens_and_masks = tokenize_strategy.tokenize(captions) tokens_and_masks = tokenize_strategy.tokenize(captions)
with torch.no_grad(): with torch.no_grad():
# attn_mask is not applied when caching to disk: it is applied when loading from disk FIXME apply mask when loading # attn_mask is applied in text_encoding_strategy.encode_tokens if apply_t5_attn_mask is True
l_pooled, t5_out, txt_ids, _ = flux_text_encoding_strategy.encode_tokens( l_pooled, t5_out, txt_ids, _ = flux_text_encoding_strategy.encode_tokens(tokenize_strategy, models, tokens_and_masks)
tokenize_strategy, models, tokens_and_masks, not self.cache_to_disk
)
if l_pooled.dtype == torch.bfloat16: if l_pooled.dtype == torch.bfloat16:
l_pooled = l_pooled.float() l_pooled = l_pooled.float()
@@ -171,6 +168,7 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
t5_out_i = t5_out[i] t5_out_i = t5_out[i]
txt_ids_i = txt_ids[i] txt_ids_i = txt_ids[i]
t5_attn_mask_i = t5_attn_mask[i] t5_attn_mask_i = t5_attn_mask[i]
apply_t5_attn_mask_i = self.apply_t5_attn_mask
if self.cache_to_disk: if self.cache_to_disk:
np.savez( np.savez(
@@ -179,6 +177,7 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
t5_out=t5_out_i, t5_out=t5_out_i,
txt_ids=txt_ids_i, txt_ids=txt_ids_i,
t5_attn_mask=t5_attn_mask_i, t5_attn_mask=t5_attn_mask_i,
apply_t5_attn_mask=apply_t5_attn_mask_i,
) )
else: else:
info.text_encoder_outputs = (l_pooled_i, t5_out_i, txt_ids_i, t5_attn_mask_i) info.text_encoder_outputs = (l_pooled_i, t5_out_i, txt_ids_i, t5_attn_mask_i)