Merge pull request #17 from rockerBOO/lumina-cache-text-encoder-outputs

Lumina cache text encoder outputs
This commit is contained in:
青龍聖者@bdsqlsz
2025-03-02 18:30:08 +08:00
committed by GitHub
2 changed files with 21 additions and 23 deletions

View File

@@ -196,6 +196,7 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy)
input_ids = data["input_ids"]
return [hidden_state, input_ids, attention_mask]
@torch.no_grad()
def cache_batch_outputs(
self,
tokenize_strategy: TokenizeStrategy,
@@ -222,23 +223,21 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy)
tokens, attention_masks, weights_list = (
tokenize_strategy.tokenize_with_weights(captions)
)
with torch.no_grad():
hidden_state, input_ids, attention_masks = (
text_encoding_strategy.encode_tokens_with_weights(
tokenize_strategy,
models,
(tokens, attention_masks),
weights_list,
)
hidden_state, input_ids, attention_masks = (
text_encoding_strategy.encode_tokens_with_weights(
tokenize_strategy,
models,
(tokens, attention_masks),
weights_list,
)
)
else:
tokens = tokenize_strategy.tokenize(captions)
with torch.no_grad():
hidden_state, input_ids, attention_masks = (
text_encoding_strategy.encode_tokens(
tokenize_strategy, models, tokens
)
hidden_state, input_ids, attention_masks = (
text_encoding_strategy.encode_tokens(
tokenize_strategy, models, tokens
)
)
if hidden_state.dtype != torch.float32:
hidden_state = hidden_state.float()
@@ -247,14 +246,14 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy)
attention_mask = attention_masks.cpu().numpy() # (B, S)
input_ids = input_ids.cpu().numpy() # (B, S)
for i, info in enumerate(batch):
hidden_state_i = hidden_state[i]
attention_mask_i = attention_mask[i]
input_ids_i = input_ids[i]
assert info.text_encoder_outputs_npz is not None, "Text encoder cache outputs to disk not found for image {info.image_path}"
if self.cache_to_disk:
assert info.text_encoder_outputs_npz is not None, f"Text encoder cache outputs to disk not found for image {info.image_key}"
np.savez(
info.text_encoder_outputs_npz,
hidden_state=hidden_state_i,
@@ -338,21 +337,21 @@ class LuminaLatentsCachingStrategy(LatentsCachingStrategy):
# TODO remove circular dependency for ImageInfo
def cache_batch_latents(
self,
vae,
image_infos: List,
model,
batch: List,
flip_aug: bool,
alpha_mask: bool,
random_crop: bool,
):
encode_by_vae = lambda img_tensor: vae.encode(img_tensor).to("cpu")
vae_device = vae.device
vae_dtype = vae.dtype
encode_by_vae = lambda img_tensor: model.encode(img_tensor).to("cpu")
vae_device = model.device
vae_dtype = model.dtype
self._default_cache_batch_latents(
encode_by_vae,
vae_device,
vae_dtype,
image_infos,
batch,
flip_aug,
alpha_mask,
random_crop,
@@ -360,4 +359,4 @@ class LuminaLatentsCachingStrategy(LatentsCachingStrategy):
)
if not train_util.HIGH_VRAM:
train_util.clean_memory_on_device(vae.device)
train_util.clean_memory_on_device(model.device)

View File

@@ -1282,7 +1282,6 @@ class NetworkTrainer:
# For --sample_at_first
optimizer_eval_fn()
self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizers, text_encoder, unet)
progress_bar.unpause() # Reset progress bar to before sampling images
optimizer_train_fn()
is_tracking = len(accelerator.trackers) > 0
if is_tracking: