mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-10 06:54:17 +00:00
Fix samples, LoRA training. Add system prompt, use_flash_attn
This commit is contained in:
@@ -75,6 +75,7 @@ class BaseSubsetParams:
|
||||
custom_attributes: Optional[Dict[str, Any]] = None
|
||||
validation_seed: int = 0
|
||||
validation_split: float = 0.0
|
||||
system_prompt: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -106,6 +107,7 @@ class BaseDatasetParams:
|
||||
debug_dataset: bool = False
|
||||
validation_seed: Optional[int] = None
|
||||
validation_split: float = 0.0
|
||||
system_prompt: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -196,6 +198,7 @@ class ConfigSanitizer:
|
||||
"caption_prefix": str,
|
||||
"caption_suffix": str,
|
||||
"custom_attributes": dict,
|
||||
"system_prompt": str,
|
||||
}
|
||||
# DO means DropOut
|
||||
DO_SUBSET_ASCENDABLE_SCHEMA = {
|
||||
@@ -241,6 +244,7 @@ class ConfigSanitizer:
|
||||
"validation_split": float,
|
||||
"resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int),
|
||||
"network_multiplier": float,
|
||||
"system_prompt": str,
|
||||
}
|
||||
|
||||
# options handled by argparse but not handled by user config
|
||||
@@ -526,6 +530,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
|
||||
batch_size: {dataset.batch_size}
|
||||
resolution: {(dataset.width, dataset.height)}
|
||||
enable_bucket: {dataset.enable_bucket}
|
||||
system_prompt: {dataset.system_prompt}
|
||||
""")
|
||||
|
||||
if dataset.enable_bucket:
|
||||
@@ -559,6 +564,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
|
||||
token_warmup_step: {subset.token_warmup_step},
|
||||
alpha_mask: {subset.alpha_mask}
|
||||
custom_attributes: {subset.custom_attributes}
|
||||
system_prompt: {subset.system_prompt}
|
||||
"""), " ")
|
||||
|
||||
if is_dreambooth:
|
||||
|
||||
@@ -14,14 +14,19 @@ from typing import List, Optional, Tuple
|
||||
from dataclasses import dataclass
|
||||
|
||||
from einops import rearrange
|
||||
from flash_attn import flash_attn_varlen_func
|
||||
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
try:
|
||||
from flash_attn import flash_attn_varlen_func
|
||||
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
|
||||
except:
|
||||
pass
|
||||
|
||||
try:
|
||||
from apex.normalization import FusedRMSNorm as RMSNorm
|
||||
except:
|
||||
@@ -75,7 +80,15 @@ class LuminaParams:
|
||||
@classmethod
|
||||
def get_7b_config(cls) -> "LuminaParams":
|
||||
"""Returns the configuration for the 7B parameter model"""
|
||||
return cls(patch_size=2, dim=4096, n_layers=32, n_heads=32, n_kv_heads=8, axes_dims=[64, 64, 64], axes_lens=[300, 512, 512])
|
||||
return cls(
|
||||
patch_size=2,
|
||||
dim=4096,
|
||||
n_layers=32,
|
||||
n_heads=32,
|
||||
n_kv_heads=8,
|
||||
axes_dims=[64, 64, 64],
|
||||
axes_lens=[300, 512, 512],
|
||||
)
|
||||
|
||||
|
||||
class GradientCheckpointMixin(nn.Module):
|
||||
@@ -248,6 +261,7 @@ class JointAttention(nn.Module):
|
||||
n_heads: int,
|
||||
n_kv_heads: Optional[int],
|
||||
qk_norm: bool,
|
||||
use_flash_attn=False,
|
||||
):
|
||||
"""
|
||||
Initialize the Attention module.
|
||||
@@ -286,7 +300,7 @@ class JointAttention(nn.Module):
|
||||
else:
|
||||
self.q_norm = self.k_norm = nn.Identity()
|
||||
|
||||
self.flash_attn = False
|
||||
self.use_flash_attn = use_flash_attn
|
||||
|
||||
# self.attention_processor = xformers.ops.memory_efficient_attention
|
||||
self.attention_processor = F.scaled_dot_product_attention
|
||||
@@ -294,35 +308,63 @@ class JointAttention(nn.Module):
|
||||
def set_attention_processor(self, attention_processor):
|
||||
self.attention_processor = attention_processor
|
||||
|
||||
@staticmethod
|
||||
def apply_rotary_emb(
|
||||
x_in: torch.Tensor,
|
||||
freqs_cis: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
def forward(
|
||||
self,
|
||||
x: Tensor,
|
||||
x_mask: Tensor,
|
||||
freqs_cis: Tensor,
|
||||
) -> Tensor:
|
||||
"""
|
||||
Apply rotary embeddings to input tensors using the given frequency
|
||||
tensor.
|
||||
|
||||
This function applies rotary embeddings to the given query 'xq' and
|
||||
key 'xk' tensors using the provided frequency tensor 'freqs_cis'. The
|
||||
input tensors are reshaped as complex numbers, and the frequency tensor
|
||||
is reshaped for broadcasting compatibility. The resulting tensors
|
||||
contain rotary embeddings and are returned as real tensors.
|
||||
|
||||
Args:
|
||||
x_in (torch.Tensor): Query or Key tensor to apply rotary embeddings.
|
||||
freqs_cis (torch.Tensor): Precomputed frequency tensor for complex
|
||||
exponentials.
|
||||
|
||||
Returns:
|
||||
Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor
|
||||
and key tensor with rotary embeddings.
|
||||
x:
|
||||
x_mask:
|
||||
freqs_cis:
|
||||
"""
|
||||
with torch.autocast("cuda", enabled=False):
|
||||
x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2))
|
||||
freqs_cis = freqs_cis.unsqueeze(2)
|
||||
x_out = torch.view_as_real(x * freqs_cis).flatten(3)
|
||||
return x_out.type_as(x_in)
|
||||
bsz, seqlen, _ = x.shape
|
||||
dtype = x.dtype
|
||||
|
||||
xq, xk, xv = torch.split(
|
||||
self.qkv(x),
|
||||
[
|
||||
self.n_local_heads * self.head_dim,
|
||||
self.n_local_kv_heads * self.head_dim,
|
||||
self.n_local_kv_heads * self.head_dim,
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
|
||||
xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
|
||||
xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
|
||||
xq = self.q_norm(xq)
|
||||
xk = self.k_norm(xk)
|
||||
xq = apply_rope(xq, freqs_cis=freqs_cis)
|
||||
xk = apply_rope(xk, freqs_cis=freqs_cis)
|
||||
xq, xk = xq.to(dtype), xk.to(dtype)
|
||||
|
||||
softmax_scale = math.sqrt(1 / self.head_dim)
|
||||
|
||||
if self.use_flash_attn:
|
||||
output = self.flash_attn(xq, xk, xv, x_mask, softmax_scale)
|
||||
else:
|
||||
n_rep = self.n_local_heads // self.n_local_kv_heads
|
||||
if n_rep >= 1:
|
||||
xk = xk.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
|
||||
xv = xv.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
|
||||
|
||||
output = (
|
||||
self.attention_processor(
|
||||
xq.permute(0, 2, 1, 3),
|
||||
xk.permute(0, 2, 1, 3),
|
||||
xv.permute(0, 2, 1, 3),
|
||||
attn_mask=x_mask.bool().view(bsz, 1, 1, seqlen).expand(-1, self.n_local_heads, seqlen, -1),
|
||||
scale=softmax_scale,
|
||||
)
|
||||
.permute(0, 2, 1, 3)
|
||||
.to(dtype)
|
||||
)
|
||||
|
||||
output = output.flatten(-2)
|
||||
return self.out(output)
|
||||
|
||||
# copied from huggingface modeling_llama.py
|
||||
def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
|
||||
@@ -377,46 +419,17 @@ class JointAttention(nn.Module):
|
||||
(max_seqlen_in_batch_q, max_seqlen_in_batch_k),
|
||||
)
|
||||
|
||||
def forward(
|
||||
def flash_attn(
|
||||
self,
|
||||
x: Tensor,
|
||||
q: Tensor,
|
||||
k: Tensor,
|
||||
v: Tensor,
|
||||
x_mask: Tensor,
|
||||
freqs_cis: Tensor,
|
||||
softmax_scale,
|
||||
) -> Tensor:
|
||||
"""
|
||||
bsz, seqlen, _, _ = q.shape
|
||||
|
||||
Args:
|
||||
x:
|
||||
x_mask:
|
||||
freqs_cis:
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
bsz, seqlen, _ = x.shape
|
||||
dtype = x.dtype
|
||||
|
||||
xq, xk, xv = torch.split(
|
||||
self.qkv(x),
|
||||
[
|
||||
self.n_local_heads * self.head_dim,
|
||||
self.n_local_kv_heads * self.head_dim,
|
||||
self.n_local_kv_heads * self.head_dim,
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
|
||||
xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
|
||||
xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
|
||||
xq = self.q_norm(xq)
|
||||
xk = self.k_norm(xk)
|
||||
xq = JointAttention.apply_rotary_emb(xq, freqs_cis=freqs_cis)
|
||||
xk = JointAttention.apply_rotary_emb(xk, freqs_cis=freqs_cis)
|
||||
xq, xk = xq.to(dtype), xk.to(dtype)
|
||||
|
||||
softmax_scale = math.sqrt(1 / self.head_dim)
|
||||
|
||||
if self.flash_attn:
|
||||
try:
|
||||
# begin var_len flash attn
|
||||
(
|
||||
query_states,
|
||||
@@ -425,7 +438,7 @@ class JointAttention(nn.Module):
|
||||
indices_q,
|
||||
cu_seq_lens,
|
||||
max_seq_lens,
|
||||
) = self._upad_input(xq, xk, xv, x_mask, seqlen)
|
||||
) = self._upad_input(q, k, v, x_mask, seqlen)
|
||||
|
||||
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
|
||||
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
|
||||
@@ -445,27 +458,12 @@ class JointAttention(nn.Module):
|
||||
output = pad_input(attn_output_unpad, indices_q, bsz, seqlen)
|
||||
# end var_len_flash_attn
|
||||
|
||||
else:
|
||||
n_rep = self.n_local_heads // self.n_local_kv_heads
|
||||
if n_rep >= 1:
|
||||
xk = xk.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
|
||||
xv = xv.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
|
||||
|
||||
output = (
|
||||
self.attention_processor(
|
||||
xq.permute(0, 2, 1, 3),
|
||||
xk.permute(0, 2, 1, 3),
|
||||
xv.permute(0, 2, 1, 3),
|
||||
attn_mask=x_mask.bool().view(bsz, 1, 1, seqlen).expand(-1, self.n_local_heads, seqlen, -1),
|
||||
scale=softmax_scale,
|
||||
)
|
||||
.permute(0, 2, 1, 3)
|
||||
.to(dtype)
|
||||
return output
|
||||
except NameError as e:
|
||||
raise RuntimeError(
|
||||
f"Could not load flash attention. Please install flash_attn. / フラッシュアテンションを読み込めませんでした。flash_attn をインストールしてください。 / {e}"
|
||||
)
|
||||
|
||||
output = output.flatten(-2)
|
||||
return self.out(output)
|
||||
|
||||
|
||||
def apply_rope(
|
||||
x_in: torch.Tensor,
|
||||
@@ -563,6 +561,7 @@ class JointTransformerBlock(GradientCheckpointMixin):
|
||||
norm_eps: float,
|
||||
qk_norm: bool,
|
||||
modulation=True,
|
||||
use_flash_attn=False,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize a TransformerBlock.
|
||||
@@ -585,7 +584,7 @@ class JointTransformerBlock(GradientCheckpointMixin):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.head_dim = dim // n_heads
|
||||
self.attention = JointAttention(dim, n_heads, n_kv_heads, qk_norm)
|
||||
self.attention = JointAttention(dim, n_heads, n_kv_heads, qk_norm, use_flash_attn=use_flash_attn)
|
||||
self.feed_forward = FeedForward(
|
||||
dim=dim,
|
||||
hidden_dim=4 * dim,
|
||||
@@ -711,7 +710,12 @@ class FinalLayer(GradientCheckpointMixin):
|
||||
|
||||
|
||||
class RopeEmbedder:
|
||||
def __init__(self, theta: float = 10000.0, axes_dims: List[int] = [16, 56, 56], axes_lens: List[int] = [1, 512, 512]):
|
||||
def __init__(
|
||||
self,
|
||||
theta: float = 10000.0,
|
||||
axes_dims: List[int] = [16, 56, 56],
|
||||
axes_lens: List[int] = [1, 512, 512],
|
||||
):
|
||||
super().__init__()
|
||||
self.theta = theta
|
||||
self.axes_dims = axes_dims
|
||||
@@ -750,6 +754,7 @@ class NextDiT(nn.Module):
|
||||
cap_feat_dim: int = 5120,
|
||||
axes_dims: List[int] = [16, 56, 56],
|
||||
axes_lens: List[int] = [1, 512, 512],
|
||||
use_flash_attn=False,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the NextDiT model.
|
||||
@@ -803,6 +808,7 @@ class NextDiT(nn.Module):
|
||||
norm_eps,
|
||||
qk_norm,
|
||||
modulation=False,
|
||||
use_flash_attn=use_flash_attn,
|
||||
)
|
||||
for layer_id in range(n_refiner_layers)
|
||||
]
|
||||
@@ -828,6 +834,7 @@ class NextDiT(nn.Module):
|
||||
norm_eps,
|
||||
qk_norm,
|
||||
modulation=True,
|
||||
use_flash_attn=use_flash_attn,
|
||||
)
|
||||
for layer_id in range(n_refiner_layers)
|
||||
]
|
||||
@@ -848,6 +855,7 @@ class NextDiT(nn.Module):
|
||||
ffn_dim_multiplier,
|
||||
norm_eps,
|
||||
qk_norm,
|
||||
use_flash_attn=use_flash_attn,
|
||||
)
|
||||
for layer_id in range(n_layers)
|
||||
]
|
||||
@@ -988,8 +996,20 @@ class NextDiT(nn.Module):
|
||||
freqs_cis = self.rope_embedder(position_ids)
|
||||
|
||||
# Create separate rotary embeddings for captions and images
|
||||
cap_freqs_cis = torch.zeros(bsz, encoder_seq_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype)
|
||||
img_freqs_cis = torch.zeros(bsz, image_seq_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype)
|
||||
cap_freqs_cis = torch.zeros(
|
||||
bsz,
|
||||
encoder_seq_len,
|
||||
freqs_cis.shape[-1],
|
||||
device=device,
|
||||
dtype=freqs_cis.dtype,
|
||||
)
|
||||
img_freqs_cis = torch.zeros(
|
||||
bsz,
|
||||
image_seq_len,
|
||||
freqs_cis.shape[-1],
|
||||
device=device,
|
||||
dtype=freqs_cis.dtype,
|
||||
)
|
||||
|
||||
for i, (cap_len, seq_len) in enumerate(zip(l_effective_cap_len, seq_lengths)):
|
||||
cap_freqs_cis[i, :cap_len] = freqs_cis[i, :cap_len]
|
||||
|
||||
@@ -1,21 +1,28 @@
|
||||
import inspect
|
||||
import enum
|
||||
import argparse
|
||||
import math
|
||||
import os
|
||||
import numpy as np
|
||||
import time
|
||||
from typing import Callable, Dict, List, Optional, Tuple, Any
|
||||
from typing import Callable, Dict, List, Optional, Tuple, Any, Union
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torchdiffeq import odeint
|
||||
from accelerate import Accelerator, PartialState
|
||||
from transformers import Gemma2Model
|
||||
from tqdm import tqdm
|
||||
from PIL import Image
|
||||
from safetensors.torch import save_file
|
||||
|
||||
from diffusers.schedulers.scheduling_heun_discrete import HeunDiscreteScheduler
|
||||
from library import lumina_models, lumina_util, strategy_base, strategy_lumina, train_util
|
||||
from library.flux_models import AutoEncoder
|
||||
from library.device_utils import init_ipex, clean_memory_on_device
|
||||
from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler
|
||||
from library.lumina_dpm_solver import NoiseScheduleFlow, DPM_Solver
|
||||
import library.lumina_path as path
|
||||
|
||||
init_ipex()
|
||||
|
||||
@@ -162,12 +169,12 @@ def sample_image_inference(
|
||||
args: argparse.Namespace,
|
||||
nextdit: lumina_models.NextDiT,
|
||||
gemma2_model: Gemma2Model,
|
||||
vae: torch.nn.Module,
|
||||
vae: AutoEncoder,
|
||||
save_dir: str,
|
||||
prompt_dict: Dict[str, str],
|
||||
epoch: int,
|
||||
global_step: int,
|
||||
sample_prompts_gemma2_outputs: List[Tuple[Tensor, Tensor, Tensor]],
|
||||
sample_prompts_gemma2_outputs: dict[str, List[Tuple[Tensor, Tensor, Tensor]]],
|
||||
prompt_replacement: Optional[Tuple[str, str]] = None,
|
||||
controlnet=None,
|
||||
):
|
||||
@@ -179,12 +186,12 @@ def sample_image_inference(
|
||||
args (argparse.Namespace): Arguments object
|
||||
nextdit (lumina_models.NextDiT): NextDiT model
|
||||
gemma2_model (Gemma2Model): Gemma2 model
|
||||
vae (torch.nn.Module): VAE model
|
||||
vae (AutoEncoder): VAE model
|
||||
save_dir (str): Directory to save images
|
||||
prompt_dict (Dict[str, str]): Prompt dictionary
|
||||
epoch (int): Epoch number
|
||||
steps (int): Number of steps to run
|
||||
sample_prompts_gemma2_outputs (List[Tuple[Tensor, Tensor, Tensor]]): List of tuples containing gemma2 outputs
|
||||
sample_prompts_gemma2_outputs (List[Tuple[Tensor, Tensor, Tensor]]): List of tuples containing Gemma 2 outputs
|
||||
prompt_replacement (Optional[Tuple[str, str]], optional): Replacement for positive and negative prompt. Defaults to None.
|
||||
|
||||
Returns:
|
||||
@@ -192,16 +199,19 @@ def sample_image_inference(
|
||||
"""
|
||||
assert isinstance(prompt_dict, dict)
|
||||
# negative_prompt = prompt_dict.get("negative_prompt")
|
||||
sample_steps = prompt_dict.get("sample_steps", 38)
|
||||
width = prompt_dict.get("width", 1024)
|
||||
height = prompt_dict.get("height", 1024)
|
||||
guidance_scale: int = prompt_dict.get("scale", 3.5)
|
||||
seed: int = prompt_dict.get("seed", None)
|
||||
sample_steps = int(prompt_dict.get("sample_steps", 38))
|
||||
width = int(prompt_dict.get("width", 1024))
|
||||
height = int(prompt_dict.get("height", 1024))
|
||||
guidance_scale = float(prompt_dict.get("scale", 3.5))
|
||||
seed = prompt_dict.get("seed", None)
|
||||
controlnet_image = prompt_dict.get("controlnet_image")
|
||||
prompt: str = prompt_dict.get("prompt", "")
|
||||
negative_prompt: str = prompt_dict.get("negative_prompt", "")
|
||||
# sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler)
|
||||
|
||||
seed = int(seed) if seed is not None else None
|
||||
assert seed is None or seed > 0, f"Invalid seed {seed}"
|
||||
|
||||
if prompt_replacement is not None:
|
||||
prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
|
||||
if negative_prompt is not None:
|
||||
@@ -213,10 +223,10 @@ def sample_image_inference(
|
||||
|
||||
# if negative_prompt is None:
|
||||
# negative_prompt = ""
|
||||
height = max(64, height - height % 16) # round to divisible by 16
|
||||
width = max(64, width - width % 16) # round to divisible by 16
|
||||
height = max(64, height - height % 8) # round to divisible by 8
|
||||
width = max(64, width - width % 8) # round to divisible by 8
|
||||
logger.info(f"prompt: {prompt}")
|
||||
# logger.info(f"negative_prompt: {negative_prompt}")
|
||||
logger.info(f"negative_prompt: {negative_prompt}")
|
||||
logger.info(f"height: {height}")
|
||||
logger.info(f"width: {width}")
|
||||
logger.info(f"sample_steps: {sample_steps}")
|
||||
@@ -232,46 +242,51 @@ def sample_image_inference(
|
||||
assert isinstance(tokenize_strategy, strategy_lumina.LuminaTokenizeStrategy)
|
||||
assert isinstance(encoding_strategy, strategy_lumina.LuminaTextEncodingStrategy)
|
||||
|
||||
gemma2_conds = []
|
||||
system_prompt = args.system_prompt or ""
|
||||
|
||||
# Apply system prompt to prompts
|
||||
prompt = system_prompt + prompt
|
||||
negative_prompt = system_prompt + negative_prompt
|
||||
|
||||
# Get sample prompts from cache
|
||||
if sample_prompts_gemma2_outputs and prompt in sample_prompts_gemma2_outputs:
|
||||
gemma2_conds = sample_prompts_gemma2_outputs[prompt]
|
||||
logger.info(f"Using cached Gemma2 outputs for prompt: {prompt}")
|
||||
|
||||
if sample_prompts_gemma2_outputs and negative_prompt in sample_prompts_gemma2_outputs:
|
||||
neg_gemma2_conds = sample_prompts_gemma2_outputs[negative_prompt]
|
||||
logger.info(f"Using cached Gemma2 outputs for negative prompt: {negative_prompt}")
|
||||
|
||||
# Load sample prompts from Gemma 2
|
||||
if gemma2_model is not None:
|
||||
logger.info(f"Encoding prompt with Gemma2: {prompt}")
|
||||
tokens_and_masks = tokenize_strategy.tokenize(prompt)
|
||||
encoded_gemma2_conds = encoding_strategy.encode_tokens(tokenize_strategy, [gemma2_model], tokens_and_masks)
|
||||
gemma2_conds = encoding_strategy.encode_tokens(tokenize_strategy, [gemma2_model], tokens_and_masks)
|
||||
|
||||
# if gemma2_conds is not cached, use encoded_gemma2_conds
|
||||
if len(gemma2_conds) == 0:
|
||||
gemma2_conds = encoded_gemma2_conds
|
||||
else:
|
||||
# if encoded_gemma2_conds is not None, update cached gemma2_conds
|
||||
for i in range(len(encoded_gemma2_conds)):
|
||||
if encoded_gemma2_conds[i] is not None:
|
||||
gemma2_conds[i] = encoded_gemma2_conds[i]
|
||||
tokens_and_masks = tokenize_strategy.tokenize(negative_prompt)
|
||||
neg_gemma2_conds = encoding_strategy.encode_tokens(tokenize_strategy, [gemma2_model], tokens_and_masks)
|
||||
|
||||
# Unpack Gemma2 outputs
|
||||
gemma2_hidden_states, input_ids, gemma2_attn_mask = gemma2_conds
|
||||
neg_gemma2_hidden_states, neg_input_ids, neg_gemma2_attn_mask = neg_gemma2_conds
|
||||
|
||||
# sample image
|
||||
weight_dtype = vae.dtype # TOFO give dtype as argument
|
||||
latent_height = height // 8
|
||||
latent_width = width // 8
|
||||
latent_channels = 16
|
||||
noise = torch.randn(
|
||||
1,
|
||||
16,
|
||||
latent_channels,
|
||||
latent_height,
|
||||
latent_width,
|
||||
device=accelerator.device,
|
||||
dtype=weight_dtype,
|
||||
generator=generator,
|
||||
)
|
||||
# Prompts are paired positive/negative
|
||||
noise = noise.repeat(gemma2_attn_mask.shape[0], 1, 1, 1)
|
||||
|
||||
timesteps = get_schedule(sample_steps, noise.shape[1], shift=True)
|
||||
# img_ids = lumina_util.prepare_img_ids(1, packed_latent_height, packed_latent_width).to(accelerator.device, weight_dtype)
|
||||
gemma2_attn_mask = gemma2_attn_mask.to(accelerator.device)
|
||||
scheduler = FlowMatchEulerDiscreteScheduler(shift=6.0, use_karras_sigmas=True)
|
||||
timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps=sample_steps)
|
||||
|
||||
# if controlnet_image is not None:
|
||||
# controlnet_image = Image.open(controlnet_image).convert("RGB")
|
||||
@@ -280,16 +295,25 @@ def sample_image_inference(
|
||||
# controlnet_image = controlnet_image.permute(2, 0, 1).unsqueeze(0).to(weight_dtype).to(accelerator.device)
|
||||
|
||||
with accelerator.autocast():
|
||||
x = denoise(nextdit, noise, gemma2_hidden_states, gemma2_attn_mask, timesteps=timesteps, guidance=guidance_scale)
|
||||
x = denoise(
|
||||
scheduler,
|
||||
nextdit,
|
||||
noise,
|
||||
gemma2_hidden_states,
|
||||
gemma2_attn_mask.to(accelerator.device),
|
||||
neg_gemma2_hidden_states,
|
||||
neg_gemma2_attn_mask.to(accelerator.device),
|
||||
timesteps=timesteps,
|
||||
num_inference_steps=num_inference_steps,
|
||||
guidance_scale=guidance_scale,
|
||||
)
|
||||
|
||||
# x = lumina_util.unpack_latents(x, packed_latent_height, packed_latent_width)
|
||||
|
||||
# latent to image
|
||||
# Latent to image
|
||||
clean_memory_on_device(accelerator.device)
|
||||
org_vae_device = vae.device # will be on cpu
|
||||
vae.to(accelerator.device) # distributed_state.device is same as accelerator.device
|
||||
with accelerator.autocast():
|
||||
x = vae.decode(x)
|
||||
x = vae.decode((x / vae.scale_factor) + vae.shift_factor)
|
||||
vae.to(org_vae_device)
|
||||
clean_memory_on_device(accelerator.device)
|
||||
|
||||
@@ -317,30 +341,25 @@ def sample_image_inference(
|
||||
wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False) # positive prompt as a caption
|
||||
|
||||
|
||||
def time_shift(mu: float, sigma: float, t: Tensor):
|
||||
"""
|
||||
Get time shift
|
||||
|
||||
Args:
|
||||
mu (float): mu value.
|
||||
sigma (float): sigma value.
|
||||
t (Tensor): timestep.
|
||||
|
||||
Return:
|
||||
float: time shift
|
||||
"""
|
||||
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
|
||||
def time_shift(mu: float, sigma: float, t: torch.Tensor):
|
||||
# the following implementation was original for t=0: clean / t=1: noise
|
||||
# Since we adopt the reverse, the 1-t operations are needed
|
||||
t = 1 - t
|
||||
t = math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
|
||||
t = 1 - t
|
||||
return t
|
||||
|
||||
|
||||
def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15) -> Callable[[float], float]:
|
||||
def get_lin_function(x1: float = 256, x2: float = 4096, y1: float = 0.5, y2: float = 1.15) -> Callable[[float], float]:
|
||||
"""
|
||||
Get linear function
|
||||
|
||||
Args:
|
||||
x1 (float, optional): x1 value. Defaults to 256.
|
||||
y1 (float, optional): y1 value. Defaults to 0.5.
|
||||
x2 (float, optional): x2 value. Defaults to 4096.
|
||||
y2 (float, optional): y2 value. Defaults to 1.15.
|
||||
image_seq_len,
|
||||
x1 base_seq_len: int = 256,
|
||||
y2 max_seq_len: int = 4096,
|
||||
y1 base_shift: float = 0.5,
|
||||
y2 max_shift: float = 1.15,
|
||||
|
||||
Return:
|
||||
Callable[[float], float]: linear function
|
||||
@@ -370,51 +389,164 @@ def get_schedule(
|
||||
Return:
|
||||
List[float]: timesteps schedule
|
||||
"""
|
||||
# extra step for zero
|
||||
timesteps = torch.linspace(1, 0, num_steps + 1)
|
||||
timesteps = torch.linspace(1, 1 / num_steps, num_steps)
|
||||
|
||||
# shifting the schedule to favor high timesteps for higher signal images
|
||||
if shift:
|
||||
# eastimate mu based on linear estimation between two points
|
||||
mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
|
||||
mu = get_lin_function(y1=base_shift, y2=max_shift, x1=256, x2=4096)(image_seq_len)
|
||||
timesteps = time_shift(mu, 1.0, timesteps)
|
||||
|
||||
return timesteps.tolist()
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, int]:
|
||||
r"""
|
||||
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
||||
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
||||
|
||||
Args:
|
||||
scheduler (`SchedulerMixin`):
|
||||
The scheduler to get timesteps from.
|
||||
num_inference_steps (`int`):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
||||
must be `None`.
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
||||
`num_inference_steps` and `sigmas` must be `None`.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
||||
`num_inference_steps` and `timesteps` must be `None`.
|
||||
|
||||
Returns:
|
||||
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
||||
second element is the number of inference steps.
|
||||
"""
|
||||
if timesteps is not None and sigmas is not None:
|
||||
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
||||
if timesteps is not None:
|
||||
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accepts_timesteps:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" timestep schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
elif sigmas is not None:
|
||||
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accept_sigmas:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
else:
|
||||
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
def denoise(
|
||||
model: lumina_models.NextDiT, img: Tensor, txt: Tensor, txt_mask: Tensor, timesteps: List[float], guidance: float = 4.0
|
||||
scheduler,
|
||||
model: lumina_models.NextDiT,
|
||||
img: Tensor,
|
||||
txt: Tensor,
|
||||
txt_mask: Tensor,
|
||||
neg_txt: Tensor,
|
||||
neg_txt_mask: Tensor,
|
||||
timesteps: Union[List[float], torch.Tensor],
|
||||
num_inference_steps: int = 38,
|
||||
guidance_scale: float = 4.0,
|
||||
cfg_trunc_ratio: float = 1.0,
|
||||
cfg_normalization: bool = True,
|
||||
):
|
||||
"""
|
||||
Denoise an image using the NextDiT model.
|
||||
|
||||
Args:
|
||||
scheduler ():
|
||||
Noise scheduler
|
||||
model (lumina_models.NextDiT): The NextDiT model instance.
|
||||
img (Tensor): The input image tensor.
|
||||
txt (Tensor): The input text tensor.
|
||||
txt_mask (Tensor): The input text mask tensor.
|
||||
timesteps (List[float]): A list of timesteps for the denoising process.
|
||||
guidance (float, optional): The guidance scale for the denoising process. Defaults to 4.0.
|
||||
img (Tensor):
|
||||
The input image latent tensor.
|
||||
txt (Tensor):
|
||||
The input text tensor.
|
||||
txt_mask (Tensor):
|
||||
The input text mask tensor.
|
||||
neg_txt (Tensor):
|
||||
The negative input txt tensor
|
||||
neg_txt_mask (Tensor):
|
||||
The negative input text mask tensor.
|
||||
timesteps (List[Union[float, torch.FloatTensor]]):
|
||||
A list of timesteps for the denoising process.
|
||||
guidance_scale (float, optional):
|
||||
The guidance scale for the denoising process. Defaults to 4.0.
|
||||
cfg_trunc_ratio (float, optional):
|
||||
The ratio of the timestep interval to apply normalization-based guidance scale.
|
||||
cfg_normalization (bool, optional):
|
||||
Whether to apply normalization-based guidance scale.
|
||||
|
||||
Returns:
|
||||
img (Tensor): Denoised tensor
|
||||
img (Tensor): Denoised latent tensor
|
||||
"""
|
||||
for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]):
|
||||
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
|
||||
# model.prepare_block_swap_before_forward()
|
||||
# block_samples = None
|
||||
# block_single_samples = None
|
||||
pred = model.forward_with_cfg(
|
||||
x=img, # image latents (B, C, H, W)
|
||||
t=t_vec / 1000, # timesteps需要除以1000来匹配模型预期
|
||||
|
||||
for i, t in enumerate(tqdm(timesteps)):
|
||||
# compute whether apply classifier-free truncation on this timestep
|
||||
do_classifier_free_truncation = (i + 1) / num_inference_steps > cfg_trunc_ratio
|
||||
|
||||
# reverse the timestep since Lumina uses t=0 as the noise and t=1 as the image
|
||||
current_timestep = 1 - t / scheduler.config.num_train_timesteps
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
current_timestep = current_timestep.expand(img.shape[0]).to(model.device)
|
||||
|
||||
noise_pred_cond = model(
|
||||
img,
|
||||
current_timestep,
|
||||
cap_feats=txt, # Gemma2的hidden states作为caption features
|
||||
cap_mask=txt_mask.to(dtype=torch.int32), # Gemma2的attention mask
|
||||
cfg_scale=guidance,
|
||||
)
|
||||
|
||||
img = img + (t_prev - t_curr) * pred
|
||||
if not do_classifier_free_truncation:
|
||||
noise_pred_uncond = model(
|
||||
img,
|
||||
current_timestep,
|
||||
cap_feats=neg_txt, # Gemma2的hidden states作为caption features
|
||||
cap_mask=neg_txt_mask.to(dtype=torch.int32), # Gemma2的attention mask
|
||||
)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
|
||||
# apply normalization after classifier-free guidance
|
||||
if cfg_normalization:
|
||||
cond_norm = torch.norm(noise_pred_cond, dim=-1, keepdim=True)
|
||||
noise_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
|
||||
noise_pred = noise_pred * (cond_norm / noise_norm)
|
||||
else:
|
||||
noise_pred = noise_pred_cond
|
||||
|
||||
img_dtype = img.dtype
|
||||
|
||||
if img.dtype != img_dtype:
|
||||
if torch.backends.mps.is_available():
|
||||
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
||||
img = img.to(img_dtype)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
noise_pred = -noise_pred
|
||||
img = scheduler.step(noise_pred, t, img, return_dict=False)[0]
|
||||
|
||||
# model.prepare_block_swap_before_forward()
|
||||
return img
|
||||
|
||||
|
||||
@@ -754,3 +886,14 @@ def add_lumina_train_arguments(parser: argparse.ArgumentParser):
|
||||
default=3.0,
|
||||
help="Discrete flow shift for the Euler Discrete Scheduler, default is 3.0. / Euler Discrete Schedulerの離散フローシフト、デフォルトは3.0。",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_flash_attn",
|
||||
action="store_true",
|
||||
help="Use Flash Attention for the model. / モデルにFlash Attentionを使用する。",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--system_prompt",
|
||||
type=str,
|
||||
default="You are an assistant designed to generate high-quality images based on user prompts. <Prompt Start> ",
|
||||
help="System prompt to add to the prompt. / プロンプトに追加するシステムプロンプト。",
|
||||
)
|
||||
|
||||
@@ -20,11 +20,13 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
MODEL_VERSION_LUMINA_V2 = "lumina2"
|
||||
|
||||
|
||||
def load_lumina_model(
|
||||
ckpt_path: str,
|
||||
dtype: torch.dtype,
|
||||
dtype: Optional[torch.dtype],
|
||||
device: torch.device,
|
||||
disable_mmap: bool = False,
|
||||
use_flash_attn: bool = False,
|
||||
):
|
||||
"""
|
||||
Load the Lumina model from the checkpoint path.
|
||||
@@ -34,22 +36,22 @@ def load_lumina_model(
|
||||
dtype (torch.dtype): The data type for the model.
|
||||
device (torch.device): The device to load the model on.
|
||||
disable_mmap (bool, optional): Whether to disable mmap. Defaults to False.
|
||||
use_flash_attn (bool, optional): Whether to use flash attention. Defaults to False.
|
||||
|
||||
Returns:
|
||||
model (lumina_models.NextDiT): The loaded model.
|
||||
"""
|
||||
logger.info("Building Lumina")
|
||||
with torch.device("meta"):
|
||||
model = lumina_models.NextDiT_2B_GQA_patch2_Adaln_Refiner().to(dtype)
|
||||
model = lumina_models.NextDiT_2B_GQA_patch2_Adaln_Refiner(use_flash_attn=use_flash_attn).to(dtype)
|
||||
|
||||
logger.info(f"Loading state dict from {ckpt_path}")
|
||||
state_dict = load_safetensors(
|
||||
ckpt_path, device=device, disable_mmap=disable_mmap, dtype=dtype
|
||||
)
|
||||
state_dict = load_safetensors(ckpt_path, device=device, disable_mmap=disable_mmap, dtype=dtype)
|
||||
info = model.load_state_dict(state_dict, strict=False, assign=True)
|
||||
logger.info(f"Loaded Lumina: {info}")
|
||||
return model
|
||||
|
||||
|
||||
def load_ae(
|
||||
ckpt_path: str,
|
||||
dtype: torch.dtype,
|
||||
@@ -74,9 +76,7 @@ def load_ae(
|
||||
ae = flux_models.AutoEncoder(flux_models.configs["schnell"].ae_params).to(dtype)
|
||||
|
||||
logger.info(f"Loading state dict from {ckpt_path}")
|
||||
sd = load_safetensors(
|
||||
ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype
|
||||
)
|
||||
sd = load_safetensors(ckpt_path, device=device, disable_mmap=disable_mmap, dtype=dtype)
|
||||
info = ae.load_state_dict(sd, strict=False, assign=True)
|
||||
logger.info(f"Loaded AE: {info}")
|
||||
return ae
|
||||
@@ -104,37 +104,35 @@ def load_gemma2(
|
||||
"""
|
||||
logger.info("Building Gemma2")
|
||||
GEMMA2_CONFIG = {
|
||||
"_name_or_path": "google/gemma-2-2b",
|
||||
"architectures": [
|
||||
"Gemma2Model"
|
||||
],
|
||||
"attention_bias": False,
|
||||
"attention_dropout": 0.0,
|
||||
"attn_logit_softcapping": 50.0,
|
||||
"bos_token_id": 2,
|
||||
"cache_implementation": "hybrid",
|
||||
"eos_token_id": 1,
|
||||
"final_logit_softcapping": 30.0,
|
||||
"head_dim": 256,
|
||||
"hidden_act": "gelu_pytorch_tanh",
|
||||
"hidden_activation": "gelu_pytorch_tanh",
|
||||
"hidden_size": 2304,
|
||||
"initializer_range": 0.02,
|
||||
"intermediate_size": 9216,
|
||||
"max_position_embeddings": 8192,
|
||||
"model_type": "gemma2",
|
||||
"num_attention_heads": 8,
|
||||
"num_hidden_layers": 26,
|
||||
"num_key_value_heads": 4,
|
||||
"pad_token_id": 0,
|
||||
"query_pre_attn_scalar": 256,
|
||||
"rms_norm_eps": 1e-06,
|
||||
"rope_theta": 10000.0,
|
||||
"sliding_window": 4096,
|
||||
"torch_dtype": "float32",
|
||||
"transformers_version": "4.44.2",
|
||||
"use_cache": True,
|
||||
"vocab_size": 256000
|
||||
"_name_or_path": "google/gemma-2-2b",
|
||||
"architectures": ["Gemma2Model"],
|
||||
"attention_bias": False,
|
||||
"attention_dropout": 0.0,
|
||||
"attn_logit_softcapping": 50.0,
|
||||
"bos_token_id": 2,
|
||||
"cache_implementation": "hybrid",
|
||||
"eos_token_id": 1,
|
||||
"final_logit_softcapping": 30.0,
|
||||
"head_dim": 256,
|
||||
"hidden_act": "gelu_pytorch_tanh",
|
||||
"hidden_activation": "gelu_pytorch_tanh",
|
||||
"hidden_size": 2304,
|
||||
"initializer_range": 0.02,
|
||||
"intermediate_size": 9216,
|
||||
"max_position_embeddings": 8192,
|
||||
"model_type": "gemma2",
|
||||
"num_attention_heads": 8,
|
||||
"num_hidden_layers": 26,
|
||||
"num_key_value_heads": 4,
|
||||
"pad_token_id": 0,
|
||||
"query_pre_attn_scalar": 256,
|
||||
"rms_norm_eps": 1e-06,
|
||||
"rope_theta": 10000.0,
|
||||
"sliding_window": 4096,
|
||||
"torch_dtype": "float32",
|
||||
"transformers_version": "4.44.2",
|
||||
"use_cache": True,
|
||||
"vocab_size": 256000,
|
||||
}
|
||||
|
||||
config = Gemma2Config(**GEMMA2_CONFIG)
|
||||
@@ -145,9 +143,7 @@ def load_gemma2(
|
||||
sd = state_dict
|
||||
else:
|
||||
logger.info(f"Loading state dict from {ckpt_path}")
|
||||
sd = load_safetensors(
|
||||
ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype
|
||||
)
|
||||
sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)
|
||||
|
||||
for key in list(sd.keys()):
|
||||
new_key = key.replace("model.", "")
|
||||
@@ -159,6 +155,7 @@ def load_gemma2(
|
||||
logger.info(f"Loaded Gemma2: {info}")
|
||||
return gemma2
|
||||
|
||||
|
||||
def unpack_latents(x: torch.Tensor, packed_latent_height: int, packed_latent_width: int) -> torch.Tensor:
|
||||
"""
|
||||
x: [b (h w) (c ph pw)] -> [b c (h ph) (w pw)], ph=2, pw=2
|
||||
@@ -174,6 +171,7 @@ def pack_latents(x: torch.Tensor) -> torch.Tensor:
|
||||
x = einops.rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
|
||||
return x
|
||||
|
||||
|
||||
DIFFUSERS_TO_ALPHA_VLLM_MAP = {
|
||||
# Embedding layers
|
||||
"cap_embedder.0.weight": ["time_caption_embed.caption_embedder.0.weight"],
|
||||
@@ -224,9 +222,7 @@ def convert_diffusers_sd_to_alpha_vllm(sd: dict, num_double_blocks: int) -> dict
|
||||
for block_idx in range(num_double_blocks):
|
||||
if str(block_idx) in key:
|
||||
converted = pattern.replace("()", str(block_idx))
|
||||
new_key = key.replace(
|
||||
converted, replacement.replace("()", str(block_idx))
|
||||
)
|
||||
new_key = key.replace(converted, replacement.replace("()", str(block_idx)))
|
||||
break
|
||||
|
||||
if new_key == key:
|
||||
|
||||
@@ -610,6 +610,21 @@ from diffusers.utils.torch_utils import randn_tensor
|
||||
from diffusers.utils import BaseOutput
|
||||
|
||||
|
||||
# Copyright 2024 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlowMatchEulerDiscreteSchedulerOutput(BaseOutput):
|
||||
"""
|
||||
@@ -649,22 +664,49 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
self,
|
||||
num_train_timesteps: int = 1000,
|
||||
shift: float = 1.0,
|
||||
use_dynamic_shifting=False,
|
||||
base_shift: Optional[float] = 0.5,
|
||||
max_shift: Optional[float] = 1.15,
|
||||
base_image_seq_len: Optional[int] = 256,
|
||||
max_image_seq_len: Optional[int] = 4096,
|
||||
invert_sigmas: bool = False,
|
||||
shift_terminal: Optional[float] = None,
|
||||
use_karras_sigmas: Optional[bool] = False,
|
||||
use_exponential_sigmas: Optional[bool] = False,
|
||||
use_beta_sigmas: Optional[bool] = False,
|
||||
):
|
||||
if self.config.use_beta_sigmas and not is_scipy_available():
|
||||
raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
|
||||
if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
|
||||
raise ValueError(
|
||||
"Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
|
||||
)
|
||||
timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy()
|
||||
timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)
|
||||
|
||||
sigmas = timesteps / num_train_timesteps
|
||||
sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
|
||||
if not use_dynamic_shifting:
|
||||
# when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution
|
||||
sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
|
||||
|
||||
self.timesteps = sigmas * num_train_timesteps
|
||||
|
||||
self._step_index = None
|
||||
self._begin_index = None
|
||||
|
||||
self._shift = shift
|
||||
|
||||
self.sigmas = sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
self.sigma_min = self.sigmas[-1].item()
|
||||
self.sigma_max = self.sigmas[0].item()
|
||||
|
||||
@property
|
||||
def shift(self):
|
||||
"""
|
||||
The value used for shifting.
|
||||
"""
|
||||
return self._shift
|
||||
|
||||
@property
|
||||
def step_index(self):
|
||||
"""
|
||||
@@ -690,6 +732,9 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
self._begin_index = begin_index
|
||||
|
||||
def set_shift(self, shift: float):
|
||||
self._shift = shift
|
||||
|
||||
def scale_noise(
|
||||
self,
|
||||
sample: torch.FloatTensor,
|
||||
@@ -709,10 +754,31 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
`torch.FloatTensor`:
|
||||
A scaled input sample.
|
||||
"""
|
||||
if self.step_index is None:
|
||||
self._init_step_index(timestep)
|
||||
# Make sure sigmas and timesteps have the same device and dtype as original_samples
|
||||
sigmas = self.sigmas.to(device=sample.device, dtype=sample.dtype)
|
||||
|
||||
if sample.device.type == "mps" and torch.is_floating_point(timestep):
|
||||
# mps does not support float64
|
||||
schedule_timesteps = self.timesteps.to(sample.device, dtype=torch.float32)
|
||||
timestep = timestep.to(sample.device, dtype=torch.float32)
|
||||
else:
|
||||
schedule_timesteps = self.timesteps.to(sample.device)
|
||||
timestep = timestep.to(sample.device)
|
||||
|
||||
# self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
|
||||
if self.begin_index is None:
|
||||
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timestep]
|
||||
elif self.step_index is not None:
|
||||
# add_noise is called after first denoising step (for inpainting)
|
||||
step_indices = [self.step_index] * timestep.shape[0]
|
||||
else:
|
||||
# add noise is called before first denoising step to create initial latent(img2img)
|
||||
step_indices = [self.begin_index] * timestep.shape[0]
|
||||
|
||||
sigma = sigmas[step_indices].flatten()
|
||||
while len(sigma.shape) < len(sample.shape):
|
||||
sigma = sigma.unsqueeze(-1)
|
||||
|
||||
sigma = self.sigmas[self.step_index]
|
||||
sample = sigma * noise + (1.0 - sigma) * sample
|
||||
|
||||
return sample
|
||||
@@ -720,7 +786,37 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
def _sigma_to_t(self, sigma):
|
||||
return sigma * self.config.num_train_timesteps
|
||||
|
||||
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
|
||||
def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
|
||||
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
|
||||
|
||||
def stretch_shift_to_terminal(self, t: torch.Tensor) -> torch.Tensor:
|
||||
r"""
|
||||
Stretches and shifts the timestep schedule to ensure it terminates at the configured `shift_terminal` config
|
||||
value.
|
||||
|
||||
Reference:
|
||||
https://github.com/Lightricks/LTX-Video/blob/a01a171f8fe3d99dce2728d60a73fecf4d4238ae/ltx_video/schedulers/rf.py#L51
|
||||
|
||||
Args:
|
||||
t (`torch.Tensor`):
|
||||
A tensor of timesteps to be stretched and shifted.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`:
|
||||
A tensor of adjusted timesteps such that the final value equals `self.config.shift_terminal`.
|
||||
"""
|
||||
one_minus_z = 1 - t
|
||||
scale_factor = one_minus_z[-1] / (1 - self.config.shift_terminal)
|
||||
stretched_t = 1 - (one_minus_z / scale_factor)
|
||||
return stretched_t
|
||||
|
||||
def set_timesteps(
|
||||
self,
|
||||
num_inference_steps: int = None,
|
||||
device: Union[str, torch.device] = None,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
mu: Optional[float] = None,
|
||||
):
|
||||
"""
|
||||
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
|
||||
|
||||
@@ -730,18 +826,49 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||
"""
|
||||
if self.config.use_dynamic_shifting and mu is None:
|
||||
raise ValueError(" you have a pass a value for `mu` when `use_dynamic_shifting` is set to be `True`")
|
||||
|
||||
if sigmas is None:
|
||||
timesteps = np.linspace(
|
||||
self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps
|
||||
)
|
||||
|
||||
sigmas = timesteps / self.config.num_train_timesteps
|
||||
else:
|
||||
sigmas = np.array(sigmas).astype(np.float32)
|
||||
num_inference_steps = len(sigmas)
|
||||
self.num_inference_steps = num_inference_steps
|
||||
|
||||
timesteps = np.linspace(self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps)
|
||||
if self.config.use_dynamic_shifting:
|
||||
sigmas = self.time_shift(mu, 1.0, sigmas)
|
||||
else:
|
||||
sigmas = self.shift * sigmas / (1 + (self.shift - 1) * sigmas)
|
||||
|
||||
if self.config.shift_terminal:
|
||||
sigmas = self.stretch_shift_to_terminal(sigmas)
|
||||
|
||||
if self.config.use_karras_sigmas:
|
||||
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
|
||||
|
||||
elif self.config.use_exponential_sigmas:
|
||||
sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
|
||||
|
||||
elif self.config.use_beta_sigmas:
|
||||
sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
|
||||
|
||||
sigmas = timesteps / self.config.num_train_timesteps
|
||||
sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas)
|
||||
sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
|
||||
|
||||
timesteps = sigmas * self.config.num_train_timesteps
|
||||
self.timesteps = timesteps.to(device=device)
|
||||
self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
|
||||
|
||||
if self.config.invert_sigmas:
|
||||
sigmas = 1.0 - sigmas
|
||||
timesteps = sigmas * self.config.num_train_timesteps
|
||||
sigmas = torch.cat([sigmas, torch.ones(1, device=sigmas.device)])
|
||||
else:
|
||||
sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
|
||||
|
||||
self.timesteps = timesteps.to(device=device)
|
||||
self.sigmas = sigmas
|
||||
self._step_index = None
|
||||
self._begin_index = None
|
||||
|
||||
@@ -807,7 +934,11 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
returned, otherwise a tuple is returned where the first element is the sample tensor.
|
||||
"""
|
||||
|
||||
if isinstance(timestep, int) or isinstance(timestep, torch.IntTensor) or isinstance(timestep, torch.LongTensor):
|
||||
if (
|
||||
isinstance(timestep, int)
|
||||
or isinstance(timestep, torch.IntTensor)
|
||||
or isinstance(timestep, torch.LongTensor)
|
||||
):
|
||||
raise ValueError(
|
||||
(
|
||||
"Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
|
||||
@@ -823,30 +954,10 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
sample = sample.to(torch.float32)
|
||||
|
||||
sigma = self.sigmas[self.step_index]
|
||||
sigma_next = self.sigmas[self.step_index + 1]
|
||||
|
||||
gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0
|
||||
prev_sample = sample + (sigma_next - sigma) * model_output
|
||||
|
||||
noise = randn_tensor(model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator)
|
||||
|
||||
eps = noise * s_noise
|
||||
sigma_hat = sigma * (gamma + 1)
|
||||
|
||||
if gamma > 0:
|
||||
sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5
|
||||
|
||||
# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
|
||||
# NOTE: "original_sample" should not be an expected prediction_type but is left in for
|
||||
# backwards compatibility
|
||||
|
||||
# if self.config.prediction_type == "vector_field":
|
||||
|
||||
denoised = sample - model_output * sigma
|
||||
# 2. Convert to an ODE derivative
|
||||
derivative = (sample - denoised) / sigma_hat
|
||||
|
||||
dt = self.sigmas[self.step_index + 1] - sigma_hat
|
||||
|
||||
prev_sample = sample + derivative * dt
|
||||
# Cast sample back to model compatible dtype
|
||||
prev_sample = prev_sample.to(model_output.dtype)
|
||||
|
||||
@@ -858,6 +969,86 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample)
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
|
||||
def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
|
||||
"""Constructs the noise schedule of Karras et al. (2022)."""
|
||||
|
||||
# Hack to make sure that other schedulers which copy this function don't break
|
||||
# TODO: Add this logic to the other schedulers
|
||||
if hasattr(self.config, "sigma_min"):
|
||||
sigma_min = self.config.sigma_min
|
||||
else:
|
||||
sigma_min = None
|
||||
|
||||
if hasattr(self.config, "sigma_max"):
|
||||
sigma_max = self.config.sigma_max
|
||||
else:
|
||||
sigma_max = None
|
||||
|
||||
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
|
||||
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
|
||||
|
||||
rho = 7.0 # 7.0 is the value used in the paper
|
||||
ramp = np.linspace(0, 1, num_inference_steps)
|
||||
min_inv_rho = sigma_min ** (1 / rho)
|
||||
max_inv_rho = sigma_max ** (1 / rho)
|
||||
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
|
||||
return sigmas
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential
|
||||
def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
|
||||
"""Constructs an exponential noise schedule."""
|
||||
|
||||
# Hack to make sure that other schedulers which copy this function don't break
|
||||
# TODO: Add this logic to the other schedulers
|
||||
if hasattr(self.config, "sigma_min"):
|
||||
sigma_min = self.config.sigma_min
|
||||
else:
|
||||
sigma_min = None
|
||||
|
||||
if hasattr(self.config, "sigma_max"):
|
||||
sigma_max = self.config.sigma_max
|
||||
else:
|
||||
sigma_max = None
|
||||
|
||||
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
|
||||
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
|
||||
|
||||
sigmas = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps))
|
||||
return sigmas
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
|
||||
def _convert_to_beta(
|
||||
self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
|
||||
) -> torch.Tensor:
|
||||
"""From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""
|
||||
|
||||
# Hack to make sure that other schedulers which copy this function don't break
|
||||
# TODO: Add this logic to the other schedulers
|
||||
if hasattr(self.config, "sigma_min"):
|
||||
sigma_min = self.config.sigma_min
|
||||
else:
|
||||
sigma_min = None
|
||||
|
||||
if hasattr(self.config, "sigma_max"):
|
||||
sigma_max = self.config.sigma_max
|
||||
else:
|
||||
sigma_max = None
|
||||
|
||||
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
|
||||
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
|
||||
|
||||
sigmas = np.array(
|
||||
[
|
||||
sigma_min + (ppf * (sigma_max - sigma_min))
|
||||
for ppf in [
|
||||
scipy.stats.beta.ppf(timestep, alpha, beta)
|
||||
for timestep in 1 - np.linspace(0, 1, num_inference_steps)
|
||||
]
|
||||
]
|
||||
)
|
||||
return sigmas
|
||||
|
||||
def __len__(self):
|
||||
return self.config.num_train_timesteps
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
import os
|
||||
import re
|
||||
from typing import Any, List, Optional, Tuple, Union
|
||||
from typing import Any, List, Optional, Tuple, Union, Callable
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -430,9 +430,21 @@ class LatentsCachingStrategy:
|
||||
bucket_reso: Tuple[int, int],
|
||||
npz_path: str,
|
||||
flip_aug: bool,
|
||||
alpha_mask: bool,
|
||||
apply_alpha_mask: bool,
|
||||
multi_resolution: bool = False,
|
||||
):
|
||||
) -> bool:
|
||||
"""
|
||||
Args:
|
||||
latents_stride: stride of latents
|
||||
bucket_reso: resolution of the bucket
|
||||
npz_path: path to the npz file
|
||||
flip_aug: whether to flip images
|
||||
apply_alpha_mask: whether to apply alpha mask
|
||||
multi_resolution: whether to use multi-resolution latents
|
||||
|
||||
Returns:
|
||||
bool
|
||||
"""
|
||||
if not self.cache_to_disk:
|
||||
return False
|
||||
if not os.path.exists(npz_path):
|
||||
@@ -451,7 +463,7 @@ class LatentsCachingStrategy:
|
||||
return False
|
||||
if flip_aug and "latents_flipped" + key_reso_suffix not in npz:
|
||||
return False
|
||||
if alpha_mask and "alpha_mask" + key_reso_suffix not in npz:
|
||||
if apply_alpha_mask and "alpha_mask" + key_reso_suffix not in npz:
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading file: {npz_path}")
|
||||
@@ -462,22 +474,35 @@ class LatentsCachingStrategy:
|
||||
# TODO remove circular dependency for ImageInfo
|
||||
def _default_cache_batch_latents(
|
||||
self,
|
||||
encode_by_vae,
|
||||
vae_device,
|
||||
vae_dtype,
|
||||
encode_by_vae: Callable,
|
||||
vae_device: torch.device,
|
||||
vae_dtype: torch.dtype,
|
||||
image_infos: List,
|
||||
flip_aug: bool,
|
||||
alpha_mask: bool,
|
||||
apply_alpha_mask: bool,
|
||||
random_crop: bool,
|
||||
multi_resolution: bool = False,
|
||||
):
|
||||
"""
|
||||
Default implementation for cache_batch_latents. Image loading, VAE, flipping, alpha mask handling are common.
|
||||
|
||||
Args:
|
||||
encode_by_vae: function to encode images by VAE
|
||||
vae_device: device to use for VAE
|
||||
vae_dtype: dtype to use for VAE
|
||||
image_infos: list of ImageInfo
|
||||
flip_aug: whether to flip images
|
||||
apply_alpha_mask: whether to apply alpha mask
|
||||
random_crop: whether to random crop images
|
||||
multi_resolution: whether to use multi-resolution latents
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
from library import train_util # import here to avoid circular import
|
||||
|
||||
img_tensor, alpha_masks, original_sizes, crop_ltrbs = train_util.load_images_and_masks_for_caching(
|
||||
image_infos, alpha_mask, random_crop
|
||||
image_infos, apply_alpha_mask, random_crop
|
||||
)
|
||||
img_tensor = img_tensor.to(device=vae_device, dtype=vae_dtype)
|
||||
|
||||
@@ -519,12 +544,40 @@ class LatentsCachingStrategy:
|
||||
) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
|
||||
"""
|
||||
for SD/SDXL
|
||||
|
||||
Args:
|
||||
npz_path (str): Path to the npz file.
|
||||
bucket_reso (Tuple[int, int]): The resolution of the bucket.
|
||||
|
||||
Returns:
|
||||
Tuple[
|
||||
Optional[np.ndarray],
|
||||
Optional[List[int]],
|
||||
Optional[List[int]],
|
||||
Optional[np.ndarray],
|
||||
Optional[np.ndarray]
|
||||
]: Latent np tensors, original size, crop (left top, right bottom), flipped latents, alpha mask
|
||||
"""
|
||||
return self._default_load_latents_from_disk(None, npz_path, bucket_reso)
|
||||
|
||||
def _default_load_latents_from_disk(
|
||||
self, latents_stride: Optional[int], npz_path: str, bucket_reso: Tuple[int, int]
|
||||
) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
|
||||
"""
|
||||
Args:
|
||||
latents_stride (Optional[int]): Stride for latents. If None, load all latents.
|
||||
npz_path (str): Path to the npz file.
|
||||
bucket_reso (Tuple[int, int]): The resolution of the bucket.
|
||||
|
||||
Returns:
|
||||
Tuple[
|
||||
Optional[np.ndarray],
|
||||
Optional[List[int]],
|
||||
Optional[List[int]],
|
||||
Optional[np.ndarray],
|
||||
Optional[np.ndarray]
|
||||
]: Latent np tensors, original size, crop (left top, right bottom), flipped latents, alpha mask
|
||||
"""
|
||||
if latents_stride is None:
|
||||
key_reso_suffix = ""
|
||||
else:
|
||||
@@ -552,6 +605,19 @@ class LatentsCachingStrategy:
|
||||
alpha_mask=None,
|
||||
key_reso_suffix="",
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
npz_path (str): Path to the npz file.
|
||||
latents_tensor (torch.Tensor): Latent tensor
|
||||
original_size (List[int]): Original size of the image
|
||||
crop_ltrb (List[int]): Crop left top right bottom
|
||||
flipped_latents_tensor (Optional[torch.Tensor]): Flipped latent tensor
|
||||
alpha_mask (Optional[torch.Tensor]): Alpha mask
|
||||
key_reso_suffix (str): Key resolution suffix
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
kwargs = {}
|
||||
|
||||
if os.path.exists(npz_path):
|
||||
|
||||
@@ -3,13 +3,13 @@ import os
|
||||
from typing import Any, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from transformers import AutoTokenizer, AutoModel, GemmaTokenizerFast
|
||||
from transformers import AutoTokenizer, AutoModel, Gemma2Model, GemmaTokenizerFast
|
||||
from library import train_util
|
||||
from library.strategy_base import (
|
||||
LatentsCachingStrategy,
|
||||
TokenizeStrategy,
|
||||
TextEncodingStrategy,
|
||||
TextEncoderOutputsCachingStrategy
|
||||
TextEncoderOutputsCachingStrategy,
|
||||
)
|
||||
import numpy as np
|
||||
from library.utils import setup_logging
|
||||
@@ -37,21 +37,38 @@ class LuminaTokenizeStrategy(TokenizeStrategy):
|
||||
else:
|
||||
self.max_length = max_length
|
||||
|
||||
def tokenize(self, text: Union[str, List[str]]) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
def tokenize(
|
||||
self, text: Union[str, List[str]]
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Args:
|
||||
text (Union[str, List[str]]): Text to tokenize
|
||||
|
||||
Returns:
|
||||
Tuple[torch.Tensor, torch.Tensor]:
|
||||
token input ids, attention_masks
|
||||
"""
|
||||
text = [text] if isinstance(text, str) else text
|
||||
encodings = self.tokenizer(
|
||||
text,
|
||||
max_length=self.max_length,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
padding="max_length",
|
||||
pad_to_multiple_of=8,
|
||||
truncation=True,
|
||||
)
|
||||
return [encodings.input_ids, encodings.attention_mask]
|
||||
return (encodings.input_ids, encodings.attention_mask)
|
||||
|
||||
def tokenize_with_weights(
|
||||
self, text: str | List[str]
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
|
||||
"""
|
||||
Args:
|
||||
text (Union[str, List[str]]): Text to tokenize
|
||||
|
||||
Returns:
|
||||
Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
|
||||
token input ids, attention_masks, weights
|
||||
"""
|
||||
# Gemma doesn't support weighted prompts, return uniform weights
|
||||
tokens, attention_masks = self.tokenize(text)
|
||||
weights = [torch.ones_like(t) for t in tokens]
|
||||
@@ -66,9 +83,20 @@ class LuminaTextEncodingStrategy(TextEncodingStrategy):
|
||||
self,
|
||||
tokenize_strategy: TokenizeStrategy,
|
||||
models: List[Any],
|
||||
tokens: List[torch.Tensor],
|
||||
tokens: Tuple[torch.Tensor, torch.Tensor],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Args:
|
||||
tokenize_strategy (LuminaTokenizeStrategy): Tokenize strategy
|
||||
models (List[Any]): Text encoders
|
||||
tokens (Tuple[torch.Tensor, torch.Tensor]): tokens, attention_masks
|
||||
|
||||
Returns:
|
||||
Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
hidden_states, input_ids, attention_masks
|
||||
"""
|
||||
text_encoder = models[0]
|
||||
assert isinstance(text_encoder, Gemma2Model)
|
||||
input_ids, attention_masks = tokens
|
||||
|
||||
outputs = text_encoder(
|
||||
@@ -84,9 +112,20 @@ class LuminaTextEncodingStrategy(TextEncodingStrategy):
|
||||
self,
|
||||
tokenize_strategy: TokenizeStrategy,
|
||||
models: List[Any],
|
||||
tokens: List[torch.Tensor],
|
||||
weights_list: List[torch.Tensor],
|
||||
tokens: Tuple[torch.Tensor, torch.Tensor],
|
||||
weights: List[torch.Tensor],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Args:
|
||||
tokenize_strategy (LuminaTokenizeStrategy): Tokenize strategy
|
||||
models (List[Any]): Text encoders
|
||||
tokens (Tuple[torch.Tensor, torch.Tensor]): tokens, attention_masks
|
||||
weights_list (List[torch.Tensor]): Currently unused
|
||||
|
||||
Returns:
|
||||
Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
hidden_states, input_ids, attention_masks
|
||||
"""
|
||||
# For simplicity, use uniform weighting
|
||||
return self.encode_tokens(tokenize_strategy, models, tokens)
|
||||
|
||||
@@ -114,7 +153,14 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy)
|
||||
+ LuminaTextEncoderOutputsCachingStrategy.LUMINA_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
|
||||
)
|
||||
|
||||
def is_disk_cached_outputs_expected(self, npz_path: str):
|
||||
def is_disk_cached_outputs_expected(self, npz_path: str) -> bool:
|
||||
"""
|
||||
Args:
|
||||
npz_path (str): Path to the npz file.
|
||||
|
||||
Returns:
|
||||
bool: True if the npz file is expected to be cached.
|
||||
"""
|
||||
if not self.cache_to_disk:
|
||||
return False
|
||||
if not os.path.exists(npz_path):
|
||||
@@ -141,7 +187,7 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy)
|
||||
Load outputs from a npz file
|
||||
|
||||
Returns:
|
||||
List[np.ndarray]: hidden_state, input_ids, attention_mask
|
||||
List[np.ndarray]: hidden_state, input_ids, attention_mask
|
||||
"""
|
||||
data = np.load(npz_path)
|
||||
hidden_state = data["hidden_state"]
|
||||
@@ -151,53 +197,75 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy)
|
||||
|
||||
def cache_batch_outputs(
|
||||
self,
|
||||
tokenize_strategy: LuminaTokenizeStrategy,
|
||||
tokenize_strategy: TokenizeStrategy,
|
||||
models: List[Any],
|
||||
text_encoding_strategy: LuminaTextEncodingStrategy,
|
||||
infos: List,
|
||||
):
|
||||
lumina_text_encoding_strategy: LuminaTextEncodingStrategy = (
|
||||
text_encoding_strategy
|
||||
)
|
||||
captions = [info.caption for info in infos]
|
||||
text_encoding_strategy: TextEncodingStrategy,
|
||||
batch: List[train_util.ImageInfo],
|
||||
) -> None:
|
||||
"""
|
||||
Args:
|
||||
tokenize_strategy (LuminaTokenizeStrategy): Tokenize strategy
|
||||
models (List[Any]): Text encoders
|
||||
text_encoding_strategy (LuminaTextEncodingStrategy):
|
||||
infos (List): List of image_info
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
assert isinstance(text_encoding_strategy, LuminaTextEncodingStrategy)
|
||||
assert isinstance(tokenize_strategy, LuminaTokenizeStrategy)
|
||||
|
||||
captions = [info.system_prompt or "" + info.caption for info in batch]
|
||||
|
||||
if self.is_weighted:
|
||||
tokens, weights_list = tokenize_strategy.tokenize_with_weights(
|
||||
captions
|
||||
tokens, attention_masks, weights_list = (
|
||||
tokenize_strategy.tokenize_with_weights(captions)
|
||||
)
|
||||
with torch.no_grad():
|
||||
hidden_state, input_ids, attention_masks = lumina_text_encoding_strategy.encode_tokens_with_weights(
|
||||
tokenize_strategy, models, tokens, weights_list
|
||||
hidden_state, input_ids, attention_masks = (
|
||||
text_encoding_strategy.encode_tokens_with_weights(
|
||||
tokenize_strategy,
|
||||
models,
|
||||
(tokens, attention_masks),
|
||||
weights_list,
|
||||
)
|
||||
)
|
||||
else:
|
||||
tokens = tokenize_strategy.tokenize(captions)
|
||||
with torch.no_grad():
|
||||
hidden_state, input_ids, attention_masks = lumina_text_encoding_strategy.encode_tokens(
|
||||
tokenize_strategy, models, tokens
|
||||
hidden_state, input_ids, attention_masks = (
|
||||
text_encoding_strategy.encode_tokens(
|
||||
tokenize_strategy, models, tokens
|
||||
)
|
||||
)
|
||||
|
||||
if hidden_state.dtype != torch.float32:
|
||||
hidden_state = hidden_state.float()
|
||||
|
||||
hidden_state = hidden_state.cpu().numpy()
|
||||
attention_mask = attention_masks.cpu().numpy()
|
||||
input_ids = tokens.cpu().numpy()
|
||||
attention_mask = attention_masks.cpu().numpy() # (B, S)
|
||||
input_ids = input_ids.cpu().numpy() # (B, S)
|
||||
|
||||
|
||||
for i, info in enumerate(infos):
|
||||
for i, info in enumerate(batch):
|
||||
hidden_state_i = hidden_state[i]
|
||||
attention_mask_i = attention_mask[i]
|
||||
input_ids_i = input_ids[i]
|
||||
|
||||
assert info.text_encoder_outputs_npz is not None, "Text encoder cache outputs to disk not found for image {info.image_path}"
|
||||
|
||||
if self.cache_to_disk:
|
||||
np.savez(
|
||||
info.text_encoder_outputs_npz,
|
||||
hidden_state=hidden_state_i,
|
||||
attention_mask=attention_mask_i,
|
||||
input_ids=input_ids_i
|
||||
input_ids=input_ids_i,
|
||||
)
|
||||
else:
|
||||
info.text_encoder_outputs = [hidden_state_i, attention_mask_i, input_ids_i]
|
||||
info.text_encoder_outputs = [
|
||||
hidden_state_i,
|
||||
attention_mask_i,
|
||||
input_ids_i,
|
||||
]
|
||||
|
||||
|
||||
class LuminaLatentsCachingStrategy(LatentsCachingStrategy):
|
||||
@@ -227,7 +295,14 @@ class LuminaLatentsCachingStrategy(LatentsCachingStrategy):
|
||||
npz_path: str,
|
||||
flip_aug: bool,
|
||||
alpha_mask: bool,
|
||||
):
|
||||
) -> bool:
|
||||
"""
|
||||
Args:
|
||||
bucket_reso (Tuple[int, int]): The resolution of the bucket.
|
||||
npz_path (str): Path to the npz file.
|
||||
flip_aug (bool): Whether to flip the image.
|
||||
alpha_mask (bool): Whether to apply
|
||||
"""
|
||||
return self._default_is_disk_cached_latents_expected(
|
||||
8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True
|
||||
)
|
||||
@@ -241,6 +316,20 @@ class LuminaLatentsCachingStrategy(LatentsCachingStrategy):
|
||||
Optional[np.ndarray],
|
||||
Optional[np.ndarray],
|
||||
]:
|
||||
"""
|
||||
Args:
|
||||
npz_path (str): Path to the npz file.
|
||||
bucket_reso (Tuple[int, int]): The resolution of the bucket.
|
||||
|
||||
Returns:
|
||||
Tuple[
|
||||
Optional[np.ndarray],
|
||||
Optional[List[int]],
|
||||
Optional[List[int]],
|
||||
Optional[np.ndarray],
|
||||
Optional[np.ndarray],
|
||||
]: Tuple of latent tensors, attention_mask, input_ids, latents, latents_unet
|
||||
"""
|
||||
return self._default_load_latents_from_disk(
|
||||
8, npz_path, bucket_reso
|
||||
) # support multi-resolution
|
||||
|
||||
@@ -195,7 +195,7 @@ class ImageInfo:
|
||||
self.latents_flipped: Optional[torch.Tensor] = None
|
||||
self.latents_npz: Optional[str] = None # set in cache_latents
|
||||
self.latents_original_size: Optional[Tuple[int, int]] = None # original image size, not latents size
|
||||
self.latents_crop_ltrb: Optional[Tuple[int, int]] = (
|
||||
self.latents_crop_ltrb: Optional[Tuple[int, int, int, int]] = (
|
||||
None # crop left top right bottom in original pixel size, not latents size
|
||||
)
|
||||
self.cond_img_path: Optional[str] = None
|
||||
@@ -211,6 +211,8 @@ class ImageInfo:
|
||||
|
||||
self.alpha_mask: Optional[torch.Tensor] = None # alpha mask can be flipped in runtime
|
||||
|
||||
self.system_prompt: Optional[str] = None
|
||||
|
||||
|
||||
class BucketManager:
|
||||
def __init__(self, no_upscale, max_reso, min_size, max_size, reso_steps) -> None:
|
||||
@@ -434,6 +436,7 @@ class BaseSubset:
|
||||
custom_attributes: Optional[Dict[str, Any]] = None,
|
||||
validation_seed: Optional[int] = None,
|
||||
validation_split: Optional[float] = 0.0,
|
||||
system_prompt: Optional[str] = None
|
||||
) -> None:
|
||||
self.image_dir = image_dir
|
||||
self.alpha_mask = alpha_mask if alpha_mask is not None else False
|
||||
@@ -464,6 +467,8 @@ class BaseSubset:
|
||||
self.validation_seed = validation_seed
|
||||
self.validation_split = validation_split
|
||||
|
||||
self.system_prompt = system_prompt
|
||||
|
||||
|
||||
class DreamBoothSubset(BaseSubset):
|
||||
def __init__(
|
||||
@@ -495,6 +500,7 @@ class DreamBoothSubset(BaseSubset):
|
||||
custom_attributes: Optional[Dict[str, Any]] = None,
|
||||
validation_seed: Optional[int] = None,
|
||||
validation_split: Optional[float] = 0.0,
|
||||
system_prompt: Optional[str] = None
|
||||
) -> None:
|
||||
assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です"
|
||||
|
||||
@@ -522,6 +528,7 @@ class DreamBoothSubset(BaseSubset):
|
||||
custom_attributes=custom_attributes,
|
||||
validation_seed=validation_seed,
|
||||
validation_split=validation_split,
|
||||
system_prompt=system_prompt
|
||||
)
|
||||
|
||||
self.is_reg = is_reg
|
||||
@@ -564,6 +571,7 @@ class FineTuningSubset(BaseSubset):
|
||||
custom_attributes: Optional[Dict[str, Any]] = None,
|
||||
validation_seed: Optional[int] = None,
|
||||
validation_split: Optional[float] = 0.0,
|
||||
system_prompt: Optional[str] = None
|
||||
) -> None:
|
||||
assert metadata_file is not None, "metadata_file must be specified / metadata_fileは指定が必須です"
|
||||
|
||||
@@ -591,6 +599,7 @@ class FineTuningSubset(BaseSubset):
|
||||
custom_attributes=custom_attributes,
|
||||
validation_seed=validation_seed,
|
||||
validation_split=validation_split,
|
||||
system_prompt=system_prompt
|
||||
)
|
||||
|
||||
self.metadata_file = metadata_file
|
||||
@@ -629,6 +638,7 @@ class ControlNetSubset(BaseSubset):
|
||||
custom_attributes: Optional[Dict[str, Any]] = None,
|
||||
validation_seed: Optional[int] = None,
|
||||
validation_split: Optional[float] = 0.0,
|
||||
system_prompt: Optional[str] = None
|
||||
) -> None:
|
||||
assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です"
|
||||
|
||||
@@ -656,6 +666,7 @@ class ControlNetSubset(BaseSubset):
|
||||
custom_attributes=custom_attributes,
|
||||
validation_seed=validation_seed,
|
||||
validation_split=validation_split,
|
||||
system_prompt=system_prompt
|
||||
)
|
||||
|
||||
self.conditioning_data_dir = conditioning_data_dir
|
||||
@@ -1686,8 +1697,9 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
text_encoder_outputs_list.append(text_encoder_outputs)
|
||||
|
||||
if tokenization_required:
|
||||
system_prompt = subset.system_prompt or ""
|
||||
caption = self.process_caption(subset, image_info.caption)
|
||||
input_ids = [ids[0] for ids in self.tokenize_strategy.tokenize(caption)] # remove batch dimension
|
||||
input_ids = [ids[0] for ids in self.tokenize_strategy.tokenize(system_prompt + caption)] # remove batch dimension
|
||||
# if self.XTI_layers:
|
||||
# caption_layer = []
|
||||
# for layer in self.XTI_layers:
|
||||
@@ -2059,6 +2071,7 @@ class DreamBoothDataset(BaseDataset):
|
||||
num_train_images = 0
|
||||
num_reg_images = 0
|
||||
reg_infos: List[Tuple[ImageInfo, DreamBoothSubset]] = []
|
||||
|
||||
for subset in subsets:
|
||||
num_repeats = subset.num_repeats if self.is_training_dataset else 1
|
||||
if num_repeats < 1:
|
||||
@@ -2086,7 +2099,7 @@ class DreamBoothDataset(BaseDataset):
|
||||
num_train_images += num_repeats * len(img_paths)
|
||||
|
||||
for img_path, caption, size in zip(img_paths, captions, sizes):
|
||||
info = ImageInfo(img_path, num_repeats, caption, subset.is_reg, img_path)
|
||||
info = ImageInfo(img_path, num_repeats, subset.system_prompt or "" + caption, subset.is_reg, img_path)
|
||||
if size is not None:
|
||||
info.image_size = size
|
||||
if subset.is_reg:
|
||||
@@ -2967,7 +2980,7 @@ def trim_and_resize_if_required(
|
||||
# for new_cache_latents
|
||||
def load_images_and_masks_for_caching(
|
||||
image_infos: List[ImageInfo], use_alpha_mask: bool, random_crop: bool
|
||||
) -> Tuple[torch.Tensor, List[np.ndarray], List[Tuple[int, int]], List[Tuple[int, int, int, int]]]:
|
||||
) -> Tuple[torch.Tensor, List[torch.Tensor], List[Tuple[int, int]], List[Tuple[int, int, int, int]]]:
|
||||
r"""
|
||||
requires image_infos to have: [absolute_path or image], bucket_reso, resized_size
|
||||
|
||||
|
||||
@@ -1,17 +1,17 @@
|
||||
import argparse
|
||||
import copy
|
||||
import math
|
||||
import random
|
||||
from typing import Any, Optional, Union, Tuple
|
||||
from typing import Any, Tuple
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from accelerate import Accelerator
|
||||
|
||||
from library.device_utils import clean_memory_on_device, init_ipex
|
||||
|
||||
init_ipex()
|
||||
|
||||
from torch import Tensor
|
||||
from accelerate import Accelerator
|
||||
|
||||
|
||||
import train_network
|
||||
from library import (
|
||||
lumina_models,
|
||||
@@ -40,10 +40,7 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
|
||||
def assert_extra_args(self, args, train_dataset_group, val_dataset_group):
|
||||
super().assert_extra_args(args, train_dataset_group, val_dataset_group)
|
||||
|
||||
if (
|
||||
args.cache_text_encoder_outputs_to_disk
|
||||
and not args.cache_text_encoder_outputs
|
||||
):
|
||||
if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
|
||||
logger.warning("Enabling cache_text_encoder_outputs due to disk caching")
|
||||
args.cache_text_encoder_outputs = True
|
||||
|
||||
@@ -59,17 +56,14 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
|
||||
model = lumina_util.load_lumina_model(
|
||||
args.pretrained_model_name_or_path,
|
||||
loading_dtype,
|
||||
"cpu",
|
||||
torch.device("cpu"),
|
||||
disable_mmap=args.disable_mmap_load_safetensors,
|
||||
use_flash_attn=args.use_flash_attn,
|
||||
)
|
||||
|
||||
if args.fp8_base:
|
||||
# check dtype of model
|
||||
if (
|
||||
model.dtype == torch.float8_e4m3fnuz
|
||||
or model.dtype == torch.float8_e5m2
|
||||
or model.dtype == torch.float8_e5m2fnuz
|
||||
):
|
||||
if model.dtype == torch.float8_e4m3fnuz or model.dtype == torch.float8_e5m2 or model.dtype == torch.float8_e5m2fnuz:
|
||||
raise ValueError(f"Unsupported fp8 model dtype: {model.dtype}")
|
||||
elif model.dtype == torch.float8_e4m3fn:
|
||||
logger.info("Loaded fp8 Lumina 2 model")
|
||||
@@ -92,17 +86,13 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
|
||||
return lumina_util.MODEL_VERSION_LUMINA_V2, [gemma2], ae, model
|
||||
|
||||
def get_tokenize_strategy(self, args):
|
||||
return strategy_lumina.LuminaTokenizeStrategy(
|
||||
args.gemma2_max_token_length, args.tokenizer_cache_dir
|
||||
)
|
||||
return strategy_lumina.LuminaTokenizeStrategy(args.gemma2_max_token_length, args.tokenizer_cache_dir)
|
||||
|
||||
def get_tokenizers(self, tokenize_strategy: strategy_lumina.LuminaTokenizeStrategy):
|
||||
return [tokenize_strategy.tokenizer]
|
||||
|
||||
def get_latents_caching_strategy(self, args):
|
||||
return strategy_lumina.LuminaLatentsCachingStrategy(
|
||||
args.cache_latents_to_disk, args.vae_batch_size, False
|
||||
)
|
||||
return strategy_lumina.LuminaLatentsCachingStrategy(args.cache_latents_to_disk, args.vae_batch_size, False)
|
||||
|
||||
def get_text_encoding_strategy(self, args):
|
||||
return strategy_lumina.LuminaTextEncodingStrategy()
|
||||
@@ -144,15 +134,11 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
|
||||
|
||||
# When TE is not be trained, it will not be prepared so we need to use explicit autocast
|
||||
logger.info("move text encoders to gpu")
|
||||
text_encoders[0].to(
|
||||
accelerator.device, dtype=weight_dtype
|
||||
) # always not fp8
|
||||
text_encoders[0].to(accelerator.device, dtype=weight_dtype) # always not fp8
|
||||
|
||||
if text_encoders[0].dtype == torch.float8_e4m3fn:
|
||||
# if we load fp8 weights, the model is already fp8, so we use it as is
|
||||
self.prepare_text_encoder_fp8(
|
||||
1, text_encoders[1], text_encoders[1].dtype, weight_dtype
|
||||
)
|
||||
self.prepare_text_encoder_fp8(1, text_encoders[1], text_encoders[1].dtype, weight_dtype)
|
||||
else:
|
||||
# otherwise, we need to convert it to target dtype
|
||||
text_encoders[0].to(weight_dtype)
|
||||
@@ -162,35 +148,36 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
|
||||
|
||||
# cache sample prompts
|
||||
if args.sample_prompts is not None:
|
||||
logger.info(
|
||||
f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}"
|
||||
)
|
||||
logger.info(f"cache Text Encoder outputs for sample prompts: {args.sample_prompts}")
|
||||
|
||||
tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy()
|
||||
text_encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy()
|
||||
|
||||
text_encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy()
|
||||
|
||||
assert isinstance(tokenize_strategy, strategy_lumina.LuminaTokenizeStrategy)
|
||||
assert isinstance(text_encoding_strategy, strategy_lumina.LuminaTextEncodingStrategy)
|
||||
|
||||
system_prompt = args.system_prompt or ""
|
||||
sample_prompts = train_util.load_prompts(args.sample_prompts)
|
||||
sample_prompts_te_outputs = (
|
||||
{}
|
||||
) # key: prompt, value: text encoder outputs
|
||||
sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs
|
||||
with accelerator.autocast(), torch.no_grad():
|
||||
for prompt_dict in sample_prompts:
|
||||
prompts = [prompt_dict.get("prompt", ""),
|
||||
prompt_dict.get("negative_prompt", "")]
|
||||
logger.info(
|
||||
f"cache Text Encoder outputs for prompt: {prompts[0]}"
|
||||
)
|
||||
tokens_and_masks = tokenize_strategy.tokenize(prompts)
|
||||
sample_prompts_te_outputs[prompts[0]] = (
|
||||
text_encoding_strategy.encode_tokens(
|
||||
prompts = [
|
||||
prompt_dict.get("prompt", ""),
|
||||
prompt_dict.get("negative_prompt", ""),
|
||||
]
|
||||
for prompt in prompts:
|
||||
prompt = system_prompt + prompt
|
||||
if prompt in sample_prompts_te_outputs:
|
||||
continue
|
||||
|
||||
logger.info(f"cache Text Encoder outputs for prompt: {prompt}")
|
||||
tokens_and_masks = tokenize_strategy.tokenize(prompt)
|
||||
sample_prompts_te_outputs[prompt] = text_encoding_strategy.encode_tokens(
|
||||
tokenize_strategy,
|
||||
text_encoders,
|
||||
tokens_and_masks,
|
||||
)
|
||||
)
|
||||
|
||||
self.sample_prompts_te_outputs = sample_prompts_te_outputs
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
@@ -235,12 +222,8 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
|
||||
# Remaining methods maintain similar structure to flux implementation
|
||||
# with Lumina-specific model calls and strategies
|
||||
|
||||
def get_noise_scheduler(
|
||||
self, args: argparse.Namespace, device: torch.device
|
||||
) -> Any:
|
||||
noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(
|
||||
num_train_timesteps=1000, shift=args.discrete_flow_shift
|
||||
)
|
||||
def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any:
|
||||
noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift)
|
||||
self.noise_scheduler_copy = copy.deepcopy(noise_scheduler)
|
||||
return noise_scheduler
|
||||
|
||||
@@ -258,26 +241,45 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
|
||||
noise_scheduler,
|
||||
latents,
|
||||
batch,
|
||||
text_encoder_conds: Tuple[Tensor, Tensor, Tensor], # (hidden_states, input_ids, attention_masks)
|
||||
text_encoder_conds: Tuple[Tensor, Tensor, Tensor], # (hidden_states, input_ids, attention_masks)
|
||||
dit: lumina_models.NextDiT,
|
||||
network,
|
||||
weight_dtype,
|
||||
train_unet,
|
||||
is_train=True,
|
||||
):
|
||||
assert isinstance(noise_scheduler, sd3_train_utils.FlowMatchEulerDiscreteScheduler)
|
||||
noise = torch.randn_like(latents)
|
||||
bsz = latents.shape[0]
|
||||
|
||||
# get noisy model input and timesteps
|
||||
noisy_model_input, timesteps, sigmas = (
|
||||
flux_train_utils.get_noisy_model_input_and_timesteps(
|
||||
args, noise_scheduler, latents, noise, accelerator.device, weight_dtype
|
||||
)
|
||||
# Sample a random timestep for each image
|
||||
# for weighting schemes where we sample timesteps non-uniformly
|
||||
u = lumina_train_util.compute_density_for_timestep_sampling(
|
||||
weighting_scheme=args.weighting_scheme,
|
||||
batch_size=bsz,
|
||||
logit_mean=args.logit_mean,
|
||||
logit_std=args.logit_std,
|
||||
mode_scale=args.mode_scale,
|
||||
)
|
||||
indices = (u * noise_scheduler.config.num_train_timesteps).long()
|
||||
timesteps = noise_scheduler.timesteps[indices].to(device=latents.device)
|
||||
|
||||
# May not need to pack/unpack?
|
||||
# pack latents and get img_ids - 这部分可以保留因为NextDiT也需要packed格式的输入
|
||||
# packed_noisy_model_input = lumina_util.pack_latents(noisy_model_input)
|
||||
def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
|
||||
sigmas = noise_scheduler.sigmas.to(device=accelerator.device, dtype=dtype)
|
||||
schedule_timesteps = noise_scheduler.timesteps.to(accelerator.device)
|
||||
timesteps = timesteps.to(accelerator.device)
|
||||
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
|
||||
|
||||
sigma = sigmas[step_indices].flatten()
|
||||
while len(sigma.shape) < n_dim:
|
||||
sigma = sigma.unsqueeze(-1)
|
||||
return sigma
|
||||
|
||||
# Add noise according to flow matching.
|
||||
# zt = (1 - texp) * x + texp * z1
|
||||
# Lumina2 reverses the lerp i.e., sigma of 1.0 should mean `latents`
|
||||
sigmas = get_sigmas(timesteps, n_dim=latents.ndim, dtype=latents.dtype)
|
||||
noisy_model_input = (1.0 - sigmas) * noise + sigmas * latents
|
||||
|
||||
# ensure the hidden state will require grad
|
||||
if args.gradient_checkpointing:
|
||||
@@ -289,48 +291,35 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
|
||||
# Unpack Gemma2 outputs
|
||||
gemma2_hidden_states, input_ids, gemma2_attn_mask = text_encoder_conds
|
||||
|
||||
def call_dit(img, gemma2_hidden_states, timesteps, gemma2_attn_mask):
|
||||
def call_dit(img, gemma2_hidden_states, gemma2_attn_mask, timesteps):
|
||||
with torch.set_grad_enabled(is_train), accelerator.autocast():
|
||||
# NextDiT forward expects (x, t, cap_feats, cap_mask)
|
||||
model_pred = dit(
|
||||
x=img, # image latents (B, C, H, W)
|
||||
t=timesteps / 1000, # timesteps需要除以1000来匹配模型预期
|
||||
cap_feats=gemma2_hidden_states, # Gemma2的hidden states作为caption features
|
||||
cap_mask=gemma2_attn_mask.to(
|
||||
dtype=torch.int32
|
||||
), # Gemma2的attention mask
|
||||
cap_mask=gemma2_attn_mask.to(dtype=torch.int32), # Gemma2的attention mask
|
||||
)
|
||||
return model_pred
|
||||
|
||||
model_pred = call_dit(
|
||||
img=noisy_model_input,
|
||||
gemma2_hidden_states=gemma2_hidden_states,
|
||||
timesteps=timesteps,
|
||||
gemma2_attn_mask=gemma2_attn_mask,
|
||||
timesteps=timesteps,
|
||||
)
|
||||
|
||||
# May not need to pack/unpack?
|
||||
# unpack latents
|
||||
# model_pred = lumina_util.unpack_latents(
|
||||
# model_pred, packed_latent_height, packed_latent_width
|
||||
# )
|
||||
|
||||
# apply model prediction type
|
||||
model_pred, weighting = flux_train_utils.apply_model_prediction_type(
|
||||
args, model_pred, noisy_model_input, sigmas
|
||||
)
|
||||
model_pred, weighting = flux_train_utils.apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas)
|
||||
|
||||
# flow matching loss: this is different from SD3
|
||||
target = noise - latents
|
||||
# flow matching loss
|
||||
target = latents - noise
|
||||
|
||||
# differential output preservation
|
||||
if "custom_attributes" in batch:
|
||||
diff_output_pr_indices = []
|
||||
for i, custom_attributes in enumerate(batch["custom_attributes"]):
|
||||
if (
|
||||
"diff_output_preservation" in custom_attributes
|
||||
and custom_attributes["diff_output_preservation"]
|
||||
):
|
||||
if "diff_output_preservation" in custom_attributes and custom_attributes["diff_output_preservation"]:
|
||||
diff_output_pr_indices.append(i)
|
||||
|
||||
if len(diff_output_pr_indices) > 0:
|
||||
@@ -338,9 +327,7 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
|
||||
with torch.no_grad():
|
||||
model_pred_prior = call_dit(
|
||||
img=noisy_model_input[diff_output_pr_indices],
|
||||
gemma2_hidden_states=gemma2_hidden_states[
|
||||
diff_output_pr_indices
|
||||
],
|
||||
gemma2_hidden_states=gemma2_hidden_states[diff_output_pr_indices],
|
||||
timesteps=timesteps[diff_output_pr_indices],
|
||||
gemma2_attn_mask=(gemma2_attn_mask[diff_output_pr_indices]),
|
||||
)
|
||||
@@ -363,9 +350,7 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
|
||||
return loss
|
||||
|
||||
def get_sai_model_spec(self, args):
|
||||
return train_util.get_sai_model_spec(
|
||||
None, args, False, True, False, lumina="lumina2"
|
||||
)
|
||||
return train_util.get_sai_model_spec(None, args, False, True, False, lumina="lumina2")
|
||||
|
||||
def update_metadata(self, metadata, args):
|
||||
metadata["ss_weighting_scheme"] = args.weighting_scheme
|
||||
@@ -384,12 +369,8 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
|
||||
def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder):
|
||||
text_encoder.embed_tokens.requires_grad_(True)
|
||||
|
||||
def prepare_text_encoder_fp8(
|
||||
self, index, text_encoder, te_weight_dtype, weight_dtype
|
||||
):
|
||||
logger.info(
|
||||
f"prepare Gemma2 for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}"
|
||||
)
|
||||
def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype):
|
||||
logger.info(f"prepare Gemma2 for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}")
|
||||
text_encoder.to(te_weight_dtype) # fp8
|
||||
text_encoder.embed_tokens.to(dtype=weight_dtype)
|
||||
|
||||
@@ -402,12 +383,8 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
|
||||
# if we doesn't swap blocks, we can move the model to device
|
||||
nextdit = unet
|
||||
assert isinstance(nextdit, lumina_models.NextDiT)
|
||||
nextdit = accelerator.prepare(
|
||||
nextdit, device_placement=[not self.is_swapping_blocks]
|
||||
)
|
||||
accelerator.unwrap_model(nextdit).move_to_device_except_swap_blocks(
|
||||
accelerator.device
|
||||
) # reduce peak memory usage
|
||||
nextdit = accelerator.prepare(nextdit, device_placement=[not self.is_swapping_blocks])
|
||||
accelerator.unwrap_model(nextdit).move_to_device_except_swap_blocks(accelerator.device) # reduce peak memory usage
|
||||
accelerator.unwrap_model(nextdit).prepare_block_swap_before_forward()
|
||||
|
||||
return nextdit
|
||||
|
||||
@@ -129,7 +129,7 @@ class NetworkTrainer:
|
||||
if val_dataset_group is not None:
|
||||
val_dataset_group.verify_bucket_reso_steps(64)
|
||||
|
||||
def load_target_model(self, args, weight_dtype, accelerator):
|
||||
def load_target_model(self, args, weight_dtype, accelerator) -> tuple:
|
||||
text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator)
|
||||
|
||||
# モデルに xformers とか memory efficient attention を組み込む
|
||||
@@ -354,12 +354,13 @@ class NetworkTrainer:
|
||||
if text_encoder_outputs_list is not None:
|
||||
text_encoder_conds = text_encoder_outputs_list # List of text encoder outputs
|
||||
|
||||
|
||||
if len(text_encoder_conds) == 0 or text_encoder_conds[0] is None or train_text_encoder:
|
||||
# TODO this does not work if 'some text_encoders are trained' and 'some are not and not cached'
|
||||
with torch.set_grad_enabled(is_train and train_text_encoder), accelerator.autocast():
|
||||
# Get the text embedding for conditioning
|
||||
if args.weighted_captions:
|
||||
input_ids_list, weights_list = tokenize_strategy.tokenize_with_weights(batch["captions"])
|
||||
input_ids_list, weights_list = tokenize_strategy.tokenize_with_weights(batch['captions'])
|
||||
encoded_text_encoder_conds = text_encoding_strategy.encode_tokens_with_weights(
|
||||
tokenize_strategy,
|
||||
self.get_models_for_text_encoding(args, accelerator, text_encoders),
|
||||
|
||||
Reference in New Issue
Block a user