From 60a76ebb72772327fcb7b2a10c87ad8f7b09f56f Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Sun, 16 Feb 2025 01:06:34 -0500 Subject: [PATCH] Add caching gemma2, add gradient checkpointing, refactor lumina model code --- library/lumina_models.py | 298 +++++++++++++++++++------------------ library/strategy_lumina.py | 108 ++++++++------ lumina_train_network.py | 113 ++++++++++---- networks/lora_lumina.py | 10 +- 4 files changed, 304 insertions(+), 225 deletions(-) diff --git a/library/lumina_models.py b/library/lumina_models.py index 3f2e854e..27194e2f 100644 --- a/library/lumina_models.py +++ b/library/lumina_models.py @@ -16,6 +16,8 @@ from dataclasses import dataclass from flash_attn import flash_attn_varlen_func from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa import torch +from torch import Tensor +from torch.utils.checkpoint import checkpoint import torch.nn as nn import torch.nn.functional as F @@ -91,6 +93,25 @@ class LuminaParams: ) +class GradientCheckpointMixin(nn.Module): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + self.gradient_checkpointing = False + self.cpu_offload_checkpointing = False + + def enable_gradient_checkpointing(self, cpu_offload: bool = False): + self.gradient_checkpointing = True + + def disable_gradient_checkpointing(self, cpu_offload: bool = False): + self.gradient_checkpointing = False + + def forward(self, *args, **kwargs): + if self.training and self.gradient_checkpointing: + return checkpoint(self._forward, *args, use_reentrant=False, **kwargs) + else: + return self._forward(*args, **kwargs) + ############################################################################# # RMSNorm # ############################################################################# @@ -114,7 +135,7 @@ class RMSNorm(torch.nn.Module): self.eps = eps self.weight = nn.Parameter(torch.ones(dim)) - def _norm(self, x): + def _norm(self, x) -> Tensor: """ Apply the RMSNorm normalization to the input tensor. @@ -125,21 +146,14 @@ class RMSNorm(torch.nn.Module): torch.Tensor: The normalized tensor. """ - return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + return x * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + self.eps) - def forward(self, x): - """ - Forward pass through the RMSNorm layer. - Args: - x (torch.Tensor): The input tensor. - - Returns: - torch.Tensor: The output tensor after applying RMSNorm. - - """ - output = self._norm(x.float()).type_as(x) - return output * self.weight + def forward(self, x: Tensor): + x_dtype = x.dtype + x = x.float() + rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6) + return ((x * rrms) * self.weight.float()).to(dtype=x_dtype) def modulate(x, scale): @@ -151,7 +165,7 @@ def modulate(x, scale): ############################################################################# -class TimestepEmbedder(nn.Module): +class TimestepEmbedder(GradientCheckpointMixin): """ Embeds scalar timesteps into vector representations. """ @@ -203,11 +217,32 @@ class TimestepEmbedder(nn.Module): ) return embedding - def forward(self, t): + def _forward(self, t): t_freq = self.timestep_embedding(t, self.frequency_embedding_size) t_emb = self.mlp(t_freq.to(self.mlp[0].weight.dtype)) return t_emb +def to_cuda(x): + if isinstance(x, torch.Tensor): + return x.cuda() + elif isinstance(x, (list, tuple)): + return [to_cuda(elem) for elem in x] + elif isinstance(x, dict): + return {k: to_cuda(v) for k, v in x.items()} + else: + return x + + +def to_cpu(x): + if isinstance(x, torch.Tensor): + return x.cpu() + elif isinstance(x, (list, tuple)): + return [to_cpu(elem) for elem in x] + elif isinstance(x, dict): + return {k: to_cpu(v) for k, v in x.items()} + else: + return x + ############################################################################# # Core NextDiT Model # @@ -284,7 +319,7 @@ class JointAttention(nn.Module): Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings. """ - with torch.amp.autocast("cuda",enabled=False): + with torch.autocast("cuda", enabled=False): x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2)) freqs_cis = freqs_cis.unsqueeze(2) x_out = torch.view_as_real(x * freqs_cis).flatten(3) @@ -496,15 +531,15 @@ class FeedForward(nn.Module): return self.w2(self._forward_silu_gating(self.w1(x), self.w3(x))) -class JointTransformerBlock(nn.Module): +class JointTransformerBlock(GradientCheckpointMixin): def __init__( self, layer_id: int, dim: int, n_heads: int, - n_kv_heads: int, + n_kv_heads: Optional[int], multiple_of: int, - ffn_dim_multiplier: float, + ffn_dim_multiplier: Optional[float], norm_eps: float, qk_norm: bool, modulation=True, @@ -520,7 +555,7 @@ class JointTransformerBlock(nn.Module): value features (if using GQA), or set to None for the same as query. multiple_of (int): - ffn_dim_multiplier (float): + ffn_dim_multiplier (Optional[float]): norm_eps (float): """ @@ -554,7 +589,7 @@ class JointTransformerBlock(nn.Module): nn.init.zeros_(self.adaLN_modulation[1].weight) nn.init.zeros_(self.adaLN_modulation[1].bias) - def forward( + def _forward( self, x: torch.Tensor, x_mask: torch.Tensor, @@ -608,7 +643,7 @@ class JointTransformerBlock(nn.Module): return x -class FinalLayer(nn.Module): +class FinalLayer(GradientCheckpointMixin): """ The final layer of NextDiT. """ @@ -661,22 +696,21 @@ class RopeEmbedder: self.axes_dims, self.axes_lens, theta=self.theta ) - def __call__(self, ids: torch.Tensor): + def get_freqs_cis(self, ids: torch.Tensor): self.freqs_cis = [freqs_cis.to(ids.device) for freqs_cis in self.freqs_cis] result = [] for i in range(len(self.axes_dims)): - # import torch.distributed as dist - # if not dist.is_initialized() or dist.get_rank() == 0: - # import pdb - # pdb.set_trace() index = ( ids[:, :, i : i + 1] .repeat(1, 1, self.freqs_cis[i].shape[-1]) .to(torch.int64) ) + + axes = self.freqs_cis[i].unsqueeze(0).repeat(index.shape[0], 1, 1) + result.append( torch.gather( - self.freqs_cis[i].unsqueeze(0).repeat(index.shape[0], 1, 1), + axes, dim=1, index=index, ) @@ -790,76 +824,98 @@ class NextDiT(nn.Module): self.dim = dim self.n_heads = n_heads + self.gradient_checkpointing = False + self.cpu_offload_checkpointing = False + self.blocks_to_swap = None + + @property + def device(self): + return next(self.parameters()).device + + @property + def dtype(self): + return next(self.parameters()).dtype + + def enable_gradient_checkpointing(self, cpu_offload: bool = False): + self.gradient_checkpointing = True + self.cpu_offload_checkpointing = cpu_offload + + self.t_embedder.enable_gradient_checkpointing() + + for block in self.layers + self.context_refiner + self.noise_refiner: + block.enable_gradient_checkpointing(cpu_offload=cpu_offload) + + self.final_layer.enable_gradient_checkpointing() + + print(f"Lumina: Gradient checkpointing enabled. CPU offload: {cpu_offload}") + + def disable_gradient_checkpointing(self): + self.gradient_checkpointing = False + self.cpu_offload_checkpointing = False + + self.t_embedder.disable_gradient_checkpointing() + + for block in self.layers + self.context_refiner + self.noise_refiner: + block.disable_gradient_checkpointing() + + self.final_layer.disable_gradient_checkpointing() + + print("Lumina: Gradient checkpointing disabled.") + def unpatchify( self, x: torch.Tensor, - img_size: List[Tuple[int, int]], - cap_size: List[int], - return_tensor=False, - ) -> List[torch.Tensor]: + width: int, + height: int, + encoder_seq_lengths: List[int], + seq_lengths: List[int], + ) -> torch.Tensor: """ x: (N, T, patch_size**2 * C) imgs: (N, H, W, C) """ pH = pW = self.patch_size - imgs = [] - for i in range(x.size(0)): - H, W = img_size[i] - begin = cap_size[i] - end = begin + (H // pH) * (W // pW) - imgs.append( - x[i][begin:end] - .view(H // pH, W // pW, pH, pW, self.out_channels) + + output = [] + for i, (encoder_seq_len, seq_len) in enumerate(zip(encoder_seq_lengths, seq_lengths)): + output.append( + x[i][encoder_seq_len:seq_len] + .view(height // pH, width // pW, pH, pW, self.out_channels) .permute(4, 0, 2, 1, 3) .flatten(3, 4) .flatten(1, 2) ) + output = torch.stack(output, dim=0) - if return_tensor: - imgs = torch.stack(imgs, dim=0) - return imgs + return output def patchify_and_embed( self, - x: List[torch.Tensor] | torch.Tensor, + x: torch.Tensor, cap_feats: torch.Tensor, cap_mask: torch.Tensor, t: torch.Tensor, ) -> Tuple[ - torch.Tensor, torch.Tensor, List[Tuple[int, int]], List[int], torch.Tensor + torch.Tensor, torch.Tensor, torch.Tensor, List[int], List[int] ]: - bsz = len(x) + bsz, channels, height, width = x.shape pH = pW = self.patch_size - device = x[0].device + device = x.device l_effective_cap_len = cap_mask.sum(dim=1).tolist() - img_sizes = [(img.size(1), img.size(2)) for img in x] - l_effective_img_len = [(H // pH) * (W // pW) for (H, W) in img_sizes] + encoder_seq_len = cap_mask.shape[1] - max_seq_len = max( - ( - cap_len + img_len - for cap_len, img_len in zip(l_effective_cap_len, l_effective_img_len) - ) - ) - max_cap_len = max(l_effective_cap_len) - max_img_len = max(l_effective_img_len) + image_seq_len = (height // self.patch_size) * (width // self.patch_size) + seq_lengths = [cap_seq_len + image_seq_len for cap_seq_len in l_effective_cap_len] + max_seq_len = max(seq_lengths) - position_ids = torch.zeros( - bsz, max_seq_len, 3, dtype=torch.int32, device=device - ) + position_ids = torch.zeros(bsz, max_seq_len, 3, dtype=torch.int32, device=device) - for i in range(bsz): - cap_len = l_effective_cap_len[i] - img_len = l_effective_img_len[i] - H, W = img_sizes[i] - H_tokens, W_tokens = H // pH, W // pW - assert H_tokens * W_tokens == img_len + for i, (cap_len, seq_len) in enumerate(zip(l_effective_cap_len, seq_lengths)): + H_tokens, W_tokens = height // pH, width // pW - position_ids[i, :cap_len, 0] = torch.arange( - cap_len, dtype=torch.int32, device=device - ) - position_ids[i, cap_len : cap_len + img_len, 0] = cap_len + position_ids[i, :cap_len, 0] = torch.arange(cap_len, dtype=torch.int32, device=device) + position_ids[i, cap_len : cap_len + seq_len, 0] = cap_len row_ids = ( torch.arange(H_tokens, dtype=torch.int32, device=device) .view(-1, 1) @@ -872,77 +928,40 @@ class NextDiT(nn.Module): .repeat(H_tokens, 1) .flatten() ) - position_ids[i, cap_len : cap_len + img_len, 1] = row_ids - position_ids[i, cap_len : cap_len + img_len, 2] = col_ids + position_ids[i, cap_len : cap_len + seq_len, 1] = row_ids + position_ids[i, cap_len : cap_len + seq_len, 2] = col_ids - freqs_cis = self.rope_embedder(position_ids) + freqs_cis = self.rope_embedder.get_freqs_cis(position_ids) - # build freqs_cis for cap and image individually - cap_freqs_cis_shape = list(freqs_cis.shape) - # cap_freqs_cis_shape[1] = max_cap_len - cap_freqs_cis_shape[1] = cap_feats.shape[1] - cap_freqs_cis = torch.zeros( - *cap_freqs_cis_shape, device=device, dtype=freqs_cis.dtype - ) + cap_freqs_cis = torch.zeros(bsz, encoder_seq_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype) + img_freqs_cis = torch.zeros(bsz, image_seq_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype) - img_freqs_cis_shape = list(freqs_cis.shape) - img_freqs_cis_shape[1] = max_img_len - img_freqs_cis = torch.zeros( - *img_freqs_cis_shape, device=device, dtype=freqs_cis.dtype - ) - - for i in range(bsz): - cap_len = l_effective_cap_len[i] - img_len = l_effective_img_len[i] + for i, (cap_len, seq_len) in enumerate(zip(l_effective_cap_len, seq_lengths)): cap_freqs_cis[i, :cap_len] = freqs_cis[i, :cap_len] - img_freqs_cis[i, :img_len] = freqs_cis[i, cap_len : cap_len + img_len] + img_freqs_cis[i, :seq_len] = freqs_cis[i, cap_len : cap_len + seq_len] + + x = x.view(bsz, channels, height // pH, pH, width // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2) + x_mask = torch.zeros(bsz, max_seq_len, dtype=torch.bool, device=device) # refine context for layer in self.context_refiner: cap_feats = layer(cap_feats, cap_mask, cap_freqs_cis) - # refine image - flat_x = [] - for i in range(bsz): - img = x[i] - C, H, W = img.size() - img = ( - img.view(C, H // pH, pH, W // pW, pW) - .permute(1, 3, 2, 4, 0) - .flatten(2) - .flatten(0, 1) - ) - flat_x.append(img) - x = flat_x - padded_img_embed = torch.zeros( - bsz, max_img_len, x[0].shape[-1], device=device, dtype=x[0].dtype - ) - padded_img_mask = torch.zeros(bsz, max_img_len, dtype=torch.bool, device=device) - for i in range(bsz): - padded_img_embed[i, : l_effective_img_len[i]] = x[i] - padded_img_mask[i, : l_effective_img_len[i]] = True + x = self.x_embedder(x) - padded_img_embed = self.x_embedder(padded_img_embed) for layer in self.noise_refiner: - padded_img_embed = layer( - padded_img_embed, padded_img_mask, img_freqs_cis, t - ) + x = layer(x, x_mask, img_freqs_cis, t) - mask = torch.zeros(bsz, max_seq_len, dtype=torch.bool, device=device) - padded_full_embed = torch.zeros( - bsz, max_seq_len, self.dim, device=device, dtype=x[0].dtype - ) - for i in range(bsz): - cap_len = l_effective_cap_len[i] - img_len = l_effective_img_len[i] + joint_hidden_states = torch.zeros(bsz, max_seq_len, self.dim, device=device, dtype=x.dtype) + attention_mask = torch.zeros(bsz, max_seq_len, dtype=torch.bool, device=device) + for i, (cap_len, seq_len) in enumerate(zip(l_effective_cap_len, seq_lengths)): + attention_mask[i, :seq_len] = True + joint_hidden_states[i, :cap_len] = cap_feats[i, :cap_len] + joint_hidden_states[i, cap_len:seq_len] = x[i] - mask[i, : cap_len + img_len] = True - padded_full_embed[i, :cap_len] = cap_feats[i, :cap_len] - padded_full_embed[i, cap_len : cap_len + img_len] = padded_img_embed[ - i, :img_len - ] + x = joint_hidden_states - return padded_full_embed, mask, img_sizes, l_effective_cap_len, freqs_cis + return x, attention_mask, freqs_cis, l_effective_cap_len, seq_lengths def forward(self, x, t, cap_feats, cap_mask): """ @@ -950,30 +969,19 @@ class NextDiT(nn.Module): t: (N,) tensor of diffusion timesteps y: (N,) tensor of text tokens/features """ - - # import torch.distributed as dist - # if not dist.is_initialized() or dist.get_rank() == 0: - # import pdb - # pdb.set_trace() - # torch.save([x, t, cap_feats, cap_mask], "./fake_input.pt") + _, _, height, width = x.shape # B, C, H, W t = self.t_embedder(t) # (N, D) - adaln_input = t + cap_feats = self.cap_embedder(cap_feats) # (N, L, D) # todo check if able to batchify w.o. redundant compute - cap_feats = self.cap_embedder( - cap_feats - ) # (N, L, D) # todo check if able to batchify w.o. redundant compute - - x_is_tensor = isinstance(x, torch.Tensor) - x, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed( + x, mask, freqs_cis, l_effective_cap_len, seq_lengths = self.patchify_and_embed( x, cap_feats, cap_mask, t ) - freqs_cis = freqs_cis.to(x.device) for layer in self.layers: - x = layer(x, mask, freqs_cis, adaln_input) + x = layer(x, mask, freqs_cis, t) - x = self.final_layer(x, adaln_input) - x = self.unpatchify(x, img_size, cap_size, return_tensor=x_is_tensor) + x = self.final_layer(x, t) + x = self.unpatchify(x, width, height, l_effective_cap_len, seq_lengths) return x diff --git a/library/strategy_lumina.py b/library/strategy_lumina.py index 615f6e00..6feea387 100644 --- a/library/strategy_lumina.py +++ b/library/strategy_lumina.py @@ -3,7 +3,7 @@ import os from typing import Any, List, Optional, Tuple, Union import torch -from transformers import AutoTokenizer, AutoModel +from transformers import AutoTokenizer, AutoModel, GemmaTokenizerFast from library import train_util from library.strategy_base import ( LatentsCachingStrategy, @@ -27,34 +27,35 @@ class LuminaTokenizeStrategy(TokenizeStrategy): def __init__( self, max_length: Optional[int], tokenizer_cache_dir: Optional[str] = None ) -> None: - self.tokenizer = AutoTokenizer.from_pretrained( + self.tokenizer: GemmaTokenizerFast = AutoTokenizer.from_pretrained( GEMMA_ID, cache_dir=tokenizer_cache_dir ) self.tokenizer.padding_side = "right" if max_length is None: - self.max_length = self.tokenizer.model_max_length + self.max_length = 256 else: self.max_length = max_length - def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]: + def tokenize(self, text: Union[str, List[str]]) -> Tuple[torch.Tensor, torch.Tensor]: text = [text] if isinstance(text, str) else text encodings = self.tokenizer( text, - padding="max_length", max_length=self.max_length, return_tensors="pt", + padding=True, + pad_to_multiple_of=8, truncation=True, ) - return [encodings.input_ids] + return encodings.input_ids, encodings.attention_mask def tokenize_with_weights( self, text: str | List[str] - ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]: # Gemma doesn't support weighted prompts, return uniform weights - tokens = self.tokenize(text) + tokens, attention_masks = self.tokenize(text) weights = [torch.ones_like(t) for t in tokens] - return tokens, weights + return tokens, attention_masks, weights class LuminaTextEncodingStrategy(TextEncodingStrategy): @@ -66,50 +67,39 @@ class LuminaTextEncodingStrategy(TextEncodingStrategy): self, tokenize_strategy: TokenizeStrategy, models: List[Any], - tokens: List[torch.Tensor], + tokens: torch.Tensor, + attention_masks: torch.Tensor, apply_gemma2_attn_mask: Optional[bool] = None, - ) -> List[torch.Tensor]: - + ) -> torch.Tensor: if apply_gemma2_attn_mask is None: apply_gemma2_attn_mask = self.apply_gemma2_attn_mask text_encoder = models[0] - input_ids = tokens[0].to(text_encoder.device) - attention_mask = None - position_ids = None - if apply_gemma2_attn_mask: - # Create attention mask (1 for non-padding, 0 for padding) - attention_mask = (input_ids != tokenize_strategy.tokenizer.pad_token_id).to( - text_encoder.device - ) + # Create position IDs + position_ids = attention_masks.cumsum(-1) - 1 + position_ids.masked_fill_(attention_masks == 0, 1) - # Create position IDs - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) + outputs = text_encoder( + input_ids=tokens.to(text_encoder.device), + attention_mask=attention_masks.to(text_encoder.device) if apply_gemma2_attn_mask else None, + position_ids=position_ids.to(text_encoder.device), + output_hidden_states=True, + return_dict=True, + ) - with torch.no_grad(): - outputs = text_encoder( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - output_hidden_states=True, - return_dict=True, - ) - # Get the last hidden state - hidden_states = outputs.last_hidden_state - - return [hidden_states] + return outputs.hidden_states[-2] def encode_tokens_with_weights( self, tokenize_strategy: TokenizeStrategy, models: List[Any], - tokens_list: List[torch.Tensor], + tokens: torch.Tensor, weights_list: List[torch.Tensor], - ) -> List[torch.Tensor]: + attention_masks: torch.Tensor + ) -> torch.Tensor: # For simplicity, use uniform weighting - return self.encode_tokens(tokenize_strategy, models, tokens_list) + return self.encode_tokens(tokenize_strategy, models, tokens, attention_masks) class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy): @@ -149,6 +139,15 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy) npz = np.load(npz_path) if "hidden_state" not in npz: return False + if "attention_mask" not in npz: + return False + if "input_ids" not in npz: + return False + if "apply_gemma2_attn_mask" not in npz: + return False + npz_apply_gemma2_attn_mask = npz["apply_gemma2_attn_mask"] + if npz_apply_gemma2_attn_mask != self.apply_gemma2_attn_mask: + return False except Exception as e: logger.error(f"Error loading file: {npz_path}") raise e @@ -158,13 +157,15 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy) def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]: data = np.load(npz_path) hidden_state = data["hidden_state"] - return [hidden_state] + attention_mask = data["attention_mask"] + input_ids = data["input_ids"] + return [hidden_state, attention_mask, input_ids] def cache_batch_outputs( self, - tokenize_strategy: TokenizeStrategy, + tokenize_strategy: LuminaTokenizeStrategy, models: List[Any], - text_encoding_strategy: TextEncodingStrategy, + text_encoding_strategy: LuminaTextEncodingStrategy, infos: List, ): lumina_text_encoding_strategy: LuminaTextEncodingStrategy = ( @@ -173,35 +174,44 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy) captions = [info.caption for info in infos] if self.is_weighted: - tokens_list, weights_list = tokenize_strategy.tokenize_with_weights( + tokens, attention_masks, weights_list = tokenize_strategy.tokenize_with_weights( captions ) with torch.no_grad(): hidden_state = lumina_text_encoding_strategy.encode_tokens_with_weights( - tokenize_strategy, models, tokens_list, weights_list - )[0] + tokenize_strategy, models, tokens, weights_list, attention_masks + ) else: - tokens = tokenize_strategy.tokenize(captions) + tokens, attention_masks = tokenize_strategy.tokenize(captions) with torch.no_grad(): hidden_state = lumina_text_encoding_strategy.encode_tokens( - tokenize_strategy, models, tokens - )[0] + tokenize_strategy, models, tokens, attention_masks + ) - if hidden_state.dtype == torch.bfloat16: + if hidden_state.dtype != torch.float32: hidden_state = hidden_state.float() hidden_state = hidden_state.cpu().numpy() + attention_mask = attention_masks.cpu().numpy() + input_ids = tokens.cpu().numpy() + for i, info in enumerate(infos): hidden_state_i = hidden_state[i] + attention_mask_i = attention_mask[i] + input_ids_i = input_ids[i] + apply_gemma2_attn_mask_i = self.apply_gemma2_attn_mask if self.cache_to_disk: np.savez( info.text_encoder_outputs_npz, hidden_state=hidden_state_i, + attention_mask=attention_mask_i, + input_ids=input_ids_i, + apply_gemma2_attn_mask=apply_gemma2_attn_mask_i, ) else: - info.text_encoder_outputs = [hidden_state_i] + info.text_encoder_outputs = [hidden_state_i, attention_mask_i, input_ids_i] class LuminaLatentsCachingStrategy(LatentsCachingStrategy): diff --git a/lumina_train_network.py b/lumina_train_network.py index 1f8ba613..3d0c7062 100644 --- a/lumina_train_network.py +++ b/lumina_train_network.py @@ -62,6 +62,19 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer): disable_mmap=args.disable_mmap_load_safetensors, ) + if args.fp8_base: + # check dtype of model + if model.dtype == torch.float8_e4m3fnuz or model.dtype == torch.float8_e5m2 or model.dtype == torch.float8_e5m2fnuz: + raise ValueError(f"Unsupported fp8 model dtype: {model.dtype}") + elif model.dtype == torch.float8_e4m3fn: + logger.info("Loaded fp8 Lumina 2 model") + else: + logger.info( + "Cast Lumina 2 model to fp8. This may take a while. You can reduce the time by using fp8 checkpoint." + " / Lumina 2モデルをfp8に変換しています。これには時間がかかる場合があります。fp8チェックポイントを使用することで時間を短縮できます。" + ) + model.to(torch.float8_e4m3fn) + # if args.blocks_to_swap: # logger.info(f'Enabling block swap: {args.blocks_to_swap}') # model.enable_block_swap(args.blocks_to_swap, accelerator.device) @@ -70,6 +83,7 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer): gemma2 = lumina_util.load_gemma2( args.gemma2, weight_dtype, "cpu" ) + gemma2.eval() ae = lumina_util.load_ae( args.ae, weight_dtype, "cpu" ) @@ -118,17 +132,65 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer): dataset, weight_dtype, ): - for text_encoder in text_encoders: - text_encoder_outputs_caching_strategy = ( - self.get_text_encoder_outputs_caching_strategy(args) - ) - if text_encoder_outputs_caching_strategy is not None: - text_encoder_outputs_caching_strategy.cache_batch_outputs( - self.get_tokenize_strategy(args), - [text_encoder], - self.get_text_encoding_strategy(args), - dataset, - ) + if args.cache_text_encoder_outputs: + if not args.lowram: + # メモリ消費を減らす + logger.info("move vae and unet to cpu to save memory") + org_vae_device = vae.device + org_unet_device = unet.device + vae.to("cpu") + unet.to("cpu") + clean_memory_on_device(accelerator.device) + + # When TE is not be trained, it will not be prepared so we need to use explicit autocast + logger.info("move text encoders to gpu") + text_encoders[0].to(accelerator.device, dtype=weight_dtype) # always not fp8 + + if text_encoders[0].dtype == torch.float8_e4m3fn: + # if we load fp8 weights, the model is already fp8, so we use it as is + self.prepare_text_encoder_fp8(1, text_encoders[1], text_encoders[1].dtype, weight_dtype) + else: + # otherwise, we need to convert it to target dtype + text_encoders[0].to(weight_dtype) + + with accelerator.autocast(): + dataset.new_cache_text_encoder_outputs(text_encoders, accelerator) + + # cache sample prompts + if args.sample_prompts is not None: + logger.info(f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}") + + tokenize_strategy: strategy_lumina.LuminaTokenizeStrategy = strategy_base.TokenizeStrategy.get_strategy() + text_encoding_strategy: strategy_lumina.LuminaTextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy() + + prompts = train_util.load_prompts(args.sample_prompts) + sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs + with accelerator.autocast(), torch.no_grad(): + for prompt_dict in prompts: + for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]: + if p not in sample_prompts_te_outputs: + logger.info(f"cache Text Encoder outputs for prompt: {p}") + tokens_and_masks = tokenize_strategy.tokenize(p) + sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens( + tokenize_strategy, text_encoders, tokens_and_masks, args.apply_t5_attn_mask + ) + self.sample_prompts_te_outputs = sample_prompts_te_outputs + + accelerator.wait_for_everyone() + + # move back to cpu + if not self.is_train_text_encoder(args): + logger.info("move Gemma 2 back to cpu") + text_encoders[0].to("cpu") + clean_memory_on_device(accelerator.device) + + if not args.lowram: + logger.info("move vae and unet back to original device") + vae.to(org_vae_device) + unet.to(org_unet_device) + else: + # Text Encoderから毎回出力を取得するので、GPUに乗せておく + text_encoders[0].to(accelerator.device, dtype=weight_dtype) def sample_images( self, @@ -196,12 +258,13 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer): ) ) + # May not need to pack/unpack? # pack latents and get img_ids - 这部分可以保留因为NextDiT也需要packed格式的输入 - packed_noisy_model_input = lumina_util.pack_latents(noisy_model_input) - packed_latent_height, packed_latent_width = ( - noisy_model_input.shape[2] // 2, - noisy_model_input.shape[3] // 2, - ) + # packed_noisy_model_input = lumina_util.pack_latents(noisy_model_input) + # packed_latent_height, packed_latent_width = ( + # noisy_model_input.shape[2] // 2, + # noisy_model_input.shape[3] // 2, + # ) # ensure the hidden state will require grad if args.gradient_checkpointing: @@ -212,32 +275,30 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer): # Unpack Gemma2 outputs gemma2_hidden_states, gemma2_attn_mask, input_ids = text_encoder_conds - if not args.apply_gemma2_attn_mask: - gemma2_attn_mask = None - def call_dit(img, gemma2_hidden_states, input_ids, timesteps, gemma2_attn_mask): + def call_dit(img, gemma2_hidden_states, timesteps, gemma2_attn_mask): with torch.set_grad_enabled(is_train), accelerator.autocast(): # NextDiT forward expects (x, t, cap_feats, cap_mask) model_pred = unet( - x=img, # packed latents + x=img, # image latents (B, C, H, W) t=timesteps / 1000, # timesteps需要除以1000来匹配模型预期 cap_feats=gemma2_hidden_states, # Gemma2的hidden states作为caption features - cap_mask=gemma2_attn_mask, # Gemma2的attention mask + cap_mask=gemma2_attn_mask.to(dtype=torch.int32), # Gemma2的attention mask ) return model_pred model_pred = call_dit( - img=packed_noisy_model_input, + img=noisy_model_input, gemma2_hidden_states=gemma2_hidden_states, - input_ids=input_ids, timesteps=timesteps, gemma2_attn_mask=gemma2_attn_mask, ) + # May not need to pack/unpack? # unpack latents - model_pred = lumina_util.unpack_latents( - model_pred, packed_latent_height, packed_latent_width - ) + # model_pred = lumina_util.unpack_latents( + # model_pred, packed_latent_height, packed_latent_width + # ) # apply model prediction type model_pred, weighting = flux_train_utils.apply_model_prediction_type( diff --git a/networks/lora_lumina.py b/networks/lora_lumina.py index d554ce13..3f6c9b41 100644 --- a/networks/lora_lumina.py +++ b/networks/lora_lumina.py @@ -462,7 +462,7 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, lumina, wei class LoRANetwork(torch.nn.Module): LUMINA_TARGET_REPLACE_MODULE = ["JointTransformerBlock"] - TEXT_ENCODER_TARGET_REPLACE_MODULE = ["GemmaAttention", "GemmaDecoderLayer", "GemmaMLP"] + TEXT_ENCODER_TARGET_REPLACE_MODULE = ["Gemma2Attention", "Gemma2MLP"] LORA_PREFIX_LUMINA = "lora_unet" LORA_PREFIX_TEXT_ENCODER = "lora_te" # Simplified prefix since we only have one text encoder @@ -533,7 +533,7 @@ class LoRANetwork(torch.nn.Module): filter: Optional[str] = None, default_dim: Optional[int] = None, ) -> List[LoRAModule]: - prefix = self.LORA_PREFIX_FLUX if is_lumina else self.LORA_PREFIX_TEXT_ENCODER + prefix = self.LORA_PREFIX_LUMINA if is_lumina else self.LORA_PREFIX_TEXT_ENCODER loras = [] skipped = [] @@ -611,7 +611,7 @@ class LoRANetwork(torch.nn.Module): skipped_te = [] logger.info(f"create LoRA for Gemma2 Text Encoder:") - text_encoder_loras, skipped = create_modules(False, text_encoders, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE) + text_encoder_loras, skipped = create_modules(False, text_encoders[0], LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE) logger.info(f"create LoRA for Gemma2 Text Encoder: {len(text_encoder_loras)} modules.") self.text_encoder_loras.extend(text_encoder_loras) skipped_te += skipped @@ -718,10 +718,10 @@ class LoRANetwork(torch.nn.Module): def state_dict(self, destination=None, prefix="", keep_vars=False): if not self.split_qkv: - return super().state_dict(destination, prefix, keep_vars) + return super().state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars) # merge qkv - state_dict = super().state_dict(destination, prefix, keep_vars) + state_dict = super().state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars) new_state_dict = {} for key in list(state_dict.keys()): if "double" in key and "qkv" in key: