Fix samples, LoRA training. Add system prompt, use_flash_attn

This commit is contained in:
rockerBOO
2025-02-23 01:29:18 -05:00
parent 6597631b90
commit 025cca699b
10 changed files with 888 additions and 386 deletions

View File

@@ -75,6 +75,7 @@ class BaseSubsetParams:
custom_attributes: Optional[Dict[str, Any]] = None
validation_seed: int = 0
validation_split: float = 0.0
system_prompt: Optional[str] = None
@dataclass
@@ -106,6 +107,7 @@ class BaseDatasetParams:
debug_dataset: bool = False
validation_seed: Optional[int] = None
validation_split: float = 0.0
system_prompt: Optional[str] = None
@dataclass
@@ -196,6 +198,7 @@ class ConfigSanitizer:
"caption_prefix": str,
"caption_suffix": str,
"custom_attributes": dict,
"system_prompt": str,
}
# DO means DropOut
DO_SUBSET_ASCENDABLE_SCHEMA = {
@@ -241,6 +244,7 @@ class ConfigSanitizer:
"validation_split": float,
"resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int),
"network_multiplier": float,
"system_prompt": str,
}
# options handled by argparse but not handled by user config
@@ -526,6 +530,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
batch_size: {dataset.batch_size}
resolution: {(dataset.width, dataset.height)}
enable_bucket: {dataset.enable_bucket}
system_prompt: {dataset.system_prompt}
""")
if dataset.enable_bucket:
@@ -559,6 +564,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
token_warmup_step: {subset.token_warmup_step},
alpha_mask: {subset.alpha_mask}
custom_attributes: {subset.custom_attributes}
system_prompt: {subset.system_prompt}
"""), " ")
if is_dreambooth:

View File

@@ -14,14 +14,19 @@ from typing import List, Optional, Tuple
from dataclasses import dataclass
from einops import rearrange
from flash_attn import flash_attn_varlen_func
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
import torch
from torch import Tensor
from torch.utils.checkpoint import checkpoint
import torch.nn as nn
import torch.nn.functional as F
try:
from flash_attn import flash_attn_varlen_func
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
except:
pass
try:
from apex.normalization import FusedRMSNorm as RMSNorm
except:
@@ -75,7 +80,15 @@ class LuminaParams:
@classmethod
def get_7b_config(cls) -> "LuminaParams":
"""Returns the configuration for the 7B parameter model"""
return cls(patch_size=2, dim=4096, n_layers=32, n_heads=32, n_kv_heads=8, axes_dims=[64, 64, 64], axes_lens=[300, 512, 512])
return cls(
patch_size=2,
dim=4096,
n_layers=32,
n_heads=32,
n_kv_heads=8,
axes_dims=[64, 64, 64],
axes_lens=[300, 512, 512],
)
class GradientCheckpointMixin(nn.Module):
@@ -248,6 +261,7 @@ class JointAttention(nn.Module):
n_heads: int,
n_kv_heads: Optional[int],
qk_norm: bool,
use_flash_attn=False,
):
"""
Initialize the Attention module.
@@ -286,7 +300,7 @@ class JointAttention(nn.Module):
else:
self.q_norm = self.k_norm = nn.Identity()
self.flash_attn = False
self.use_flash_attn = use_flash_attn
# self.attention_processor = xformers.ops.memory_efficient_attention
self.attention_processor = F.scaled_dot_product_attention
@@ -294,35 +308,63 @@ class JointAttention(nn.Module):
def set_attention_processor(self, attention_processor):
self.attention_processor = attention_processor
@staticmethod
def apply_rotary_emb(
x_in: torch.Tensor,
freqs_cis: torch.Tensor,
) -> torch.Tensor:
def forward(
self,
x: Tensor,
x_mask: Tensor,
freqs_cis: Tensor,
) -> Tensor:
"""
Apply rotary embeddings to input tensors using the given frequency
tensor.
This function applies rotary embeddings to the given query 'xq' and
key 'xk' tensors using the provided frequency tensor 'freqs_cis'. The
input tensors are reshaped as complex numbers, and the frequency tensor
is reshaped for broadcasting compatibility. The resulting tensors
contain rotary embeddings and are returned as real tensors.
Args:
x_in (torch.Tensor): Query or Key tensor to apply rotary embeddings.
freqs_cis (torch.Tensor): Precomputed frequency tensor for complex
exponentials.
Returns:
Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor
and key tensor with rotary embeddings.
x (Tensor): Input hidden states, shape (bsz, seqlen, dim).
x_mask (Tensor): Attention mask over the sequence positions.
freqs_cis (Tensor): Precomputed rotary position-embedding frequencies.
"""
with torch.autocast("cuda", enabled=False):
x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2))
freqs_cis = freqs_cis.unsqueeze(2)
x_out = torch.view_as_real(x * freqs_cis).flatten(3)
return x_out.type_as(x_in)
bsz, seqlen, _ = x.shape
dtype = x.dtype
xq, xk, xv = torch.split(
self.qkv(x),
[
self.n_local_heads * self.head_dim,
self.n_local_kv_heads * self.head_dim,
self.n_local_kv_heads * self.head_dim,
],
dim=-1,
)
xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
xq = self.q_norm(xq)
xk = self.k_norm(xk)
xq = apply_rope(xq, freqs_cis=freqs_cis)
xk = apply_rope(xk, freqs_cis=freqs_cis)
xq, xk = xq.to(dtype), xk.to(dtype)
softmax_scale = math.sqrt(1 / self.head_dim)
if self.use_flash_attn:
output = self.flash_attn(xq, xk, xv, x_mask, softmax_scale)
else:
n_rep = self.n_local_heads // self.n_local_kv_heads
if n_rep >= 1:
xk = xk.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
xv = xv.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
output = (
self.attention_processor(
xq.permute(0, 2, 1, 3),
xk.permute(0, 2, 1, 3),
xv.permute(0, 2, 1, 3),
attn_mask=x_mask.bool().view(bsz, 1, 1, seqlen).expand(-1, self.n_local_heads, seqlen, -1),
scale=softmax_scale,
)
.permute(0, 2, 1, 3)
.to(dtype)
)
output = output.flatten(-2)
return self.out(output)
# copied from huggingface modeling_llama.py
def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
@@ -377,46 +419,17 @@ class JointAttention(nn.Module):
(max_seqlen_in_batch_q, max_seqlen_in_batch_k),
)
def forward(
def flash_attn(
self,
x: Tensor,
q: Tensor,
k: Tensor,
v: Tensor,
x_mask: Tensor,
freqs_cis: Tensor,
softmax_scale,
) -> Tensor:
"""
bsz, seqlen, _, _ = q.shape
Args:
x:
x_mask:
freqs_cis:
Returns:
"""
bsz, seqlen, _ = x.shape
dtype = x.dtype
xq, xk, xv = torch.split(
self.qkv(x),
[
self.n_local_heads * self.head_dim,
self.n_local_kv_heads * self.head_dim,
self.n_local_kv_heads * self.head_dim,
],
dim=-1,
)
xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
xq = self.q_norm(xq)
xk = self.k_norm(xk)
xq = JointAttention.apply_rotary_emb(xq, freqs_cis=freqs_cis)
xk = JointAttention.apply_rotary_emb(xk, freqs_cis=freqs_cis)
xq, xk = xq.to(dtype), xk.to(dtype)
softmax_scale = math.sqrt(1 / self.head_dim)
if self.flash_attn:
try:
# begin var_len flash attn
(
query_states,
@@ -425,7 +438,7 @@ class JointAttention(nn.Module):
indices_q,
cu_seq_lens,
max_seq_lens,
) = self._upad_input(xq, xk, xv, x_mask, seqlen)
) = self._upad_input(q, k, v, x_mask, seqlen)
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
@@ -445,27 +458,12 @@ class JointAttention(nn.Module):
output = pad_input(attn_output_unpad, indices_q, bsz, seqlen)
# end var_len_flash_attn
else:
n_rep = self.n_local_heads // self.n_local_kv_heads
if n_rep >= 1:
xk = xk.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
xv = xv.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
output = (
self.attention_processor(
xq.permute(0, 2, 1, 3),
xk.permute(0, 2, 1, 3),
xv.permute(0, 2, 1, 3),
attn_mask=x_mask.bool().view(bsz, 1, 1, seqlen).expand(-1, self.n_local_heads, seqlen, -1),
scale=softmax_scale,
)
.permute(0, 2, 1, 3)
.to(dtype)
return output
except NameError as e:
raise RuntimeError(
f"Could not load flash attention. Please install flash_attn. / フラッシュアテンションを読み込めませんでした。flash_attn をインストールしてください。 / {e}"
)
output = output.flatten(-2)
return self.out(output)
def apply_rope(
x_in: torch.Tensor,
@@ -563,6 +561,7 @@ class JointTransformerBlock(GradientCheckpointMixin):
norm_eps: float,
qk_norm: bool,
modulation=True,
use_flash_attn=False,
) -> None:
"""
Initialize a TransformerBlock.
@@ -585,7 +584,7 @@ class JointTransformerBlock(GradientCheckpointMixin):
super().__init__()
self.dim = dim
self.head_dim = dim // n_heads
self.attention = JointAttention(dim, n_heads, n_kv_heads, qk_norm)
self.attention = JointAttention(dim, n_heads, n_kv_heads, qk_norm, use_flash_attn=use_flash_attn)
self.feed_forward = FeedForward(
dim=dim,
hidden_dim=4 * dim,
@@ -711,7 +710,12 @@ class FinalLayer(GradientCheckpointMixin):
class RopeEmbedder:
def __init__(self, theta: float = 10000.0, axes_dims: List[int] = [16, 56, 56], axes_lens: List[int] = [1, 512, 512]):
def __init__(
self,
theta: float = 10000.0,
axes_dims: List[int] = [16, 56, 56],
axes_lens: List[int] = [1, 512, 512],
):
super().__init__()
self.theta = theta
self.axes_dims = axes_dims
@@ -750,6 +754,7 @@ class NextDiT(nn.Module):
cap_feat_dim: int = 5120,
axes_dims: List[int] = [16, 56, 56],
axes_lens: List[int] = [1, 512, 512],
use_flash_attn=False,
) -> None:
"""
Initialize the NextDiT model.
@@ -803,6 +808,7 @@ class NextDiT(nn.Module):
norm_eps,
qk_norm,
modulation=False,
use_flash_attn=use_flash_attn,
)
for layer_id in range(n_refiner_layers)
]
@@ -828,6 +834,7 @@ class NextDiT(nn.Module):
norm_eps,
qk_norm,
modulation=True,
use_flash_attn=use_flash_attn,
)
for layer_id in range(n_refiner_layers)
]
@@ -848,6 +855,7 @@ class NextDiT(nn.Module):
ffn_dim_multiplier,
norm_eps,
qk_norm,
use_flash_attn=use_flash_attn,
)
for layer_id in range(n_layers)
]
@@ -988,8 +996,20 @@ class NextDiT(nn.Module):
freqs_cis = self.rope_embedder(position_ids)
# Create separate rotary embeddings for captions and images
cap_freqs_cis = torch.zeros(bsz, encoder_seq_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype)
img_freqs_cis = torch.zeros(bsz, image_seq_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype)
cap_freqs_cis = torch.zeros(
bsz,
encoder_seq_len,
freqs_cis.shape[-1],
device=device,
dtype=freqs_cis.dtype,
)
img_freqs_cis = torch.zeros(
bsz,
image_seq_len,
freqs_cis.shape[-1],
device=device,
dtype=freqs_cis.dtype,
)
for i, (cap_len, seq_len) in enumerate(zip(l_effective_cap_len, seq_lengths)):
cap_freqs_cis[i, :cap_len] = freqs_cis[i, :cap_len]

View File

@@ -1,21 +1,28 @@
import inspect
import enum
import argparse
import math
import os
import numpy as np
import time
from typing import Callable, Dict, List, Optional, Tuple, Any
from typing import Callable, Dict, List, Optional, Tuple, Any, Union
import torch
from torch import Tensor
from torchdiffeq import odeint
from accelerate import Accelerator, PartialState
from transformers import Gemma2Model
from tqdm import tqdm
from PIL import Image
from safetensors.torch import save_file
from diffusers.schedulers.scheduling_heun_discrete import HeunDiscreteScheduler
from library import lumina_models, lumina_util, strategy_base, strategy_lumina, train_util
from library.flux_models import AutoEncoder
from library.device_utils import init_ipex, clean_memory_on_device
from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler
from library.lumina_dpm_solver import NoiseScheduleFlow, DPM_Solver
import library.lumina_path as path
init_ipex()
@@ -162,12 +169,12 @@ def sample_image_inference(
args: argparse.Namespace,
nextdit: lumina_models.NextDiT,
gemma2_model: Gemma2Model,
vae: torch.nn.Module,
vae: AutoEncoder,
save_dir: str,
prompt_dict: Dict[str, str],
epoch: int,
global_step: int,
sample_prompts_gemma2_outputs: List[Tuple[Tensor, Tensor, Tensor]],
sample_prompts_gemma2_outputs: dict[str, List[Tuple[Tensor, Tensor, Tensor]]],
prompt_replacement: Optional[Tuple[str, str]] = None,
controlnet=None,
):
@@ -179,12 +186,12 @@ def sample_image_inference(
args (argparse.Namespace): Arguments object
nextdit (lumina_models.NextDiT): NextDiT model
gemma2_model (Gemma2Model): Gemma2 model
vae (torch.nn.Module): VAE model
vae (AutoEncoder): VAE model
save_dir (str): Directory to save images
prompt_dict (Dict[str, str]): Prompt dictionary
epoch (int): Epoch number
steps (int): Number of steps to run
sample_prompts_gemma2_outputs (List[Tuple[Tensor, Tensor, Tensor]]): List of tuples containing gemma2 outputs
sample_prompts_gemma2_outputs (dict[str, List[Tuple[Tensor, Tensor, Tensor]]]): Mapping from prompt text to cached Gemma 2 outputs (hidden states, input ids, attention mask)
prompt_replacement (Optional[Tuple[str, str]], optional): Replacement for positive and negative prompt. Defaults to None.
Returns:
@@ -192,16 +199,19 @@ def sample_image_inference(
"""
assert isinstance(prompt_dict, dict)
# negative_prompt = prompt_dict.get("negative_prompt")
sample_steps = prompt_dict.get("sample_steps", 38)
width = prompt_dict.get("width", 1024)
height = prompt_dict.get("height", 1024)
guidance_scale: int = prompt_dict.get("scale", 3.5)
seed: int = prompt_dict.get("seed", None)
sample_steps = int(prompt_dict.get("sample_steps", 38))
width = int(prompt_dict.get("width", 1024))
height = int(prompt_dict.get("height", 1024))
guidance_scale = float(prompt_dict.get("scale", 3.5))
seed = prompt_dict.get("seed", None)
controlnet_image = prompt_dict.get("controlnet_image")
prompt: str = prompt_dict.get("prompt", "")
negative_prompt: str = prompt_dict.get("negative_prompt", "")
# sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler)
seed = int(seed) if seed is not None else None
assert seed is None or seed > 0, f"Invalid seed {seed}"
if prompt_replacement is not None:
prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
if negative_prompt is not None:
@@ -213,10 +223,10 @@ def sample_image_inference(
# if negative_prompt is None:
# negative_prompt = ""
height = max(64, height - height % 16) # round to divisible by 16
width = max(64, width - width % 16) # round to divisible by 16
height = max(64, height - height % 8) # round to divisible by 8
width = max(64, width - width % 8) # round to divisible by 8
logger.info(f"prompt: {prompt}")
# logger.info(f"negative_prompt: {negative_prompt}")
logger.info(f"negative_prompt: {negative_prompt}")
logger.info(f"height: {height}")
logger.info(f"width: {width}")
logger.info(f"sample_steps: {sample_steps}")
@@ -232,46 +242,51 @@ def sample_image_inference(
assert isinstance(tokenize_strategy, strategy_lumina.LuminaTokenizeStrategy)
assert isinstance(encoding_strategy, strategy_lumina.LuminaTextEncodingStrategy)
gemma2_conds = []
system_prompt = args.system_prompt or ""
# Apply system prompt to prompts
prompt = system_prompt + prompt
negative_prompt = system_prompt + negative_prompt
# Get sample prompts from cache
if sample_prompts_gemma2_outputs and prompt in sample_prompts_gemma2_outputs:
gemma2_conds = sample_prompts_gemma2_outputs[prompt]
logger.info(f"Using cached Gemma2 outputs for prompt: {prompt}")
if sample_prompts_gemma2_outputs and negative_prompt in sample_prompts_gemma2_outputs:
neg_gemma2_conds = sample_prompts_gemma2_outputs[negative_prompt]
logger.info(f"Using cached Gemma2 outputs for negative prompt: {negative_prompt}")
# Load sample prompts from Gemma 2
if gemma2_model is not None:
logger.info(f"Encoding prompt with Gemma2: {prompt}")
tokens_and_masks = tokenize_strategy.tokenize(prompt)
encoded_gemma2_conds = encoding_strategy.encode_tokens(tokenize_strategy, [gemma2_model], tokens_and_masks)
gemma2_conds = encoding_strategy.encode_tokens(tokenize_strategy, [gemma2_model], tokens_and_masks)
# if gemma2_conds is not cached, use encoded_gemma2_conds
if len(gemma2_conds) == 0:
gemma2_conds = encoded_gemma2_conds
else:
# if encoded_gemma2_conds is not None, update cached gemma2_conds
for i in range(len(encoded_gemma2_conds)):
if encoded_gemma2_conds[i] is not None:
gemma2_conds[i] = encoded_gemma2_conds[i]
tokens_and_masks = tokenize_strategy.tokenize(negative_prompt)
neg_gemma2_conds = encoding_strategy.encode_tokens(tokenize_strategy, [gemma2_model], tokens_and_masks)
# Unpack Gemma2 outputs
gemma2_hidden_states, input_ids, gemma2_attn_mask = gemma2_conds
neg_gemma2_hidden_states, neg_input_ids, neg_gemma2_attn_mask = neg_gemma2_conds
# sample image
weight_dtype = vae.dtype  # TODO: give dtype as argument
latent_height = height // 8
latent_width = width // 8
latent_channels = 16
noise = torch.randn(
1,
16,
latent_channels,
latent_height,
latent_width,
device=accelerator.device,
dtype=weight_dtype,
generator=generator,
)
# Prompts are paired positive/negative
noise = noise.repeat(gemma2_attn_mask.shape[0], 1, 1, 1)
timesteps = get_schedule(sample_steps, noise.shape[1], shift=True)
# img_ids = lumina_util.prepare_img_ids(1, packed_latent_height, packed_latent_width).to(accelerator.device, weight_dtype)
gemma2_attn_mask = gemma2_attn_mask.to(accelerator.device)
scheduler = FlowMatchEulerDiscreteScheduler(shift=6.0, use_karras_sigmas=True)
timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps=sample_steps)
# if controlnet_image is not None:
# controlnet_image = Image.open(controlnet_image).convert("RGB")
@@ -280,16 +295,25 @@ def sample_image_inference(
# controlnet_image = controlnet_image.permute(2, 0, 1).unsqueeze(0).to(weight_dtype).to(accelerator.device)
with accelerator.autocast():
x = denoise(nextdit, noise, gemma2_hidden_states, gemma2_attn_mask, timesteps=timesteps, guidance=guidance_scale)
x = denoise(
scheduler,
nextdit,
noise,
gemma2_hidden_states,
gemma2_attn_mask.to(accelerator.device),
neg_gemma2_hidden_states,
neg_gemma2_attn_mask.to(accelerator.device),
timesteps=timesteps,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
)
# x = lumina_util.unpack_latents(x, packed_latent_height, packed_latent_width)
# latent to image
# Latent to image
clean_memory_on_device(accelerator.device)
org_vae_device = vae.device # will be on cpu
vae.to(accelerator.device) # distributed_state.device is same as accelerator.device
with accelerator.autocast():
x = vae.decode(x)
x = vae.decode((x / vae.scale_factor) + vae.shift_factor)
vae.to(org_vae_device)
clean_memory_on_device(accelerator.device)
@@ -317,30 +341,25 @@ def sample_image_inference(
wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False) # positive prompt as a caption
def time_shift(mu: float, sigma: float, t: Tensor):
    """
    Get time shift
    Args:
        mu (float): mu value.
        sigma (float): sigma value.
        t (Tensor): timestep.
    Return:
        float: time shift
    """
    return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
def time_shift(mu: float, sigma: float, t: torch.Tensor):
    # the following implementation was originally for t=0: clean / t=1: noise
    # Since we adopt the reverse convention, the 1-t operations are needed
    t = 1 - t
    t = math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
    t = 1 - t
    return t
def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15) -> Callable[[float], float]:
def get_lin_function(x1: float = 256, x2: float = 4096, y1: float = 0.5, y2: float = 1.15) -> Callable[[float], float]:
"""
Get linear function
Args:
x1 (float, optional): x1 value. Defaults to 256.
y1 (float, optional): y1 value. Defaults to 0.5.
x2 (float, optional): x2 value. Defaults to 4096.
y2 (float, optional): y2 value. Defaults to 1.15.
x1 (float, optional): First reference sequence length. Defaults to 256.
x2 (float, optional): Second reference sequence length. Defaults to 4096.
y1 (float, optional): Shift value at x1. Defaults to 0.5.
y2 (float, optional): Shift value at x2. Defaults to 1.15.
Return:
Callable[[float], float]: linear function
@@ -370,51 +389,164 @@ def get_schedule(
Return:
List[float]: timesteps schedule
"""
# extra step for zero
timesteps = torch.linspace(1, 0, num_steps + 1)
timesteps = torch.linspace(1, 1 / num_steps, num_steps)
# shifting the schedule to favor high timesteps for higher signal images
if shift:
# estimate mu based on linear interpolation between two points
mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
mu = get_lin_function(y1=base_shift, y2=max_shift, x1=256, x2=4096)(image_seq_len)
timesteps = time_shift(mu, 1.0, timesteps)
return timesteps.tolist()
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
) -> Tuple[torch.Tensor, int]:
    r"""
    Configure the scheduler's timestep schedule and return it.

    Calls ``scheduler.set_timesteps`` and reads back ``scheduler.timesteps``,
    supporting either a plain step count or a custom ``timesteps``/``sigmas``
    schedule (only for schedulers whose ``set_timesteps`` accepts those
    keywords). Extra kwargs are forwarded to ``scheduler.set_timesteps``.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`, *optional*):
            Number of diffusion steps; must be `None` when `timesteps` is given.
        device (`str` or `torch.device`, *optional*):
            Device the timesteps should be moved to; `None` leaves them in place.
        timesteps (`List[int]`, *optional*):
            Custom timestep schedule; mutually exclusive with `sigmas`.
        sigmas (`List[float]`, *optional*):
            Custom sigma schedule; mutually exclusive with `timesteps`.

    Returns:
        `Tuple[torch.Tensor, int]`: The timestep schedule and the number of
        inference steps.

    Raises:
        ValueError: If both `timesteps` and `sigmas` are passed, or the
            scheduler does not support the requested custom schedule.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    # Whether this scheduler's set_timesteps accepts a given keyword argument.
    def _accepts(keyword: str) -> bool:
        return keyword in inspect.signature(scheduler.set_timesteps).parameters

    if timesteps is not None:
        if not _accepts("timesteps"):
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    elif sigmas is not None:
        if not _accepts("sigmas"):
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        # Plain step count: schedule length comes from the scheduler itself.
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps
    return timesteps, num_inference_steps
def denoise(
model: lumina_models.NextDiT, img: Tensor, txt: Tensor, txt_mask: Tensor, timesteps: List[float], guidance: float = 4.0
scheduler,
model: lumina_models.NextDiT,
img: Tensor,
txt: Tensor,
txt_mask: Tensor,
neg_txt: Tensor,
neg_txt_mask: Tensor,
timesteps: Union[List[float], torch.Tensor],
num_inference_steps: int = 38,
guidance_scale: float = 4.0,
cfg_trunc_ratio: float = 1.0,
cfg_normalization: bool = True,
):
"""
Denoise an image using the NextDiT model.
Args:
scheduler (FlowMatchEulerDiscreteScheduler):
    Noise scheduler; must expose `config.num_train_timesteps` and a `step` method.
model (lumina_models.NextDiT): The NextDiT model instance.
img (Tensor): The input image tensor.
txt (Tensor): The input text tensor.
txt_mask (Tensor): The input text mask tensor.
timesteps (List[float]): A list of timesteps for the denoising process.
guidance (float, optional): The guidance scale for the denoising process. Defaults to 4.0.
img (Tensor):
The input image latent tensor.
txt (Tensor):
The input text tensor.
txt_mask (Tensor):
The input text mask tensor.
neg_txt (Tensor):
The negative input txt tensor
neg_txt_mask (Tensor):
The negative input text mask tensor.
timesteps (List[Union[float, torch.FloatTensor]]):
A list of timesteps for the denoising process.
guidance_scale (float, optional):
The guidance scale for the denoising process. Defaults to 4.0.
cfg_trunc_ratio (float, optional):
The ratio of the timestep interval to apply normalization-based guidance scale.
cfg_normalization (bool, optional):
Whether to apply normalization-based guidance scale.
Returns:
img (Tensor): Denoised tensor
img (Tensor): Denoised latent tensor
"""
for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]):
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
# model.prepare_block_swap_before_forward()
# block_samples = None
# block_single_samples = None
pred = model.forward_with_cfg(
x=img, # image latents (B, C, H, W)
t=t_vec / 1000,  # timesteps are divided by 1000 to match the model's expected range
for i, t in enumerate(tqdm(timesteps)):
# compute whether apply classifier-free truncation on this timestep
do_classifier_free_truncation = (i + 1) / num_inference_steps > cfg_trunc_ratio
# reverse the timestep since Lumina uses t=0 as the noise and t=1 as the image
current_timestep = 1 - t / scheduler.config.num_train_timesteps
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
current_timestep = current_timestep.expand(img.shape[0]).to(model.device)
noise_pred_cond = model(
img,
current_timestep,
cap_feats=txt,  # Gemma2 hidden states as caption features
cap_mask=txt_mask.to(dtype=torch.int32),  # Gemma2 attention mask
cfg_scale=guidance,
)
img = img + (t_prev - t_curr) * pred
if not do_classifier_free_truncation:
noise_pred_uncond = model(
img,
current_timestep,
cap_feats=neg_txt,  # Gemma2 hidden states as caption features
cap_mask=neg_txt_mask.to(dtype=torch.int32),  # Gemma2 attention mask
)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
# apply normalization after classifier-free guidance
if cfg_normalization:
cond_norm = torch.norm(noise_pred_cond, dim=-1, keepdim=True)
noise_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
noise_pred = noise_pred * (cond_norm / noise_norm)
else:
noise_pred = noise_pred_cond
img_dtype = img.dtype
if img.dtype != img_dtype:
if torch.backends.mps.is_available():
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
img = img.to(img_dtype)
# compute the previous noisy sample x_t -> x_t-1
noise_pred = -noise_pred
img = scheduler.step(noise_pred, t, img, return_dict=False)[0]
# model.prepare_block_swap_before_forward()
return img
@@ -754,3 +886,14 @@ def add_lumina_train_arguments(parser: argparse.ArgumentParser):
default=3.0,
help="Discrete flow shift for the Euler Discrete Scheduler, default is 3.0. / Euler Discrete Schedulerの離散フローシフト、デフォルトは3.0。",
)
parser.add_argument(
"--use_flash_attn",
action="store_true",
help="Use Flash Attention for the model. / モデルにFlash Attentionを使用する。",
)
parser.add_argument(
"--system_prompt",
type=str,
default="You are an assistant designed to generate high-quality images based on user prompts. <Prompt Start> ",
help="System prompt to add to the prompt. / プロンプトに追加するシステムプロンプト。",
)

View File

@@ -20,11 +20,13 @@ logger = logging.getLogger(__name__)
MODEL_VERSION_LUMINA_V2 = "lumina2"
def load_lumina_model(
ckpt_path: str,
dtype: torch.dtype,
dtype: Optional[torch.dtype],
device: torch.device,
disable_mmap: bool = False,
use_flash_attn: bool = False,
):
"""
Load the Lumina model from the checkpoint path.
@@ -34,22 +36,22 @@ def load_lumina_model(
dtype (torch.dtype): The data type for the model.
device (torch.device): The device to load the model on.
disable_mmap (bool, optional): Whether to disable mmap. Defaults to False.
use_flash_attn (bool, optional): Whether to use flash attention. Defaults to False.
Returns:
model (lumina_models.NextDiT): The loaded model.
"""
logger.info("Building Lumina")
with torch.device("meta"):
model = lumina_models.NextDiT_2B_GQA_patch2_Adaln_Refiner().to(dtype)
model = lumina_models.NextDiT_2B_GQA_patch2_Adaln_Refiner(use_flash_attn=use_flash_attn).to(dtype)
logger.info(f"Loading state dict from {ckpt_path}")
state_dict = load_safetensors(
ckpt_path, device=device, disable_mmap=disable_mmap, dtype=dtype
)
state_dict = load_safetensors(ckpt_path, device=device, disable_mmap=disable_mmap, dtype=dtype)
info = model.load_state_dict(state_dict, strict=False, assign=True)
logger.info(f"Loaded Lumina: {info}")
return model
def load_ae(
ckpt_path: str,
dtype: torch.dtype,
@@ -74,9 +76,7 @@ def load_ae(
ae = flux_models.AutoEncoder(flux_models.configs["schnell"].ae_params).to(dtype)
logger.info(f"Loading state dict from {ckpt_path}")
sd = load_safetensors(
ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype
)
sd = load_safetensors(ckpt_path, device=device, disable_mmap=disable_mmap, dtype=dtype)
info = ae.load_state_dict(sd, strict=False, assign=True)
logger.info(f"Loaded AE: {info}")
return ae
@@ -104,37 +104,35 @@ def load_gemma2(
"""
logger.info("Building Gemma2")
GEMMA2_CONFIG = {
"_name_or_path": "google/gemma-2-2b",
"architectures": [
"Gemma2Model"
],
"attention_bias": False,
"attention_dropout": 0.0,
"attn_logit_softcapping": 50.0,
"bos_token_id": 2,
"cache_implementation": "hybrid",
"eos_token_id": 1,
"final_logit_softcapping": 30.0,
"head_dim": 256,
"hidden_act": "gelu_pytorch_tanh",
"hidden_activation": "gelu_pytorch_tanh",
"hidden_size": 2304,
"initializer_range": 0.02,
"intermediate_size": 9216,
"max_position_embeddings": 8192,
"model_type": "gemma2",
"num_attention_heads": 8,
"num_hidden_layers": 26,
"num_key_value_heads": 4,
"pad_token_id": 0,
"query_pre_attn_scalar": 256,
"rms_norm_eps": 1e-06,
"rope_theta": 10000.0,
"sliding_window": 4096,
"torch_dtype": "float32",
"transformers_version": "4.44.2",
"use_cache": True,
"vocab_size": 256000
"_name_or_path": "google/gemma-2-2b",
"architectures": ["Gemma2Model"],
"attention_bias": False,
"attention_dropout": 0.0,
"attn_logit_softcapping": 50.0,
"bos_token_id": 2,
"cache_implementation": "hybrid",
"eos_token_id": 1,
"final_logit_softcapping": 30.0,
"head_dim": 256,
"hidden_act": "gelu_pytorch_tanh",
"hidden_activation": "gelu_pytorch_tanh",
"hidden_size": 2304,
"initializer_range": 0.02,
"intermediate_size": 9216,
"max_position_embeddings": 8192,
"model_type": "gemma2",
"num_attention_heads": 8,
"num_hidden_layers": 26,
"num_key_value_heads": 4,
"pad_token_id": 0,
"query_pre_attn_scalar": 256,
"rms_norm_eps": 1e-06,
"rope_theta": 10000.0,
"sliding_window": 4096,
"torch_dtype": "float32",
"transformers_version": "4.44.2",
"use_cache": True,
"vocab_size": 256000,
}
config = Gemma2Config(**GEMMA2_CONFIG)
@@ -145,9 +143,7 @@ def load_gemma2(
sd = state_dict
else:
logger.info(f"Loading state dict from {ckpt_path}")
sd = load_safetensors(
ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype
)
sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)
for key in list(sd.keys()):
new_key = key.replace("model.", "")
@@ -159,6 +155,7 @@ def load_gemma2(
logger.info(f"Loaded Gemma2: {info}")
return gemma2
def unpack_latents(x: torch.Tensor, packed_latent_height: int, packed_latent_width: int) -> torch.Tensor:
"""
x: [b (h w) (c ph pw)] -> [b c (h ph) (w pw)], ph=2, pw=2
@@ -174,6 +171,7 @@ def pack_latents(x: torch.Tensor) -> torch.Tensor:
x = einops.rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
return x
DIFFUSERS_TO_ALPHA_VLLM_MAP = {
# Embedding layers
"cap_embedder.0.weight": ["time_caption_embed.caption_embedder.0.weight"],
@@ -224,9 +222,7 @@ def convert_diffusers_sd_to_alpha_vllm(sd: dict, num_double_blocks: int) -> dict
for block_idx in range(num_double_blocks):
if str(block_idx) in key:
converted = pattern.replace("()", str(block_idx))
new_key = key.replace(
converted, replacement.replace("()", str(block_idx))
)
new_key = key.replace(converted, replacement.replace("()", str(block_idx)))
break
if new_key == key:

View File

@@ -610,6 +610,21 @@ from diffusers.utils.torch_utils import randn_tensor
from diffusers.utils import BaseOutput
# Copyright 2024 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@dataclass
class FlowMatchEulerDiscreteSchedulerOutput(BaseOutput):
"""
@@ -649,22 +664,49 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
self,
num_train_timesteps: int = 1000,
shift: float = 1.0,
use_dynamic_shifting=False,
base_shift: Optional[float] = 0.5,
max_shift: Optional[float] = 1.15,
base_image_seq_len: Optional[int] = 256,
max_image_seq_len: Optional[int] = 4096,
invert_sigmas: bool = False,
shift_terminal: Optional[float] = None,
use_karras_sigmas: Optional[bool] = False,
use_exponential_sigmas: Optional[bool] = False,
use_beta_sigmas: Optional[bool] = False,
):
if self.config.use_beta_sigmas and not is_scipy_available():
raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
raise ValueError(
"Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
)
timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy()
timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)
sigmas = timesteps / num_train_timesteps
sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
if not use_dynamic_shifting:
# when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution
sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
self.timesteps = sigmas * num_train_timesteps
self._step_index = None
self._begin_index = None
self._shift = shift
self.sigmas = sigmas.to("cpu") # to avoid too much CPU/GPU communication
self.sigma_min = self.sigmas[-1].item()
self.sigma_max = self.sigmas[0].item()
@property
def shift(self) -> float:
    """
    The value used for shifting.

    Initialized from the constructor's ``shift`` argument and updated via
    ``set_shift``; read when rescaling the sigma schedule in ``set_timesteps``.
    """
    return self._shift
@property
def step_index(self):
"""
@@ -690,6 +732,9 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
"""
self._begin_index = begin_index
def set_shift(self, shift: float):
    """Replace the timestep-shift factor exposed by the ``shift`` property."""
    setattr(self, "_shift", shift)
def scale_noise(
self,
sample: torch.FloatTensor,
@@ -709,10 +754,31 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
`torch.FloatTensor`:
A scaled input sample.
"""
if self.step_index is None:
self._init_step_index(timestep)
# Make sure sigmas and timesteps have the same device and dtype as original_samples
sigmas = self.sigmas.to(device=sample.device, dtype=sample.dtype)
if sample.device.type == "mps" and torch.is_floating_point(timestep):
# mps does not support float64
schedule_timesteps = self.timesteps.to(sample.device, dtype=torch.float32)
timestep = timestep.to(sample.device, dtype=torch.float32)
else:
schedule_timesteps = self.timesteps.to(sample.device)
timestep = timestep.to(sample.device)
# self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
if self.begin_index is None:
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timestep]
elif self.step_index is not None:
# add_noise is called after first denoising step (for inpainting)
step_indices = [self.step_index] * timestep.shape[0]
else:
# add noise is called before first denoising step to create initial latent(img2img)
step_indices = [self.begin_index] * timestep.shape[0]
sigma = sigmas[step_indices].flatten()
while len(sigma.shape) < len(sample.shape):
sigma = sigma.unsqueeze(-1)
sigma = self.sigmas[self.step_index]
sample = sigma * noise + (1.0 - sigma) * sample
return sample
@@ -720,7 +786,37 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
def _sigma_to_t(self, sigma):
return sigma * self.config.num_train_timesteps
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
    """Apply the exponential (dynamic) timestep shift to ``t``.

    Args:
        mu: Log-scale shift strength; larger ``mu`` pushes values toward 1.
        sigma: Exponent applied to the odds ratio ``1/t - 1``.
        t: Timestep values in (0, 1]; scalar or tensor.

    Returns:
        Shifted timesteps with the same shape as ``t``.
    """
    exp_mu = math.exp(mu)
    odds = (1 / t) - 1
    return exp_mu / (exp_mu + odds ** sigma)
def stretch_shift_to_terminal(self, t: torch.Tensor) -> torch.Tensor:
    r"""Rescale timesteps so the schedule terminates at ``config.shift_terminal``.

    Reference:
        https://github.com/Lightricks/LTX-Video/blob/a01a171f8fe3d99dce2728d60a73fecf4d4238ae/ltx_video/schedulers/rf.py#L51

    Args:
        t (`torch.Tensor`):
            A tensor of timesteps to be stretched and shifted.

    Returns:
        `torch.Tensor`:
            Adjusted timesteps whose final value equals ``self.config.shift_terminal``.
    """
    remaining = 1 - t
    # Normalize by the last element so that, after flipping back to 1 - x,
    # the terminal entry lands exactly on the configured target.
    scale = remaining[-1] / (1 - self.config.shift_terminal)
    return 1 - remaining / scale
def set_timesteps(
self,
num_inference_steps: int = None,
device: Union[str, torch.device] = None,
sigmas: Optional[List[float]] = None,
mu: Optional[float] = None,
):
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
@@ -730,18 +826,49 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
"""
if self.config.use_dynamic_shifting and mu is None:
raise ValueError(" you have a pass a value for `mu` when `use_dynamic_shifting` is set to be `True`")
if sigmas is None:
timesteps = np.linspace(
self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps
)
sigmas = timesteps / self.config.num_train_timesteps
else:
sigmas = np.array(sigmas).astype(np.float32)
num_inference_steps = len(sigmas)
self.num_inference_steps = num_inference_steps
timesteps = np.linspace(self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps)
if self.config.use_dynamic_shifting:
sigmas = self.time_shift(mu, 1.0, sigmas)
else:
sigmas = self.shift * sigmas / (1 + (self.shift - 1) * sigmas)
if self.config.shift_terminal:
sigmas = self.stretch_shift_to_terminal(sigmas)
if self.config.use_karras_sigmas:
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
elif self.config.use_exponential_sigmas:
sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
elif self.config.use_beta_sigmas:
sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
sigmas = timesteps / self.config.num_train_timesteps
sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas)
sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
timesteps = sigmas * self.config.num_train_timesteps
self.timesteps = timesteps.to(device=device)
self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
if self.config.invert_sigmas:
sigmas = 1.0 - sigmas
timesteps = sigmas * self.config.num_train_timesteps
sigmas = torch.cat([sigmas, torch.ones(1, device=sigmas.device)])
else:
sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
self.timesteps = timesteps.to(device=device)
self.sigmas = sigmas
self._step_index = None
self._begin_index = None
@@ -807,7 +934,11 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
returned, otherwise a tuple is returned where the first element is the sample tensor.
"""
if isinstance(timestep, int) or isinstance(timestep, torch.IntTensor) or isinstance(timestep, torch.LongTensor):
if (
isinstance(timestep, int)
or isinstance(timestep, torch.IntTensor)
or isinstance(timestep, torch.LongTensor)
):
raise ValueError(
(
"Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
@@ -823,30 +954,10 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
sample = sample.to(torch.float32)
sigma = self.sigmas[self.step_index]
sigma_next = self.sigmas[self.step_index + 1]
gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0
prev_sample = sample + (sigma_next - sigma) * model_output
noise = randn_tensor(model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator)
eps = noise * s_noise
sigma_hat = sigma * (gamma + 1)
if gamma > 0:
sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5
# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
# NOTE: "original_sample" should not be an expected prediction_type but is left in for
# backwards compatibility
# if self.config.prediction_type == "vector_field":
denoised = sample - model_output * sigma
# 2. Convert to an ODE derivative
derivative = (sample - denoised) / sigma_hat
dt = self.sigmas[self.step_index + 1] - sigma_hat
prev_sample = sample + derivative * dt
# Cast sample back to model compatible dtype
prev_sample = prev_sample.to(model_output.dtype)
@@ -858,6 +969,86 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample)
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
"""Constructs the noise schedule of Karras et al. (2022)."""
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
if hasattr(self.config, "sigma_min"):
sigma_min = self.config.sigma_min
else:
sigma_min = None
if hasattr(self.config, "sigma_max"):
sigma_max = self.config.sigma_max
else:
sigma_max = None
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
rho = 7.0 # 7.0 is the value used in the paper
ramp = np.linspace(0, 1, num_inference_steps)
min_inv_rho = sigma_min ** (1 / rho)
max_inv_rho = sigma_max ** (1 / rho)
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
return sigmas
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential
def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
"""Constructs an exponential noise schedule."""
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
if hasattr(self.config, "sigma_min"):
sigma_min = self.config.sigma_min
else:
sigma_min = None
if hasattr(self.config, "sigma_max"):
sigma_max = self.config.sigma_max
else:
sigma_max = None
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
sigmas = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps))
return sigmas
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
def _convert_to_beta(
self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
) -> torch.Tensor:
"""From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
if hasattr(self.config, "sigma_min"):
sigma_min = self.config.sigma_min
else:
sigma_min = None
if hasattr(self.config, "sigma_max"):
sigma_max = self.config.sigma_max
else:
sigma_max = None
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
sigmas = np.array(
[
sigma_min + (ppf * (sigma_max - sigma_min))
for ppf in [
scipy.stats.beta.ppf(timestep, alpha, beta)
for timestep in 1 - np.linspace(0, 1, num_inference_steps)
]
]
)
return sigmas
def __len__(self):
    """Scheduler length: the number of training timesteps (``config.num_train_timesteps``)."""
    return self.config.num_train_timesteps

View File

@@ -2,7 +2,7 @@
import os
import re
from typing import Any, List, Optional, Tuple, Union
from typing import Any, List, Optional, Tuple, Union, Callable
import numpy as np
import torch
@@ -430,9 +430,21 @@ class LatentsCachingStrategy:
bucket_reso: Tuple[int, int],
npz_path: str,
flip_aug: bool,
alpha_mask: bool,
apply_alpha_mask: bool,
multi_resolution: bool = False,
):
) -> bool:
"""
Args:
latents_stride: stride of latents
bucket_reso: resolution of the bucket
npz_path: path to the npz file
flip_aug: whether to flip images
apply_alpha_mask: whether to apply alpha mask
multi_resolution: whether to use multi-resolution latents
Returns:
bool
"""
if not self.cache_to_disk:
return False
if not os.path.exists(npz_path):
@@ -451,7 +463,7 @@ class LatentsCachingStrategy:
return False
if flip_aug and "latents_flipped" + key_reso_suffix not in npz:
return False
if alpha_mask and "alpha_mask" + key_reso_suffix not in npz:
if apply_alpha_mask and "alpha_mask" + key_reso_suffix not in npz:
return False
except Exception as e:
logger.error(f"Error loading file: {npz_path}")
@@ -462,22 +474,35 @@ class LatentsCachingStrategy:
# TODO remove circular dependency for ImageInfo
def _default_cache_batch_latents(
self,
encode_by_vae,
vae_device,
vae_dtype,
encode_by_vae: Callable,
vae_device: torch.device,
vae_dtype: torch.dtype,
image_infos: List,
flip_aug: bool,
alpha_mask: bool,
apply_alpha_mask: bool,
random_crop: bool,
multi_resolution: bool = False,
):
"""
Default implementation for cache_batch_latents. Image loading, VAE, flipping, alpha mask handling are common.
Args:
encode_by_vae: function to encode images by VAE
vae_device: device to use for VAE
vae_dtype: dtype to use for VAE
image_infos: list of ImageInfo
flip_aug: whether to flip images
apply_alpha_mask: whether to apply alpha mask
random_crop: whether to random crop images
multi_resolution: whether to use multi-resolution latents
Returns:
None
"""
from library import train_util # import here to avoid circular import
img_tensor, alpha_masks, original_sizes, crop_ltrbs = train_util.load_images_and_masks_for_caching(
image_infos, alpha_mask, random_crop
image_infos, apply_alpha_mask, random_crop
)
img_tensor = img_tensor.to(device=vae_device, dtype=vae_dtype)
@@ -519,12 +544,40 @@ class LatentsCachingStrategy:
) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
"""
for SD/SDXL
Args:
npz_path (str): Path to the npz file.
bucket_reso (Tuple[int, int]): The resolution of the bucket.
Returns:
Tuple[
Optional[np.ndarray],
Optional[List[int]],
Optional[List[int]],
Optional[np.ndarray],
Optional[np.ndarray]
]: Latent np tensors, original size, crop (left top, right bottom), flipped latents, alpha mask
"""
return self._default_load_latents_from_disk(None, npz_path, bucket_reso)
def _default_load_latents_from_disk(
self, latents_stride: Optional[int], npz_path: str, bucket_reso: Tuple[int, int]
) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
"""
Args:
latents_stride (Optional[int]): Stride for latents. If None, load all latents.
npz_path (str): Path to the npz file.
bucket_reso (Tuple[int, int]): The resolution of the bucket.
Returns:
Tuple[
Optional[np.ndarray],
Optional[List[int]],
Optional[List[int]],
Optional[np.ndarray],
Optional[np.ndarray]
]: Latent np tensors, original size, crop (left top, right bottom), flipped latents, alpha mask
"""
if latents_stride is None:
key_reso_suffix = ""
else:
@@ -552,6 +605,19 @@ class LatentsCachingStrategy:
alpha_mask=None,
key_reso_suffix="",
):
"""
Args:
npz_path (str): Path to the npz file.
latents_tensor (torch.Tensor): Latent tensor
original_size (List[int]): Original size of the image
crop_ltrb (List[int]): Crop left top right bottom
flipped_latents_tensor (Optional[torch.Tensor]): Flipped latent tensor
alpha_mask (Optional[torch.Tensor]): Alpha mask
key_reso_suffix (str): Key resolution suffix
Returns:
None
"""
kwargs = {}
if os.path.exists(npz_path):

View File

@@ -3,13 +3,13 @@ import os
from typing import Any, List, Optional, Tuple, Union
import torch
from transformers import AutoTokenizer, AutoModel, GemmaTokenizerFast
from transformers import AutoTokenizer, AutoModel, Gemma2Model, GemmaTokenizerFast
from library import train_util
from library.strategy_base import (
LatentsCachingStrategy,
TokenizeStrategy,
TextEncodingStrategy,
TextEncoderOutputsCachingStrategy
TextEncoderOutputsCachingStrategy,
)
import numpy as np
from library.utils import setup_logging
@@ -37,21 +37,38 @@ class LuminaTokenizeStrategy(TokenizeStrategy):
else:
self.max_length = max_length
def tokenize(self, text: Union[str, List[str]]) -> Tuple[torch.Tensor, torch.Tensor]:
def tokenize(
self, text: Union[str, List[str]]
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
text (Union[str, List[str]]): Text to tokenize
Returns:
Tuple[torch.Tensor, torch.Tensor]:
token input ids, attention_masks
"""
text = [text] if isinstance(text, str) else text
encodings = self.tokenizer(
text,
max_length=self.max_length,
return_tensors="pt",
padding=True,
padding="max_length",
pad_to_multiple_of=8,
truncation=True,
)
return [encodings.input_ids, encodings.attention_mask]
return (encodings.input_ids, encodings.attention_mask)
def tokenize_with_weights(
self, text: str | List[str]
) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
"""
Args:
text (Union[str, List[str]]): Text to tokenize
Returns:
Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
token input ids, attention_masks, weights
"""
# Gemma doesn't support weighted prompts, return uniform weights
tokens, attention_masks = self.tokenize(text)
weights = [torch.ones_like(t) for t in tokens]
@@ -66,9 +83,20 @@ class LuminaTextEncodingStrategy(TextEncodingStrategy):
self,
tokenize_strategy: TokenizeStrategy,
models: List[Any],
tokens: List[torch.Tensor],
tokens: Tuple[torch.Tensor, torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Args:
tokenize_strategy (LuminaTokenizeStrategy): Tokenize strategy
models (List[Any]): Text encoders
tokens (Tuple[torch.Tensor, torch.Tensor]): tokens, attention_masks
Returns:
Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
hidden_states, input_ids, attention_masks
"""
text_encoder = models[0]
assert isinstance(text_encoder, Gemma2Model)
input_ids, attention_masks = tokens
outputs = text_encoder(
@@ -84,9 +112,20 @@ class LuminaTextEncodingStrategy(TextEncodingStrategy):
self,
tokenize_strategy: TokenizeStrategy,
models: List[Any],
tokens: List[torch.Tensor],
weights_list: List[torch.Tensor],
tokens: Tuple[torch.Tensor, torch.Tensor],
weights: List[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Args:
tokenize_strategy (LuminaTokenizeStrategy): Tokenize strategy
models (List[Any]): Text encoders
tokens (Tuple[torch.Tensor, torch.Tensor]): tokens, attention_masks
weights_list (List[torch.Tensor]): Currently unused
Returns:
Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
hidden_states, input_ids, attention_masks
"""
# For simplicity, use uniform weighting
return self.encode_tokens(tokenize_strategy, models, tokens)
@@ -114,7 +153,14 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy)
+ LuminaTextEncoderOutputsCachingStrategy.LUMINA_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
)
def is_disk_cached_outputs_expected(self, npz_path: str):
def is_disk_cached_outputs_expected(self, npz_path: str) -> bool:
"""
Args:
npz_path (str): Path to the npz file.
Returns:
bool: True if the npz file is expected to be cached.
"""
if not self.cache_to_disk:
return False
if not os.path.exists(npz_path):
@@ -141,7 +187,7 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy)
Load outputs from a npz file
Returns:
List[np.ndarray]: hidden_state, input_ids, attention_mask
List[np.ndarray]: hidden_state, input_ids, attention_mask
"""
data = np.load(npz_path)
hidden_state = data["hidden_state"]
@@ -151,53 +197,75 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy)
def cache_batch_outputs(
self,
tokenize_strategy: LuminaTokenizeStrategy,
tokenize_strategy: TokenizeStrategy,
models: List[Any],
text_encoding_strategy: LuminaTextEncodingStrategy,
infos: List,
):
lumina_text_encoding_strategy: LuminaTextEncodingStrategy = (
text_encoding_strategy
)
captions = [info.caption for info in infos]
text_encoding_strategy: TextEncodingStrategy,
batch: List[train_util.ImageInfo],
) -> None:
"""
Args:
tokenize_strategy (LuminaTokenizeStrategy): Tokenize strategy
models (List[Any]): Text encoders
text_encoding_strategy (LuminaTextEncodingStrategy):
infos (List): List of image_info
Returns:
None
"""
assert isinstance(text_encoding_strategy, LuminaTextEncodingStrategy)
assert isinstance(tokenize_strategy, LuminaTokenizeStrategy)
captions = [info.system_prompt or "" + info.caption for info in batch]
if self.is_weighted:
tokens, weights_list = tokenize_strategy.tokenize_with_weights(
captions
tokens, attention_masks, weights_list = (
tokenize_strategy.tokenize_with_weights(captions)
)
with torch.no_grad():
hidden_state, input_ids, attention_masks = lumina_text_encoding_strategy.encode_tokens_with_weights(
tokenize_strategy, models, tokens, weights_list
hidden_state, input_ids, attention_masks = (
text_encoding_strategy.encode_tokens_with_weights(
tokenize_strategy,
models,
(tokens, attention_masks),
weights_list,
)
)
else:
tokens = tokenize_strategy.tokenize(captions)
with torch.no_grad():
hidden_state, input_ids, attention_masks = lumina_text_encoding_strategy.encode_tokens(
tokenize_strategy, models, tokens
hidden_state, input_ids, attention_masks = (
text_encoding_strategy.encode_tokens(
tokenize_strategy, models, tokens
)
)
if hidden_state.dtype != torch.float32:
hidden_state = hidden_state.float()
hidden_state = hidden_state.cpu().numpy()
attention_mask = attention_masks.cpu().numpy()
input_ids = tokens.cpu().numpy()
attention_mask = attention_masks.cpu().numpy() # (B, S)
input_ids = input_ids.cpu().numpy() # (B, S)
for i, info in enumerate(infos):
for i, info in enumerate(batch):
hidden_state_i = hidden_state[i]
attention_mask_i = attention_mask[i]
input_ids_i = input_ids[i]
assert info.text_encoder_outputs_npz is not None, "Text encoder cache outputs to disk not found for image {info.image_path}"
if self.cache_to_disk:
np.savez(
info.text_encoder_outputs_npz,
hidden_state=hidden_state_i,
attention_mask=attention_mask_i,
input_ids=input_ids_i
input_ids=input_ids_i,
)
else:
info.text_encoder_outputs = [hidden_state_i, attention_mask_i, input_ids_i]
info.text_encoder_outputs = [
hidden_state_i,
attention_mask_i,
input_ids_i,
]
class LuminaLatentsCachingStrategy(LatentsCachingStrategy):
@@ -227,7 +295,14 @@ class LuminaLatentsCachingStrategy(LatentsCachingStrategy):
npz_path: str,
flip_aug: bool,
alpha_mask: bool,
):
) -> bool:
"""
Args:
bucket_reso (Tuple[int, int]): The resolution of the bucket.
npz_path (str): Path to the npz file.
flip_aug (bool): Whether to flip the image.
alpha_mask (bool): Whether to apply
"""
return self._default_is_disk_cached_latents_expected(
8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True
)
@@ -241,6 +316,20 @@ class LuminaLatentsCachingStrategy(LatentsCachingStrategy):
Optional[np.ndarray],
Optional[np.ndarray],
]:
"""
Args:
npz_path (str): Path to the npz file.
bucket_reso (Tuple[int, int]): The resolution of the bucket.
Returns:
Tuple[
Optional[np.ndarray],
Optional[List[int]],
Optional[List[int]],
Optional[np.ndarray],
Optional[np.ndarray],
]: Tuple of latent tensors, attention_mask, input_ids, latents, latents_unet
"""
return self._default_load_latents_from_disk(
8, npz_path, bucket_reso
) # support multi-resolution

View File

@@ -195,7 +195,7 @@ class ImageInfo:
self.latents_flipped: Optional[torch.Tensor] = None
self.latents_npz: Optional[str] = None # set in cache_latents
self.latents_original_size: Optional[Tuple[int, int]] = None # original image size, not latents size
self.latents_crop_ltrb: Optional[Tuple[int, int]] = (
self.latents_crop_ltrb: Optional[Tuple[int, int, int, int]] = (
None # crop left top right bottom in original pixel size, not latents size
)
self.cond_img_path: Optional[str] = None
@@ -211,6 +211,8 @@ class ImageInfo:
self.alpha_mask: Optional[torch.Tensor] = None # alpha mask can be flipped in runtime
self.system_prompt: Optional[str] = None
class BucketManager:
def __init__(self, no_upscale, max_reso, min_size, max_size, reso_steps) -> None:
@@ -434,6 +436,7 @@ class BaseSubset:
custom_attributes: Optional[Dict[str, Any]] = None,
validation_seed: Optional[int] = None,
validation_split: Optional[float] = 0.0,
system_prompt: Optional[str] = None
) -> None:
self.image_dir = image_dir
self.alpha_mask = alpha_mask if alpha_mask is not None else False
@@ -464,6 +467,8 @@ class BaseSubset:
self.validation_seed = validation_seed
self.validation_split = validation_split
self.system_prompt = system_prompt
class DreamBoothSubset(BaseSubset):
def __init__(
@@ -495,6 +500,7 @@ class DreamBoothSubset(BaseSubset):
custom_attributes: Optional[Dict[str, Any]] = None,
validation_seed: Optional[int] = None,
validation_split: Optional[float] = 0.0,
system_prompt: Optional[str] = None
) -> None:
assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です"
@@ -522,6 +528,7 @@ class DreamBoothSubset(BaseSubset):
custom_attributes=custom_attributes,
validation_seed=validation_seed,
validation_split=validation_split,
system_prompt=system_prompt
)
self.is_reg = is_reg
@@ -564,6 +571,7 @@ class FineTuningSubset(BaseSubset):
custom_attributes: Optional[Dict[str, Any]] = None,
validation_seed: Optional[int] = None,
validation_split: Optional[float] = 0.0,
system_prompt: Optional[str] = None
) -> None:
assert metadata_file is not None, "metadata_file must be specified / metadata_fileは指定が必須です"
@@ -591,6 +599,7 @@ class FineTuningSubset(BaseSubset):
custom_attributes=custom_attributes,
validation_seed=validation_seed,
validation_split=validation_split,
system_prompt=system_prompt
)
self.metadata_file = metadata_file
@@ -629,6 +638,7 @@ class ControlNetSubset(BaseSubset):
custom_attributes: Optional[Dict[str, Any]] = None,
validation_seed: Optional[int] = None,
validation_split: Optional[float] = 0.0,
system_prompt: Optional[str] = None
) -> None:
assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です"
@@ -656,6 +666,7 @@ class ControlNetSubset(BaseSubset):
custom_attributes=custom_attributes,
validation_seed=validation_seed,
validation_split=validation_split,
system_prompt=system_prompt
)
self.conditioning_data_dir = conditioning_data_dir
@@ -1686,8 +1697,9 @@ class BaseDataset(torch.utils.data.Dataset):
text_encoder_outputs_list.append(text_encoder_outputs)
if tokenization_required:
system_prompt = subset.system_prompt or ""
caption = self.process_caption(subset, image_info.caption)
input_ids = [ids[0] for ids in self.tokenize_strategy.tokenize(caption)] # remove batch dimension
input_ids = [ids[0] for ids in self.tokenize_strategy.tokenize(system_prompt + caption)] # remove batch dimension
# if self.XTI_layers:
# caption_layer = []
# for layer in self.XTI_layers:
@@ -2059,6 +2071,7 @@ class DreamBoothDataset(BaseDataset):
num_train_images = 0
num_reg_images = 0
reg_infos: List[Tuple[ImageInfo, DreamBoothSubset]] = []
for subset in subsets:
num_repeats = subset.num_repeats if self.is_training_dataset else 1
if num_repeats < 1:
@@ -2086,7 +2099,7 @@ class DreamBoothDataset(BaseDataset):
num_train_images += num_repeats * len(img_paths)
for img_path, caption, size in zip(img_paths, captions, sizes):
info = ImageInfo(img_path, num_repeats, caption, subset.is_reg, img_path)
info = ImageInfo(img_path, num_repeats, subset.system_prompt or "" + caption, subset.is_reg, img_path)
if size is not None:
info.image_size = size
if subset.is_reg:
@@ -2967,7 +2980,7 @@ def trim_and_resize_if_required(
# for new_cache_latents
def load_images_and_masks_for_caching(
image_infos: List[ImageInfo], use_alpha_mask: bool, random_crop: bool
) -> Tuple[torch.Tensor, List[np.ndarray], List[Tuple[int, int]], List[Tuple[int, int, int, int]]]:
) -> Tuple[torch.Tensor, List[torch.Tensor], List[Tuple[int, int]], List[Tuple[int, int, int, int]]]:
r"""
requires image_infos to have: [absolute_path or image], bucket_reso, resized_size

View File

@@ -1,17 +1,17 @@
import argparse
import copy
import math
import random
from typing import Any, Optional, Union, Tuple
from typing import Any, Tuple
import torch
from torch import Tensor
from accelerate import Accelerator
from library.device_utils import clean_memory_on_device, init_ipex
init_ipex()
from torch import Tensor
from accelerate import Accelerator
import train_network
from library import (
lumina_models,
@@ -40,10 +40,7 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
def assert_extra_args(self, args, train_dataset_group, val_dataset_group):
super().assert_extra_args(args, train_dataset_group, val_dataset_group)
if (
args.cache_text_encoder_outputs_to_disk
and not args.cache_text_encoder_outputs
):
if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
logger.warning("Enabling cache_text_encoder_outputs due to disk caching")
args.cache_text_encoder_outputs = True
@@ -59,17 +56,14 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
model = lumina_util.load_lumina_model(
args.pretrained_model_name_or_path,
loading_dtype,
"cpu",
torch.device("cpu"),
disable_mmap=args.disable_mmap_load_safetensors,
use_flash_attn=args.use_flash_attn,
)
if args.fp8_base:
# check dtype of model
if (
model.dtype == torch.float8_e4m3fnuz
or model.dtype == torch.float8_e5m2
or model.dtype == torch.float8_e5m2fnuz
):
if model.dtype == torch.float8_e4m3fnuz or model.dtype == torch.float8_e5m2 or model.dtype == torch.float8_e5m2fnuz:
raise ValueError(f"Unsupported fp8 model dtype: {model.dtype}")
elif model.dtype == torch.float8_e4m3fn:
logger.info("Loaded fp8 Lumina 2 model")
@@ -92,17 +86,13 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
return lumina_util.MODEL_VERSION_LUMINA_V2, [gemma2], ae, model
def get_tokenize_strategy(self, args):
return strategy_lumina.LuminaTokenizeStrategy(
args.gemma2_max_token_length, args.tokenizer_cache_dir
)
return strategy_lumina.LuminaTokenizeStrategy(args.gemma2_max_token_length, args.tokenizer_cache_dir)
def get_tokenizers(self, tokenize_strategy: strategy_lumina.LuminaTokenizeStrategy):
return [tokenize_strategy.tokenizer]
def get_latents_caching_strategy(self, args):
return strategy_lumina.LuminaLatentsCachingStrategy(
args.cache_latents_to_disk, args.vae_batch_size, False
)
return strategy_lumina.LuminaLatentsCachingStrategy(args.cache_latents_to_disk, args.vae_batch_size, False)
def get_text_encoding_strategy(self, args):
return strategy_lumina.LuminaTextEncodingStrategy()
@@ -144,15 +134,11 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
# When TE is not be trained, it will not be prepared so we need to use explicit autocast
logger.info("move text encoders to gpu")
text_encoders[0].to(
accelerator.device, dtype=weight_dtype
) # always not fp8
text_encoders[0].to(accelerator.device, dtype=weight_dtype) # always not fp8
if text_encoders[0].dtype == torch.float8_e4m3fn:
# if we load fp8 weights, the model is already fp8, so we use it as is
self.prepare_text_encoder_fp8(
1, text_encoders[1], text_encoders[1].dtype, weight_dtype
)
self.prepare_text_encoder_fp8(1, text_encoders[1], text_encoders[1].dtype, weight_dtype)
else:
# otherwise, we need to convert it to target dtype
text_encoders[0].to(weight_dtype)
@@ -162,35 +148,36 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
# cache sample prompts
if args.sample_prompts is not None:
logger.info(
f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}"
)
logger.info(f"cache Text Encoder outputs for sample prompts: {args.sample_prompts}")
tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy()
text_encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy()
text_encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy()
assert isinstance(tokenize_strategy, strategy_lumina.LuminaTokenizeStrategy)
assert isinstance(text_encoding_strategy, strategy_lumina.LuminaTextEncodingStrategy)
system_prompt = args.system_prompt or ""
sample_prompts = train_util.load_prompts(args.sample_prompts)
sample_prompts_te_outputs = (
{}
) # key: prompt, value: text encoder outputs
sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs
with accelerator.autocast(), torch.no_grad():
for prompt_dict in sample_prompts:
prompts = [prompt_dict.get("prompt", ""),
prompt_dict.get("negative_prompt", "")]
logger.info(
f"cache Text Encoder outputs for prompt: {prompts[0]}"
)
tokens_and_masks = tokenize_strategy.tokenize(prompts)
sample_prompts_te_outputs[prompts[0]] = (
text_encoding_strategy.encode_tokens(
prompts = [
prompt_dict.get("prompt", ""),
prompt_dict.get("negative_prompt", ""),
]
for prompt in prompts:
prompt = system_prompt + prompt
if prompt in sample_prompts_te_outputs:
continue
logger.info(f"cache Text Encoder outputs for prompt: {prompt}")
tokens_and_masks = tokenize_strategy.tokenize(prompt)
sample_prompts_te_outputs[prompt] = text_encoding_strategy.encode_tokens(
tokenize_strategy,
text_encoders,
tokens_and_masks,
)
)
self.sample_prompts_te_outputs = sample_prompts_te_outputs
accelerator.wait_for_everyone()
@@ -235,12 +222,8 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
# Remaining methods maintain similar structure to flux implementation
# with Lumina-specific model calls and strategies
def get_noise_scheduler(
self, args: argparse.Namespace, device: torch.device
) -> Any:
noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(
num_train_timesteps=1000, shift=args.discrete_flow_shift
)
def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any:
noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift)
self.noise_scheduler_copy = copy.deepcopy(noise_scheduler)
return noise_scheduler
@@ -258,26 +241,45 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
noise_scheduler,
latents,
batch,
text_encoder_conds: Tuple[Tensor, Tensor, Tensor], # (hidden_states, input_ids, attention_masks)
text_encoder_conds: Tuple[Tensor, Tensor, Tensor], # (hidden_states, input_ids, attention_masks)
dit: lumina_models.NextDiT,
network,
weight_dtype,
train_unet,
is_train=True,
):
assert isinstance(noise_scheduler, sd3_train_utils.FlowMatchEulerDiscreteScheduler)
noise = torch.randn_like(latents)
bsz = latents.shape[0]
# get noisy model input and timesteps
noisy_model_input, timesteps, sigmas = (
flux_train_utils.get_noisy_model_input_and_timesteps(
args, noise_scheduler, latents, noise, accelerator.device, weight_dtype
)
# Sample a random timestep for each image
# for weighting schemes where we sample timesteps non-uniformly
u = lumina_train_util.compute_density_for_timestep_sampling(
weighting_scheme=args.weighting_scheme,
batch_size=bsz,
logit_mean=args.logit_mean,
logit_std=args.logit_std,
mode_scale=args.mode_scale,
)
indices = (u * noise_scheduler.config.num_train_timesteps).long()
timesteps = noise_scheduler.timesteps[indices].to(device=latents.device)
# May not need to pack/unpack?
# pack latents and get img_ids - 这部分可以保留因为NextDiT也需要packed格式的输入
# packed_noisy_model_input = lumina_util.pack_latents(noisy_model_input)
def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
sigmas = noise_scheduler.sigmas.to(device=accelerator.device, dtype=dtype)
schedule_timesteps = noise_scheduler.timesteps.to(accelerator.device)
timesteps = timesteps.to(accelerator.device)
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
sigma = sigmas[step_indices].flatten()
while len(sigma.shape) < n_dim:
sigma = sigma.unsqueeze(-1)
return sigma
# Add noise according to flow matching.
# zt = (1 - texp) * x + texp * z1
# Lumina2 reverses the lerp i.e., sigma of 1.0 should mean `latents`
sigmas = get_sigmas(timesteps, n_dim=latents.ndim, dtype=latents.dtype)
noisy_model_input = (1.0 - sigmas) * noise + sigmas * latents
# ensure the hidden state will require grad
if args.gradient_checkpointing:
@@ -289,48 +291,35 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
# Unpack Gemma2 outputs
gemma2_hidden_states, input_ids, gemma2_attn_mask = text_encoder_conds
def call_dit(img, gemma2_hidden_states, timesteps, gemma2_attn_mask):
def call_dit(img, gemma2_hidden_states, gemma2_attn_mask, timesteps):
with torch.set_grad_enabled(is_train), accelerator.autocast():
# NextDiT forward expects (x, t, cap_feats, cap_mask)
model_pred = dit(
x=img, # image latents (B, C, H, W)
t=timesteps / 1000, # timesteps需要除以1000来匹配模型预期
cap_feats=gemma2_hidden_states, # Gemma2的hidden states作为caption features
cap_mask=gemma2_attn_mask.to(
dtype=torch.int32
), # Gemma2的attention mask
cap_mask=gemma2_attn_mask.to(dtype=torch.int32), # Gemma2的attention mask
)
return model_pred
model_pred = call_dit(
img=noisy_model_input,
gemma2_hidden_states=gemma2_hidden_states,
timesteps=timesteps,
gemma2_attn_mask=gemma2_attn_mask,
timesteps=timesteps,
)
# May not need to pack/unpack?
# unpack latents
# model_pred = lumina_util.unpack_latents(
# model_pred, packed_latent_height, packed_latent_width
# )
# apply model prediction type
model_pred, weighting = flux_train_utils.apply_model_prediction_type(
args, model_pred, noisy_model_input, sigmas
)
model_pred, weighting = flux_train_utils.apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas)
# flow matching loss: this is different from SD3
target = noise - latents
# flow matching loss
target = latents - noise
# differential output preservation
if "custom_attributes" in batch:
diff_output_pr_indices = []
for i, custom_attributes in enumerate(batch["custom_attributes"]):
if (
"diff_output_preservation" in custom_attributes
and custom_attributes["diff_output_preservation"]
):
if "diff_output_preservation" in custom_attributes and custom_attributes["diff_output_preservation"]:
diff_output_pr_indices.append(i)
if len(diff_output_pr_indices) > 0:
@@ -338,9 +327,7 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
with torch.no_grad():
model_pred_prior = call_dit(
img=noisy_model_input[diff_output_pr_indices],
gemma2_hidden_states=gemma2_hidden_states[
diff_output_pr_indices
],
gemma2_hidden_states=gemma2_hidden_states[diff_output_pr_indices],
timesteps=timesteps[diff_output_pr_indices],
gemma2_attn_mask=(gemma2_attn_mask[diff_output_pr_indices]),
)
@@ -363,9 +350,7 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
return loss
def get_sai_model_spec(self, args):
return train_util.get_sai_model_spec(
None, args, False, True, False, lumina="lumina2"
)
return train_util.get_sai_model_spec(None, args, False, True, False, lumina="lumina2")
def update_metadata(self, metadata, args):
metadata["ss_weighting_scheme"] = args.weighting_scheme
@@ -384,12 +369,8 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder):
text_encoder.embed_tokens.requires_grad_(True)
def prepare_text_encoder_fp8(
self, index, text_encoder, te_weight_dtype, weight_dtype
):
logger.info(
f"prepare Gemma2 for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}"
)
def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype):
logger.info(f"prepare Gemma2 for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}")
text_encoder.to(te_weight_dtype) # fp8
text_encoder.embed_tokens.to(dtype=weight_dtype)
@@ -402,12 +383,8 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
# if we doesn't swap blocks, we can move the model to device
nextdit = unet
assert isinstance(nextdit, lumina_models.NextDiT)
nextdit = accelerator.prepare(
nextdit, device_placement=[not self.is_swapping_blocks]
)
accelerator.unwrap_model(nextdit).move_to_device_except_swap_blocks(
accelerator.device
) # reduce peak memory usage
nextdit = accelerator.prepare(nextdit, device_placement=[not self.is_swapping_blocks])
accelerator.unwrap_model(nextdit).move_to_device_except_swap_blocks(accelerator.device) # reduce peak memory usage
accelerator.unwrap_model(nextdit).prepare_block_swap_before_forward()
return nextdit

View File

@@ -129,7 +129,7 @@ class NetworkTrainer:
if val_dataset_group is not None:
val_dataset_group.verify_bucket_reso_steps(64)
def load_target_model(self, args, weight_dtype, accelerator):
def load_target_model(self, args, weight_dtype, accelerator) -> tuple:
text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator)
# モデルに xformers とか memory efficient attention を組み込む
@@ -354,12 +354,13 @@ class NetworkTrainer:
if text_encoder_outputs_list is not None:
text_encoder_conds = text_encoder_outputs_list # List of text encoder outputs
if len(text_encoder_conds) == 0 or text_encoder_conds[0] is None or train_text_encoder:
# TODO this does not work if 'some text_encoders are trained' and 'some are not and not cached'
with torch.set_grad_enabled(is_train and train_text_encoder), accelerator.autocast():
# Get the text embedding for conditioning
if args.weighted_captions:
input_ids_list, weights_list = tokenize_strategy.tokenize_with_weights(batch["captions"])
input_ids_list, weights_list = tokenize_strategy.tokenize_with_weights(batch['captions'])
encoded_text_encoder_conds = text_encoding_strategy.encode_tokens_with_weights(
tokenize_strategy,
self.get_models_for_text_encoding(args, accelerator, text_encoders),