From da4d0fe0165b3e0143c237de8cf307d53a9de45a Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 5 Aug 2024 20:51:34 +0900 Subject: [PATCH] support attn mask for l+g/t5 --- library/strategy_sd3.py | 88 +++++++++++++++++++++++++++++++++------- library/train_util.py | 3 +- sd3_minimal_inference.py | 10 +++-- sd3_train.py | 30 +++++++++++--- 4 files changed, 107 insertions(+), 24 deletions(-) diff --git a/library/strategy_sd3.py b/library/strategy_sd3.py index 7491e814..a2281890 100644 --- a/library/strategy_sd3.py +++ b/library/strategy_sd3.py @@ -37,11 +37,14 @@ class Sd3TokenizeStrategy(TokenizeStrategy): g_tokens = self.clip_g(text, max_length=77, padding="max_length", truncation=True, return_tensors="pt") t5_tokens = self.t5xxl(text, max_length=self.t5xxl_max_length, padding="max_length", truncation=True, return_tensors="pt") + l_attn_mask = l_tokens["attention_mask"] + g_attn_mask = g_tokens["attention_mask"] + t5_attn_mask = t5_tokens["attention_mask"] l_tokens = l_tokens["input_ids"] g_tokens = g_tokens["input_ids"] t5_tokens = t5_tokens["input_ids"] - return [l_tokens, g_tokens, t5_tokens] + return [l_tokens, g_tokens, t5_tokens, l_attn_mask, g_attn_mask, t5_attn_mask] class Sd3TextEncodingStrategy(TextEncodingStrategy): @@ -49,11 +52,20 @@ class Sd3TextEncodingStrategy(TextEncodingStrategy): pass def encode_tokens( - self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor] + self, + tokenize_strategy: TokenizeStrategy, + models: List[Any], + tokens: List[torch.Tensor], + apply_lg_attn_mask: bool = False, + apply_t5_attn_mask: bool = False, ) -> List[torch.Tensor]: + """ + returned embeddings are not masked + """ clip_l, clip_g, t5xxl = models - l_tokens, g_tokens, t5_tokens = tokens + l_tokens, g_tokens, t5_tokens = tokens[:3] + l_attn_mask, g_attn_mask, t5_attn_mask = tokens[3:] if len(tokens) > 3 else [None, None, None] if l_tokens is None: assert g_tokens is None, "g_tokens must be None if l_tokens is None" lg_out = None @@ -61,10 +73,15 @@ class Sd3TextEncodingStrategy(TextEncodingStrategy): assert g_tokens is not None, "g_tokens must not be None if l_tokens is not None" l_out, l_pooled = clip_l(l_tokens) g_out, g_pooled = clip_g(g_tokens) + if apply_lg_attn_mask: + l_out = l_out * l_attn_mask.to(l_out.device).unsqueeze(-1) + g_out = g_out * g_attn_mask.to(g_out.device).unsqueeze(-1) lg_out = torch.cat([l_out, g_out], dim=-1) if t5xxl is not None and t5_tokens is not None: t5_out, _ = t5xxl(t5_tokens) # t5_out is [1, max length, 4096] + if apply_t5_attn_mask: + t5_out = t5_out * t5_attn_mask.to(t5_out.device).unsqueeze(-1) else: t5_out = None @@ -84,50 +101,81 @@ class Sd3TextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy): SD3_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_sd3_te.npz" def __init__( - self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool, is_partial: bool = False + self, + cache_to_disk: bool, + batch_size: int, + skip_disk_cache_validity_check: bool, + is_partial: bool = False, + apply_lg_attn_mask: bool = False, + apply_t5_attn_mask: bool = False, ) -> None: super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial) + self.apply_lg_attn_mask = apply_lg_attn_mask + self.apply_t5_attn_mask = apply_t5_attn_mask def get_outputs_npz_path(self, image_abs_path: str) -> str: return os.path.splitext(image_abs_path)[0] + Sd3TextEncoderOutputsCachingStrategy.SD3_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX - def is_disk_cached_outputs_expected(self, abs_path: str): + def 
is_disk_cached_outputs_expected(self, npz_path: str): if not self.cache_to_disk: return False - if not os.path.exists(self.get_outputs_npz_path(abs_path)): + if not os.path.exists(npz_path): return False if self.skip_disk_cache_validity_check: return True try: - npz = np.load(self.get_outputs_npz_path(abs_path)) - if "clip_l" not in npz or "clip_g" not in npz: + npz = np.load(npz_path) + if "lg_out" not in npz: return False - if "clip_l_pool" not in npz or "clip_g_pool" not in npz: + if "lg_pooled" not in npz: + return False + if "clip_l_attn_mask" not in npz or "clip_g_attn_mask" not in npz: # necessary even if not used return False # t5xxl is optional except Exception as e: - logger.error(f"Error loading file: {self.get_outputs_npz_path(abs_path)}") + logger.error(f"Error loading file: {npz_path}") raise e return True + def mask_lg_attn(self, lg_out: np.ndarray, l_attn_mask: np.ndarray, g_attn_mask: np.ndarray) -> np.ndarray: + l_out = lg_out[..., :768] + g_out = lg_out[..., 768:] # 1280 + l_out = l_out * np.expand_dims(l_attn_mask, -1) # l_out = l_out * l_attn_mask. + g_out = g_out * np.expand_dims(g_attn_mask, -1) # g_out = g_out * g_attn_mask. + return np.concatenate([l_out, g_out], axis=-1) + + def mask_t5_attn(self, t5_out: np.ndarray, t5_attn_mask: np.ndarray) -> np.ndarray: + return t5_out * np.expand_dims(t5_attn_mask, -1) + def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]: data = np.load(npz_path) lg_out = data["lg_out"] lg_pooled = data["lg_pooled"] t5_out = data["t5_out"] if "t5_out" in data else None + + if self.apply_lg_attn_mask: + l_attn_mask = data["clip_l_attn_mask"] + g_attn_mask = data["clip_g_attn_mask"] + lg_out = self.mask_lg_attn(lg_out, l_attn_mask, g_attn_mask) + + if self.apply_t5_attn_mask and t5_out is not None: + t5_attn_mask = data["t5_attn_mask"] + t5_out = self.mask_t5_attn(t5_out, t5_attn_mask) + return [lg_out, t5_out, lg_pooled] def cache_batch_outputs( self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, infos: List ): + sd3_text_encoding_strategy: Sd3TextEncodingStrategy = text_encoding_strategy captions = [info.caption for info in infos] - clip_l_tokens, clip_g_tokens, t5xxl_tokens = tokenize_strategy.tokenize(captions) + tokens_and_masks = tokenize_strategy.tokenize(captions) with torch.no_grad(): - lg_out, t5_out, lg_pooled = text_encoding_strategy.encode_tokens( - tokenize_strategy, models, [clip_l_tokens, clip_g_tokens, t5xxl_tokens] + lg_out, t5_out, lg_pooled = sd3_text_encoding_strategy.encode_tokens( + tokenize_strategy, models, tokens_and_masks, self.apply_lg_attn_mask, self.apply_t5_attn_mask ) if lg_out.dtype == torch.bfloat16: @@ -148,10 +196,22 @@ class Sd3TextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy): lg_pooled_i = lg_pooled[i] if self.cache_to_disk: + clip_l_attn_mask, clip_g_attn_mask, t5_attn_mask = tokens_and_masks[3:6] + clip_l_attn_mask_i = clip_l_attn_mask[i].cpu().numpy() + clip_g_attn_mask_i = clip_g_attn_mask[i].cpu().numpy() + t5_attn_mask_i = t5_attn_mask[i].cpu().numpy() if t5_attn_mask is not None else None # shouldn't be None kwargs = {} if t5_out is not None: kwargs["t5_out"] = t5_out_i - np.savez(info.text_encoder_outputs_npz, lg_out=lg_out_i, lg_pooled=lg_pooled_i, **kwargs) + np.savez( + info.text_encoder_outputs_npz, + lg_out=lg_out_i, + lg_pooled=lg_pooled_i, + clip_l_attn_mask=clip_l_attn_mask_i, + clip_g_attn_mask=clip_g_attn_mask_i, + t5_attn_mask=t5_attn_mask_i, + **kwargs, + ) else: info.text_encoder_outputs = (lg_out_i, 
t5_out_i, lg_pooled_i) diff --git a/library/train_util.py b/library/train_util.py index a747e047..fc458a88 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -646,7 +646,7 @@ class BaseDataset(torch.utils.data.Dataset): # caching self.caching_mode = None # None, 'latents', 'text' - + self.tokenize_strategy = None self.text_encoder_output_caching_strategy = None self.latents_caching_strategy = None @@ -1486,6 +1486,7 @@ class BaseDataset(torch.utils.data.Dataset): text_encoder_outputs = self.text_encoder_output_caching_strategy.load_outputs_npz( image_info.text_encoder_outputs_npz ) + text_encoder_outputs = [torch.FloatTensor(x) for x in text_encoder_outputs] else: tokenization_required = True text_encoder_outputs_list.append(text_encoder_outputs) diff --git a/sd3_minimal_inference.py b/sd3_minimal_inference.py index e9e61af1..630da7e0 100644 --- a/sd3_minimal_inference.py +++ b/sd3_minimal_inference.py @@ -146,6 +146,8 @@ if __name__ == "__main__": parser.add_argument("--clip_l", type=str, required=False) parser.add_argument("--t5xxl", type=str, required=False) parser.add_argument("--t5xxl_token_length", type=int, default=77, help="t5xxl token length, default: 77") + parser.add_argument("--apply_lg_attn_mask", action="store_true") + parser.add_argument("--apply_t5_attn_mask", action="store_true") parser.add_argument("--prompt", type=str, default="A photo of a cat") # parser.add_argument("--prompt2", type=str, default=None) # do not support different prompts for text encoders parser.add_argument("--negative_prompt", type=str, default="") @@ -323,15 +325,15 @@ if __name__ == "__main__": logger.info("Encoding prompts...") encoding_strategy = strategy_sd3.Sd3TextEncodingStrategy() - l_tokens, g_tokens, t5_tokens = tokenize_strategy.tokenize(args.prompt) + tokens_and_masks = tokenize_strategy.tokenize(args.prompt) lg_out, t5_out, pooled = encoding_strategy.encode_tokens( - tokenize_strategy, [clip_l, clip_g, t5xxl], [l_tokens, g_tokens, t5_tokens] + tokenize_strategy, [clip_l, clip_g, t5xxl], tokens_and_masks, args.apply_lg_attn_mask, args.apply_t5_attn_mask ) cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled) - l_tokens, g_tokens, t5_tokens = tokenize_strategy.tokenize(args.negative_prompt) + tokens_and_masks = tokenize_strategy.tokenize(args.negative_prompt) lg_out, t5_out, pooled = encoding_strategy.encode_tokens( - tokenize_strategy, [clip_l, clip_g, t5xxl], [l_tokens, g_tokens, t5_tokens] + tokenize_strategy, [clip_l, clip_g, t5xxl], tokens_and_masks, args.apply_lg_attn_mask, args.apply_t5_attn_mask ) neg_cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled) diff --git a/sd3_train.py b/sd3_train.py index 2f4ea8cb..9c37cbce 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -172,6 +172,8 @@ def train(args): args.text_encoder_batch_size, False, False, + False, + False, ) ) train_dataset_group.set_current_strategies() @@ -312,6 +314,8 @@ def train(args): args.text_encoder_batch_size, False, train_clip_g or train_clip_l or args.use_t5xxl_cache_only, + args.apply_lg_attn_mask, + args.apply_t5_attn_mask, ) strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_caching_strategy) @@ -335,7 +339,11 @@ def train(args): logger.info(f"cache Text Encoder outputs for prompt: {p}") tokens_list = sd3_tokenize_strategy.tokenize(p) sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens( - sd3_tokenize_strategy, [clip_l, clip_g, t5xxl], tokens_list + sd3_tokenize_strategy, + [clip_l, clip_g, t5xxl], + tokens_list, + 
args.apply_lg_attn_mask, + args.apply_t5_attn_mask, ) accelerator.wait_for_everyone() @@ -748,21 +756,23 @@ def train(args): if lg_out is None or (train_clip_l or train_clip_g): # not cached or training, so get from text encoders - input_ids_clip_l, input_ids_clip_g, _ = batch["input_ids_list"] + input_ids_clip_l, input_ids_clip_g, _, l_attn_mask, g_attn_mask, _ = batch["input_ids_list"] with torch.set_grad_enabled(args.train_text_encoder): # TODO support weighted captions input_ids_clip_l = input_ids_clip_l.to(accelerator.device) input_ids_clip_g = input_ids_clip_g.to(accelerator.device) lg_out, _, lg_pooled = text_encoding_strategy.encode_tokens( - sd3_tokenize_strategy, [clip_l, clip_g, None], [input_ids_clip_l, input_ids_clip_g, None] + sd3_tokenize_strategy, + [clip_l, clip_g, None], + [input_ids_clip_l, input_ids_clip_g, None, l_attn_mask, g_attn_mask, None], ) if t5_out is None: - _, _, input_ids_t5xxl = batch["input_ids_list"] + _, _, input_ids_t5xxl, _, _, t5_attn_mask = batch["input_ids_list"] with torch.no_grad(): input_ids_t5xxl = input_ids_t5xxl.to(accelerator.device) if t5_out is None else None _, t5_out, _ = text_encoding_strategy.encode_tokens( - sd3_tokenize_strategy, [None, None, t5xxl], [None, None, input_ids_t5xxl] + sd3_tokenize_strategy, [None, None, t5xxl], [None, None, input_ids_t5xxl, None, None, t5_attn_mask] ) context, lg_pooled = text_encoding_strategy.concat_encodings(lg_out, t5_out, lg_pooled) @@ -969,6 +979,16 @@ def setup_parser() -> argparse.ArgumentParser: default=None, help="maximum token length for T5-XXL. 256 if omitted / T5-XXLの最大トークン数。省略時は256", ) + parser.add_argument( + "--apply_lg_attn_mask", + action="store_true", + help="apply attention mask (zero embs) to CLIP-L and G / CLIP-LとGにアテンションマスク(ゼロ埋め)を適用する", + ) + parser.add_argument( + "--apply_t5_attn_mask", + action="store_true", + help="apply attention mask (zero embs) to T5-XXL / T5-XXLにアテンションマスク(ゼロ埋め)を適用する", + ) # TE training is disabled temporarily # parser.add_argument(
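
Note (illustration, not part of the patch): the sketch below shows what the new apply_lg_attn_mask and apply_t5_attn_mask options do. Positions that the tokenizer's attention_mask marks as padding are zeroed in the encoder outputs before the CLIP-L and CLIP-G outputs are concatenated, and likewise for the T5-XXL output. The dummy tensors and shapes are assumptions for illustration only; in the patch the real values come from Sd3TokenizeStrategy and the text encoders inside Sd3TextEncodingStrategy.encode_tokens.

import torch

# Dummy data standing in for tokenizer/encoder outputs (shapes are assumptions).
batch, seq_len = 2, 77
l_out = torch.randn(batch, seq_len, 768)    # CLIP-L hidden states
g_out = torch.randn(batch, seq_len, 1280)   # CLIP-G hidden states
t5_out = torch.randn(batch, seq_len, 4096)  # T5-XXL output

# attention_mask from the tokenizer: 1 for real tokens, 0 for padding.
l_attn_mask = torch.ones(batch, seq_len)
l_attn_mask[:, 10:] = 0                      # pretend the prompt is 10 tokens long
g_attn_mask = l_attn_mask.clone()
t5_attn_mask = l_attn_mask.clone()

# apply_lg_attn_mask: zero padded positions, then concatenate L and G along the channel dim.
l_out = l_out * l_attn_mask.to(l_out.device).unsqueeze(-1)
g_out = g_out * g_attn_mask.to(g_out.device).unsqueeze(-1)
lg_out = torch.cat([l_out, g_out], dim=-1)   # [batch, seq_len, 2048]

# apply_t5_attn_mask: the same idea for the T5-XXL output.
t5_out = t5_out * t5_attn_mask.to(t5_out.device).unsqueeze(-1)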
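For the disk-cache path, cache_batch_outputs now stores the three attention masks next to the (unmasked) embeddings, and load_outputs_npz re-applies them via mask_lg_attn / mask_t5_attn when the corresponding flag is set. Below is a minimal round-trip sketch with made-up shapes and a hypothetical file name; the real cache files use the per-image _sd3_te.npz suffix, and the T5 length may differ from 77 in training.

import numpy as np

seq = 77  # assumed token length for both CLIP and T5 in this sketch
keep = np.concatenate([np.ones(10), np.zeros(seq - 10)])  # 1 = real token, 0 = padding

# Key names match the patch; the values are random placeholders.
np.savez(
    "example_sd3_te.npz",
    lg_out=np.random.randn(seq, 2048).astype(np.float32),
    lg_pooled=np.random.randn(2048).astype(np.float32),
    t5_out=np.random.randn(seq, 4096).astype(np.float32),
    clip_l_attn_mask=keep,
    clip_g_attn_mask=keep,
    t5_attn_mask=keep,
)

data = np.load("example_sd3_te.npz")
lg_out, t5_out = data["lg_out"], data["t5_out"]

# What load_outputs_npz does when apply_lg_attn_mask / apply_t5_attn_mask are True:
l_out, g_out = lg_out[..., :768], lg_out[..., 768:]
lg_out = np.concatenate(
    [l_out * np.expand_dims(data["clip_l_attn_mask"], -1),
     g_out * np.expand_dims(data["clip_g_attn_mask"], -1)],
    axis=-1,
)
t5_out = t5_out * np.expand_dims(data["t5_attn_mask"], -1)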