mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
Fix to remove zero pad for t5xxl output
This commit is contained in:
@@ -9,6 +9,11 @@ __Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchv
|
|||||||
The command to install PyTorch is as follows:
|
The command to install PyTorch is as follows:
|
||||||
`pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124`
|
`pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124`
|
||||||
|
|
||||||
|
Aug 22, 2024 (update 2):
|
||||||
|
Fixed a bug where the T5-XXL embedding was zero-padded when the `--apply_t5_attn_mask` option was specified. Also, the cache file for text encoder outputs now records whether the mask was applied. Please note that the cache file will be recreated when you switch the `--apply_t5_attn_mask` option.
|
||||||
|
|
||||||
|
Added a script to extract a LoRA from the difference between two FLUX.1 models. Use `networks/flux_extract_lora.py`. See `--help` for details. Normally, more than 50GB of memory is required, but specifying the `--mem_eff_safe_open` option significantly reduces memory usage. However, this option is a custom implementation, so unexpected problems may occur. Please always check that the model is loaded correctly.
|
||||||
|
|
||||||
Aug 22, 2024:
|
Aug 22, 2024:
|
||||||
Fixed a bug in multi-GPU training. It should work with fine-tuning and LoRA training. `--double_blocks_to_swap` and `--single_blocks_to_swap` cannot be used in multi-GPU training.
|
Fixed a bug in multi-GPU training. It should work with fine-tuning and LoRA training. `--double_blocks_to_swap` and `--single_blocks_to_swap` cannot be used in multi-GPU training.
|
||||||
|
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ T5_XXL_TOKENIZER_ID = "google/t5-v1_1-xxl"
|
|||||||
|
|
||||||
|
|
||||||
class FluxTokenizeStrategy(TokenizeStrategy):
|
class FluxTokenizeStrategy(TokenizeStrategy):
|
||||||
def __init__(self, t5xxl_max_length: int = 256, tokenizer_cache_dir: Optional[str] = None) -> None:
|
def __init__(self, t5xxl_max_length: int = 512, tokenizer_cache_dir: Optional[str] = None) -> None:
|
||||||
self.t5xxl_max_length = t5xxl_max_length
|
self.t5xxl_max_length = t5xxl_max_length
|
||||||
self.clip_l = self._load_tokenizer(CLIPTokenizer, CLIP_L_TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir)
|
self.clip_l = self._load_tokenizer(CLIPTokenizer, CLIP_L_TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir)
|
||||||
self.t5xxl = self._load_tokenizer(T5TokenizerFast, T5_XXL_TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir)
|
self.t5xxl = self._load_tokenizer(T5TokenizerFast, T5_XXL_TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir)
|
||||||
@@ -120,25 +120,24 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
|
|||||||
return False
|
return False
|
||||||
if "t5_attn_mask" not in npz:
|
if "t5_attn_mask" not in npz:
|
||||||
return False
|
return False
|
||||||
|
if "apply_t5_attn_mask" not in npz:
|
||||||
|
return False
|
||||||
|
npz_apply_t5_attn_mask = npz["apply_t5_attn_mask"]
|
||||||
|
if npz_apply_t5_attn_mask != self.apply_t5_attn_mask:
|
||||||
|
return False
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error loading file: {npz_path}")
|
logger.error(f"Error loading file: {npz_path}")
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def mask_t5_attn(self, t5_out: np.ndarray, t5_attn_mask: np.ndarray) -> np.ndarray:
|
|
||||||
return t5_out * np.expand_dims(t5_attn_mask, -1)
|
|
||||||
|
|
||||||
def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
|
def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
|
||||||
data = np.load(npz_path)
|
data = np.load(npz_path)
|
||||||
l_pooled = data["l_pooled"]
|
l_pooled = data["l_pooled"]
|
||||||
t5_out = data["t5_out"]
|
t5_out = data["t5_out"]
|
||||||
txt_ids = data["txt_ids"]
|
txt_ids = data["txt_ids"]
|
||||||
t5_attn_mask = data["t5_attn_mask"]
|
t5_attn_mask = data["t5_attn_mask"]
|
||||||
|
# apply_t5_attn_mask should be same as self.apply_t5_attn_mask
|
||||||
if self.apply_t5_attn_mask:
|
|
||||||
t5_out = self.mask_t5_attn(t5_out, t5_attn_mask) # FIXME do not mask here!!!
|
|
||||||
|
|
||||||
return [l_pooled, t5_out, txt_ids, t5_attn_mask]
|
return [l_pooled, t5_out, txt_ids, t5_attn_mask]
|
||||||
|
|
||||||
def cache_batch_outputs(
|
def cache_batch_outputs(
|
||||||
@@ -149,10 +148,8 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
|
|||||||
|
|
||||||
tokens_and_masks = tokenize_strategy.tokenize(captions)
|
tokens_and_masks = tokenize_strategy.tokenize(captions)
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
# attn_mask is not applied when caching to disk: it is applied when loading from disk FIXME apply mask when loading
|
# attn_mask is applied in text_encoding_strategy.encode_tokens if apply_t5_attn_mask is True
|
||||||
l_pooled, t5_out, txt_ids, _ = flux_text_encoding_strategy.encode_tokens(
|
l_pooled, t5_out, txt_ids, _ = flux_text_encoding_strategy.encode_tokens(tokenize_strategy, models, tokens_and_masks)
|
||||||
tokenize_strategy, models, tokens_and_masks, not self.cache_to_disk
|
|
||||||
)
|
|
||||||
|
|
||||||
if l_pooled.dtype == torch.bfloat16:
|
if l_pooled.dtype == torch.bfloat16:
|
||||||
l_pooled = l_pooled.float()
|
l_pooled = l_pooled.float()
|
||||||
@@ -171,6 +168,7 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
|
|||||||
t5_out_i = t5_out[i]
|
t5_out_i = t5_out[i]
|
||||||
txt_ids_i = txt_ids[i]
|
txt_ids_i = txt_ids[i]
|
||||||
t5_attn_mask_i = t5_attn_mask[i]
|
t5_attn_mask_i = t5_attn_mask[i]
|
||||||
|
apply_t5_attn_mask_i = self.apply_t5_attn_mask
|
||||||
|
|
||||||
if self.cache_to_disk:
|
if self.cache_to_disk:
|
||||||
np.savez(
|
np.savez(
|
||||||
@@ -179,6 +177,7 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
|
|||||||
t5_out=t5_out_i,
|
t5_out=t5_out_i,
|
||||||
txt_ids=txt_ids_i,
|
txt_ids=txt_ids_i,
|
||||||
t5_attn_mask=t5_attn_mask_i,
|
t5_attn_mask=t5_attn_mask_i,
|
||||||
|
apply_t5_attn_mask=apply_t5_attn_mask_i,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
info.text_encoder_outputs = (l_pooled_i, t5_out_i, txt_ids_i, t5_attn_mask_i)
|
info.text_encoder_outputs = (l_pooled_i, t5_out_i, txt_ids_i, t5_attn_mask_i)
|
||||||
|
|||||||
Reference in New Issue
Block a user