mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-10 15:00:23 +00:00
Fix caching of text encoder outputs when not caching to disk; small cleanup/alignment.
This commit is contained in:
@@ -196,6 +196,7 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy)
|
||||
input_ids = data["input_ids"]
|
||||
return [hidden_state, input_ids, attention_mask]
|
||||
|
||||
@torch.no_grad()
|
||||
def cache_batch_outputs(
|
||||
self,
|
||||
tokenize_strategy: TokenizeStrategy,
|
||||
@@ -222,23 +223,21 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy)
|
||||
tokens, attention_masks, weights_list = (
|
||||
tokenize_strategy.tokenize_with_weights(captions)
|
||||
)
|
||||
with torch.no_grad():
|
||||
hidden_state, input_ids, attention_masks = (
|
||||
text_encoding_strategy.encode_tokens_with_weights(
|
||||
tokenize_strategy,
|
||||
models,
|
||||
(tokens, attention_masks),
|
||||
weights_list,
|
||||
)
|
||||
hidden_state, input_ids, attention_masks = (
|
||||
text_encoding_strategy.encode_tokens_with_weights(
|
||||
tokenize_strategy,
|
||||
models,
|
||||
(tokens, attention_masks),
|
||||
weights_list,
|
||||
)
|
||||
)
|
||||
else:
|
||||
tokens = tokenize_strategy.tokenize(captions)
|
||||
with torch.no_grad():
|
||||
hidden_state, input_ids, attention_masks = (
|
||||
text_encoding_strategy.encode_tokens(
|
||||
tokenize_strategy, models, tokens
|
||||
)
|
||||
hidden_state, input_ids, attention_masks = (
|
||||
text_encoding_strategy.encode_tokens(
|
||||
tokenize_strategy, models, tokens
|
||||
)
|
||||
)
|
||||
|
||||
if hidden_state.dtype != torch.float32:
|
||||
hidden_state = hidden_state.float()
|
||||
@@ -247,14 +246,14 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy)
|
||||
attention_mask = attention_masks.cpu().numpy() # (B, S)
|
||||
input_ids = input_ids.cpu().numpy() # (B, S)
|
||||
|
||||
|
||||
for i, info in enumerate(batch):
|
||||
hidden_state_i = hidden_state[i]
|
||||
attention_mask_i = attention_mask[i]
|
||||
input_ids_i = input_ids[i]
|
||||
|
||||
assert info.text_encoder_outputs_npz is not None, "Text encoder cache outputs to disk not found for image {info.image_path}"
|
||||
|
||||
if self.cache_to_disk:
|
||||
assert info.text_encoder_outputs_npz is not None, f"Text encoder cache outputs to disk not found for image {info.image_key}"
|
||||
np.savez(
|
||||
info.text_encoder_outputs_npz,
|
||||
hidden_state=hidden_state_i,
|
||||
@@ -338,21 +337,21 @@ class LuminaLatentsCachingStrategy(LatentsCachingStrategy):
|
||||
# TODO remove circular dependency for ImageInfo
|
||||
def cache_batch_latents(
|
||||
self,
|
||||
vae,
|
||||
image_infos: List,
|
||||
model,
|
||||
batch: List,
|
||||
flip_aug: bool,
|
||||
alpha_mask: bool,
|
||||
random_crop: bool,
|
||||
):
|
||||
encode_by_vae = lambda img_tensor: vae.encode(img_tensor).to("cpu")
|
||||
vae_device = vae.device
|
||||
vae_dtype = vae.dtype
|
||||
encode_by_vae = lambda img_tensor: model.encode(img_tensor).to("cpu")
|
||||
vae_device = model.device
|
||||
vae_dtype = model.dtype
|
||||
|
||||
self._default_cache_batch_latents(
|
||||
encode_by_vae,
|
||||
vae_device,
|
||||
vae_dtype,
|
||||
image_infos,
|
||||
batch,
|
||||
flip_aug,
|
||||
alpha_mask,
|
||||
random_crop,
|
||||
@@ -360,4 +359,4 @@ class LuminaLatentsCachingStrategy(LatentsCachingStrategy):
|
||||
)
|
||||
|
||||
if not train_util.HIGH_VRAM:
|
||||
train_util.clean_memory_on_device(vae.device)
|
||||
train_util.clean_memory_on_device(model.device)
|
||||
|
||||
Reference in New Issue
Block a user