mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-15 16:39:42 +00:00
Add caching gemma2, add gradient checkpointing, refactor lumina model code
This commit is contained in:
@@ -16,6 +16,8 @@ from dataclasses import dataclass
|
||||
from flash_attn import flash_attn_varlen_func
|
||||
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
@@ -91,6 +93,25 @@ class LuminaParams:
|
||||
)
|
||||
|
||||
|
||||
class GradientCheckpointMixin(nn.Module):
    """Base module that can run its ``_forward`` under activation checkpointing.

    Subclasses implement ``_forward``. ``forward`` dispatches to it, routing the
    call through ``torch.utils.checkpoint.checkpoint`` when the module is in
    training mode and gradient checkpointing has been enabled.

    Attributes:
        gradient_checkpointing: whether ``forward`` should checkpoint.
        cpu_offload_checkpointing: the CPU-offload request recorded by
            ``enable_gradient_checkpointing``. This class only stores the flag;
            ``forward`` below does not act on it.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

        self.gradient_checkpointing = False
        self.cpu_offload_checkpointing = False

    def enable_gradient_checkpointing(self, cpu_offload: bool = False):
        """Turn on gradient checkpointing for this module.

        Args:
            cpu_offload: requested CPU-offload setting; stored on
                ``cpu_offload_checkpointing`` rather than silently discarded.
        """
        self.gradient_checkpointing = True
        self.cpu_offload_checkpointing = cpu_offload

    def disable_gradient_checkpointing(self, cpu_offload: bool = False):
        """Turn off gradient checkpointing and clear the CPU-offload flag.

        Args:
            cpu_offload: accepted for signature symmetry with
                ``enable_gradient_checkpointing``; it is not used.
        """
        self.gradient_checkpointing = False
        self.cpu_offload_checkpointing = False

    def forward(self, *args, **kwargs):
        """Call ``_forward``, checkpointed only while training with the flag set."""
        if self.training and self.gradient_checkpointing:
            return checkpoint(self._forward, *args, use_reentrant=False, **kwargs)
        return self._forward(*args, **kwargs)
|
||||
|
||||
#############################################################################
|
||||
# RMSNorm #
|
||||
#############################################################################
|
||||
@@ -114,7 +135,7 @@ class RMSNorm(torch.nn.Module):
|
||||
self.eps = eps
|
||||
self.weight = nn.Parameter(torch.ones(dim))
|
||||
|
||||
def _norm(self, x):
|
||||
def _norm(self, x) -> Tensor:
|
||||
"""
|
||||
Apply the RMSNorm normalization to the input tensor.
|
||||
|
||||
@@ -125,21 +146,14 @@ class RMSNorm(torch.nn.Module):
|
||||
torch.Tensor: The normalized tensor.
|
||||
|
||||
"""
|
||||
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
||||
return x * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + self.eps)
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Forward pass through the RMSNorm layer.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): The input tensor.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The output tensor after applying RMSNorm.
|
||||
|
||||
"""
|
||||
output = self._norm(x.float()).type_as(x)
|
||||
return output * self.weight
|
||||
def forward(self, x: Tensor):
    """Apply RMS normalization over the last dimension and scale by ``weight``.

    The computation is done in float32 (the input and the weight are both
    upcast) and the result is cast back to the input's original dtype.

    Args:
        x: input tensor; normalization is taken over its last dimension.

    Returns:
        Tensor with the same dtype as ``x``.
    """
    x_dtype = x.dtype
    x = x.float()
    # Use the epsilon configured on the module rather than a hard-coded 1e-6,
    # so this stays consistent with the constructor argument and ``_norm``.
    rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + self.eps)
    return ((x * rrms) * self.weight.float()).to(dtype=x_dtype)
|
||||
|
||||
|
||||
def modulate(x, scale):
|
||||
@@ -151,7 +165,7 @@ def modulate(x, scale):
|
||||
#############################################################################
|
||||
|
||||
|
||||
class TimestepEmbedder(nn.Module):
|
||||
class TimestepEmbedder(GradientCheckpointMixin):
|
||||
"""
|
||||
Embeds scalar timesteps into vector representations.
|
||||
"""
|
||||
@@ -203,11 +217,32 @@ class TimestepEmbedder(nn.Module):
|
||||
)
|
||||
return embedding
|
||||
|
||||
def forward(self, t):
|
||||
def _forward(self, t):
    """Embed timesteps ``t`` and project them through ``self.mlp``.

    ``self.timestep_embedding`` is called with ``t`` and
    ``self.frequency_embedding_size``; its output is cast to the dtype of the
    first MLP layer's weight before being fed to the MLP.
    """
    mlp_dtype = self.mlp[0].weight.dtype
    freq_features = self.timestep_embedding(t, self.frequency_embedding_size)
    return self.mlp(freq_features.to(mlp_dtype))
|
||||
|
||||
def to_cuda(x):
    """Recursively move every tensor found in ``x`` to CUDA.

    Tensors are moved with ``.cuda()``. Lists and tuples are rebuilt as lists,
    dicts are rebuilt with the same keys, and any other value is returned
    unchanged.
    """
    if isinstance(x, torch.Tensor):
        return x.cuda()
    if isinstance(x, (list, tuple)):
        return list(map(to_cuda, x))
    if isinstance(x, dict):
        return {key: to_cuda(value) for key, value in x.items()}
    return x
|
||||
|
||||
|
||||
def to_cpu(x):
    """Recursively move every tensor found in ``x`` to the CPU.

    Tensors are moved with ``.cpu()``. Lists and tuples are rebuilt as lists,
    dicts are rebuilt with the same keys, and any other value is returned
    unchanged.
    """
    if isinstance(x, torch.Tensor):
        return x.cpu()
    if isinstance(x, (list, tuple)):
        return list(map(to_cpu, x))
    if isinstance(x, dict):
        return {key: to_cpu(value) for key, value in x.items()}
    return x
|
||||
|
||||
|
||||
#############################################################################
|
||||
# Core NextDiT Model #
|
||||
@@ -284,7 +319,7 @@ class JointAttention(nn.Module):
|
||||
Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor
|
||||
and key tensor with rotary embeddings.
|
||||
"""
|
||||
with torch.amp.autocast("cuda",enabled=False):
|
||||
with torch.autocast("cuda", enabled=False):
|
||||
x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2))
|
||||
freqs_cis = freqs_cis.unsqueeze(2)
|
||||
x_out = torch.view_as_real(x * freqs_cis).flatten(3)
|
||||
@@ -496,15 +531,15 @@ class FeedForward(nn.Module):
|
||||
return self.w2(self._forward_silu_gating(self.w1(x), self.w3(x)))
|
||||
|
||||
|
||||
class JointTransformerBlock(nn.Module):
|
||||
class JointTransformerBlock(GradientCheckpointMixin):
|
||||
def __init__(
|
||||
self,
|
||||
layer_id: int,
|
||||
dim: int,
|
||||
n_heads: int,
|
||||
n_kv_heads: int,
|
||||
n_kv_heads: Optional[int],
|
||||
multiple_of: int,
|
||||
ffn_dim_multiplier: float,
|
||||
ffn_dim_multiplier: Optional[float],
|
||||
norm_eps: float,
|
||||
qk_norm: bool,
|
||||
modulation=True,
|
||||
@@ -520,7 +555,7 @@ class JointTransformerBlock(nn.Module):
|
||||
value features (if using GQA), or set to None for the same as
|
||||
query.
|
||||
multiple_of (int):
|
||||
ffn_dim_multiplier (float):
|
||||
ffn_dim_multiplier (Optional[float]):
|
||||
norm_eps (float):
|
||||
|
||||
"""
|
||||
@@ -554,7 +589,7 @@ class JointTransformerBlock(nn.Module):
|
||||
nn.init.zeros_(self.adaLN_modulation[1].weight)
|
||||
nn.init.zeros_(self.adaLN_modulation[1].bias)
|
||||
|
||||
def forward(
|
||||
def _forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
x_mask: torch.Tensor,
|
||||
@@ -608,7 +643,7 @@ class JointTransformerBlock(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
class FinalLayer(nn.Module):
|
||||
class FinalLayer(GradientCheckpointMixin):
|
||||
"""
|
||||
The final layer of NextDiT.
|
||||
"""
|
||||
@@ -661,22 +696,21 @@ class RopeEmbedder:
|
||||
self.axes_dims, self.axes_lens, theta=self.theta
|
||||
)
|
||||
|
||||
def __call__(self, ids: torch.Tensor):
|
||||
def get_freqs_cis(self, ids: torch.Tensor):
|
||||
self.freqs_cis = [freqs_cis.to(ids.device) for freqs_cis in self.freqs_cis]
|
||||
result = []
|
||||
for i in range(len(self.axes_dims)):
|
||||
# import torch.distributed as dist
|
||||
# if not dist.is_initialized() or dist.get_rank() == 0:
|
||||
# import pdb
|
||||
# pdb.set_trace()
|
||||
index = (
|
||||
ids[:, :, i : i + 1]
|
||||
.repeat(1, 1, self.freqs_cis[i].shape[-1])
|
||||
.to(torch.int64)
|
||||
)
|
||||
|
||||
axes = self.freqs_cis[i].unsqueeze(0).repeat(index.shape[0], 1, 1)
|
||||
|
||||
result.append(
|
||||
torch.gather(
|
||||
self.freqs_cis[i].unsqueeze(0).repeat(index.shape[0], 1, 1),
|
||||
axes,
|
||||
dim=1,
|
||||
index=index,
|
||||
)
|
||||
@@ -790,76 +824,98 @@ class NextDiT(nn.Module):
|
||||
self.dim = dim
|
||||
self.n_heads = n_heads
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
self.cpu_offload_checkpointing = False
|
||||
self.blocks_to_swap = None
|
||||
|
||||
@property
def device(self):
    """Device of the first parameter yielded by ``self.parameters()``."""
    first_param = next(self.parameters())
    return first_param.device
|
||||
|
||||
@property
def dtype(self):
    """Dtype of the first parameter yielded by ``self.parameters()``."""
    first_param = next(self.parameters())
    return first_param.dtype
|
||||
|
||||
def enable_gradient_checkpointing(self, cpu_offload: bool = False):
    """Enable gradient checkpointing on the model and its sub-modules.

    Records both flags on the model, then enables checkpointing on the
    timestep embedder, on every block in ``layers``, ``context_refiner`` and
    ``noise_refiner`` (forwarding ``cpu_offload`` to those), and on the final
    layer. Prints a status line when done.

    Args:
        cpu_offload: CPU-offload setting stored on the model and passed to
            each transformer block.
    """
    self.cpu_offload_checkpointing = cpu_offload
    self.gradient_checkpointing = True

    self.t_embedder.enable_gradient_checkpointing()

    transformer_blocks = self.layers + self.context_refiner + self.noise_refiner
    for blk in transformer_blocks:
        blk.enable_gradient_checkpointing(cpu_offload=cpu_offload)

    self.final_layer.enable_gradient_checkpointing()

    print(f"Lumina: Gradient checkpointing enabled. CPU offload: {cpu_offload}")
|
||||
|
||||
def disable_gradient_checkpointing(self):
    """Disable gradient checkpointing on the model and its sub-modules.

    Clears both flags on the model, then disables checkpointing on the
    timestep embedder, on every block in ``layers``, ``context_refiner`` and
    ``noise_refiner``, and on the final layer. Prints a status line when done.
    """
    self.cpu_offload_checkpointing = False
    self.gradient_checkpointing = False

    self.t_embedder.disable_gradient_checkpointing()

    transformer_blocks = self.layers + self.context_refiner + self.noise_refiner
    for blk in transformer_blocks:
        blk.disable_gradient_checkpointing()

    self.final_layer.disable_gradient_checkpointing()

    print("Lumina: Gradient checkpointing disabled.")
|
||||
|
||||
def unpatchify(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
img_size: List[Tuple[int, int]],
|
||||
cap_size: List[int],
|
||||
return_tensor=False,
|
||||
) -> List[torch.Tensor]:
|
||||
width: int,
|
||||
height: int,
|
||||
encoder_seq_lengths: List[int],
|
||||
seq_lengths: List[int],
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
x: (N, T, patch_size**2 * C)
|
||||
imgs: (N, H, W, C)
|
||||
"""
|
||||
pH = pW = self.patch_size
|
||||
imgs = []
|
||||
for i in range(x.size(0)):
|
||||
H, W = img_size[i]
|
||||
begin = cap_size[i]
|
||||
end = begin + (H // pH) * (W // pW)
|
||||
imgs.append(
|
||||
x[i][begin:end]
|
||||
.view(H // pH, W // pW, pH, pW, self.out_channels)
|
||||
|
||||
output = []
|
||||
for i, (encoder_seq_len, seq_len) in enumerate(zip(encoder_seq_lengths, seq_lengths)):
|
||||
output.append(
|
||||
x[i][encoder_seq_len:seq_len]
|
||||
.view(height // pH, width // pW, pH, pW, self.out_channels)
|
||||
.permute(4, 0, 2, 1, 3)
|
||||
.flatten(3, 4)
|
||||
.flatten(1, 2)
|
||||
)
|
||||
output = torch.stack(output, dim=0)
|
||||
|
||||
if return_tensor:
|
||||
imgs = torch.stack(imgs, dim=0)
|
||||
return imgs
|
||||
return output
|
||||
|
||||
def patchify_and_embed(
|
||||
self,
|
||||
x: List[torch.Tensor] | torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
cap_feats: torch.Tensor,
|
||||
cap_mask: torch.Tensor,
|
||||
t: torch.Tensor,
|
||||
) -> Tuple[
|
||||
torch.Tensor, torch.Tensor, List[Tuple[int, int]], List[int], torch.Tensor
|
||||
torch.Tensor, torch.Tensor, torch.Tensor, List[int], List[int]
|
||||
]:
|
||||
bsz = len(x)
|
||||
bsz, channels, height, width = x.shape
|
||||
pH = pW = self.patch_size
|
||||
device = x[0].device
|
||||
device = x.device
|
||||
|
||||
l_effective_cap_len = cap_mask.sum(dim=1).tolist()
|
||||
img_sizes = [(img.size(1), img.size(2)) for img in x]
|
||||
l_effective_img_len = [(H // pH) * (W // pW) for (H, W) in img_sizes]
|
||||
encoder_seq_len = cap_mask.shape[1]
|
||||
|
||||
max_seq_len = max(
|
||||
(
|
||||
cap_len + img_len
|
||||
for cap_len, img_len in zip(l_effective_cap_len, l_effective_img_len)
|
||||
)
|
||||
)
|
||||
max_cap_len = max(l_effective_cap_len)
|
||||
max_img_len = max(l_effective_img_len)
|
||||
image_seq_len = (height // self.patch_size) * (width // self.patch_size)
|
||||
seq_lengths = [cap_seq_len + image_seq_len for cap_seq_len in l_effective_cap_len]
|
||||
max_seq_len = max(seq_lengths)
|
||||
|
||||
position_ids = torch.zeros(
|
||||
bsz, max_seq_len, 3, dtype=torch.int32, device=device
|
||||
)
|
||||
position_ids = torch.zeros(bsz, max_seq_len, 3, dtype=torch.int32, device=device)
|
||||
|
||||
for i in range(bsz):
|
||||
cap_len = l_effective_cap_len[i]
|
||||
img_len = l_effective_img_len[i]
|
||||
H, W = img_sizes[i]
|
||||
H_tokens, W_tokens = H // pH, W // pW
|
||||
assert H_tokens * W_tokens == img_len
|
||||
for i, (cap_len, seq_len) in enumerate(zip(l_effective_cap_len, seq_lengths)):
|
||||
H_tokens, W_tokens = height // pH, width // pW
|
||||
|
||||
position_ids[i, :cap_len, 0] = torch.arange(
|
||||
cap_len, dtype=torch.int32, device=device
|
||||
)
|
||||
position_ids[i, cap_len : cap_len + img_len, 0] = cap_len
|
||||
position_ids[i, :cap_len, 0] = torch.arange(cap_len, dtype=torch.int32, device=device)
|
||||
position_ids[i, cap_len : cap_len + seq_len, 0] = cap_len
|
||||
row_ids = (
|
||||
torch.arange(H_tokens, dtype=torch.int32, device=device)
|
||||
.view(-1, 1)
|
||||
@@ -872,77 +928,40 @@ class NextDiT(nn.Module):
|
||||
.repeat(H_tokens, 1)
|
||||
.flatten()
|
||||
)
|
||||
position_ids[i, cap_len : cap_len + img_len, 1] = row_ids
|
||||
position_ids[i, cap_len : cap_len + img_len, 2] = col_ids
|
||||
position_ids[i, cap_len : cap_len + seq_len, 1] = row_ids
|
||||
position_ids[i, cap_len : cap_len + seq_len, 2] = col_ids
|
||||
|
||||
freqs_cis = self.rope_embedder(position_ids)
|
||||
freqs_cis = self.rope_embedder.get_freqs_cis(position_ids)
|
||||
|
||||
# build freqs_cis for cap and image individually
|
||||
cap_freqs_cis_shape = list(freqs_cis.shape)
|
||||
# cap_freqs_cis_shape[1] = max_cap_len
|
||||
cap_freqs_cis_shape[1] = cap_feats.shape[1]
|
||||
cap_freqs_cis = torch.zeros(
|
||||
*cap_freqs_cis_shape, device=device, dtype=freqs_cis.dtype
|
||||
)
|
||||
cap_freqs_cis = torch.zeros(bsz, encoder_seq_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype)
|
||||
img_freqs_cis = torch.zeros(bsz, image_seq_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype)
|
||||
|
||||
img_freqs_cis_shape = list(freqs_cis.shape)
|
||||
img_freqs_cis_shape[1] = max_img_len
|
||||
img_freqs_cis = torch.zeros(
|
||||
*img_freqs_cis_shape, device=device, dtype=freqs_cis.dtype
|
||||
)
|
||||
|
||||
for i in range(bsz):
|
||||
cap_len = l_effective_cap_len[i]
|
||||
img_len = l_effective_img_len[i]
|
||||
for i, (cap_len, seq_len) in enumerate(zip(l_effective_cap_len, seq_lengths)):
|
||||
cap_freqs_cis[i, :cap_len] = freqs_cis[i, :cap_len]
|
||||
img_freqs_cis[i, :img_len] = freqs_cis[i, cap_len : cap_len + img_len]
|
||||
img_freqs_cis[i, :seq_len] = freqs_cis[i, cap_len : cap_len + seq_len]
|
||||
|
||||
x = x.view(bsz, channels, height // pH, pH, width // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2)
|
||||
x_mask = torch.zeros(bsz, max_seq_len, dtype=torch.bool, device=device)
|
||||
|
||||
# refine context
|
||||
for layer in self.context_refiner:
|
||||
cap_feats = layer(cap_feats, cap_mask, cap_freqs_cis)
|
||||
|
||||
# refine image
|
||||
flat_x = []
|
||||
for i in range(bsz):
|
||||
img = x[i]
|
||||
C, H, W = img.size()
|
||||
img = (
|
||||
img.view(C, H // pH, pH, W // pW, pW)
|
||||
.permute(1, 3, 2, 4, 0)
|
||||
.flatten(2)
|
||||
.flatten(0, 1)
|
||||
)
|
||||
flat_x.append(img)
|
||||
x = flat_x
|
||||
padded_img_embed = torch.zeros(
|
||||
bsz, max_img_len, x[0].shape[-1], device=device, dtype=x[0].dtype
|
||||
)
|
||||
padded_img_mask = torch.zeros(bsz, max_img_len, dtype=torch.bool, device=device)
|
||||
for i in range(bsz):
|
||||
padded_img_embed[i, : l_effective_img_len[i]] = x[i]
|
||||
padded_img_mask[i, : l_effective_img_len[i]] = True
|
||||
x = self.x_embedder(x)
|
||||
|
||||
padded_img_embed = self.x_embedder(padded_img_embed)
|
||||
for layer in self.noise_refiner:
|
||||
padded_img_embed = layer(
|
||||
padded_img_embed, padded_img_mask, img_freqs_cis, t
|
||||
)
|
||||
x = layer(x, x_mask, img_freqs_cis, t)
|
||||
|
||||
mask = torch.zeros(bsz, max_seq_len, dtype=torch.bool, device=device)
|
||||
padded_full_embed = torch.zeros(
|
||||
bsz, max_seq_len, self.dim, device=device, dtype=x[0].dtype
|
||||
)
|
||||
for i in range(bsz):
|
||||
cap_len = l_effective_cap_len[i]
|
||||
img_len = l_effective_img_len[i]
|
||||
joint_hidden_states = torch.zeros(bsz, max_seq_len, self.dim, device=device, dtype=x.dtype)
|
||||
attention_mask = torch.zeros(bsz, max_seq_len, dtype=torch.bool, device=device)
|
||||
for i, (cap_len, seq_len) in enumerate(zip(l_effective_cap_len, seq_lengths)):
|
||||
attention_mask[i, :seq_len] = True
|
||||
joint_hidden_states[i, :cap_len] = cap_feats[i, :cap_len]
|
||||
joint_hidden_states[i, cap_len:seq_len] = x[i]
|
||||
|
||||
mask[i, : cap_len + img_len] = True
|
||||
padded_full_embed[i, :cap_len] = cap_feats[i, :cap_len]
|
||||
padded_full_embed[i, cap_len : cap_len + img_len] = padded_img_embed[
|
||||
i, :img_len
|
||||
]
|
||||
x = joint_hidden_states
|
||||
|
||||
return padded_full_embed, mask, img_sizes, l_effective_cap_len, freqs_cis
|
||||
return x, attention_mask, freqs_cis, l_effective_cap_len, seq_lengths
|
||||
|
||||
def forward(self, x, t, cap_feats, cap_mask):
|
||||
"""
|
||||
@@ -950,30 +969,19 @@ class NextDiT(nn.Module):
|
||||
t: (N,) tensor of diffusion timesteps
|
||||
y: (N,) tensor of text tokens/features
|
||||
"""
|
||||
|
||||
# import torch.distributed as dist
|
||||
# if not dist.is_initialized() or dist.get_rank() == 0:
|
||||
# import pdb
|
||||
# pdb.set_trace()
|
||||
# torch.save([x, t, cap_feats, cap_mask], "./fake_input.pt")
|
||||
_, _, height, width = x.shape # B, C, H, W
|
||||
t = self.t_embedder(t) # (N, D)
|
||||
adaln_input = t
|
||||
cap_feats = self.cap_embedder(cap_feats) # (N, L, D) # todo check if able to batchify w.o. redundant compute
|
||||
|
||||
cap_feats = self.cap_embedder(
|
||||
cap_feats
|
||||
) # (N, L, D) # todo check if able to batchify w.o. redundant compute
|
||||
|
||||
x_is_tensor = isinstance(x, torch.Tensor)
|
||||
x, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(
|
||||
x, mask, freqs_cis, l_effective_cap_len, seq_lengths = self.patchify_and_embed(
|
||||
x, cap_feats, cap_mask, t
|
||||
)
|
||||
freqs_cis = freqs_cis.to(x.device)
|
||||
|
||||
for layer in self.layers:
|
||||
x = layer(x, mask, freqs_cis, adaln_input)
|
||||
x = layer(x, mask, freqs_cis, t)
|
||||
|
||||
x = self.final_layer(x, adaln_input)
|
||||
x = self.unpatchify(x, img_size, cap_size, return_tensor=x_is_tensor)
|
||||
x = self.final_layer(x, t)
|
||||
x = self.unpatchify(x, width, height, l_effective_cap_len, seq_lengths)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ import os
|
||||
from typing import Any, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from transformers import AutoTokenizer, AutoModel
|
||||
from transformers import AutoTokenizer, AutoModel, GemmaTokenizerFast
|
||||
from library import train_util
|
||||
from library.strategy_base import (
|
||||
LatentsCachingStrategy,
|
||||
@@ -27,34 +27,35 @@ class LuminaTokenizeStrategy(TokenizeStrategy):
|
||||
def __init__(
|
||||
self, max_length: Optional[int], tokenizer_cache_dir: Optional[str] = None
|
||||
) -> None:
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(
|
||||
self.tokenizer: GemmaTokenizerFast = AutoTokenizer.from_pretrained(
|
||||
GEMMA_ID, cache_dir=tokenizer_cache_dir
|
||||
)
|
||||
self.tokenizer.padding_side = "right"
|
||||
|
||||
if max_length is None:
|
||||
self.max_length = self.tokenizer.model_max_length
|
||||
self.max_length = 256
|
||||
else:
|
||||
self.max_length = max_length
|
||||
|
||||
def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]:
|
||||
def tokenize(self, text: Union[str, List[str]]) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
text = [text] if isinstance(text, str) else text
|
||||
encodings = self.tokenizer(
|
||||
text,
|
||||
padding="max_length",
|
||||
max_length=self.max_length,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
pad_to_multiple_of=8,
|
||||
truncation=True,
|
||||
)
|
||||
return [encodings.input_ids]
|
||||
return encodings.input_ids, encodings.attention_mask
|
||||
|
||||
def tokenize_with_weights(
|
||||
self, text: str | List[str]
|
||||
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
|
||||
# Gemma doesn't support weighted prompts, return uniform weights
|
||||
tokens = self.tokenize(text)
|
||||
tokens, attention_masks = self.tokenize(text)
|
||||
weights = [torch.ones_like(t) for t in tokens]
|
||||
return tokens, weights
|
||||
return tokens, attention_masks, weights
|
||||
|
||||
|
||||
class LuminaTextEncodingStrategy(TextEncodingStrategy):
|
||||
@@ -66,50 +67,39 @@ class LuminaTextEncodingStrategy(TextEncodingStrategy):
|
||||
self,
|
||||
tokenize_strategy: TokenizeStrategy,
|
||||
models: List[Any],
|
||||
tokens: List[torch.Tensor],
|
||||
tokens: torch.Tensor,
|
||||
attention_masks: torch.Tensor,
|
||||
apply_gemma2_attn_mask: Optional[bool] = None,
|
||||
) -> List[torch.Tensor]:
|
||||
|
||||
) -> torch.Tensor:
|
||||
if apply_gemma2_attn_mask is None:
|
||||
apply_gemma2_attn_mask = self.apply_gemma2_attn_mask
|
||||
|
||||
text_encoder = models[0]
|
||||
input_ids = tokens[0].to(text_encoder.device)
|
||||
|
||||
attention_mask = None
|
||||
position_ids = None
|
||||
if apply_gemma2_attn_mask:
|
||||
# Create attention mask (1 for non-padding, 0 for padding)
|
||||
attention_mask = (input_ids != tokenize_strategy.tokenizer.pad_token_id).to(
|
||||
text_encoder.device
|
||||
)
|
||||
# Create position IDs
|
||||
position_ids = attention_masks.cumsum(-1) - 1
|
||||
position_ids.masked_fill_(attention_masks == 0, 1)
|
||||
|
||||
# Create position IDs
|
||||
position_ids = attention_mask.long().cumsum(-1) - 1
|
||||
position_ids.masked_fill_(attention_mask == 0, 1)
|
||||
outputs = text_encoder(
|
||||
input_ids=tokens.to(text_encoder.device),
|
||||
attention_mask=attention_masks.to(text_encoder.device) if apply_gemma2_attn_mask else None,
|
||||
position_ids=position_ids.to(text_encoder.device),
|
||||
output_hidden_states=True,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = text_encoder(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
output_hidden_states=True,
|
||||
return_dict=True,
|
||||
)
|
||||
# Get the last hidden state
|
||||
hidden_states = outputs.last_hidden_state
|
||||
|
||||
return [hidden_states]
|
||||
return outputs.hidden_states[-2]
|
||||
|
||||
def encode_tokens_with_weights(
|
||||
self,
|
||||
tokenize_strategy: TokenizeStrategy,
|
||||
models: List[Any],
|
||||
tokens_list: List[torch.Tensor],
|
||||
tokens: torch.Tensor,
|
||||
weights_list: List[torch.Tensor],
|
||||
) -> List[torch.Tensor]:
|
||||
attention_masks: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
# For simplicity, use uniform weighting
|
||||
return self.encode_tokens(tokenize_strategy, models, tokens_list)
|
||||
return self.encode_tokens(tokenize_strategy, models, tokens, attention_masks)
|
||||
|
||||
|
||||
class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
|
||||
@@ -149,6 +139,15 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy)
|
||||
npz = np.load(npz_path)
|
||||
if "hidden_state" not in npz:
|
||||
return False
|
||||
if "attention_mask" not in npz:
|
||||
return False
|
||||
if "input_ids" not in npz:
|
||||
return False
|
||||
if "apply_gemma2_attn_mask" not in npz:
|
||||
return False
|
||||
npz_apply_gemma2_attn_mask = npz["apply_gemma2_attn_mask"]
|
||||
if npz_apply_gemma2_attn_mask != self.apply_gemma2_attn_mask:
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading file: {npz_path}")
|
||||
raise e
|
||||
@@ -158,13 +157,15 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy)
|
||||
def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
|
||||
data = np.load(npz_path)
|
||||
hidden_state = data["hidden_state"]
|
||||
return [hidden_state]
|
||||
attention_mask = data["attention_mask"]
|
||||
input_ids = data["input_ids"]
|
||||
return [hidden_state, attention_mask, input_ids]
|
||||
|
||||
def cache_batch_outputs(
|
||||
self,
|
||||
tokenize_strategy: TokenizeStrategy,
|
||||
tokenize_strategy: LuminaTokenizeStrategy,
|
||||
models: List[Any],
|
||||
text_encoding_strategy: TextEncodingStrategy,
|
||||
text_encoding_strategy: LuminaTextEncodingStrategy,
|
||||
infos: List,
|
||||
):
|
||||
lumina_text_encoding_strategy: LuminaTextEncodingStrategy = (
|
||||
@@ -173,35 +174,44 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy)
|
||||
captions = [info.caption for info in infos]
|
||||
|
||||
if self.is_weighted:
|
||||
tokens_list, weights_list = tokenize_strategy.tokenize_with_weights(
|
||||
tokens, attention_masks, weights_list = tokenize_strategy.tokenize_with_weights(
|
||||
captions
|
||||
)
|
||||
with torch.no_grad():
|
||||
hidden_state = lumina_text_encoding_strategy.encode_tokens_with_weights(
|
||||
tokenize_strategy, models, tokens_list, weights_list
|
||||
)[0]
|
||||
tokenize_strategy, models, tokens, weights_list, attention_masks
|
||||
)
|
||||
else:
|
||||
tokens = tokenize_strategy.tokenize(captions)
|
||||
tokens, attention_masks = tokenize_strategy.tokenize(captions)
|
||||
with torch.no_grad():
|
||||
hidden_state = lumina_text_encoding_strategy.encode_tokens(
|
||||
tokenize_strategy, models, tokens
|
||||
)[0]
|
||||
tokenize_strategy, models, tokens, attention_masks
|
||||
)
|
||||
|
||||
if hidden_state.dtype == torch.bfloat16:
|
||||
if hidden_state.dtype != torch.float32:
|
||||
hidden_state = hidden_state.float()
|
||||
|
||||
hidden_state = hidden_state.cpu().numpy()
|
||||
attention_mask = attention_masks.cpu().numpy()
|
||||
input_ids = tokens.cpu().numpy()
|
||||
|
||||
|
||||
for i, info in enumerate(infos):
|
||||
hidden_state_i = hidden_state[i]
|
||||
attention_mask_i = attention_mask[i]
|
||||
input_ids_i = input_ids[i]
|
||||
apply_gemma2_attn_mask_i = self.apply_gemma2_attn_mask
|
||||
|
||||
if self.cache_to_disk:
|
||||
np.savez(
|
||||
info.text_encoder_outputs_npz,
|
||||
hidden_state=hidden_state_i,
|
||||
attention_mask=attention_mask_i,
|
||||
input_ids=input_ids_i,
|
||||
apply_gemma2_attn_mask=apply_gemma2_attn_mask_i,
|
||||
)
|
||||
else:
|
||||
info.text_encoder_outputs = [hidden_state_i]
|
||||
info.text_encoder_outputs = [hidden_state_i, attention_mask_i, input_ids_i]
|
||||
|
||||
|
||||
class LuminaLatentsCachingStrategy(LatentsCachingStrategy):
|
||||
|
||||
@@ -62,6 +62,19 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
|
||||
disable_mmap=args.disable_mmap_load_safetensors,
|
||||
)
|
||||
|
||||
if args.fp8_base:
|
||||
# check dtype of model
|
||||
if model.dtype == torch.float8_e4m3fnuz or model.dtype == torch.float8_e5m2 or model.dtype == torch.float8_e5m2fnuz:
|
||||
raise ValueError(f"Unsupported fp8 model dtype: {model.dtype}")
|
||||
elif model.dtype == torch.float8_e4m3fn:
|
||||
logger.info("Loaded fp8 Lumina 2 model")
|
||||
else:
|
||||
logger.info(
|
||||
"Cast Lumina 2 model to fp8. This may take a while. You can reduce the time by using fp8 checkpoint."
|
||||
" / Lumina 2モデルをfp8に変換しています。これには時間がかかる場合があります。fp8チェックポイントを使用することで時間を短縮できます。"
|
||||
)
|
||||
model.to(torch.float8_e4m3fn)
|
||||
|
||||
# if args.blocks_to_swap:
|
||||
# logger.info(f'Enabling block swap: {args.blocks_to_swap}')
|
||||
# model.enable_block_swap(args.blocks_to_swap, accelerator.device)
|
||||
@@ -70,6 +83,7 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
|
||||
gemma2 = lumina_util.load_gemma2(
|
||||
args.gemma2, weight_dtype, "cpu"
|
||||
)
|
||||
gemma2.eval()
|
||||
ae = lumina_util.load_ae(
|
||||
args.ae, weight_dtype, "cpu"
|
||||
)
|
||||
@@ -118,17 +132,65 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
|
||||
dataset,
|
||||
weight_dtype,
|
||||
):
|
||||
for text_encoder in text_encoders:
|
||||
text_encoder_outputs_caching_strategy = (
|
||||
self.get_text_encoder_outputs_caching_strategy(args)
|
||||
)
|
||||
if text_encoder_outputs_caching_strategy is not None:
|
||||
text_encoder_outputs_caching_strategy.cache_batch_outputs(
|
||||
self.get_tokenize_strategy(args),
|
||||
[text_encoder],
|
||||
self.get_text_encoding_strategy(args),
|
||||
dataset,
|
||||
)
|
||||
if args.cache_text_encoder_outputs:
|
||||
if not args.lowram:
|
||||
# メモリ消費を減らす
|
||||
logger.info("move vae and unet to cpu to save memory")
|
||||
org_vae_device = vae.device
|
||||
org_unet_device = unet.device
|
||||
vae.to("cpu")
|
||||
unet.to("cpu")
|
||||
clean_memory_on_device(accelerator.device)
|
||||
|
||||
# When the TE is not being trained, it will not be prepared, so we need to use explicit autocast
|
||||
logger.info("move text encoders to gpu")
|
||||
text_encoders[0].to(accelerator.device, dtype=weight_dtype) # always not fp8
|
||||
|
||||
if text_encoders[0].dtype == torch.float8_e4m3fn:
|
||||
# if we load fp8 weights, the model is already fp8, so we use it as is
|
||||
self.prepare_text_encoder_fp8(1, text_encoders[1], text_encoders[1].dtype, weight_dtype)
|
||||
else:
|
||||
# otherwise, we need to convert it to target dtype
|
||||
text_encoders[0].to(weight_dtype)
|
||||
|
||||
with accelerator.autocast():
|
||||
dataset.new_cache_text_encoder_outputs(text_encoders, accelerator)
|
||||
|
||||
# cache sample prompts
|
||||
if args.sample_prompts is not None:
|
||||
logger.info(f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}")
|
||||
|
||||
tokenize_strategy: strategy_lumina.LuminaTokenizeStrategy = strategy_base.TokenizeStrategy.get_strategy()
|
||||
text_encoding_strategy: strategy_lumina.LuminaTextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy()
|
||||
|
||||
prompts = train_util.load_prompts(args.sample_prompts)
|
||||
sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs
|
||||
with accelerator.autocast(), torch.no_grad():
|
||||
for prompt_dict in prompts:
|
||||
for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]:
|
||||
if p not in sample_prompts_te_outputs:
|
||||
logger.info(f"cache Text Encoder outputs for prompt: {p}")
|
||||
tokens_and_masks = tokenize_strategy.tokenize(p)
|
||||
sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens(
|
||||
tokenize_strategy, text_encoders, tokens_and_masks, args.apply_t5_attn_mask
|
||||
)
|
||||
self.sample_prompts_te_outputs = sample_prompts_te_outputs
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
# move back to cpu
|
||||
if not self.is_train_text_encoder(args):
|
||||
logger.info("move Gemma 2 back to cpu")
|
||||
text_encoders[0].to("cpu")
|
||||
clean_memory_on_device(accelerator.device)
|
||||
|
||||
if not args.lowram:
|
||||
logger.info("move vae and unet back to original device")
|
||||
vae.to(org_vae_device)
|
||||
unet.to(org_unet_device)
|
||||
else:
|
||||
# Text Encoderから毎回出力を取得するので、GPUに乗せておく
|
||||
text_encoders[0].to(accelerator.device, dtype=weight_dtype)
|
||||
|
||||
def sample_images(
|
||||
self,
|
||||
@@ -196,12 +258,13 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
|
||||
)
|
||||
)
|
||||
|
||||
# May not need to pack/unpack?
|
||||
# pack latents and get img_ids - 这部分可以保留因为NextDiT也需要packed格式的输入
|
||||
packed_noisy_model_input = lumina_util.pack_latents(noisy_model_input)
|
||||
packed_latent_height, packed_latent_width = (
|
||||
noisy_model_input.shape[2] // 2,
|
||||
noisy_model_input.shape[3] // 2,
|
||||
)
|
||||
# packed_noisy_model_input = lumina_util.pack_latents(noisy_model_input)
|
||||
# packed_latent_height, packed_latent_width = (
|
||||
# noisy_model_input.shape[2] // 2,
|
||||
# noisy_model_input.shape[3] // 2,
|
||||
# )
|
||||
|
||||
# ensure the hidden state will require grad
|
||||
if args.gradient_checkpointing:
|
||||
@@ -212,32 +275,30 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
|
||||
|
||||
# Unpack Gemma2 outputs
|
||||
gemma2_hidden_states, gemma2_attn_mask, input_ids = text_encoder_conds
|
||||
if not args.apply_gemma2_attn_mask:
|
||||
gemma2_attn_mask = None
|
||||
|
||||
def call_dit(img, gemma2_hidden_states, input_ids, timesteps, gemma2_attn_mask):
|
||||
def call_dit(img, gemma2_hidden_states, timesteps, gemma2_attn_mask):
|
||||
with torch.set_grad_enabled(is_train), accelerator.autocast():
|
||||
# NextDiT forward expects (x, t, cap_feats, cap_mask)
|
||||
model_pred = unet(
|
||||
x=img, # packed latents
|
||||
x=img, # image latents (B, C, H, W)
|
||||
t=timesteps / 1000, # timesteps需要除以1000来匹配模型预期
|
||||
cap_feats=gemma2_hidden_states, # Gemma2的hidden states作为caption features
|
||||
cap_mask=gemma2_attn_mask, # Gemma2的attention mask
|
||||
cap_mask=gemma2_attn_mask.to(dtype=torch.int32), # Gemma2的attention mask
|
||||
)
|
||||
return model_pred
|
||||
|
||||
model_pred = call_dit(
|
||||
img=packed_noisy_model_input,
|
||||
img=noisy_model_input,
|
||||
gemma2_hidden_states=gemma2_hidden_states,
|
||||
input_ids=input_ids,
|
||||
timesteps=timesteps,
|
||||
gemma2_attn_mask=gemma2_attn_mask,
|
||||
)
|
||||
|
||||
# May not need to pack/unpack?
|
||||
# unpack latents
|
||||
model_pred = lumina_util.unpack_latents(
|
||||
model_pred, packed_latent_height, packed_latent_width
|
||||
)
|
||||
# model_pred = lumina_util.unpack_latents(
|
||||
# model_pred, packed_latent_height, packed_latent_width
|
||||
# )
|
||||
|
||||
# apply model prediction type
|
||||
model_pred, weighting = flux_train_utils.apply_model_prediction_type(
|
||||
|
||||
@@ -462,7 +462,7 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, lumina, wei
|
||||
|
||||
class LoRANetwork(torch.nn.Module):
|
||||
LUMINA_TARGET_REPLACE_MODULE = ["JointTransformerBlock"]
|
||||
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["GemmaAttention", "GemmaDecoderLayer", "GemmaMLP"]
|
||||
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["Gemma2Attention", "Gemma2MLP"]
|
||||
LORA_PREFIX_LUMINA = "lora_unet"
|
||||
LORA_PREFIX_TEXT_ENCODER = "lora_te" # Simplified prefix since we only have one text encoder
|
||||
|
||||
@@ -533,7 +533,7 @@ class LoRANetwork(torch.nn.Module):
|
||||
filter: Optional[str] = None,
|
||||
default_dim: Optional[int] = None,
|
||||
) -> List[LoRAModule]:
|
||||
prefix = self.LORA_PREFIX_FLUX if is_lumina else self.LORA_PREFIX_TEXT_ENCODER
|
||||
prefix = self.LORA_PREFIX_LUMINA if is_lumina else self.LORA_PREFIX_TEXT_ENCODER
|
||||
|
||||
loras = []
|
||||
skipped = []
|
||||
@@ -611,7 +611,7 @@ class LoRANetwork(torch.nn.Module):
|
||||
skipped_te = []
|
||||
|
||||
logger.info(f"create LoRA for Gemma2 Text Encoder:")
|
||||
text_encoder_loras, skipped = create_modules(False, text_encoders, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
|
||||
text_encoder_loras, skipped = create_modules(False, text_encoders[0], LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
|
||||
logger.info(f"create LoRA for Gemma2 Text Encoder: {len(text_encoder_loras)} modules.")
|
||||
self.text_encoder_loras.extend(text_encoder_loras)
|
||||
skipped_te += skipped
|
||||
@@ -718,10 +718,10 @@ class LoRANetwork(torch.nn.Module):
|
||||
|
||||
def state_dict(self, destination=None, prefix="", keep_vars=False):
|
||||
if not self.split_qkv:
|
||||
return super().state_dict(destination, prefix, keep_vars)
|
||||
return super().state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
|
||||
|
||||
# merge qkv
|
||||
state_dict = super().state_dict(destination, prefix, keep_vars)
|
||||
state_dict = super().state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
|
||||
new_state_dict = {}
|
||||
for key in list(state_dict.keys()):
|
||||
if "double" in key and "qkv" in key:
|
||||
|
||||
Reference in New Issue
Block a user