This commit is contained in:
sdbds
2025-02-12 16:30:05 +08:00
parent 0778dd9b1d
commit d154e76c45
7 changed files with 2373 additions and 0 deletions

1144
library/lumina_models.py Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,554 @@
import argparse
import math
import os
import numpy as np
import toml
import json
import time
from typing import Callable, Dict, List, Optional, Tuple, Union
import torch
from accelerate import Accelerator, PartialState
from transformers import AutoTokenizer, AutoModelForCausalLM
from tqdm import tqdm
from PIL import Image
from safetensors.torch import save_file
from library import lumina_models, lumina_util, strategy_base, train_util
from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
from .utils import setup_logging, mem_eff_save_file
setup_logging()
import logging
logger = logging.getLogger(__name__)
# region sample images
def sample_images(
accelerator: Accelerator,
args: argparse.Namespace,
epoch,
steps,
nextdit,
ae,
gemma2_model,
sample_prompts_gemma2_outputs,
prompt_replacement=None,
controlnet=None
):
if steps == 0:
if not args.sample_at_first:
return
else:
if args.sample_every_n_steps is None and args.sample_every_n_epochs is None:
return
if args.sample_every_n_epochs is not None:
# sample_every_n_steps は無視する
if epoch is None or epoch % args.sample_every_n_epochs != 0:
return
else:
if steps % args.sample_every_n_steps != 0 or epoch is not None: # steps is not divisible or end of epoch
return
logger.info("")
logger.info(f"generating sample images at step / サンプル画像生成 ステップ: {steps}")
if not os.path.isfile(args.sample_prompts) and sample_prompts_gemma2_outputs is None:
logger.error(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}")
return
distributed_state = PartialState() # for multi gpu distributed inference. this is a singleton, so it's safe to use it here
# unwrap nextdit and gemma2_model
nextdit = accelerator.unwrap_model(nextdit)
if gemma2_model is not None:
gemma2_model = accelerator.unwrap_model(gemma2_model)
# if controlnet is not None:
# controlnet = accelerator.unwrap_model(controlnet)
# print([(te.parameters().__next__().device if te is not None else None) for te in text_encoders])
prompts = train_util.load_prompts(args.sample_prompts)
save_dir = args.output_dir + "/sample"
os.makedirs(save_dir, exist_ok=True)
# save random state to restore later
rng_state = torch.get_rng_state()
cuda_rng_state = None
try:
cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None
except Exception:
pass
if distributed_state.num_processes <= 1:
# If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts.
with torch.no_grad(), accelerator.autocast():
for prompt_dict in prompts:
sample_image_inference(
accelerator,
args,
nextdit,
gemma2_model,
ae,
save_dir,
prompt_dict,
epoch,
steps,
sample_prompts_gemma2_outputs,
prompt_replacement,
controlnet
)
else:
# Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processes available (number of devices available)
# prompt_dicts are assigned to lists based on order of processes, to attempt to time the image creation time to match enum order. Probably only works when steps and sampler are identical.
per_process_prompts = [] # list of lists
for i in range(distributed_state.num_processes):
per_process_prompts.append(prompts[i :: distributed_state.num_processes])
with torch.no_grad():
with distributed_state.split_between_processes(per_process_prompts) as prompt_dict_lists:
for prompt_dict in prompt_dict_lists[0]:
sample_image_inference(
accelerator,
args,
nextdit,
gemma2_model,
ae,
save_dir,
prompt_dict,
epoch,
steps,
sample_prompts_gemma2_outputs,
prompt_replacement,
controlnet
)
torch.set_rng_state(rng_state)
if cuda_rng_state is not None:
torch.cuda.set_rng_state(cuda_rng_state)
clean_memory_on_device(accelerator.device)
def sample_image_inference(
accelerator: Accelerator,
args: argparse.Namespace,
nextdit,
gemma2_model,
ae,
save_dir,
prompt_dict,
epoch,
steps,
sample_prompts_gemma2_outputs,
prompt_replacement,
# controlnet
):
assert isinstance(prompt_dict, dict)
# negative_prompt = prompt_dict.get("negative_prompt")
sample_steps = prompt_dict.get("sample_steps", 20)
width = prompt_dict.get("width", 512)
height = prompt_dict.get("height", 512)
scale = prompt_dict.get("scale", 3.5)
seed = prompt_dict.get("seed")
controlnet_image = prompt_dict.get("controlnet_image")
prompt: str = prompt_dict.get("prompt", "")
# sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler)
if prompt_replacement is not None:
prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
# if negative_prompt is not None:
# negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])
if seed is not None:
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
else:
# True random sample image generation
torch.seed()
torch.cuda.seed()
# if negative_prompt is None:
# negative_prompt = ""
height = max(64, height - height % 16) # round to divisible by 16
width = max(64, width - width % 16) # round to divisible by 16
logger.info(f"prompt: {prompt}")
# logger.info(f"negative_prompt: {negative_prompt}")
logger.info(f"height: {height}")
logger.info(f"width: {width}")
logger.info(f"sample_steps: {sample_steps}")
logger.info(f"scale: {scale}")
# logger.info(f"sample_sampler: {sampler_name}")
if seed is not None:
logger.info(f"seed: {seed}")
# encode prompts
tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy()
encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy()
gemma2_conds = []
if sample_prompts_gemma2_outputs and prompt in sample_prompts_gemma2_outputs:
gemma2_conds = sample_prompts_gemma2_outputs[prompt]
print(f"Using cached Gemma2 outputs for prompt: {prompt}")
if gemma2_model is not None:
print(f"Encoding prompt with Gemma2: {prompt}")
tokens_and_masks = tokenize_strategy.tokenize(prompt)
# strategy has apply_gemma2_attn_mask option
encoded_gemma2_conds = encoding_strategy.encode_tokens(tokenize_strategy, [gemma2_model], tokens_and_masks)
# if gemma2_conds is not cached, use encoded_gemma2_conds
if len(gemma2_conds) == 0:
gemma2_conds = encoded_gemma2_conds
else:
# if encoded_gemma2_conds is not None, update cached gemma2_conds
for i in range(len(encoded_gemma2_conds)):
if encoded_gemma2_conds[i] is not None:
gemma2_conds[i] = encoded_gemma2_conds[i]
# Unpack Gemma2 outputs
gemma2_hidden_states, gemma2_attn_mask, input_ids = gemma2_conds
# sample image
weight_dtype = ae.dtype # TOFO give dtype as argument
packed_latent_height = height // 16
packed_latent_width = width // 16
noise = torch.randn(
1,
packed_latent_height * packed_latent_width,
16 * 2 * 2,
device=accelerator.device,
dtype=weight_dtype,
generator=torch.Generator(device=accelerator.device).manual_seed(seed) if seed is not None else None,
)
timesteps = get_schedule(sample_steps, noise.shape[1], shift=True)
img_ids = lumina_util.prepare_img_ids(1, packed_latent_height, packed_latent_width).to(accelerator.device, weight_dtype)
gemma2_attn_mask = gemma2_attn_mask.to(accelerator.device) if args.apply_gemma2_attn_mask else None
# if controlnet_image is not None:
# controlnet_image = Image.open(controlnet_image).convert("RGB")
# controlnet_image = controlnet_image.resize((width, height), Image.LANCZOS)
# controlnet_image = torch.from_numpy((np.array(controlnet_image) / 127.5) - 1)
# controlnet_image = controlnet_image.permute(2, 0, 1).unsqueeze(0).to(weight_dtype).to(accelerator.device)
with accelerator.autocast(), torch.no_grad():
x = denoise(nextdit, noise, img_ids, gemma2_hidden_states, input_ids, None, timesteps=timesteps, guidance=scale, gemma2_attn_mask=gemma2_attn_mask, controlnet=controlnet, controlnet_img=controlnet_image)
x = lumina_util.unpack_latents(x, packed_latent_height, packed_latent_width)
# latent to image
clean_memory_on_device(accelerator.device)
org_vae_device = ae.device # will be on cpu
ae.to(accelerator.device) # distributed_state.device is same as accelerator.device
with accelerator.autocast(), torch.no_grad():
x = ae.decode(x)
ae.to(org_vae_device)
clean_memory_on_device(accelerator.device)
x = x.clamp(-1, 1)
x = x.permute(0, 2, 3, 1)
image = Image.fromarray((127.5 * (x + 1.0)).float().cpu().numpy().astype(np.uint8)[0])
# adding accelerator.wait_for_everyone() here should sync up and ensure that sample images are saved in the same order as the original prompt list
# but adding 'enum' to the filename should be enough
ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}"
seed_suffix = "" if seed is None else f"_{seed}"
i: int = prompt_dict["enum"]
img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png"
image.save(os.path.join(save_dir, img_filename))
# send images to wandb if enabled
if "wandb" in [tracker.name for tracker in accelerator.trackers]:
wandb_tracker = accelerator.get_tracker("wandb")
import wandb
# not to commit images to avoid inconsistency between training and logging steps
wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False) # positive prompt as a caption
def time_shift(mu: float, sigma: float, t: torch.Tensor):
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15) -> Callable[[float], float]:
m = (y2 - y1) / (x2 - x1)
b = y1 - m * x1
return lambda x: m * x + b
def get_schedule(
num_steps: int,
image_seq_len: int,
base_shift: float = 0.5,
max_shift: float = 1.15,
shift: bool = True,
) -> list[float]:
# extra step for zero
timesteps = torch.linspace(1, 0, num_steps + 1)
# shifting the schedule to favor high timesteps for higher signal images
if shift:
# eastimate mu based on linear estimation between two points
mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
timesteps = time_shift(mu, 1.0, timesteps)
return timesteps.tolist()
# endregion
# region train
def get_sigmas(noise_scheduler, timesteps, device, n_dim=4, dtype=torch.float32):
sigmas = noise_scheduler.sigmas.to(device=device, dtype=dtype)
schedule_timesteps = noise_scheduler.timesteps.to(device)
timesteps = timesteps.to(device)
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
sigma = sigmas[step_indices].flatten()
while len(sigma.shape) < n_dim:
sigma = sigma.unsqueeze(-1)
return sigma
def compute_density_for_timestep_sampling(
weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None
):
"""Compute the density for sampling the timesteps when doing SD3 training.
Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
"""
if weighting_scheme == "logit_normal":
# See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu")
u = torch.nn.functional.sigmoid(u)
elif weighting_scheme == "mode":
u = torch.rand(size=(batch_size,), device="cpu")
u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
else:
u = torch.rand(size=(batch_size,), device="cpu")
return u
def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
"""Computes loss weighting scheme for SD3 training.
Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
"""
if weighting_scheme == "sigma_sqrt":
weighting = (sigmas**-2.0).float()
elif weighting_scheme == "cosmap":
bot = 1 - 2 * sigmas + 2 * sigmas**2
weighting = 2 / (math.pi * bot)
else:
weighting = torch.ones_like(sigmas)
return weighting
def get_noisy_model_input_and_timesteps(
args, noise_scheduler, latents, noise, device, dtype
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
bsz, _, h, w = latents.shape
sigmas = None
if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid":
# Simple random t-based noise sampling
if args.timestep_sampling == "sigmoid":
# https://github.com/XLabs-AI/x-flux/tree/main
t = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=device))
else:
t = torch.rand((bsz,), device=device)
timesteps = t * 1000.0
t = t.view(-1, 1, 1, 1)
noisy_model_input = (1 - t) * latents + t * noise
elif args.timestep_sampling == "shift":
shift = args.discrete_flow_shift
logits_norm = torch.randn(bsz, device=device)
logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling
timesteps = logits_norm.sigmoid()
timesteps = (timesteps * shift) / (1 + (shift - 1) * timesteps)
t = timesteps.view(-1, 1, 1, 1)
timesteps = timesteps * 1000.0
noisy_model_input = (1 - t) * latents + t * noise
elif args.timestep_sampling == "nextdit_shift":
logits_norm = torch.randn(bsz, device=device)
logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling
timesteps = logits_norm.sigmoid()
mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2))
timesteps = time_shift(mu, 1.0, timesteps)
t = timesteps.view(-1, 1, 1, 1)
timesteps = timesteps * 1000.0
noisy_model_input = (1 - t) * latents + t * noise
else:
# Sample a random timestep for each image
# for weighting schemes where we sample timesteps non-uniformly
u = compute_density_for_timestep_sampling(
weighting_scheme=args.weighting_scheme,
batch_size=bsz,
logit_mean=args.logit_mean,
logit_std=args.logit_std,
mode_scale=args.mode_scale,
)
indices = (u * noise_scheduler.config.num_train_timesteps).long()
timesteps = noise_scheduler.timesteps[indices].to(device=device)
# Add noise according to flow matching.
sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype)
noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents
return noisy_model_input.to(dtype), timesteps.to(dtype), sigmas
def apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas):
weighting = None
if args.model_prediction_type == "raw":
pass
elif args.model_prediction_type == "additive":
# add the model_pred to the noisy_model_input
model_pred = model_pred + noisy_model_input
elif args.model_prediction_type == "sigma_scaled":
# apply sigma scaling
model_pred = model_pred * (-sigmas) + noisy_model_input
# these weighting schemes use a uniform timestep sampling
# and instead post-weight the loss
weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)
return model_pred, weighting
def save_models(
ckpt_path: str,
lumina: lumina_models.NextDiT,
sai_metadata: Optional[dict],
save_dtype: Optional[torch.dtype] = None,
use_mem_eff_save: bool = False,
):
state_dict = {}
def update_sd(prefix, sd):
for k, v in sd.items():
key = prefix + k
if save_dtype is not None and v.dtype != save_dtype:
v = v.detach().clone().to("cpu").to(save_dtype)
state_dict[key] = v
update_sd("", lumina.state_dict())
if not use_mem_eff_save:
save_file(state_dict, ckpt_path, metadata=sai_metadata)
else:
mem_eff_save_file(state_dict, ckpt_path, metadata=sai_metadata)
def save_lumina_model_on_train_end(
args: argparse.Namespace, save_dtype: torch.dtype, epoch: int, global_step: int, lumina: lumina_models.NextDiT
):
def sd_saver(ckpt_file, epoch_no, global_step):
sai_metadata = train_util.get_sai_model_spec(None, args, False, False, False, is_stable_diffusion_ckpt=True, lumina="lumina2")
save_models(ckpt_file, lumina, sai_metadata, save_dtype, args.mem_eff_save)
train_util.save_sd_model_on_train_end_common(args, True, True, epoch, global_step, sd_saver, None)
# epochとstepの保存、メタデータにepoch/stepが含まれ引数が同じになるため、統合してている
# on_epoch_end: Trueならepoch終了時、Falseならstep経過時
def save_lumina_model_on_epoch_end_or_stepwise(
args: argparse.Namespace,
on_epoch_end: bool,
accelerator,
save_dtype: torch.dtype,
epoch: int,
num_train_epochs: int,
global_step: int,
lumina: lumina_models.NextDiT,
):
def sd_saver(ckpt_file, epoch_no, global_step):
sai_metadata = train_util.get_sai_model_spec(None, args, False, False, False, is_stable_diffusion_ckpt=True, lumina="lumina2")
save_models(ckpt_file, lumina, sai_metadata, save_dtype, args.mem_eff_save)
train_util.save_sd_model_on_epoch_end_or_stepwise_common(
args,
on_epoch_end,
accelerator,
True,
True,
epoch,
num_train_epochs,
global_step,
sd_saver,
None,
)
# endregion
def add_lumina_train_arguments(parser: argparse.ArgumentParser):
parser.add_argument(
"--gemma2",
type=str,
help="path to gemma2 model (*.sft or *.safetensors), should be float16 / gemma2のパス*.sftまたは*.safetensors、float16が前提",
)
parser.add_argument("--ae", type=str, help="path to ae (*.sft or *.safetensors) / aeのパス*.sftまたは*.safetensors")
parser.add_argument(
"--gemma2_max_token_length",
type=int,
default=None,
help="maximum token length for Gemma2. if omitted, 256 for schnell and 512 for dev"
" / Gemma2の最大トークン長。省略された場合、schnellの場合は256、devの場合は512",
)
parser.add_argument(
"--apply_gemma2_attn_mask",
action="store_true",
help="apply attention mask to Gemma2 encode and NextDIT double blocks / Gemma2エンコードとNextDITダブルブロックにアテンションマスクを適用する",
)
parser.add_argument(
"--guidance_scale",
type=float,
default=3.5,
help="the NextDIT.1 dev variant is a guidance distilled model",
)
parser.add_argument(
"--timestep_sampling",
choices=["sigma", "uniform", "sigmoid", "shift", "nextdit_shift"],
default="sigma",
help="Method to sample timesteps: sigma-based, uniform random, sigmoid of random normal, shift of sigmoid and NextDIT.1 shifting."
" / タイムステップをサンプリングする方法sigma、random uniform、random normalのsigmoid、sigmoidのシフト、NextDIT.1のシフト。",
)
parser.add_argument(
"--sigmoid_scale",
type=float,
default=1.0,
help='Scale factor for sigmoid timestep sampling (only used when timestep-sampling is "sigmoid"). / sigmoidタイムステップサンプリングの倍率timestep-samplingが"sigmoid"の場合のみ有効)。',
)
parser.add_argument(
"--model_prediction_type",
choices=["raw", "additive", "sigma_scaled"],
default="sigma_scaled",
help="How to interpret and process the model prediction: "
"raw (use as is), additive (add to noisy input), sigma_scaled (apply sigma scaling)."
" / モデル予測の解釈と処理方法:"
"rawそのまま使用、additiveイズ入力に加算、sigma_scaledシグマスケーリングを適用",
)
parser.add_argument(
"--discrete_flow_shift",
type=float,
default=3.0,
help="Discrete flow shift for the Euler Discrete Scheduler, default is 3.0. / Euler Discrete Schedulerの離散フローシフト、デフォルトは3.0。",
)

194
library/lumina_util.py Normal file
View File

@@ -0,0 +1,194 @@
import json
import os
from dataclasses import replace
from typing import List, Optional, Tuple, Union
import einops
import torch
from accelerate import init_empty_weights
from safetensors import safe_open
from safetensors.torch import load_file
from transformers import Gemma2Config, Gemma2Model
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
from library import lumina_models, flux_models
from library.utils import load_safetensors
MODEL_VERSION_LUMINA_V2 = "lumina2"
def load_lumina_model(
ckpt_path: str,
dtype: torch.dtype,
device: Union[str, torch.device],
disable_mmap: bool = False,
) -> lumina_models.Lumina:
logger.info("Building Lumina")
with torch.device("meta"):
model = lumina_models.NextDiT_2B_GQA_patch2_Adaln_Refiner().to(dtype)
logger.info(f"Loading state dict from {ckpt_path}")
state_dict = load_safetensors(
ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype
)
info = model.load_state_dict(state_dict, strict=False, assign=True)
logger.info(f"Loaded Lumina: {info}")
return model
def load_ae(
ckpt_path: str,
dtype: torch.dtype,
device: Union[str, torch.device],
disable_mmap: bool = False,
) -> flux_models.AutoEncoder:
logger.info("Building AutoEncoder")
with torch.device("meta"):
# dev and schnell have the same AE params
ae = flux_models.AutoEncoder(flux_models.configs["schnell"].ae_params).to(dtype)
logger.info(f"Loading state dict from {ckpt_path}")
sd = load_safetensors(
ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype
)
info = ae.load_state_dict(sd, strict=False, assign=True)
logger.info(f"Loaded AE: {info}")
return ae
def load_gemma2(
ckpt_path: Optional[str],
dtype: torch.dtype,
device: Union[str, torch.device],
disable_mmap: bool = False,
state_dict: Optional[dict] = None,
) -> Gemma2Model:
logger.info("Building Gemma2")
GEMMA2_CONFIG = {
"_name_or_path": "google/gemma-2b",
"attention_bias": false,
"attention_dropout": 0.0,
"bos_token_id": 2,
"eos_token_id": 1,
"head_dim": 256,
"hidden_act": "gelu",
"hidden_size": 2048,
"initializer_range": 0.02,
"intermediate_size": 16384,
"max_position_embeddings": 8192,
"model_type": "gemma",
"num_attention_heads": 8,
"num_hidden_layers": 18,
"num_key_value_heads": 1,
"pad_token_id": 0,
"rms_norm_eps": 1e-06,
"rope_scaling": null,
"rope_theta": 10000.0,
"torch_dtype": "bfloat16",
"transformers_version": "4.38.0.dev0",
"use_cache": true,
"vocab_size": 256000
}
config = Gemma2Config(**GEMMA2_CONFIG)
with init_empty_weights():
gemma2 = Gemma2Model._from_config(config)
if state_dict is not None:
sd = state_dict
else:
logger.info(f"Loading state dict from {ckpt_path}")
sd = load_safetensors(
ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype
)
info = gemma2.load_state_dict(sd, strict=False, assign=True)
logger.info(f"Loaded Gemma2: {info}")
return gemma2
def prepare_img_ids(batch_size: int, packed_latent_height: int, packed_latent_width: int):
img_ids = torch.zeros(packed_latent_height, packed_latent_width, 3)
img_ids[..., 1] = img_ids[..., 1] + torch.arange(packed_latent_height)[:, None]
img_ids[..., 2] = img_ids[..., 2] + torch.arange(packed_latent_width)[None, :]
img_ids = einops.repeat(img_ids, "h w c -> b (h w) c", b=batch_size)
return img_ids
def unpack_latents(x: torch.Tensor, packed_latent_height: int, packed_latent_width: int) -> torch.Tensor:
"""
x: [b (h w) (c ph pw)] -> [b c (h ph) (w pw)], ph=2, pw=2
"""
x = einops.rearrange(x, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=packed_latent_height, w=packed_latent_width, ph=2, pw=2)
return x
def pack_latents(x: torch.Tensor) -> torch.Tensor:
"""
x: [b c (h ph) (w pw)] -> [b (h w) (c ph pw)], ph=2, pw=2
"""
x = einops.rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
return x
DIFFUSERS_TO_ALPHA_VLLM_MAP = {
# Embedding layers
"cap_embedder.0.weight": ["time_caption_embed.caption_embedder.0.weight"],
"cap_embedder.1.weight": "time_caption_embed.caption_embedder.1.weight",
"cap_embedder.1.bias": "text_embedder.1.bias",
"x_embedder.weight": "patch_embedder.proj.weight",
"x_embedder.bias": "patch_embedder.proj.bias",
# Attention modulation
"layers.().adaLN_modulation.1.weight": "transformer_blocks.().adaln_modulation.1.weight",
"layers.().adaLN_modulation.1.bias": "transformer_blocks.().adaln_modulation.1.bias",
# Final layers
"final_layer.adaLN_modulation.1.weight": "final_adaln_modulation.1.weight",
"final_layer.adaLN_modulation.1.bias": "final_adaln_modulation.1.bias",
"final_layer.linear.weight": "final_linear.weight",
"final_layer.linear.bias": "final_linear.bias",
# Noise refiner
"noise_refiner.().adaLN_modulation.1.weight": "single_transformer_blocks.().adaln_modulation.1.weight",
"noise_refiner.().adaLN_modulation.1.bias": "single_transformer_blocks.().adaln_modulation.1.bias",
"noise_refiner.().attention.qkv.weight": "single_transformer_blocks.().attn.to_qkv.weight",
"noise_refiner.().attention.out.weight": "single_transformer_blocks.().attn.to_out.0.weight",
# Time embedding
"t_embedder.mlp.0.weight": "time_embedder.0.weight",
"t_embedder.mlp.0.bias": "time_embedder.0.bias",
"t_embedder.mlp.2.weight": "time_embedder.2.weight",
"t_embedder.mlp.2.bias": "time_embedder.2.bias",
# Context attention
"context_refiner.().attention.qkv.weight": "transformer_blocks.().attn2.to_qkv.weight",
"context_refiner.().attention.out.weight": "transformer_blocks.().attn2.to_out.0.weight",
# Normalization
"layers.().attention_norm1.weight": "transformer_blocks.().norm1.weight",
"layers.().attention_norm2.weight": "transformer_blocks.().norm2.weight",
# FFN
"layers.().feed_forward.w1.weight": "transformer_blocks.().ff.net.0.proj.weight",
"layers.().feed_forward.w2.weight": "transformer_blocks.().ff.net.2.weight",
"layers.().feed_forward.w3.weight": "transformer_blocks.().ff.net.4.weight",
}
def convert_diffusers_sd_to_alpha_vllm(sd: dict, num_double_blocks: int) -> dict:
"""Convert Diffusers checkpoint to Alpha-VLLM format"""
logger.info("Converting Diffusers checkpoint to Alpha-VLLM format")
new_sd = {}
for key, value in sd.items():
new_key = key
for pattern, replacement in DIFFUSERS_TO_ALPHA_VLLM_MAP.items():
if "()." in pattern:
for block_idx in range(num_double_blocks):
if str(block_idx) in key:
converted = pattern.replace("()", str(block_idx))
new_key = key.replace(
converted, replacement.replace("()", str(block_idx))
)
break
if new_key == key:
logger.debug(f"Unmatched key in conversion: {key}")
new_sd[new_key] = value
logger.info(f"Converted {len(new_sd)} keys to Alpha-VLLM format")
return new_sd

View File

@@ -61,6 +61,8 @@ ARCH_SD3_M = "stable-diffusion-3" # may be followed by "-m" or "-5-large" etc.
# ARCH_SD3_UNKNOWN = "stable-diffusion-3"
ARCH_FLUX_1_DEV = "flux-1-dev"
ARCH_FLUX_1_UNKNOWN = "flux-1"
ARCH_LUMINA_2 = "lumina-2"
ARCH_LUMINA_UNKNOWN = "lumina"
ADAPTER_LORA = "lora"
ADAPTER_TEXTUAL_INVERSION = "textual-inversion"
@@ -69,6 +71,7 @@ IMPL_STABILITY_AI = "https://github.com/Stability-AI/generative-models"
IMPL_COMFY_UI = "https://github.com/comfyanonymous/ComfyUI"
IMPL_DIFFUSERS = "diffusers"
IMPL_FLUX = "https://github.com/black-forest-labs/flux"
IMPL_LUMINA = "https://github.com/Alpha-VLLM/Lumina-Image-2.0"
PRED_TYPE_EPSILON = "epsilon"
PRED_TYPE_V = "v"
@@ -123,6 +126,7 @@ def build_metadata(
clip_skip: Optional[int] = None,
sd3: Optional[str] = None,
flux: Optional[str] = None,
lumina: Optional[str] = None,
):
"""
sd3: only supports "m", flux: only supports "dev"
@@ -146,6 +150,11 @@ def build_metadata(
arch = ARCH_FLUX_1_DEV
else:
arch = ARCH_FLUX_1_UNKNOWN
elif lumina is not None:
if lumina == "lumina2":
arch = ARCH_LUMINA_2
else:
arch = ARCH_LUMINA_UNKNOWN
elif v2:
if v_parameterization:
arch = ARCH_SD_V2_768_V
@@ -167,6 +176,9 @@ def build_metadata(
if flux is not None:
# Flux
impl = IMPL_FLUX
elif lumina is not None:
# Lumina
impl = IMPL_LUMINA
elif (lora and sdxl) or textual_inversion or is_stable_diffusion_ckpt:
# Stable Diffusion ckpt, TI, SDXL LoRA
impl = IMPL_STABILITY_AI

275
library/strategy_lumina.py Normal file
View File

@@ -0,0 +1,275 @@
import glob
import os
from typing import Any, List, Optional, Tuple, Union
import torch
from transformers import AutoTokenizer, AutoModel
from library import train_util
from library.strategy_base import (
LatentsCachingStrategy,
TokenizeStrategy,
TextEncodingStrategy,
)
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
GEMMA_ID = "google/gemma-2-2b"
class LuminaTokenizeStrategy(TokenizeStrategy):
def __init__(
self, max_length: Optional[int], tokenizer_cache_dir: Optional[str] = None
) -> None:
self.tokenizer = AutoTokenizer.from_pretrained(
GEMMA_ID, cache_dir=tokenizer_cache_dir
)
self.tokenizer.padding_side = "right"
if max_length is None:
self.max_length = self.tokenizer.model_max_length
else:
self.max_length = max_length
def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]:
text = [text] if isinstance(text, str) else text
encodings = self.tokenizer(
text,
padding="max_length",
max_length=self.max_length,
return_tensors="pt",
truncation=True,
)
return [encodings.input_ids]
def tokenize_with_weights(
self, text: str | List[str]
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
# Gemma doesn't support weighted prompts, return uniform weights
tokens = self.tokenize(text)
weights = [torch.ones_like(t) for t in tokens]
return tokens, weights
class LuminaTextEncodingStrategy(TextEncodingStrategy):
def __init__(self, apply_gemma2_attn_mask: Optional[bool] = None) -> None:
super().__init__()
self.apply_gemma2_attn_mask = apply_gemma2_attn_mask
def encode_tokens(
self,
tokenize_strategy: TokenizeStrategy,
models: List[Any],
tokens: List[torch.Tensor],
apply_gemma2_attn_mask: Optional[bool] = None,
) -> List[torch.Tensor]:
if apply_gemma2_attn_mask is None:
apply_gemma2_attn_mask = self.apply_gemma2_attn_mask
text_encoder = models[0]
input_ids = tokens[0].to(text_encoder.device)
attention_mask = None
position_ids = None
if apply_gemma2_attn_mask:
# Create attention mask (1 for non-padding, 0 for padding)
attention_mask = (input_ids != tokenize_strategy.tokenizer.pad_token_id).to(
text_encoder.device
)
# Create position IDs
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
with torch.no_grad():
outputs = text_encoder(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
output_hidden_states=True,
return_dict=True,
)
# Get the last hidden state
hidden_states = outputs.last_hidden_state
return [hidden_states]
def encode_tokens_with_weights(
self,
tokenize_strategy: TokenizeStrategy,
models: List[Any],
tokens_list: List[torch.Tensor],
weights_list: List[torch.Tensor],
) -> List[torch.Tensor]:
# For simplicity, use uniform weighting
return self.encode_tokens(tokenize_strategy, models, tokens_list)
class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
LUMINA_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_lumina_te.npz"
def __init__(
self,
cache_to_disk: bool,
batch_size: int,
skip_disk_cache_validity_check: bool,
is_partial: bool = False,
apply_gemma2_attn_mask: bool = False,
) -> None:
super().__init__(
cache_to_disk,
batch_size,
skip_disk_cache_validity_check,
is_partial,
)
self.apply_gemma2_attn_mask = apply_gemma2_attn_mask
def get_outputs_npz_path(self, image_abs_path: str) -> str:
return (
os.path.splitext(image_abs_path)[0]
+ LuminaTextEncoderOutputsCachingStrategy.LUMINA_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
)
def is_disk_cached_outputs_expected(self, npz_path: str):
if not self.cache_to_disk:
return False
if not os.path.exists(npz_path):
return False
if self.skip_disk_cache_validity_check:
return True
try:
npz = np.load(npz_path)
if "hidden_state" not in npz:
return False
except Exception as e:
logger.error(f"Error loading file: {npz_path}")
raise e
return True
def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
data = np.load(npz_path)
hidden_state = data["hidden_state"]
return [hidden_state]
def cache_batch_outputs(
self,
tokenize_strategy: TokenizeStrategy,
models: List[Any],
text_encoding_strategy: TextEncodingStrategy,
infos: List,
):
lumina_text_encoding_strategy: LuminaTextEncodingStrategy = (
text_encoding_strategy
)
captions = [info.caption for info in infos]
if self.is_weighted:
tokens_list, weights_list = tokenize_strategy.tokenize_with_weights(
captions
)
with torch.no_grad():
hidden_state = lumina_text_encoding_strategy.encode_tokens_with_weights(
tokenize_strategy, models, tokens_list, weights_list
)[0]
else:
tokens = tokenize_strategy.tokenize(captions)
with torch.no_grad():
hidden_state = lumina_text_encoding_strategy.encode_tokens(
tokenize_strategy, models, tokens
)[0]
if hidden_state.dtype == torch.bfloat16:
hidden_state = hidden_state.float()
hidden_state = hidden_state.cpu().numpy()
for i, info in enumerate(infos):
hidden_state_i = hidden_state[i]
if self.cache_to_disk:
np.savez(
info.text_encoder_outputs_npz,
hidden_state=hidden_state_i,
)
else:
info.text_encoder_outputs = [hidden_state_i]
class LuminaLatentsCachingStrategy(LatentsCachingStrategy):
LUMINA_LATENTS_NPZ_SUFFIX = "_lumina.npz"
def __init__(
self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool
) -> None:
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check)
@property
def cache_suffix(self) -> str:
return LuminaLatentsCachingStrategy.LUMINA_LATENTS_NPZ_SUFFIX
def get_latents_npz_path(
self, absolute_path: str, image_size: Tuple[int, int]
) -> str:
return (
os.path.splitext(absolute_path)[0]
+ f"_{image_size[0]:04d}x{image_size[1]:04d}"
+ LuminaLatentsCachingStrategy.LUMINA_LATENTS_NPZ_SUFFIX
)
def is_disk_cached_latents_expected(
self,
bucket_reso: Tuple[int, int],
npz_path: str,
flip_aug: bool,
alpha_mask: bool,
):
return self._default_is_disk_cached_latents_expected(
8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True
)
def load_latents_from_disk(
self, npz_path: str, bucket_reso: Tuple[int, int]
) -> Tuple[
Optional[np.ndarray],
Optional[List[int]],
Optional[List[int]],
Optional[np.ndarray],
Optional[np.ndarray],
]:
return self._default_load_latents_from_disk(
8, npz_path, bucket_reso
) # support multi-resolution
# TODO remove circular dependency for ImageInfo
def cache_batch_latents(
self,
vae,
image_infos: List,
flip_aug: bool,
alpha_mask: bool,
random_crop: bool,
):
encode_by_vae = lambda img_tensor: vae.encode(img_tensor).to("cpu")
vae_device = vae.device
vae_dtype = vae.dtype
self._default_cache_batch_latents(
encode_by_vae,
vae_device,
vae_dtype,
image_infos,
flip_aug,
alpha_mask,
random_crop,
multi_resolution=True,
)
if not train_util.HIGH_VRAM:
train_util.clean_memory_on_device(vae.device)

View File

@@ -3463,6 +3463,7 @@ def get_sai_model_spec(
is_stable_diffusion_ckpt: Optional[bool] = None, # None for TI and LoRA
sd3: str = None,
flux: str = None,
lumina: str = None,
):
timestamp = time.time()
@@ -3498,6 +3499,7 @@ def get_sai_model_spec(
clip_skip=args.clip_skip, # None or int
sd3=sd3,
flux=flux,
lumina=lumina,
)
return metadata

192
lumina_train_network.py Normal file
View File

@@ -0,0 +1,192 @@
import argparse
import copy
import math
import random
from typing import Any, Optional, Union
import torch
from accelerate import Accelerator
from library.device_utils import clean_memory_on_device, init_ipex
init_ipex()
import train_network
from library import (
lumina_models,
flux_train_utils,
lumina_util,
lumina_train_util,
sd3_train_utils,
strategy_base,
strategy_lumina,
train_util,
)
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
class LuminaNetworkTrainer(train_network.NetworkTrainer):
def __init__(self):
super().__init__()
self.sample_prompts_te_outputs = None
self.is_swapping_blocks: bool = False
def assert_extra_args(self, args, train_dataset_group, val_dataset_group):
super().assert_extra_args(args, train_dataset_group, val_dataset_group)
if (
args.cache_text_encoder_outputs_to_disk
and not args.cache_text_encoder_outputs
):
logger.warning("Enabling cache_text_encoder_outputs due to disk caching")
args.cache_text_encoder_outputs = True
train_dataset_group.verify_bucket_reso_steps(32)
if val_dataset_group is not None:
val_dataset_group.verify_bucket_reso_steps(32)
self.train_gemma2 = not args.network_train_unet_only
def load_target_model(self, args, weight_dtype, accelerator):
loading_dtype = None if args.fp8 else weight_dtype
model = lumina_util.load_lumina_model(
args.pretrained_model_name_or_path,
loading_dtype,
"cpu",
disable_mmap=args.disable_mmap_load_safetensors,
)
# if args.blocks_to_swap:
# logger.info(f'Enabling block swap: {args.blocks_to_swap}')
# model.enable_block_swap(args.blocks_to_swap, accelerator.device)
# self.is_swapping_blocks = True
gemma2 = lumina_util.load_gemma2(args.gemma2, weight_dtype, "cpu")
ae = lumina_util.load_ae(args.ae, weight_dtype, "cpu")
return lumina_util.MODEL_VERSION_LUMINA_V2, [gemma2], ae, model
def get_tokenize_strategy(self, args):
return strategy_lumina.LuminaTokenizeStrategy(
args.gemma2_max_token_length, args.tokenizer_cache_dir
)
def get_tokenizers(self, tokenize_strategy: strategy_lumina.LuminaTokenizeStrategy):
return [tokenize_strategy.tokenizer]
def get_latents_caching_strategy(self, args):
return strategy_lumina.LuminaLatentsCachingStrategy(
args.cache_latents_to_disk, args.vae_batch_size, False
)
def get_text_encoding_strategy(self, args):
return strategy_lumina.LuminaTextEncodingStrategy(args.apply_gemma2_attn_mask)
def get_text_encoders_train_flags(self, args, text_encoders):
return [self.train_gemma2]
def get_text_encoder_outputs_caching_strategy(self, args):
if args.cache_text_encoder_outputs:
# if the text encoders is trained, we need tokenization, so is_partial is True
return strategy_lumina.LuminaTextEncoderOutputsCachingStrategy(
args.cache_text_encoder_outputs_to_disk,
args.text_encoder_batch_size,
args.skip_cache_check,
is_partial=self.train_gemma2,
apply_gemma2_attn_mask=args.apply_gemma2_attn_mask,
)
else:
return None
def cache_text_encoder_outputs_if_needed(
self,
args,
accelerator: Accelerator,
unet,
vae,
text_encoders,
dataset,
weight_dtype,
):
for text_encoder in text_encoders:
text_encoder_outputs_caching_strategy = (
self.get_text_encoder_outputs_caching_strategy(args)
)
if text_encoder_outputs_caching_strategy is not None:
text_encoder_outputs_caching_strategy.cache_batch_outputs(
self.get_tokenize_strategy(args),
[text_encoder],
self.get_text_encoding_strategy(args),
dataset,
)
def sample_images(
self,
accelerator,
args,
epoch,
global_step,
device,
ae,
tokenizer,
text_encoder,
lumina,
):
lumina_train_util.sample_images(
accelerator,
args,
epoch,
global_step,
lumina,
ae,
self.get_models_for_text_encoding(args, accelerator, text_encoder),
self.sample_prompts_te_outputs,
)
# Remaining methods maintain similar structure to flux implementation
# with Lumina-specific model calls and strategies
def get_noise_scheduler(
self, args: argparse.Namespace, device: torch.device
) -> Any:
noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(
num_train_timesteps=1000, shift=args.discrete_flow_shift
)
self.noise_scheduler_copy = copy.deepcopy(noise_scheduler)
return noise_scheduler
def encode_images_to_latents(self, args, accelerator, vae, images):
return vae.encode(images)
# not sure, they use same flux vae
def shift_scale_latents(self, args, latents):
return latents
def post_process_loss(self, loss, args, timesteps, noise_scheduler):
return loss
def get_sai_model_spec(self, args):
return train_util.get_sai_model_spec(None, args, False, True, False, flux="dev")
def setup_parser() -> argparse.ArgumentParser:
parser = train_network.setup_parser()
train_util.add_dit_training_arguments(parser)
lumina_train_utils.add_lumina_train_arguments(parser)
return parser
if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
train_util.verify_command_line_training_args(args)
args = train_util.read_config_from_file(args, parser)
trainer = LuminaNetworkTrainer()
trainer.train(args)