Add sample batch size for Lumina

rockerBOO
2025-02-23 20:19:24 -05:00
parent ba725a84e9
commit 48e7da2d4a
3 changed files with 201 additions and 106 deletions

View File

@@ -880,8 +880,8 @@ class NextDiT(nn.Module):
self.n_heads = n_heads
self.gradient_checkpointing = False
self.cpu_offload_checkpointing = False
self.blocks_to_swap = None
self.cpu_offload_checkpointing = False # TODO: not yet supported
self.blocks_to_swap = None # TODO: not yet supported
@property
def device(self):
@@ -982,8 +982,8 @@ class NextDiT(nn.Module):
l_effective_cap_len = cap_mask.sum(dim=1).tolist()
encoder_seq_len = cap_mask.shape[1]
image_seq_len = (height // self.patch_size) * (width // self.patch_size)
seq_lengths = [cap_seq_len + image_seq_len for cap_seq_len in l_effective_cap_len]
max_seq_len = max(seq_lengths)

View File

@@ -4,7 +4,7 @@ import math
import os
import numpy as np
import time
from typing import Callable, Dict, List, Optional, Tuple, Any, Union
from typing import Callable, Dict, List, Optional, Tuple, Any, Union, Generator
import torch
from torch import Tensor
@@ -32,6 +32,59 @@ logger = logging.getLogger(__name__)
# region sample images
def batchify(prompt_dicts, batch_size=None) -> Generator[list[dict[str, str]], None, None]:
"""
Group prompt dictionaries into batches with configurable batch size.
Args:
prompt_dicts (list): List of dictionaries containing prompt parameters.
batch_size (int, optional): Number of prompts per batch. Defaults to None, in which case each parameter group is yielded as a single batch.
Yields:
list[dict[str, str]]: Batch of prompts.
"""
# Validate batch_size
if batch_size is not None:
if not isinstance(batch_size, int) or batch_size <= 0:
raise ValueError("batch_size must be a positive integer or None")
# Group prompts by their parameters
batches = {}
for prompt_dict in prompt_dicts:
# Extract parameters
width = int(prompt_dict.get("width", 1024))
height = int(prompt_dict.get("height", 1024))
height = max(64, height - height % 8) # round to divisible by 8
width = max(64, width - width % 8) # round to divisible by 8
guidance_scale = float(prompt_dict.get("scale", 3.5))
sample_steps = int(prompt_dict.get("sample_steps", 38))
seed = prompt_dict.get("seed", None)
seed = int(seed) if seed is not None else None
# Create a key based on the parameters
key = (width, height, guidance_scale, seed, sample_steps)
# Add the prompt_dict to the corresponding batch
if key not in batches:
batches[key] = []
batches[key].append(prompt_dict)
# Yield each batch with its parameters
for key in batches:
prompts = batches[key]
if batch_size is None:
# Yield the entire group as a single batch
yield prompts
else:
# Split the group into batches of size `batch_size`
start = 0
while start < len(prompts):
end = start + batch_size
batch = prompts[start:end]
yield batch
start = end
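# Usage sketch (illustrative only; prompt values below are hypothetical):
# prompts sharing width/height/scale/seed/steps fall into one bucket, and
# batch_size then splits each bucket.
#
#   prompts = [
#       {"prompt": "a cat", "width": 1024, "height": 1024, "seed": 1},
#       {"prompt": "a dog", "width": 1024, "height": 1024, "seed": 1},
#       {"prompt": "a wide landscape", "width": 1536, "height": 640, "seed": 1},
#   ]
#   for batch in batchify(prompts, batch_size=2):
#       print([p["prompt"] for p in batch])
#   # -> ["a cat", "a dog"], then ["a wide landscape"]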
@torch.no_grad()
def sample_images(
accelerator: Accelerator,
@@ -39,9 +92,9 @@ def sample_images(
epoch: int,
global_step: int,
nextdit: lumina_models.NextDiT,
vae: torch.nn.Module,
vae: AutoEncoder,
gemma2_model: Gemma2Model,
sample_prompts_gemma2_outputs: List[Tuple[Tensor, Tensor, Tensor]],
sample_prompts_gemma2_outputs: dict[str, Tuple[Tensor, Tensor, Tensor]],
prompt_replacement: Optional[Tuple[str, str]] = None,
controlnet=None,
):
@@ -54,11 +107,13 @@ def sample_images(
epoch (int): Current epoch number.
global_step (int): Current global step number.
nextdit (lumina_models.NextDiT): The NextDiT model instance.
vae (torch.nn.Module): The VAE module.
vae (AutoEncoder): The VAE module.
gemma2_model (Gemma2Model): The Gemma2 model instance.
sample_prompts_gemma2_outputs (List[Tuple[Tensor, Tensor, Tensor]]): List of tuples containing the encoded prompts, text masks, and timestep for each sample.
prompt_replacement (Optional[Tuple[str, str]], optional): Tuple containing the prompt and negative prompt replacements. Defaults to None.
controlnet:: ControlNet model
sample_prompts_gemma2_outputs (dict[str, Tuple[Tensor, Tensor, Tensor]]):
Dictionary mapping each sample prompt to its Gemma2 outputs (hidden states, input ids, attention mask).
prompt_replacement (Optional[Tuple[str, str]], optional):
Tuple containing the prompt and negative prompt replacements. Defaults to None.
controlnet: ControlNet model (not yet supported). Defaults to None.
Returns:
None
@@ -110,9 +165,12 @@ def sample_images(
except Exception:
pass
batch_size = args.sample_batch_size or args.train_batch_size or 1
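# e.g. --sample_batch_size 4 -> 4; unset with --train_batch_size 2 -> 2; neither set -> 1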
if distributed_state.num_processes <= 1:
# If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts.
for prompt_dict in prompts:
# TODO: batch prompts together with buckets of image sizes
for prompt_dicts in batchify(prompts, batch_size):
sample_image_inference(
accelerator,
args,
@@ -120,7 +178,7 @@ def sample_images(
gemma2_model,
vae,
save_dir,
prompt_dict,
prompt_dicts,
epoch,
global_step,
sample_prompts_gemma2_outputs,
@@ -135,7 +193,8 @@ def sample_images(
per_process_prompts.append(prompts[i :: distributed_state.num_processes])
with distributed_state.split_between_processes(per_process_prompts) as prompt_dict_lists:
for prompt_dict in prompt_dict_lists[0]:
# TODO: batch prompts together with buckets of image sizes
for prompt_dicts in batchify(prompt_dict_lists[0], batch_size):
sample_image_inference(
accelerator,
args,
@@ -143,7 +202,7 @@ def sample_images(
gemma2_model,
vae,
save_dir,
prompt_dict,
prompt_dicts,
epoch,
global_step,
sample_prompts_gemma2_outputs,
@@ -166,10 +225,10 @@ def sample_image_inference(
gemma2_model: Gemma2Model,
vae: AutoEncoder,
save_dir: str,
prompt_dict: Dict[str, str],
prompt_dicts: list[Dict[str, str]],
epoch: int,
global_step: int,
sample_prompts_gemma2_outputs: dict[str, List[Tuple[Tensor, Tensor, Tensor]]],
sample_prompts_gemma2_outputs: dict[str, Tuple[Tensor, Tensor, Tensor]],
prompt_replacement: Optional[Tuple[str, str]] = None,
controlnet=None,
):
@@ -192,43 +251,6 @@ def sample_image_inference(
Returns:
None
"""
assert isinstance(prompt_dict, dict)
# negative_prompt = prompt_dict.get("negative_prompt")
sample_steps = int(prompt_dict.get("sample_steps", 38))
width = int(prompt_dict.get("width", 1024))
height = int(prompt_dict.get("height", 1024))
guidance_scale = float(prompt_dict.get("scale", 3.5))
seed = prompt_dict.get("seed", None)
controlnet_image = prompt_dict.get("controlnet_image")
prompt: str = prompt_dict.get("prompt", "")
negative_prompt: str = prompt_dict.get("negative_prompt", "")
# sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler)
seed = int(seed) if seed is not None else None
assert seed is None or seed > 0, f"Invalid seed {seed}"
if prompt_replacement is not None:
prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
if negative_prompt is not None:
negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])
generator = torch.Generator(device=accelerator.device)
if seed is not None:
generator.manual_seed(seed)
# if negative_prompt is None:
# negative_prompt = ""
height = max(64, height - height % 8) # round to divisible by 8
width = max(64, width - width % 8) # round to divisible by 8
logger.info(f"prompt: {prompt}")
logger.info(f"negative_prompt: {negative_prompt}")
logger.info(f"height: {height}")
logger.info(f"width: {width}")
logger.info(f"sample_steps: {sample_steps}")
logger.info(f"scale: {guidance_scale}")
# logger.info(f"sample_sampler: {sampler_name}")
if seed is not None:
logger.info(f"seed: {seed}")
# encode prompts
tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy()
@@ -237,33 +259,86 @@ def sample_image_inference(
assert isinstance(tokenize_strategy, strategy_lumina.LuminaTokenizeStrategy)
assert isinstance(encoding_strategy, strategy_lumina.LuminaTextEncodingStrategy)
system_prompt = args.system_prompt or ""
text_conds = []
# Apply system prompt to prompts
prompt = system_prompt + prompt
negative_prompt = system_prompt + negative_prompt
# seed, width, height, sample steps, and guidance scale are assumed identical within the batch (batchify groups by these)
width = int(prompt_dicts[0].get("width", 1024))
height = int(prompt_dicts[0].get("height", 1024))
height = max(64, height - height % 8) # round to divisible by 8
width = max(64, width - width % 8) # round to divisible by 8
# Get sample prompts from cache
if sample_prompts_gemma2_outputs and prompt in sample_prompts_gemma2_outputs:
gemma2_conds = sample_prompts_gemma2_outputs[prompt]
logger.info(f"Using cached Gemma2 outputs for prompt: {prompt}")
guidance_scale = float(prompt_dicts[0].get("scale", 3.5))
sample_steps = int(prompt_dicts[0].get("sample_steps", 38))
seed = prompt_dicts[0].get("seed", None)
seed = int(seed) if seed is not None else None
assert seed is None or seed > 0, f"Invalid seed {seed}"
generator = torch.Generator(device=accelerator.device)
if seed is not None:
generator.manual_seed(seed)
if sample_prompts_gemma2_outputs and negative_prompt in sample_prompts_gemma2_outputs:
neg_gemma2_conds = sample_prompts_gemma2_outputs[negative_prompt]
logger.info(f"Using cached Gemma2 outputs for negative prompt: {negative_prompt}")
for prompt_dict in prompt_dicts:
controlnet_image = prompt_dict.get("controlnet_image")
prompt: str = prompt_dict.get("prompt", "")
negative_prompt = prompt_dict.get("negative_prompt", "")
# sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler)
# Load sample prompts from Gemma 2
if gemma2_model is not None:
logger.info(f"Encoding prompt with Gemma2: {prompt}")
tokens_and_masks = tokenize_strategy.tokenize(prompt)
gemma2_conds = encoding_strategy.encode_tokens(tokenize_strategy, [gemma2_model], tokens_and_masks)
if prompt_replacement is not None:
prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
if negative_prompt is not None:
negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])
tokens_and_masks = tokenize_strategy.tokenize(negative_prompt)
neg_gemma2_conds = encoding_strategy.encode_tokens(tokenize_strategy, [gemma2_model], tokens_and_masks)
if negative_prompt is None:
negative_prompt = ""
logger.info(f"prompt: {prompt}")
logger.info(f"negative_prompt: {negative_prompt}")
logger.info(f"height: {height}")
logger.info(f"width: {width}")
logger.info(f"sample_steps: {sample_steps}")
logger.info(f"scale: {guidance_scale}")
# logger.info(f"sample_sampler: {sampler_name}")
# Unpack Gemma2 outputs
gemma2_hidden_states, input_ids, gemma2_attn_mask = gemma2_conds
neg_gemma2_hidden_states, neg_input_ids, neg_gemma2_attn_mask = neg_gemma2_conds
system_prompt = args.system_prompt or ""
# Apply system prompt to prompts
prompt = system_prompt + prompt
negative_prompt = system_prompt + negative_prompt
# Get sample prompts from cache
if sample_prompts_gemma2_outputs and prompt in sample_prompts_gemma2_outputs:
gemma2_conds = sample_prompts_gemma2_outputs[prompt]
logger.info(f"Using cached Gemma2 outputs for prompt: {prompt}")
if sample_prompts_gemma2_outputs and negative_prompt in sample_prompts_gemma2_outputs:
neg_gemma2_conds = sample_prompts_gemma2_outputs[negative_prompt]
logger.info(f"Using cached Gemma2 outputs for negative prompt: {negative_prompt}")
# Load sample prompts from Gemma 2
if gemma2_model is not None:
logger.info(f"Encoding prompt with Gemma2: {prompt}")
tokens_and_masks = tokenize_strategy.tokenize(prompt)
gemma2_conds = encoding_strategy.encode_tokens(tokenize_strategy, [gemma2_model], tokens_and_masks)
tokens_and_masks = tokenize_strategy.tokenize(negative_prompt)
neg_gemma2_conds = encoding_strategy.encode_tokens(tokenize_strategy, [gemma2_model], tokens_and_masks)
# Unpack Gemma2 outputs
gemma2_hidden_states, _, gemma2_attn_mask = gemma2_conds
neg_gemma2_hidden_states, _, neg_gemma2_attn_mask = neg_gemma2_conds
text_conds.append(
(
gemma2_hidden_states.squeeze(0),
gemma2_attn_mask.squeeze(0),
neg_gemma2_hidden_states.squeeze(0),
neg_gemma2_attn_mask.squeeze(0),
)
)
# Stack conditioning
cond_hidden_states = torch.stack([text_cond[0] for text_cond in text_conds]).to(accelerator.device)
cond_attn_masks = torch.stack([text_cond[1] for text_cond in text_conds]).to(accelerator.device)
uncond_hidden_states = torch.stack([text_cond[2] for text_cond in text_conds]).to(accelerator.device)
uncond_attn_masks = torch.stack([text_cond[3] for text_cond in text_conds]).to(accelerator.device)
# sample image
weight_dtype = vae.dtype  # TODO: give dtype as argument
@@ -279,6 +354,7 @@ def sample_image_inference(
dtype=weight_dtype,
generator=generator,
)
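# repeat the single seeded noise across the batch so every prompt starts from the same initial latent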
noise = noise.repeat(cond_hidden_states.shape[0], 1, 1, 1)
scheduler = FlowMatchEulerDiscreteScheduler(shift=6.0)
timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps=sample_steps)
@@ -294,10 +370,10 @@ def sample_image_inference(
scheduler,
nextdit,
noise,
gemma2_hidden_states,
gemma2_attn_mask.to(accelerator.device),
neg_gemma2_hidden_states,
neg_gemma2_attn_mask.to(accelerator.device),
cond_hidden_states,
cond_attn_masks,
uncond_hidden_states,
uncond_attn_masks,
timesteps=timesteps,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
@@ -307,34 +383,44 @@ def sample_image_inference(
clean_memory_on_device(accelerator.device)
org_vae_device = vae.device # will be on cpu
vae.to(accelerator.device) # distributed_state.device is same as accelerator.device
with accelerator.autocast():
x = vae.decode((x / vae.scale_factor) + vae.shift_factor)
for img, prompt_dict in zip(x, prompt_dicts):
img = (img / vae.scale_factor) + vae.shift_factor
with accelerator.autocast():
# Add a batch dimension so the VAE decodes a single image at a time
img = vae.decode(img.unsqueeze(0))
img = img.clamp(-1, 1)
img = img.permute(0, 2, 3, 1) # B, H, W, C
# Scale images back to 0 to 255
img = (127.5 * (img + 1.0)).float().cpu().numpy().astype(np.uint8)
# Get single image
image = Image.fromarray(img[0])
# adding accelerator.wait_for_everyone() here should sync up and ensure that sample images are saved in the same order as the original prompt list
# but adding 'enum' to the filename should be enough
ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
num_suffix = f"e{epoch:06d}" if epoch is not None else f"{global_step:06d}"
seed_suffix = "" if seed is None else f"_{seed}"
i: int = int(prompt_dict.get("enum", 0))
img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png"
image.save(os.path.join(save_dir, img_filename))
# send images to wandb if enabled
if "wandb" in [tracker.name for tracker in accelerator.trackers]:
wandb_tracker = accelerator.get_tracker("wandb")
import wandb
# not to commit images to avoid inconsistency between training and logging steps
wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False) # positive prompt as a caption
vae.to(org_vae_device)
clean_memory_on_device(accelerator.device)
x = x.clamp(-1, 1)
x = x.permute(0, 2, 3, 1)
image = Image.fromarray((127.5 * (x + 1.0)).float().cpu().numpy().astype(np.uint8)[0])
# adding accelerator.wait_for_everyone() here should sync up and ensure that sample images are saved in the same order as the original prompt list
# but adding 'enum' to the filename should be enough
ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
num_suffix = f"e{epoch:06d}" if epoch is not None else f"{global_step:06d}"
seed_suffix = "" if seed is None else f"_{seed}"
i: int = int(prompt_dict.get("enum", 0))
img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png"
image.save(os.path.join(save_dir, img_filename))
# send images to wandb if enabled
if "wandb" in [tracker.name for tracker in accelerator.trackers]:
wandb_tracker = accelerator.get_tracker("wandb")
import wandb
# not to commit images to avoid inconsistency between training and logging steps
wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False) # positive prompt as a caption
def time_shift(mu: float, sigma: float, t: torch.Tensor):
# the following implementation was original for t=0: clean / t=1: noise
@@ -879,16 +965,22 @@ def add_lumina_train_arguments(parser: argparse.ArgumentParser):
"--discrete_flow_shift",
type=float,
default=6.0,
help="Discrete flow shift for the Euler Discrete Scheduler, default is 6.0. / Euler Discrete Schedulerの離散フローシフト、デフォルトは6.0",
help="Discrete flow shift for the Euler Discrete Scheduler, default is 6.0 / Euler Discrete Schedulerの離散フローシフト、デフォルトは6.0",
)
parser.add_argument(
"--use_flash_attn",
action="store_true",
help="Use Flash Attention for the model. / モデルにFlash Attentionを使用する",
help="Use Flash Attention for the model / モデルにFlash Attentionを使用する",
)
parser.add_argument(
"--system_prompt",
type=str,
default="",
help="System prompt to add to the prompt. / プロンプトに追加するシステムプロンプト",
help="System prompt to add to the prompt / プロンプトに追加するシステムプロンプト",
)
parser.add_argument(
"--sample_batch_size",
type=int,
default=None,
help="Batch size to use for sampling, defaults to --training_batch_size value. Sample batches are bucketed by width, height, guidance scale, and seed / サンプリングに使用するバッチサイズ。デフォルトは --training_batch_size の値です。サンプルバッチは、幅、高さ、ガイダンススケール、シードによってバケット化されます",
)

View File

@@ -1242,6 +1242,7 @@ class NetworkTrainer:
# For --sample_at_first
optimizer_eval_fn()
self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizers, text_encoder, unet)
progress_bar.unpause()  # Resume progress bar timing after sampling images
optimizer_train_fn()
is_tracking = len(accelerator.trackers) > 0
if is_tracking:
@@ -1344,6 +1345,7 @@ class NetworkTrainer:
self.sample_images(
accelerator, args, None, global_step, accelerator.device, vae, tokenizers, text_encoder, unet
)
progress_bar.unpause()
# Save the model every specified number of steps
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
@@ -1531,6 +1533,7 @@ class NetworkTrainer:
train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1)
self.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizers, text_encoder, unet)
progress_bar.unpause()
optimizer_train_fn()
# end of epoch