fix cache text encoder outputs if not using disk. small cleanup/alignment

rockerBOO
2025-02-26 23:33:50 -05:00
parent 7b83d50dc0
commit 70403f6977
2 changed files with 21 additions and 23 deletions


@@ -196,6 +196,7 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy)
         input_ids = data["input_ids"]
         return [hidden_state, input_ids, attention_mask]
 
+    @torch.no_grad()
     def cache_batch_outputs(
         self,
         tokenize_strategy: TokenizeStrategy,
@@ -222,23 +223,21 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy)
             tokens, attention_masks, weights_list = (
                 tokenize_strategy.tokenize_with_weights(captions)
             )
-            with torch.no_grad():
-                hidden_state, input_ids, attention_masks = (
-                    text_encoding_strategy.encode_tokens_with_weights(
-                        tokenize_strategy,
-                        models,
-                        (tokens, attention_masks),
-                        weights_list,
-                    )
-                )
+            hidden_state, input_ids, attention_masks = (
+                text_encoding_strategy.encode_tokens_with_weights(
+                    tokenize_strategy,
+                    models,
+                    (tokens, attention_masks),
+                    weights_list,
+                )
+            )
         else:
             tokens = tokenize_strategy.tokenize(captions)
-            with torch.no_grad():
-                hidden_state, input_ids, attention_masks = (
-                    text_encoding_strategy.encode_tokens(
-                        tokenize_strategy, models, tokens
-                    )
-                )
+            hidden_state, input_ids, attention_masks = (
+                text_encoding_strategy.encode_tokens(
+                    tokenize_strategy, models, tokens
+                )
+            )
 
         if hidden_state.dtype != torch.float32:
             hidden_state = hidden_state.float()
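
Note on the first two hunks: this is the "small cleanup" part of the commit. A single @torch.no_grad() decorator on cache_batch_outputs disables gradient tracking for the whole method, which makes the two inner context managers redundant; behavior is unchanged. A minimal standalone sketch of the equivalence (dummy model, not from this repo):

    import torch

    m = torch.nn.Linear(4, 4)
    x = torch.randn(2, 4)

    @torch.no_grad()
    def encode_decorated(model, inputs):
        # the decorator disables gradient tracking for the whole function body
        return model(inputs)

    def encode_wrapped(model, inputs):
        with torch.no_grad():
            # equivalent, but only the wrapped statements are covered
            return model(inputs)

    assert not encode_decorated(m, x).requires_grad
    assert not encode_wrapped(m, x).requires_grad
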
@@ -247,14 +246,14 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy)
         attention_mask = attention_masks.cpu().numpy()  # (B, S)
         input_ids = input_ids.cpu().numpy()  # (B, S)
 
         for i, info in enumerate(batch):
             hidden_state_i = hidden_state[i]
             attention_mask_i = attention_mask[i]
             input_ids_i = input_ids[i]
-            assert info.text_encoder_outputs_npz is not None, "Text encoder cache outputs to disk not found for image {info.image_path}"
 
             if self.cache_to_disk:
+                assert info.text_encoder_outputs_npz is not None, f"Text encoder cache outputs to disk not found for image {info.image_key}"
                 np.savez(
                     info.text_encoder_outputs_npz,
                     hidden_state=hidden_state_i,
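
Note on this hunk: this is the "if not using disk" fix from the commit message. The removed assert ran on every image even when caching in memory, where info.text_encoder_outputs_npz is legitimately None, so in-memory caching always failed. Moving it under if self.cache_to_disk: restricts the check to the disk path. The message also gains the f prefix it was missing, so the placeholder is now interpolated instead of printed literally. A trivial sketch of that difference (hypothetical key, not from this repo):

    key = "img_001"
    print("not found for image {key}")   # no f prefix: prints "{key}" literally
    print(f"not found for image {key}")  # f-string: prints "img_001"
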
@@ -338,21 +337,21 @@ class LuminaLatentsCachingStrategy(LatentsCachingStrategy):
     # TODO remove circular dependency for ImageInfo
     def cache_batch_latents(
         self,
-        vae,
-        image_infos: List,
+        model,
+        batch: List,
         flip_aug: bool,
         alpha_mask: bool,
         random_crop: bool,
     ):
-        encode_by_vae = lambda img_tensor: vae.encode(img_tensor).to("cpu")
-        vae_device = vae.device
-        vae_dtype = vae.dtype
+        encode_by_vae = lambda img_tensor: model.encode(img_tensor).to("cpu")
+        vae_device = model.device
+        vae_dtype = model.dtype
 
         self._default_cache_batch_latents(
             encode_by_vae,
             vae_device,
             vae_dtype,
-            image_infos,
+            batch,
             flip_aug,
             alpha_mask,
             random_crop,
@@ -360,4 +359,4 @@ class LuminaLatentsCachingStrategy(LatentsCachingStrategy):
         )
 
         if not train_util.HIGH_VRAM:
-            train_util.clean_memory_on_device(vae.device)
+            train_util.clean_memory_on_device(model.device)
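
Note on the last two hunks: this is the "alignment" part, renaming the parameters vae and image_infos to model and batch. The final hunk is required, not cosmetic: once the parameter is named model, a leftover vae.device in the cleanup call would raise NameError at runtime. A minimal standalone sketch (DummyVAE and all names here are hypothetical, not from this repo):

    import torch

    class DummyVAE(torch.nn.Module):
        # stand-in for the real VAE, for illustration only
        def __init__(self):
            super().__init__()
            self.proj = torch.nn.Linear(4, 2)

        @property
        def device(self):
            return next(self.parameters()).device

        def encode(self, img_tensor):
            return self.proj(img_tensor)

    def cache_batch_latents(model, batch):
        encode_by_vae = lambda t: model.encode(t).to("cpu")
        latents = [encode_by_vae(t) for t in batch]
        # `vae.device` here would raise NameError: the parameter is now `model`
        return latents, model.device

    cache_batch_latents(DummyVAE(), [torch.randn(1, 4)])
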