Add Gemma2 output caching, add gradient checkpointing, refactor Lumina model code

This commit is contained in:
rockerBOO
2025-02-16 01:06:34 -05:00
parent a00b06bc97
commit 60a76ebb72
4 changed files with 304 additions and 225 deletions

View File

@@ -16,6 +16,8 @@ from dataclasses import dataclass
from flash_attn import flash_attn_varlen_func
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
import torch
from torch import Tensor
from torch.utils.checkpoint import checkpoint
import torch.nn as nn
import torch.nn.functional as F
@@ -91,6 +93,25 @@ class LuminaParams:
)
class GradientCheckpointMixin(nn.Module):
    """Mixin adding toggleable gradient checkpointing to an ``nn.Module``.

    Subclasses implement ``_forward``; ``forward`` routes through
    ``torch.utils.checkpoint.checkpoint`` when the module is in training
    mode with checkpointing enabled, and calls ``_forward`` directly
    otherwise.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Both flags default off; toggled via the enable/disable methods below.
        self.gradient_checkpointing = False
        self.cpu_offload_checkpointing = False

    def enable_gradient_checkpointing(self, cpu_offload: bool = False):
        """Turn on checkpointing; record the CPU-offload preference."""
        self.gradient_checkpointing = True
        # Fix: cpu_offload was previously accepted but never stored.
        self.cpu_offload_checkpointing = cpu_offload

    def disable_gradient_checkpointing(self, cpu_offload: bool = False):
        """Turn off checkpointing and clear the CPU-offload flag."""
        self.gradient_checkpointing = False
        # Fix: reset the offload flag so it does not stay stale after disable.
        self.cpu_offload_checkpointing = False

    def forward(self, *args, **kwargs):
        if self.training and self.gradient_checkpointing:
            # use_reentrant=False selects the recommended non-reentrant
            # checkpoint implementation (supports keyword args).
            return checkpoint(self._forward, *args, use_reentrant=False, **kwargs)
        else:
            return self._forward(*args, **kwargs)
#############################################################################
# RMSNorm #
#############################################################################
@@ -114,7 +135,7 @@ class RMSNorm(torch.nn.Module):
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def _norm(self, x):
def _norm(self, x) -> Tensor:
"""
Apply the RMSNorm normalization to the input tensor.
@@ -125,21 +146,14 @@ class RMSNorm(torch.nn.Module):
torch.Tensor: The normalized tensor.
"""
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
return x * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
"""
Forward pass through the RMSNorm layer.
Args:
x (torch.Tensor): The input tensor.
Returns:
torch.Tensor: The output tensor after applying RMSNorm.
"""
output = self._norm(x.float()).type_as(x)
return output * self.weight
def forward(self, x: Tensor):
x_dtype = x.dtype
x = x.float()
rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
return ((x * rrms) * self.weight.float()).to(dtype=x_dtype)
def modulate(x, scale):
@@ -151,7 +165,7 @@ def modulate(x, scale):
#############################################################################
class TimestepEmbedder(nn.Module):
class TimestepEmbedder(GradientCheckpointMixin):
"""
Embeds scalar timesteps into vector representations.
"""
@@ -203,11 +217,32 @@ class TimestepEmbedder(nn.Module):
)
return embedding
def forward(self, t):
def _forward(self, t):
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
t_emb = self.mlp(t_freq.to(self.mlp[0].weight.dtype))
return t_emb
def to_cuda(x):
    """Recursively move every tensor in a nested structure to CUDA.

    Handles tensors, lists, tuples, and dicts; any other value is
    returned unchanged. Container types are preserved (a tuple input
    yields a tuple — previously it was collapsed to a list).
    """
    if isinstance(x, torch.Tensor):
        return x.cuda()
    elif isinstance(x, tuple):
        # Keep tuple inputs as tuples instead of converting to a list.
        return tuple(to_cuda(elem) for elem in x)
    elif isinstance(x, list):
        return [to_cuda(elem) for elem in x]
    elif isinstance(x, dict):
        return {k: to_cuda(v) for k, v in x.items()}
    else:
        return x
def to_cpu(x):
    """Recursively move every tensor in a nested structure to the CPU.

    Mirror of ``to_cuda``: handles tensors, lists, tuples, and dicts;
    other values pass through unchanged. Container types are preserved
    (a tuple input yields a tuple — previously it was collapsed to a
    list).
    """
    if isinstance(x, torch.Tensor):
        return x.cpu()
    elif isinstance(x, tuple):
        # Keep tuple inputs as tuples instead of converting to a list.
        return tuple(to_cpu(elem) for elem in x)
    elif isinstance(x, list):
        return [to_cpu(elem) for elem in x]
    elif isinstance(x, dict):
        return {k: to_cpu(v) for k, v in x.items()}
    else:
        return x
#############################################################################
# Core NextDiT Model #
@@ -284,7 +319,7 @@ class JointAttention(nn.Module):
Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor
and key tensor with rotary embeddings.
"""
with torch.amp.autocast("cuda",enabled=False):
with torch.autocast("cuda", enabled=False):
x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2))
freqs_cis = freqs_cis.unsqueeze(2)
x_out = torch.view_as_real(x * freqs_cis).flatten(3)
@@ -496,15 +531,15 @@ class FeedForward(nn.Module):
return self.w2(self._forward_silu_gating(self.w1(x), self.w3(x)))
class JointTransformerBlock(nn.Module):
class JointTransformerBlock(GradientCheckpointMixin):
def __init__(
self,
layer_id: int,
dim: int,
n_heads: int,
n_kv_heads: int,
n_kv_heads: Optional[int],
multiple_of: int,
ffn_dim_multiplier: float,
ffn_dim_multiplier: Optional[float],
norm_eps: float,
qk_norm: bool,
modulation=True,
@@ -520,7 +555,7 @@ class JointTransformerBlock(nn.Module):
value features (if using GQA), or set to None for the same as
query.
multiple_of (int):
ffn_dim_multiplier (float):
ffn_dim_multiplier (Optional[float]):
norm_eps (float):
"""
@@ -554,7 +589,7 @@ class JointTransformerBlock(nn.Module):
nn.init.zeros_(self.adaLN_modulation[1].weight)
nn.init.zeros_(self.adaLN_modulation[1].bias)
def forward(
def _forward(
self,
x: torch.Tensor,
x_mask: torch.Tensor,
@@ -608,7 +643,7 @@ class JointTransformerBlock(nn.Module):
return x
class FinalLayer(nn.Module):
class FinalLayer(GradientCheckpointMixin):
"""
The final layer of NextDiT.
"""
@@ -661,22 +696,21 @@ class RopeEmbedder:
self.axes_dims, self.axes_lens, theta=self.theta
)
def __call__(self, ids: torch.Tensor):
def get_freqs_cis(self, ids: torch.Tensor):
self.freqs_cis = [freqs_cis.to(ids.device) for freqs_cis in self.freqs_cis]
result = []
for i in range(len(self.axes_dims)):
# import torch.distributed as dist
# if not dist.is_initialized() or dist.get_rank() == 0:
# import pdb
# pdb.set_trace()
index = (
ids[:, :, i : i + 1]
.repeat(1, 1, self.freqs_cis[i].shape[-1])
.to(torch.int64)
)
axes = self.freqs_cis[i].unsqueeze(0).repeat(index.shape[0], 1, 1)
result.append(
torch.gather(
self.freqs_cis[i].unsqueeze(0).repeat(index.shape[0], 1, 1),
axes,
dim=1,
index=index,
)
@@ -790,76 +824,98 @@ class NextDiT(nn.Module):
self.dim = dim
self.n_heads = n_heads
self.gradient_checkpointing = False
self.cpu_offload_checkpointing = False
self.blocks_to_swap = None
@property
def device(self):
return next(self.parameters()).device
@property
def dtype(self):
return next(self.parameters()).dtype
def enable_gradient_checkpointing(self, cpu_offload: bool = False):
self.gradient_checkpointing = True
self.cpu_offload_checkpointing = cpu_offload
self.t_embedder.enable_gradient_checkpointing()
for block in self.layers + self.context_refiner + self.noise_refiner:
block.enable_gradient_checkpointing(cpu_offload=cpu_offload)
self.final_layer.enable_gradient_checkpointing()
print(f"Lumina: Gradient checkpointing enabled. CPU offload: {cpu_offload}")
def disable_gradient_checkpointing(self):
self.gradient_checkpointing = False
self.cpu_offload_checkpointing = False
self.t_embedder.disable_gradient_checkpointing()
for block in self.layers + self.context_refiner + self.noise_refiner:
block.disable_gradient_checkpointing()
self.final_layer.disable_gradient_checkpointing()
print("Lumina: Gradient checkpointing disabled.")
def unpatchify(
self,
x: torch.Tensor,
img_size: List[Tuple[int, int]],
cap_size: List[int],
return_tensor=False,
) -> List[torch.Tensor]:
width: int,
height: int,
encoder_seq_lengths: List[int],
seq_lengths: List[int],
) -> torch.Tensor:
"""
x: (N, T, patch_size**2 * C)
imgs: (N, H, W, C)
"""
pH = pW = self.patch_size
imgs = []
for i in range(x.size(0)):
H, W = img_size[i]
begin = cap_size[i]
end = begin + (H // pH) * (W // pW)
imgs.append(
x[i][begin:end]
.view(H // pH, W // pW, pH, pW, self.out_channels)
output = []
for i, (encoder_seq_len, seq_len) in enumerate(zip(encoder_seq_lengths, seq_lengths)):
output.append(
x[i][encoder_seq_len:seq_len]
.view(height // pH, width // pW, pH, pW, self.out_channels)
.permute(4, 0, 2, 1, 3)
.flatten(3, 4)
.flatten(1, 2)
)
output = torch.stack(output, dim=0)
if return_tensor:
imgs = torch.stack(imgs, dim=0)
return imgs
return output
def patchify_and_embed(
self,
x: List[torch.Tensor] | torch.Tensor,
x: torch.Tensor,
cap_feats: torch.Tensor,
cap_mask: torch.Tensor,
t: torch.Tensor,
) -> Tuple[
torch.Tensor, torch.Tensor, List[Tuple[int, int]], List[int], torch.Tensor
torch.Tensor, torch.Tensor, torch.Tensor, List[int], List[int]
]:
bsz = len(x)
bsz, channels, height, width = x.shape
pH = pW = self.patch_size
device = x[0].device
device = x.device
l_effective_cap_len = cap_mask.sum(dim=1).tolist()
img_sizes = [(img.size(1), img.size(2)) for img in x]
l_effective_img_len = [(H // pH) * (W // pW) for (H, W) in img_sizes]
encoder_seq_len = cap_mask.shape[1]
max_seq_len = max(
(
cap_len + img_len
for cap_len, img_len in zip(l_effective_cap_len, l_effective_img_len)
)
)
max_cap_len = max(l_effective_cap_len)
max_img_len = max(l_effective_img_len)
image_seq_len = (height // self.patch_size) * (width // self.patch_size)
seq_lengths = [cap_seq_len + image_seq_len for cap_seq_len in l_effective_cap_len]
max_seq_len = max(seq_lengths)
position_ids = torch.zeros(
bsz, max_seq_len, 3, dtype=torch.int32, device=device
)
position_ids = torch.zeros(bsz, max_seq_len, 3, dtype=torch.int32, device=device)
for i in range(bsz):
cap_len = l_effective_cap_len[i]
img_len = l_effective_img_len[i]
H, W = img_sizes[i]
H_tokens, W_tokens = H // pH, W // pW
assert H_tokens * W_tokens == img_len
for i, (cap_len, seq_len) in enumerate(zip(l_effective_cap_len, seq_lengths)):
H_tokens, W_tokens = height // pH, width // pW
position_ids[i, :cap_len, 0] = torch.arange(
cap_len, dtype=torch.int32, device=device
)
position_ids[i, cap_len : cap_len + img_len, 0] = cap_len
position_ids[i, :cap_len, 0] = torch.arange(cap_len, dtype=torch.int32, device=device)
position_ids[i, cap_len : cap_len + seq_len, 0] = cap_len
row_ids = (
torch.arange(H_tokens, dtype=torch.int32, device=device)
.view(-1, 1)
@@ -872,77 +928,40 @@ class NextDiT(nn.Module):
.repeat(H_tokens, 1)
.flatten()
)
position_ids[i, cap_len : cap_len + img_len, 1] = row_ids
position_ids[i, cap_len : cap_len + img_len, 2] = col_ids
position_ids[i, cap_len : cap_len + seq_len, 1] = row_ids
position_ids[i, cap_len : cap_len + seq_len, 2] = col_ids
freqs_cis = self.rope_embedder(position_ids)
freqs_cis = self.rope_embedder.get_freqs_cis(position_ids)
# build freqs_cis for cap and image individually
cap_freqs_cis_shape = list(freqs_cis.shape)
# cap_freqs_cis_shape[1] = max_cap_len
cap_freqs_cis_shape[1] = cap_feats.shape[1]
cap_freqs_cis = torch.zeros(
*cap_freqs_cis_shape, device=device, dtype=freqs_cis.dtype
)
cap_freqs_cis = torch.zeros(bsz, encoder_seq_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype)
img_freqs_cis = torch.zeros(bsz, image_seq_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype)
img_freqs_cis_shape = list(freqs_cis.shape)
img_freqs_cis_shape[1] = max_img_len
img_freqs_cis = torch.zeros(
*img_freqs_cis_shape, device=device, dtype=freqs_cis.dtype
)
for i in range(bsz):
cap_len = l_effective_cap_len[i]
img_len = l_effective_img_len[i]
for i, (cap_len, seq_len) in enumerate(zip(l_effective_cap_len, seq_lengths)):
cap_freqs_cis[i, :cap_len] = freqs_cis[i, :cap_len]
img_freqs_cis[i, :img_len] = freqs_cis[i, cap_len : cap_len + img_len]
img_freqs_cis[i, :seq_len] = freqs_cis[i, cap_len : cap_len + seq_len]
x = x.view(bsz, channels, height // pH, pH, width // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2)
x_mask = torch.zeros(bsz, max_seq_len, dtype=torch.bool, device=device)
# refine context
for layer in self.context_refiner:
cap_feats = layer(cap_feats, cap_mask, cap_freqs_cis)
# refine image
flat_x = []
for i in range(bsz):
img = x[i]
C, H, W = img.size()
img = (
img.view(C, H // pH, pH, W // pW, pW)
.permute(1, 3, 2, 4, 0)
.flatten(2)
.flatten(0, 1)
)
flat_x.append(img)
x = flat_x
padded_img_embed = torch.zeros(
bsz, max_img_len, x[0].shape[-1], device=device, dtype=x[0].dtype
)
padded_img_mask = torch.zeros(bsz, max_img_len, dtype=torch.bool, device=device)
for i in range(bsz):
padded_img_embed[i, : l_effective_img_len[i]] = x[i]
padded_img_mask[i, : l_effective_img_len[i]] = True
x = self.x_embedder(x)
padded_img_embed = self.x_embedder(padded_img_embed)
for layer in self.noise_refiner:
padded_img_embed = layer(
padded_img_embed, padded_img_mask, img_freqs_cis, t
)
x = layer(x, x_mask, img_freqs_cis, t)
mask = torch.zeros(bsz, max_seq_len, dtype=torch.bool, device=device)
padded_full_embed = torch.zeros(
bsz, max_seq_len, self.dim, device=device, dtype=x[0].dtype
)
for i in range(bsz):
cap_len = l_effective_cap_len[i]
img_len = l_effective_img_len[i]
joint_hidden_states = torch.zeros(bsz, max_seq_len, self.dim, device=device, dtype=x.dtype)
attention_mask = torch.zeros(bsz, max_seq_len, dtype=torch.bool, device=device)
for i, (cap_len, seq_len) in enumerate(zip(l_effective_cap_len, seq_lengths)):
attention_mask[i, :seq_len] = True
joint_hidden_states[i, :cap_len] = cap_feats[i, :cap_len]
joint_hidden_states[i, cap_len:seq_len] = x[i]
mask[i, : cap_len + img_len] = True
padded_full_embed[i, :cap_len] = cap_feats[i, :cap_len]
padded_full_embed[i, cap_len : cap_len + img_len] = padded_img_embed[
i, :img_len
]
x = joint_hidden_states
return padded_full_embed, mask, img_sizes, l_effective_cap_len, freqs_cis
return x, attention_mask, freqs_cis, l_effective_cap_len, seq_lengths
def forward(self, x, t, cap_feats, cap_mask):
"""
@@ -950,30 +969,19 @@ class NextDiT(nn.Module):
t: (N,) tensor of diffusion timesteps
y: (N,) tensor of text tokens/features
"""
# import torch.distributed as dist
# if not dist.is_initialized() or dist.get_rank() == 0:
# import pdb
# pdb.set_trace()
# torch.save([x, t, cap_feats, cap_mask], "./fake_input.pt")
_, _, height, width = x.shape # B, C, H, W
t = self.t_embedder(t) # (N, D)
adaln_input = t
cap_feats = self.cap_embedder(cap_feats) # (N, L, D) # todo check if able to batchify w.o. redundant compute
cap_feats = self.cap_embedder(
cap_feats
) # (N, L, D) # todo check if able to batchify w.o. redundant compute
x_is_tensor = isinstance(x, torch.Tensor)
x, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(
x, mask, freqs_cis, l_effective_cap_len, seq_lengths = self.patchify_and_embed(
x, cap_feats, cap_mask, t
)
freqs_cis = freqs_cis.to(x.device)
for layer in self.layers:
x = layer(x, mask, freqs_cis, adaln_input)
x = layer(x, mask, freqs_cis, t)
x = self.final_layer(x, adaln_input)
x = self.unpatchify(x, img_size, cap_size, return_tensor=x_is_tensor)
x = self.final_layer(x, t)
x = self.unpatchify(x, width, height, l_effective_cap_len, seq_lengths)
return x

View File

@@ -3,7 +3,7 @@ import os
from typing import Any, List, Optional, Tuple, Union
import torch
from transformers import AutoTokenizer, AutoModel
from transformers import AutoTokenizer, AutoModel, GemmaTokenizerFast
from library import train_util
from library.strategy_base import (
LatentsCachingStrategy,
@@ -27,34 +27,35 @@ class LuminaTokenizeStrategy(TokenizeStrategy):
def __init__(
self, max_length: Optional[int], tokenizer_cache_dir: Optional[str] = None
) -> None:
self.tokenizer = AutoTokenizer.from_pretrained(
self.tokenizer: GemmaTokenizerFast = AutoTokenizer.from_pretrained(
GEMMA_ID, cache_dir=tokenizer_cache_dir
)
self.tokenizer.padding_side = "right"
if max_length is None:
self.max_length = self.tokenizer.model_max_length
self.max_length = 256
else:
self.max_length = max_length
def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]:
def tokenize(self, text: Union[str, List[str]]) -> Tuple[torch.Tensor, torch.Tensor]:
text = [text] if isinstance(text, str) else text
encodings = self.tokenizer(
text,
padding="max_length",
max_length=self.max_length,
return_tensors="pt",
padding=True,
pad_to_multiple_of=8,
truncation=True,
)
return [encodings.input_ids]
return encodings.input_ids, encodings.attention_mask
def tokenize_with_weights(
self, text: str | List[str]
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
# Gemma doesn't support weighted prompts, return uniform weights
tokens = self.tokenize(text)
tokens, attention_masks = self.tokenize(text)
weights = [torch.ones_like(t) for t in tokens]
return tokens, weights
return tokens, attention_masks, weights
class LuminaTextEncodingStrategy(TextEncodingStrategy):
@@ -66,50 +67,39 @@ class LuminaTextEncodingStrategy(TextEncodingStrategy):
self,
tokenize_strategy: TokenizeStrategy,
models: List[Any],
tokens: List[torch.Tensor],
tokens: torch.Tensor,
attention_masks: torch.Tensor,
apply_gemma2_attn_mask: Optional[bool] = None,
) -> List[torch.Tensor]:
) -> torch.Tensor:
if apply_gemma2_attn_mask is None:
apply_gemma2_attn_mask = self.apply_gemma2_attn_mask
text_encoder = models[0]
input_ids = tokens[0].to(text_encoder.device)
attention_mask = None
position_ids = None
if apply_gemma2_attn_mask:
# Create attention mask (1 for non-padding, 0 for padding)
attention_mask = (input_ids != tokenize_strategy.tokenizer.pad_token_id).to(
text_encoder.device
)
# Create position IDs
position_ids = attention_masks.cumsum(-1) - 1
position_ids.masked_fill_(attention_masks == 0, 1)
# Create position IDs
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
outputs = text_encoder(
input_ids=tokens.to(text_encoder.device),
attention_mask=attention_masks.to(text_encoder.device) if apply_gemma2_attn_mask else None,
position_ids=position_ids.to(text_encoder.device),
output_hidden_states=True,
return_dict=True,
)
with torch.no_grad():
outputs = text_encoder(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
output_hidden_states=True,
return_dict=True,
)
# Get the last hidden state
hidden_states = outputs.last_hidden_state
return [hidden_states]
return outputs.hidden_states[-2]
def encode_tokens_with_weights(
self,
tokenize_strategy: TokenizeStrategy,
models: List[Any],
tokens_list: List[torch.Tensor],
tokens: torch.Tensor,
weights_list: List[torch.Tensor],
) -> List[torch.Tensor]:
attention_masks: torch.Tensor
) -> torch.Tensor:
# For simplicity, use uniform weighting
return self.encode_tokens(tokenize_strategy, models, tokens_list)
return self.encode_tokens(tokenize_strategy, models, tokens, attention_masks)
class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
@@ -149,6 +139,15 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy)
npz = np.load(npz_path)
if "hidden_state" not in npz:
return False
if "attention_mask" not in npz:
return False
if "input_ids" not in npz:
return False
if "apply_gemma2_attn_mask" not in npz:
return False
npz_apply_gemma2_attn_mask = npz["apply_gemma2_attn_mask"]
if npz_apply_gemma2_attn_mask != self.apply_gemma2_attn_mask:
return False
except Exception as e:
logger.error(f"Error loading file: {npz_path}")
raise e
@@ -158,13 +157,15 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy)
def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
data = np.load(npz_path)
hidden_state = data["hidden_state"]
return [hidden_state]
attention_mask = data["attention_mask"]
input_ids = data["input_ids"]
return [hidden_state, attention_mask, input_ids]
def cache_batch_outputs(
self,
tokenize_strategy: TokenizeStrategy,
tokenize_strategy: LuminaTokenizeStrategy,
models: List[Any],
text_encoding_strategy: TextEncodingStrategy,
text_encoding_strategy: LuminaTextEncodingStrategy,
infos: List,
):
lumina_text_encoding_strategy: LuminaTextEncodingStrategy = (
@@ -173,35 +174,44 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy)
captions = [info.caption for info in infos]
if self.is_weighted:
tokens_list, weights_list = tokenize_strategy.tokenize_with_weights(
tokens, attention_masks, weights_list = tokenize_strategy.tokenize_with_weights(
captions
)
with torch.no_grad():
hidden_state = lumina_text_encoding_strategy.encode_tokens_with_weights(
tokenize_strategy, models, tokens_list, weights_list
)[0]
tokenize_strategy, models, tokens, weights_list, attention_masks
)
else:
tokens = tokenize_strategy.tokenize(captions)
tokens, attention_masks = tokenize_strategy.tokenize(captions)
with torch.no_grad():
hidden_state = lumina_text_encoding_strategy.encode_tokens(
tokenize_strategy, models, tokens
)[0]
tokenize_strategy, models, tokens, attention_masks
)
if hidden_state.dtype == torch.bfloat16:
if hidden_state.dtype != torch.float32:
hidden_state = hidden_state.float()
hidden_state = hidden_state.cpu().numpy()
attention_mask = attention_masks.cpu().numpy()
input_ids = tokens.cpu().numpy()
for i, info in enumerate(infos):
hidden_state_i = hidden_state[i]
attention_mask_i = attention_mask[i]
input_ids_i = input_ids[i]
apply_gemma2_attn_mask_i = self.apply_gemma2_attn_mask
if self.cache_to_disk:
np.savez(
info.text_encoder_outputs_npz,
hidden_state=hidden_state_i,
attention_mask=attention_mask_i,
input_ids=input_ids_i,
apply_gemma2_attn_mask=apply_gemma2_attn_mask_i,
)
else:
info.text_encoder_outputs = [hidden_state_i]
info.text_encoder_outputs = [hidden_state_i, attention_mask_i, input_ids_i]
class LuminaLatentsCachingStrategy(LatentsCachingStrategy):

View File

@@ -62,6 +62,19 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
disable_mmap=args.disable_mmap_load_safetensors,
)
if args.fp8_base:
# check dtype of model
if model.dtype == torch.float8_e4m3fnuz or model.dtype == torch.float8_e5m2 or model.dtype == torch.float8_e5m2fnuz:
raise ValueError(f"Unsupported fp8 model dtype: {model.dtype}")
elif model.dtype == torch.float8_e4m3fn:
logger.info("Loaded fp8 Lumina 2 model")
else:
logger.info(
"Cast Lumina 2 model to fp8. This may take a while. You can reduce the time by using fp8 checkpoint."
" / Lumina 2モデルをfp8に変換しています。これには時間がかかる場合があります。fp8チェックポイントを使用することで時間を短縮できます。"
)
model.to(torch.float8_e4m3fn)
# if args.blocks_to_swap:
# logger.info(f'Enabling block swap: {args.blocks_to_swap}')
# model.enable_block_swap(args.blocks_to_swap, accelerator.device)
@@ -70,6 +83,7 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
gemma2 = lumina_util.load_gemma2(
args.gemma2, weight_dtype, "cpu"
)
gemma2.eval()
ae = lumina_util.load_ae(
args.ae, weight_dtype, "cpu"
)
@@ -118,17 +132,65 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
dataset,
weight_dtype,
):
for text_encoder in text_encoders:
text_encoder_outputs_caching_strategy = (
self.get_text_encoder_outputs_caching_strategy(args)
)
if text_encoder_outputs_caching_strategy is not None:
text_encoder_outputs_caching_strategy.cache_batch_outputs(
self.get_tokenize_strategy(args),
[text_encoder],
self.get_text_encoding_strategy(args),
dataset,
)
if args.cache_text_encoder_outputs:
if not args.lowram:
# メモリ消費を減らす
logger.info("move vae and unet to cpu to save memory")
org_vae_device = vae.device
org_unet_device = unet.device
vae.to("cpu")
unet.to("cpu")
clean_memory_on_device(accelerator.device)
# When TE is not be trained, it will not be prepared so we need to use explicit autocast
logger.info("move text encoders to gpu")
text_encoders[0].to(accelerator.device, dtype=weight_dtype) # always not fp8
if text_encoders[0].dtype == torch.float8_e4m3fn:
# if we load fp8 weights, the model is already fp8, so we use it as is
self.prepare_text_encoder_fp8(1, text_encoders[1], text_encoders[1].dtype, weight_dtype)
else:
# otherwise, we need to convert it to target dtype
text_encoders[0].to(weight_dtype)
with accelerator.autocast():
dataset.new_cache_text_encoder_outputs(text_encoders, accelerator)
# cache sample prompts
if args.sample_prompts is not None:
logger.info(f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}")
tokenize_strategy: strategy_lumina.LuminaTokenizeStrategy = strategy_base.TokenizeStrategy.get_strategy()
text_encoding_strategy: strategy_lumina.LuminaTextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy()
prompts = train_util.load_prompts(args.sample_prompts)
sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs
with accelerator.autocast(), torch.no_grad():
for prompt_dict in prompts:
for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]:
if p not in sample_prompts_te_outputs:
logger.info(f"cache Text Encoder outputs for prompt: {p}")
tokens_and_masks = tokenize_strategy.tokenize(p)
sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens(
tokenize_strategy, text_encoders, tokens_and_masks, args.apply_t5_attn_mask
)
self.sample_prompts_te_outputs = sample_prompts_te_outputs
accelerator.wait_for_everyone()
# move back to cpu
if not self.is_train_text_encoder(args):
logger.info("move Gemma 2 back to cpu")
text_encoders[0].to("cpu")
clean_memory_on_device(accelerator.device)
if not args.lowram:
logger.info("move vae and unet back to original device")
vae.to(org_vae_device)
unet.to(org_unet_device)
else:
# Text Encoderから毎回出力を取得するので、GPUに乗せておく
text_encoders[0].to(accelerator.device, dtype=weight_dtype)
def sample_images(
self,
@@ -196,12 +258,13 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
)
)
# May not need to pack/unpack?
# pack latents and get img_ids - 这部分可以保留因为NextDiT也需要packed格式的输入
packed_noisy_model_input = lumina_util.pack_latents(noisy_model_input)
packed_latent_height, packed_latent_width = (
noisy_model_input.shape[2] // 2,
noisy_model_input.shape[3] // 2,
)
# packed_noisy_model_input = lumina_util.pack_latents(noisy_model_input)
# packed_latent_height, packed_latent_width = (
# noisy_model_input.shape[2] // 2,
# noisy_model_input.shape[3] // 2,
# )
# ensure the hidden state will require grad
if args.gradient_checkpointing:
@@ -212,32 +275,30 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
# Unpack Gemma2 outputs
gemma2_hidden_states, gemma2_attn_mask, input_ids = text_encoder_conds
if not args.apply_gemma2_attn_mask:
gemma2_attn_mask = None
def call_dit(img, gemma2_hidden_states, input_ids, timesteps, gemma2_attn_mask):
def call_dit(img, gemma2_hidden_states, timesteps, gemma2_attn_mask):
with torch.set_grad_enabled(is_train), accelerator.autocast():
# NextDiT forward expects (x, t, cap_feats, cap_mask)
model_pred = unet(
x=img, # packed latents
x=img, # image latents (B, C, H, W)
t=timesteps / 1000, # timesteps需要除以1000来匹配模型预期
cap_feats=gemma2_hidden_states, # Gemma2的hidden states作为caption features
cap_mask=gemma2_attn_mask, # Gemma2的attention mask
cap_mask=gemma2_attn_mask.to(dtype=torch.int32), # Gemma2的attention mask
)
return model_pred
model_pred = call_dit(
img=packed_noisy_model_input,
img=noisy_model_input,
gemma2_hidden_states=gemma2_hidden_states,
input_ids=input_ids,
timesteps=timesteps,
gemma2_attn_mask=gemma2_attn_mask,
)
# May not need to pack/unpack?
# unpack latents
model_pred = lumina_util.unpack_latents(
model_pred, packed_latent_height, packed_latent_width
)
# model_pred = lumina_util.unpack_latents(
# model_pred, packed_latent_height, packed_latent_width
# )
# apply model prediction type
model_pred, weighting = flux_train_utils.apply_model_prediction_type(

View File

@@ -462,7 +462,7 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, lumina, wei
class LoRANetwork(torch.nn.Module):
LUMINA_TARGET_REPLACE_MODULE = ["JointTransformerBlock"]
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["GemmaAttention", "GemmaDecoderLayer", "GemmaMLP"]
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["Gemma2Attention", "Gemma2MLP"]
LORA_PREFIX_LUMINA = "lora_unet"
LORA_PREFIX_TEXT_ENCODER = "lora_te" # Simplified prefix since we only have one text encoder
@@ -533,7 +533,7 @@ class LoRANetwork(torch.nn.Module):
filter: Optional[str] = None,
default_dim: Optional[int] = None,
) -> List[LoRAModule]:
prefix = self.LORA_PREFIX_FLUX if is_lumina else self.LORA_PREFIX_TEXT_ENCODER
prefix = self.LORA_PREFIX_LUMINA if is_lumina else self.LORA_PREFIX_TEXT_ENCODER
loras = []
skipped = []
@@ -611,7 +611,7 @@ class LoRANetwork(torch.nn.Module):
skipped_te = []
logger.info(f"create LoRA for Gemma2 Text Encoder:")
text_encoder_loras, skipped = create_modules(False, text_encoders, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
text_encoder_loras, skipped = create_modules(False, text_encoders[0], LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
logger.info(f"create LoRA for Gemma2 Text Encoder: {len(text_encoder_loras)} modules.")
self.text_encoder_loras.extend(text_encoder_loras)
skipped_te += skipped
@@ -718,10 +718,10 @@ class LoRANetwork(torch.nn.Module):
def state_dict(self, destination=None, prefix="", keep_vars=False):
if not self.split_qkv:
return super().state_dict(destination, prefix, keep_vars)
return super().state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
# merge qkv
state_dict = super().state_dict(destination, prefix, keep_vars)
state_dict = super().state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
new_state_dict = {}
for key in list(state_dict.keys()):
if "double" in key and "qkv" in key: