support attn mask for l+g/t5

This commit is contained in:
Kohya S
2024-08-05 20:51:34 +09:00
parent 231df197dd
commit da4d0fe016
4 changed files with 107 additions and 24 deletions

View File

@@ -646,7 +646,7 @@ class BaseDataset(torch.utils.data.Dataset):
# caching
self.caching_mode = None # None, 'latents', 'text'
self.tokenize_strategy = None
self.text_encoder_output_caching_strategy = None
self.latents_caching_strategy = None
@@ -1486,6 +1486,7 @@ class BaseDataset(torch.utils.data.Dataset):
text_encoder_outputs = self.text_encoder_output_caching_strategy.load_outputs_npz(
image_info.text_encoder_outputs_npz
)
text_encoder_outputs = [torch.FloatTensor(x) for x in text_encoder_outputs]
else:
tokenization_required = True
text_encoder_outputs_list.append(text_encoder_outputs)