diff --git a/library/strategy_lumina.py b/library/strategy_lumina.py
index c9e65423..74f15cec 100644
--- a/library/strategy_lumina.py
+++ b/library/strategy_lumina.py
@@ -196,6 +196,7 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy)
         input_ids = data["input_ids"]
         return [hidden_state, input_ids, attention_mask]
 
+    @torch.no_grad()
     def cache_batch_outputs(
         self,
         tokenize_strategy: TokenizeStrategy,
@@ -222,23 +223,21 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy)
             tokens, attention_masks, weights_list = (
                 tokenize_strategy.tokenize_with_weights(captions)
             )
-            with torch.no_grad():
-                hidden_state, input_ids, attention_masks = (
-                    text_encoding_strategy.encode_tokens_with_weights(
-                        tokenize_strategy,
-                        models,
-                        (tokens, attention_masks),
-                        weights_list,
-                    )
+            hidden_state, input_ids, attention_masks = (
+                text_encoding_strategy.encode_tokens_with_weights(
+                    tokenize_strategy,
+                    models,
+                    (tokens, attention_masks),
+                    weights_list,
                 )
+            )
         else:
             tokens = tokenize_strategy.tokenize(captions)
-            with torch.no_grad():
-                hidden_state, input_ids, attention_masks = (
-                    text_encoding_strategy.encode_tokens(
-                        tokenize_strategy, models, tokens
-                    )
+            hidden_state, input_ids, attention_masks = (
+                text_encoding_strategy.encode_tokens(
+                    tokenize_strategy, models, tokens
                 )
+            )
 
         if hidden_state.dtype != torch.float32:
             hidden_state = hidden_state.float()
@@ -247,14 +246,14 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy)
         attention_mask = attention_masks.cpu().numpy()  # (B, S)
         input_ids = input_ids.cpu().numpy()  # (B, S)
 
+
         for i, info in enumerate(batch):
             hidden_state_i = hidden_state[i]
             attention_mask_i = attention_mask[i]
             input_ids_i = input_ids[i]
-            assert info.text_encoder_outputs_npz is not None, "Text encoder cache outputs to disk not found for image {info.image_path}"
-
             if self.cache_to_disk:
+                assert info.text_encoder_outputs_npz is not None, f"Text encoder cache outputs to disk not found for image {info.image_key}"
                 np.savez(
                     info.text_encoder_outputs_npz,
                     hidden_state=hidden_state_i,
@@ -338,21 +337,21 @@ class LuminaLatentsCachingStrategy(LatentsCachingStrategy):
     # TODO remove circular dependency for ImageInfo
     def cache_batch_latents(
         self,
-        vae,
-        image_infos: List,
+        model,
+        batch: List,
         flip_aug: bool,
         alpha_mask: bool,
         random_crop: bool,
     ):
-        encode_by_vae = lambda img_tensor: vae.encode(img_tensor).to("cpu")
-        vae_device = vae.device
-        vae_dtype = vae.dtype
+        encode_by_vae = lambda img_tensor: model.encode(img_tensor).to("cpu")
+        vae_device = model.device
+        vae_dtype = model.dtype
         self._default_cache_batch_latents(
             encode_by_vae,
             vae_device,
             vae_dtype,
-            image_infos,
+            batch,
             flip_aug,
             alpha_mask,
             random_crop,
@@ -360,4 +359,4 @@ class LuminaLatentsCachingStrategy(LatentsCachingStrategy):
         )
 
         if not train_util.HIGH_VRAM:
-            train_util.clean_memory_on_device(vae.device)
+            train_util.clean_memory_on_device(model.device)
diff --git a/train_network.py b/train_network.py
index ff62f46a..b4b0d42d 100644
--- a/train_network.py
+++ b/train_network.py
@@ -1282,7 +1282,6 @@ class NetworkTrainer:
         # For --sample_at_first
         optimizer_eval_fn()
         self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizers, text_encoder, unet)
-        progress_bar.unpause()  # Reset progress bar to before sampling images
         optimizer_train_fn()
         is_tracking = len(accelerator.trackers) > 0
         if is_tracking: