mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
support attn mask for l+g/t5
This commit is contained in:
@@ -37,11 +37,14 @@ class Sd3TokenizeStrategy(TokenizeStrategy):
|
|||||||
g_tokens = self.clip_g(text, max_length=77, padding="max_length", truncation=True, return_tensors="pt")
|
g_tokens = self.clip_g(text, max_length=77, padding="max_length", truncation=True, return_tensors="pt")
|
||||||
t5_tokens = self.t5xxl(text, max_length=self.t5xxl_max_length, padding="max_length", truncation=True, return_tensors="pt")
|
t5_tokens = self.t5xxl(text, max_length=self.t5xxl_max_length, padding="max_length", truncation=True, return_tensors="pt")
|
||||||
|
|
||||||
|
l_attn_mask = l_tokens["attention_mask"]
|
||||||
|
g_attn_mask = g_tokens["attention_mask"]
|
||||||
|
t5_attn_mask = t5_tokens["attention_mask"]
|
||||||
l_tokens = l_tokens["input_ids"]
|
l_tokens = l_tokens["input_ids"]
|
||||||
g_tokens = g_tokens["input_ids"]
|
g_tokens = g_tokens["input_ids"]
|
||||||
t5_tokens = t5_tokens["input_ids"]
|
t5_tokens = t5_tokens["input_ids"]
|
||||||
|
|
||||||
return [l_tokens, g_tokens, t5_tokens]
|
return [l_tokens, g_tokens, t5_tokens, l_attn_mask, g_attn_mask, t5_attn_mask]
|
||||||
|
|
||||||
|
|
||||||
class Sd3TextEncodingStrategy(TextEncodingStrategy):
|
class Sd3TextEncodingStrategy(TextEncodingStrategy):
|
||||||
@@ -49,11 +52,20 @@ class Sd3TextEncodingStrategy(TextEncodingStrategy):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
def encode_tokens(
|
def encode_tokens(
|
||||||
self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor]
|
self,
|
||||||
|
tokenize_strategy: TokenizeStrategy,
|
||||||
|
models: List[Any],
|
||||||
|
tokens: List[torch.Tensor],
|
||||||
|
apply_lg_attn_mask: bool = False,
|
||||||
|
apply_t5_attn_mask: bool = False,
|
||||||
) -> List[torch.Tensor]:
|
) -> List[torch.Tensor]:
|
||||||
|
"""
|
||||||
|
returned embeddings are not masked
|
||||||
|
"""
|
||||||
clip_l, clip_g, t5xxl = models
|
clip_l, clip_g, t5xxl = models
|
||||||
|
|
||||||
l_tokens, g_tokens, t5_tokens = tokens
|
l_tokens, g_tokens, t5_tokens = tokens[:3]
|
||||||
|
l_attn_mask, g_attn_mask, t5_attn_mask = tokens[3:] if len(tokens) > 3 else [None, None, None]
|
||||||
if l_tokens is None:
|
if l_tokens is None:
|
||||||
assert g_tokens is None, "g_tokens must be None if l_tokens is None"
|
assert g_tokens is None, "g_tokens must be None if l_tokens is None"
|
||||||
lg_out = None
|
lg_out = None
|
||||||
@@ -61,10 +73,15 @@ class Sd3TextEncodingStrategy(TextEncodingStrategy):
|
|||||||
assert g_tokens is not None, "g_tokens must not be None if l_tokens is not None"
|
assert g_tokens is not None, "g_tokens must not be None if l_tokens is not None"
|
||||||
l_out, l_pooled = clip_l(l_tokens)
|
l_out, l_pooled = clip_l(l_tokens)
|
||||||
g_out, g_pooled = clip_g(g_tokens)
|
g_out, g_pooled = clip_g(g_tokens)
|
||||||
|
if apply_lg_attn_mask:
|
||||||
|
l_out = l_out * l_attn_mask.to(l_out.device).unsqueeze(-1)
|
||||||
|
g_out = g_out * g_attn_mask.to(g_out.device).unsqueeze(-1)
|
||||||
lg_out = torch.cat([l_out, g_out], dim=-1)
|
lg_out = torch.cat([l_out, g_out], dim=-1)
|
||||||
|
|
||||||
if t5xxl is not None and t5_tokens is not None:
|
if t5xxl is not None and t5_tokens is not None:
|
||||||
t5_out, _ = t5xxl(t5_tokens) # t5_out is [1, max length, 4096]
|
t5_out, _ = t5xxl(t5_tokens) # t5_out is [1, max length, 4096]
|
||||||
|
if apply_t5_attn_mask:
|
||||||
|
t5_out = t5_out * t5_attn_mask.to(t5_out.device).unsqueeze(-1)
|
||||||
else:
|
else:
|
||||||
t5_out = None
|
t5_out = None
|
||||||
|
|
||||||
@@ -84,50 +101,81 @@ class Sd3TextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
|
|||||||
SD3_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_sd3_te.npz"
|
SD3_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_sd3_te.npz"
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool, is_partial: bool = False
|
self,
|
||||||
|
cache_to_disk: bool,
|
||||||
|
batch_size: int,
|
||||||
|
skip_disk_cache_validity_check: bool,
|
||||||
|
is_partial: bool = False,
|
||||||
|
apply_lg_attn_mask: bool = False,
|
||||||
|
apply_t5_attn_mask: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial)
|
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial)
|
||||||
|
self.apply_lg_attn_mask = apply_lg_attn_mask
|
||||||
|
self.apply_t5_attn_mask = apply_t5_attn_mask
|
||||||
|
|
||||||
def get_outputs_npz_path(self, image_abs_path: str) -> str:
|
def get_outputs_npz_path(self, image_abs_path: str) -> str:
|
||||||
return os.path.splitext(image_abs_path)[0] + Sd3TextEncoderOutputsCachingStrategy.SD3_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
|
return os.path.splitext(image_abs_path)[0] + Sd3TextEncoderOutputsCachingStrategy.SD3_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
|
||||||
|
|
||||||
def is_disk_cached_outputs_expected(self, abs_path: str):
|
def is_disk_cached_outputs_expected(self, npz_path: str):
|
||||||
if not self.cache_to_disk:
|
if not self.cache_to_disk:
|
||||||
return False
|
return False
|
||||||
if not os.path.exists(self.get_outputs_npz_path(abs_path)):
|
if not os.path.exists(npz_path):
|
||||||
return False
|
return False
|
||||||
if self.skip_disk_cache_validity_check:
|
if self.skip_disk_cache_validity_check:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
try:
|
try:
|
||||||
npz = np.load(self.get_outputs_npz_path(abs_path))
|
npz = np.load(npz_path)
|
||||||
if "clip_l" not in npz or "clip_g" not in npz:
|
if "lg_out" not in npz:
|
||||||
return False
|
return False
|
||||||
if "clip_l_pool" not in npz or "clip_g_pool" not in npz:
|
if "lg_pooled" not in npz:
|
||||||
|
return False
|
||||||
|
if "clip_l_attn_mask" not in npz or "clip_g_attn_mask" not in npz: # necessary even if not used
|
||||||
return False
|
return False
|
||||||
# t5xxl is optional
|
# t5xxl is optional
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error loading file: {self.get_outputs_npz_path(abs_path)}")
|
logger.error(f"Error loading file: {npz_path}")
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
def mask_lg_attn(self, lg_out: np.ndarray, l_attn_mask: np.ndarray, g_attn_mask: np.ndarray) -> np.ndarray:
|
||||||
|
l_out = lg_out[..., :768]
|
||||||
|
g_out = lg_out[..., 768:] # 1280
|
||||||
|
l_out = l_out * np.expand_dims(l_attn_mask, -1) # l_out = l_out * l_attn_mask.
|
||||||
|
g_out = g_out * np.expand_dims(g_attn_mask, -1) # g_out = g_out * g_attn_mask.
|
||||||
|
return np.concatenate([l_out, g_out], axis=-1)
|
||||||
|
|
||||||
|
def mask_t5_attn(self, t5_out: np.ndarray, t5_attn_mask: np.ndarray) -> np.ndarray:
|
||||||
|
return t5_out * np.expand_dims(t5_attn_mask, -1)
|
||||||
|
|
||||||
def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
|
def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
|
||||||
data = np.load(npz_path)
|
data = np.load(npz_path)
|
||||||
lg_out = data["lg_out"]
|
lg_out = data["lg_out"]
|
||||||
lg_pooled = data["lg_pooled"]
|
lg_pooled = data["lg_pooled"]
|
||||||
t5_out = data["t5_out"] if "t5_out" in data else None
|
t5_out = data["t5_out"] if "t5_out" in data else None
|
||||||
|
|
||||||
|
if self.apply_lg_attn_mask:
|
||||||
|
l_attn_mask = data["clip_l_attn_mask"]
|
||||||
|
g_attn_mask = data["clip_g_attn_mask"]
|
||||||
|
lg_out = self.mask_lg_attn(lg_out, l_attn_mask, g_attn_mask)
|
||||||
|
|
||||||
|
if self.apply_t5_attn_mask and t5_out is not None:
|
||||||
|
t5_attn_mask = data["t5_attn_mask"]
|
||||||
|
t5_out = self.mask_t5_attn(t5_out, t5_attn_mask)
|
||||||
|
|
||||||
return [lg_out, t5_out, lg_pooled]
|
return [lg_out, t5_out, lg_pooled]
|
||||||
|
|
||||||
def cache_batch_outputs(
|
def cache_batch_outputs(
|
||||||
self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, infos: List
|
self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, infos: List
|
||||||
):
|
):
|
||||||
|
sd3_text_encoding_strategy: Sd3TextEncodingStrategy = text_encoding_strategy
|
||||||
captions = [info.caption for info in infos]
|
captions = [info.caption for info in infos]
|
||||||
|
|
||||||
clip_l_tokens, clip_g_tokens, t5xxl_tokens = tokenize_strategy.tokenize(captions)
|
tokens_and_masks = tokenize_strategy.tokenize(captions)
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
lg_out, t5_out, lg_pooled = text_encoding_strategy.encode_tokens(
|
lg_out, t5_out, lg_pooled = sd3_text_encoding_strategy.encode_tokens(
|
||||||
tokenize_strategy, models, [clip_l_tokens, clip_g_tokens, t5xxl_tokens]
|
tokenize_strategy, models, tokens_and_masks, self.apply_lg_attn_mask, self.apply_t5_attn_mask
|
||||||
)
|
)
|
||||||
|
|
||||||
if lg_out.dtype == torch.bfloat16:
|
if lg_out.dtype == torch.bfloat16:
|
||||||
@@ -148,10 +196,22 @@ class Sd3TextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
|
|||||||
lg_pooled_i = lg_pooled[i]
|
lg_pooled_i = lg_pooled[i]
|
||||||
|
|
||||||
if self.cache_to_disk:
|
if self.cache_to_disk:
|
||||||
|
clip_l_attn_mask, clip_g_attn_mask, t5_attn_mask = tokens_and_masks[3:6]
|
||||||
|
clip_l_attn_mask_i = clip_l_attn_mask[i].cpu().numpy()
|
||||||
|
clip_g_attn_mask_i = clip_g_attn_mask[i].cpu().numpy()
|
||||||
|
t5_attn_mask_i = t5_attn_mask[i].cpu().numpy() if t5_attn_mask is not None else None # shouldn't be None
|
||||||
kwargs = {}
|
kwargs = {}
|
||||||
if t5_out is not None:
|
if t5_out is not None:
|
||||||
kwargs["t5_out"] = t5_out_i
|
kwargs["t5_out"] = t5_out_i
|
||||||
np.savez(info.text_encoder_outputs_npz, lg_out=lg_out_i, lg_pooled=lg_pooled_i, **kwargs)
|
np.savez(
|
||||||
|
info.text_encoder_outputs_npz,
|
||||||
|
lg_out=lg_out_i,
|
||||||
|
lg_pooled=lg_pooled_i,
|
||||||
|
clip_l_attn_mask=clip_l_attn_mask_i,
|
||||||
|
clip_g_attn_mask=clip_g_attn_mask_i,
|
||||||
|
t5_attn_mask=t5_attn_mask_i,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
info.text_encoder_outputs = (lg_out_i, t5_out_i, lg_pooled_i)
|
info.text_encoder_outputs = (lg_out_i, t5_out_i, lg_pooled_i)
|
||||||
|
|
||||||
|
|||||||
@@ -1486,6 +1486,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
|||||||
text_encoder_outputs = self.text_encoder_output_caching_strategy.load_outputs_npz(
|
text_encoder_outputs = self.text_encoder_output_caching_strategy.load_outputs_npz(
|
||||||
image_info.text_encoder_outputs_npz
|
image_info.text_encoder_outputs_npz
|
||||||
)
|
)
|
||||||
|
text_encoder_outputs = [torch.FloatTensor(x) for x in text_encoder_outputs]
|
||||||
else:
|
else:
|
||||||
tokenization_required = True
|
tokenization_required = True
|
||||||
text_encoder_outputs_list.append(text_encoder_outputs)
|
text_encoder_outputs_list.append(text_encoder_outputs)
|
||||||
|
|||||||
@@ -146,6 +146,8 @@ if __name__ == "__main__":
|
|||||||
parser.add_argument("--clip_l", type=str, required=False)
|
parser.add_argument("--clip_l", type=str, required=False)
|
||||||
parser.add_argument("--t5xxl", type=str, required=False)
|
parser.add_argument("--t5xxl", type=str, required=False)
|
||||||
parser.add_argument("--t5xxl_token_length", type=int, default=77, help="t5xxl token length, default: 77")
|
parser.add_argument("--t5xxl_token_length", type=int, default=77, help="t5xxl token length, default: 77")
|
||||||
|
parser.add_argument("--apply_lg_attn_mask", action="store_true")
|
||||||
|
parser.add_argument("--apply_t5_attn_mask", action="store_true")
|
||||||
parser.add_argument("--prompt", type=str, default="A photo of a cat")
|
parser.add_argument("--prompt", type=str, default="A photo of a cat")
|
||||||
# parser.add_argument("--prompt2", type=str, default=None) # do not support different prompts for text encoders
|
# parser.add_argument("--prompt2", type=str, default=None) # do not support different prompts for text encoders
|
||||||
parser.add_argument("--negative_prompt", type=str, default="")
|
parser.add_argument("--negative_prompt", type=str, default="")
|
||||||
@@ -323,15 +325,15 @@ if __name__ == "__main__":
|
|||||||
logger.info("Encoding prompts...")
|
logger.info("Encoding prompts...")
|
||||||
encoding_strategy = strategy_sd3.Sd3TextEncodingStrategy()
|
encoding_strategy = strategy_sd3.Sd3TextEncodingStrategy()
|
||||||
|
|
||||||
l_tokens, g_tokens, t5_tokens = tokenize_strategy.tokenize(args.prompt)
|
tokens_and_masks = tokenize_strategy.tokenize(args.prompt)
|
||||||
lg_out, t5_out, pooled = encoding_strategy.encode_tokens(
|
lg_out, t5_out, pooled = encoding_strategy.encode_tokens(
|
||||||
tokenize_strategy, [clip_l, clip_g, t5xxl], [l_tokens, g_tokens, t5_tokens]
|
tokenize_strategy, [clip_l, clip_g, t5xxl], tokens_and_masks, args.apply_lg_attn_mask, args.apply_t5_attn_mask
|
||||||
)
|
)
|
||||||
cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled)
|
cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled)
|
||||||
|
|
||||||
l_tokens, g_tokens, t5_tokens = tokenize_strategy.tokenize(args.negative_prompt)
|
tokens_and_masks = tokenize_strategy.tokenize(args.negative_prompt)
|
||||||
lg_out, t5_out, pooled = encoding_strategy.encode_tokens(
|
lg_out, t5_out, pooled = encoding_strategy.encode_tokens(
|
||||||
tokenize_strategy, [clip_l, clip_g, t5xxl], [l_tokens, g_tokens, t5_tokens]
|
tokenize_strategy, [clip_l, clip_g, t5xxl], tokens_and_masks, args.apply_lg_attn_mask, args.apply_t5_attn_mask
|
||||||
)
|
)
|
||||||
neg_cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled)
|
neg_cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled)
|
||||||
|
|
||||||
|
|||||||
30
sd3_train.py
30
sd3_train.py
@@ -172,6 +172,8 @@ def train(args):
|
|||||||
args.text_encoder_batch_size,
|
args.text_encoder_batch_size,
|
||||||
False,
|
False,
|
||||||
False,
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
train_dataset_group.set_current_strategies()
|
train_dataset_group.set_current_strategies()
|
||||||
@@ -312,6 +314,8 @@ def train(args):
|
|||||||
args.text_encoder_batch_size,
|
args.text_encoder_batch_size,
|
||||||
False,
|
False,
|
||||||
train_clip_g or train_clip_l or args.use_t5xxl_cache_only,
|
train_clip_g or train_clip_l or args.use_t5xxl_cache_only,
|
||||||
|
args.apply_lg_attn_mask,
|
||||||
|
args.apply_t5_attn_mask,
|
||||||
)
|
)
|
||||||
strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_caching_strategy)
|
strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_caching_strategy)
|
||||||
|
|
||||||
@@ -335,7 +339,11 @@ def train(args):
|
|||||||
logger.info(f"cache Text Encoder outputs for prompt: {p}")
|
logger.info(f"cache Text Encoder outputs for prompt: {p}")
|
||||||
tokens_list = sd3_tokenize_strategy.tokenize(p)
|
tokens_list = sd3_tokenize_strategy.tokenize(p)
|
||||||
sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens(
|
sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens(
|
||||||
sd3_tokenize_strategy, [clip_l, clip_g, t5xxl], tokens_list
|
sd3_tokenize_strategy,
|
||||||
|
[clip_l, clip_g, t5xxl],
|
||||||
|
tokens_list,
|
||||||
|
args.apply_lg_attn_mask,
|
||||||
|
args.apply_t5_attn_mask,
|
||||||
)
|
)
|
||||||
|
|
||||||
accelerator.wait_for_everyone()
|
accelerator.wait_for_everyone()
|
||||||
@@ -748,21 +756,23 @@ def train(args):
|
|||||||
|
|
||||||
if lg_out is None or (train_clip_l or train_clip_g):
|
if lg_out is None or (train_clip_l or train_clip_g):
|
||||||
# not cached or training, so get from text encoders
|
# not cached or training, so get from text encoders
|
||||||
input_ids_clip_l, input_ids_clip_g, _ = batch["input_ids_list"]
|
input_ids_clip_l, input_ids_clip_g, _, l_attn_mask, g_attn_mask, _ = batch["input_ids_list"]
|
||||||
with torch.set_grad_enabled(args.train_text_encoder):
|
with torch.set_grad_enabled(args.train_text_encoder):
|
||||||
# TODO support weighted captions
|
# TODO support weighted captions
|
||||||
input_ids_clip_l = input_ids_clip_l.to(accelerator.device)
|
input_ids_clip_l = input_ids_clip_l.to(accelerator.device)
|
||||||
input_ids_clip_g = input_ids_clip_g.to(accelerator.device)
|
input_ids_clip_g = input_ids_clip_g.to(accelerator.device)
|
||||||
lg_out, _, lg_pooled = text_encoding_strategy.encode_tokens(
|
lg_out, _, lg_pooled = text_encoding_strategy.encode_tokens(
|
||||||
sd3_tokenize_strategy, [clip_l, clip_g, None], [input_ids_clip_l, input_ids_clip_g, None]
|
sd3_tokenize_strategy,
|
||||||
|
[clip_l, clip_g, None],
|
||||||
|
[input_ids_clip_l, input_ids_clip_g, None, l_attn_mask, g_attn_mask, None],
|
||||||
)
|
)
|
||||||
|
|
||||||
if t5_out is None:
|
if t5_out is None:
|
||||||
_, _, input_ids_t5xxl = batch["input_ids_list"]
|
_, _, input_ids_t5xxl, _, _, t5_attn_mask = batch["input_ids_list"]
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
input_ids_t5xxl = input_ids_t5xxl.to(accelerator.device) if t5_out is None else None
|
input_ids_t5xxl = input_ids_t5xxl.to(accelerator.device) if t5_out is None else None
|
||||||
_, t5_out, _ = text_encoding_strategy.encode_tokens(
|
_, t5_out, _ = text_encoding_strategy.encode_tokens(
|
||||||
sd3_tokenize_strategy, [None, None, t5xxl], [None, None, input_ids_t5xxl]
|
sd3_tokenize_strategy, [None, None, t5xxl], [None, None, input_ids_t5xxl, None, None, t5_attn_mask]
|
||||||
)
|
)
|
||||||
|
|
||||||
context, lg_pooled = text_encoding_strategy.concat_encodings(lg_out, t5_out, lg_pooled)
|
context, lg_pooled = text_encoding_strategy.concat_encodings(lg_out, t5_out, lg_pooled)
|
||||||
@@ -969,6 +979,16 @@ def setup_parser() -> argparse.ArgumentParser:
|
|||||||
default=None,
|
default=None,
|
||||||
help="maximum token length for T5-XXL. 256 if omitted / T5-XXLの最大トークン数。省略時は256",
|
help="maximum token length for T5-XXL. 256 if omitted / T5-XXLの最大トークン数。省略時は256",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--apply_lg_attn_mask",
|
||||||
|
action="store_true",
|
||||||
|
help="apply attention mask (zero embs) to CLIP-L and G / CLIP-LとGにアテンションマスク(ゼロ埋め)を適用する",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--apply_t5_attn_mask",
|
||||||
|
action="store_true",
|
||||||
|
help="apply attention mask (zero embs) to T5-XXL / T5-XXLにアテンションマスク(ゼロ埋め)を適用する",
|
||||||
|
)
|
||||||
|
|
||||||
# TE training is disabled temporarily
|
# TE training is disabled temporarily
|
||||||
# parser.add_argument(
|
# parser.add_argument(
|
||||||
|
|||||||
Reference in New Issue
Block a user