Compare commits

...

10 Commits

Author SHA1 Message Date
Urle Sistiana
d6210845d6 Merge 3bbfa9b258 into f7f971f50d 2026-01-18 15:10:39 +09:00
Kohya S.
f7f971f50d Merge pull request #2251 from kohya-ss/fix-pytest-for-lumina
fix(tests): add ip_noise_gamma args for MockArgs in pytest
2026-01-18 15:09:47 +09:00
Kohya S
c4be615f69 fix(tests): add ip_noise_gamma args for MockArgs in pytest 2026-01-18 15:05:57 +09:00
Kohya S.
e06e063970 Merge pull request #2225 from urlesistiana/sd3_lumina2_ts_fix
fix: lumina 2 timesteps handling
2026-01-18 14:39:04 +09:00
urlesistiana
f7fc7ddda2 fix #2201: lumina 2 timesteps handling 2025-10-13 16:08:28 +08:00
urlesistiana
3bbfa9b258 fixed FeedForward 2025-10-01 18:00:16 +08:00
urlesistiana
3420a6f7d1 check activation_memory_budget value range and accept value 0 2025-10-01 17:55:24 +08:00
urlesistiana
45cab086cc make torch.compile happy
don't compile funcs with complex ops

simplify FeedForward to avoid "cache line invalidated" error
2025-10-01 17:40:12 +08:00
urlesistiana
f25cb8abd1 add --activation_memory_budget to training arguments
remove SDSCRIPTS_TORCH_COMPILE_ACTIVATION_MEMORY_BUDGET env
2025-09-30 22:03:30 +08:00
urlesistiana
c15e6b4f3b feat: add selective torch compile and activation memory budget to Lumina 2 2025-09-30 15:25:13 +08:00
6 changed files with 106 additions and 107 deletions

View File

@@ -20,6 +20,7 @@
# --------------------------------------------------------
import math
import os
from typing import List, Optional, Tuple
from dataclasses import dataclass
@@ -31,6 +32,10 @@ import torch.nn.functional as F
from library import custom_offloading_utils
disable_selective_torch_compile = (
os.getenv("SDSCRIPTS_SELECTIVE_TORCH_COMPILE", "0") == "0"
)
try:
from flash_attn import flash_attn_varlen_func
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
@@ -553,7 +558,7 @@ class JointAttention(nn.Module):
f"Could not load flash attention. Please install flash_attn. / フラッシュアテンションを読み込めませんでした。flash_attn をインストールしてください。 / {e}"
)
@torch.compiler.disable(reason="complex ops inside")
def apply_rope(
x_in: torch.Tensor,
freqs_cis: torch.Tensor,
@@ -629,13 +634,10 @@ class FeedForward(nn.Module):
bias=False,
)
nn.init.xavier_uniform_(self.w3.weight)
# @torch.compile
def _forward_silu_gating(self, x1, x3):
return F.silu(x1) * x3
@torch.compile(disable=disable_selective_torch_compile)
def forward(self, x):
return self.w2(self._forward_silu_gating(self.w1(x), self.w3(x)))
return self.w2(F.silu(self.w1(x))*self.w3(x))
class JointTransformerBlock(GradientCheckpointMixin):
@@ -701,6 +703,7 @@ class JointTransformerBlock(GradientCheckpointMixin):
nn.init.zeros_(self.adaLN_modulation[1].weight)
nn.init.zeros_(self.adaLN_modulation[1].bias)
@torch.compile(disable=disable_selective_torch_compile)
def _forward(
self,
x: torch.Tensor,
@@ -792,6 +795,7 @@ class FinalLayer(GradientCheckpointMixin):
nn.init.zeros_(self.adaLN_modulation[1].weight)
nn.init.zeros_(self.adaLN_modulation[1].bias)
@torch.compile(disable=disable_selective_torch_compile)
def forward(self, x, c):
scale = self.adaLN_modulation(c)
x = modulate(self.norm_final(x), scale)
@@ -812,6 +816,7 @@ class RopeEmbedder:
self.axes_lens = axes_lens
self.freqs_cis = NextDiT.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta)
@torch.compiler.disable(reason="complex ops inside")
def __call__(self, ids: torch.Tensor):
device = ids.device
self.freqs_cis = [freqs_cis.to(ids.device) for freqs_cis in self.freqs_cis]
@@ -1224,6 +1229,7 @@ class NextDiT(nn.Module):
return output
@staticmethod
@torch.compiler.disable(reason="complex ops inside")
def precompute_freqs_cis(
dim: List[int],
end: List[int],

View File

@@ -475,11 +475,7 @@ def sample_image_inference(
def time_shift(mu: float, sigma: float, t: torch.Tensor):
    """Apply the flow-matching time-shift schedule to *t*, under the
    reversed time convention (t=0: noise / t=1: clean).

    The original formulation assumed t=0: clean / t=1: noise, so the input
    is flipped on entry and the result flipped back on exit.
    """
    flipped = 1 - t
    shifted = math.exp(mu) / (math.exp(mu) + (1 / flipped - 1) ** sigma)
    return 1 - shifted
@@ -802,61 +798,42 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None) -> Tensor
weighting = torch.ones_like(sigmas)
return weighting
# mainly copied from flux_train_utils.get_noisy_model_input_and_timesteps
def get_noisy_model_input_and_timesteps(
args, noise_scheduler, latents, noise, device, dtype
) -> Tuple[Tensor, Tensor, Tensor]:
"""
Get noisy model input and timesteps.
Args:
args (argparse.Namespace): Arguments.
noise_scheduler (noise_scheduler): Noise scheduler.
latents (Tensor): Latents.
noise (Tensor): Latent noise.
device (torch.device): Device.
dtype (torch.dtype): Data type
Return:
Tuple[Tensor, Tensor, Tensor]:
noisy model input
timesteps
sigmas
"""
args, noise_scheduler, latents: torch.Tensor, noise: torch.Tensor, device, dtype
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
bsz, _, h, w = latents.shape
sigmas = None
assert bsz > 0, "Batch size not large enough"
num_timesteps = noise_scheduler.config.num_train_timesteps
if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid":
# Simple random t-based noise sampling
# Simple random sigma-based noise sampling
if args.timestep_sampling == "sigmoid":
# https://github.com/XLabs-AI/x-flux/tree/main
t = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=device))
sigmas = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=device))
else:
t = torch.rand((bsz,), device=device)
sigmas = torch.rand((bsz,), device=device)
timesteps = t * 1000.0
t = t.view(-1, 1, 1, 1)
noisy_model_input = (1 - t) * noise + t * latents
timesteps = sigmas * num_timesteps
elif args.timestep_sampling == "shift":
shift = args.discrete_flow_shift
logits_norm = torch.randn(bsz, device=device)
logits_norm = (
logits_norm * args.sigmoid_scale
) # larger scale for more uniform sampling
timesteps = logits_norm.sigmoid()
timesteps = (timesteps * shift) / (1 + (shift - 1) * timesteps)
t = timesteps.view(-1, 1, 1, 1)
timesteps = timesteps * 1000.0
noisy_model_input = (1 - t) * noise + t * latents
sigmas = torch.randn(bsz, device=device)
sigmas = sigmas * args.sigmoid_scale # larger scale for more uniform sampling
sigmas = sigmas.sigmoid()
sigmas = (sigmas * shift) / (1 + (shift - 1) * sigmas)
timesteps = sigmas * num_timesteps
elif args.timestep_sampling == "nextdit_shift":
t = torch.rand((bsz,), device=device)
sigmas = torch.rand((bsz,), device=device)
mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2))
t = time_shift(mu, 1.0, t)
sigmas = time_shift(mu, 1.0, sigmas)
timesteps = t * 1000.0
t = t.view(-1, 1, 1, 1)
noisy_model_input = (1 - t) * noise + t * latents
timesteps = sigmas * num_timesteps
elif args.timestep_sampling == "flux_shift":
sigmas = torch.randn(bsz, device=device)
sigmas = sigmas * args.sigmoid_scale # larger scale for more uniform sampling
sigmas = sigmas.sigmoid()
mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2)) # we are pre-packed so must adjust for packed size
sigmas = time_shift(mu, 1.0, sigmas)
timesteps = sigmas * num_timesteps
else:
# Sample a random timestep for each image
# for weighting schemes where we sample timesteps non-uniformly
@@ -867,14 +844,24 @@ def get_noisy_model_input_and_timesteps(
logit_std=args.logit_std,
mode_scale=args.mode_scale,
)
indices = (u * noise_scheduler.config.num_train_timesteps).long()
indices = (u * num_timesteps).long()
timesteps = noise_scheduler.timesteps[indices].to(device=device)
sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype)
# Add noise according to flow matching.
sigmas = get_sigmas(
noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype
)
noisy_model_input = sigmas * latents + (1.0 - sigmas) * noise
# Broadcast sigmas to latent shape
sigmas = sigmas.view(-1, 1, 1, 1)
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
if args.ip_noise_gamma:
xi = torch.randn_like(latents, device=latents.device, dtype=dtype)
if args.ip_noise_gamma_random_strength:
ip_noise_gamma = torch.rand(1, device=latents.device, dtype=dtype) * args.ip_noise_gamma
else:
ip_noise_gamma = args.ip_noise_gamma
noisy_model_input = (1.0 - sigmas) * latents + sigmas * (noise + ip_noise_gamma * xi)
else:
noisy_model_input = (1.0 - sigmas) * latents + sigmas * noise
return noisy_model_input.to(dtype), timesteps.to(dtype), sigmas
@@ -1049,10 +1036,10 @@ def add_lumina_train_arguments(parser: argparse.ArgumentParser):
parser.add_argument(
"--timestep_sampling",
choices=["sigma", "uniform", "sigmoid", "shift", "nextdit_shift"],
choices=["sigma", "uniform", "sigmoid", "shift", "nextdit_shift", "flux_shift"],
default="shift",
help="Method to sample timesteps: sigma-based, uniform random, sigmoid of random normal, shift of sigmoid and NextDIT.1 shifting. Default is 'shift'."
" / タイムステップをサンプリングする方法sigma、random uniform、random normalのsigmoid、sigmoidのシフト、NextDIT.1のシフト。デフォルトは'shift'です。",
help="Method to sample timesteps: sigma-based, uniform random, sigmoid of random normal, shift of sigmoid, Flux.1 and NextDIT.1 shifting. Default is 'shift'."
" / タイムステップをサンプリングする方法sigma、random uniform、random normalのsigmoid、sigmoidのシフト、Flux.1、NextDIT.1のシフト。デフォルトは'shift'です。",
)
parser.add_argument(
"--sigmoid_scale",

View File

@@ -3974,6 +3974,12 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
],
help="dynamo backend type (default is inductor) / dynamoのbackendの種類デフォルトは inductor",
)
parser.add_argument(
"--activation_memory_budget",
type=float,
default=None,
help="activation memory budget setting for torch.compile (range: 0~1). Smaller value saves more memory at cost of speed. If set, use --torch_compile without --gradient_checkpointing is recommended. Requires PyTorch 2.4. / torch.compileのactivation memory budget設定01の値。この値を小さくするとメモリ使用量を節約できますが、処理速度は低下します。この設定を行う場合は、--gradient_checkpointing オプションを指定せずに --torch_compile を使用することをお勧めします。PyTorch 2.4以降が必要です。"
)
parser.add_argument("--xformers", action="store_true", help="use xformers for CrossAttention / CrossAttentionにxformersを使う")
parser.add_argument(
"--sdpa",
@@ -5506,6 +5512,19 @@ def prepare_accelerator(args: argparse.Namespace):
if args.torch_compile:
dynamo_backend = args.dynamo_backend
if args.activation_memory_budget is not None: # Note: 0 is a valid value.
if 0 <= args.activation_memory_budget <= 1:
logger.info(
f"set torch compile activation memory budget to {args.activation_memory_budget}"
)
torch._functorch.config.activation_memory_budget = ( # type: ignore
args.activation_memory_budget
)
else:
raise ValueError(
"activation_memory_budget must be between 0 and 1 (inclusive)"
)
kwargs_handlers = [
(
InitProcessGroupKwargs(

View File

@@ -743,7 +743,7 @@ def train(args):
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
model_pred = nextdit(
x=noisy_model_input, # image latents (B, C, H, W)
t=timesteps / 1000, # timesteps需要除以1000来匹配模型预期
t=1 - timesteps / 1000, # timesteps需要除以1000来匹配模型预期
cap_feats=gemma2_hidden_states, # Gemma2的hidden states作为caption features
cap_mask=gemma2_attn_mask.to(
dtype=torch.int32

View File

@@ -268,7 +268,7 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
# NextDiT forward expects (x, t, cap_feats, cap_mask)
model_pred = dit(
x=img, # image latents (B, C, H, W)
t=timesteps / 1000, # timesteps需要除以1000来匹配模型预期
t=1 - timesteps / 1000, # timesteps需要除以1000来匹配模型预期
cap_feats=gemma2_hidden_states, # Gemma2的hidden states作为caption features
cap_mask=gemma2_attn_mask.to(dtype=torch.int32), # Gemma2的attention mask
)

View File

@@ -19,11 +19,7 @@ from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler
def test_batchify():
# Test case with no batch size specified
prompts = [
{"prompt": "test1"},
{"prompt": "test2"},
{"prompt": "test3"}
]
prompts = [{"prompt": "test1"}, {"prompt": "test2"}, {"prompt": "test3"}]
batchified = list(batchify(prompts))
assert len(batchified) == 1
assert len(batchified[0]) == 3
@@ -38,7 +34,7 @@ def test_batchify():
prompts_with_params = [
{"prompt": "test1", "width": 512, "height": 512},
{"prompt": "test2", "width": 512, "height": 512},
{"prompt": "test3", "width": 1024, "height": 1024}
{"prompt": "test3", "width": 1024, "height": 1024},
]
batchified_params = list(batchify(prompts_with_params))
assert len(batchified_params) == 2
@@ -61,7 +57,7 @@ def test_time_shift():
# Test with edge cases
t_edges = torch.tensor([0.0, 1.0])
result_edges = time_shift(1.0, 1.0, t_edges)
# Check that results are bounded within [0, 1]
assert torch.all(result_edges >= 0)
assert torch.all(result_edges <= 1)
@@ -93,10 +89,7 @@ def test_get_schedule():
# Test with shift disabled
unshifted_schedule = get_schedule(num_steps=10, image_seq_len=256, shift=False)
assert torch.allclose(
torch.tensor(unshifted_schedule),
torch.linspace(1, 1/10, 10)
)
assert torch.allclose(torch.tensor(unshifted_schedule), torch.linspace(1, 1 / 10, 10))
def test_compute_density_for_timestep_sampling():
@@ -106,16 +99,12 @@ def test_compute_density_for_timestep_sampling():
assert torch.all((uniform_samples >= 0) & (uniform_samples <= 1))
# Test logit normal sampling
logit_normal_samples = compute_density_for_timestep_sampling(
"logit_normal", batch_size=100, logit_mean=0.0, logit_std=1.0
)
logit_normal_samples = compute_density_for_timestep_sampling("logit_normal", batch_size=100, logit_mean=0.0, logit_std=1.0)
assert len(logit_normal_samples) == 100
assert torch.all((logit_normal_samples >= 0) & (logit_normal_samples <= 1))
# Test mode sampling
mode_samples = compute_density_for_timestep_sampling(
"mode", batch_size=100, mode_scale=0.5
)
mode_samples = compute_density_for_timestep_sampling("mode", batch_size=100, mode_scale=0.5)
assert len(mode_samples) == 100
assert torch.all((mode_samples >= 0) & (mode_samples <= 1))
@@ -123,20 +112,20 @@ def test_compute_density_for_timestep_sampling():
def test_get_sigmas():
# Create a mock noise scheduler
scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000)
device = torch.device('cpu')
device = torch.device("cpu")
# Test with default parameters
timesteps = torch.tensor([100, 500, 900])
sigmas = get_sigmas(scheduler, timesteps, device)
# Check shape and basic properties
assert sigmas.shape[0] == 3
assert torch.all(sigmas >= 0)
# Test with different n_dim
sigmas_4d = get_sigmas(scheduler, timesteps, device, n_dim=4)
assert sigmas_4d.ndim == 4
# Test with different dtype
sigmas_float16 = get_sigmas(scheduler, timesteps, device, dtype=torch.float16)
assert sigmas_float16.dtype == torch.float16
@@ -145,17 +134,17 @@ def test_get_sigmas():
def test_compute_loss_weighting_for_sd3():
    """Check the three weighting schemes against their closed-form values."""
    mock_sigmas = torch.tensor([0.1, 0.5, 1.0])

    # sigma_sqrt scheme: weight = sigma^-2
    w_sqrt = compute_loss_weighting_for_sd3("sigma_sqrt", mock_sigmas)
    assert torch.allclose(w_sqrt, mock_sigmas**-2, rtol=1e-5)

    # cosmap scheme: weight = 2 / (pi * (1 - 2*sigma + 2*sigma^2))
    w_cosmap = compute_loss_weighting_for_sd3("cosmap", mock_sigmas)
    denominator = 1 - 2 * mock_sigmas + 2 * mock_sigmas**2
    assert torch.allclose(w_cosmap, 2 / (math.pi * denominator), rtol=1e-5)

    # Any unrecognized scheme falls back to uniform (all-ones) weighting.
    w_default = compute_loss_weighting_for_sd3("unknown", mock_sigmas)
    assert torch.all(w_default == 1)
@@ -166,22 +155,22 @@ def test_apply_model_prediction_type():
class MockArgs:
model_prediction_type = "raw"
weighting_scheme = "sigma_sqrt"
args = MockArgs()
model_pred = torch.tensor([1.0, 2.0, 3.0])
noisy_model_input = torch.tensor([0.5, 1.0, 1.5])
sigmas = torch.tensor([0.1, 0.5, 1.0])
# Test raw prediction type
raw_pred, raw_weighting = apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas)
assert torch.all(raw_pred == model_pred)
assert raw_weighting is None
# Test additive prediction type
args.model_prediction_type = "additive"
additive_pred, _ = apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas)
assert torch.all(additive_pred == model_pred + noisy_model_input)
# Test sigma scaled prediction type
args.model_prediction_type = "sigma_scaled"
sigma_scaled_pred, sigma_weighting = apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas)
@@ -192,12 +181,12 @@ def test_apply_model_prediction_type():
def test_retrieve_timesteps():
    """retrieve_timesteps honors num_inference_steps and rejects passing
    both explicit timesteps and sigmas at once."""
    sched = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000)

    # Requesting 50 inference steps yields exactly 50 timesteps.
    steps, n_steps = retrieve_timesteps(sched, num_inference_steps=50)
    assert n_steps == 50
    assert len(steps) == 50

    # Supplying timesteps and sigmas simultaneously is an error.
    with pytest.raises(ValueError):
        retrieve_timesteps(sched, timesteps=[1, 2, 3], sigmas=[0.1, 0.2, 0.3])
@@ -210,32 +199,30 @@ def test_get_noisy_model_input_and_timesteps():
weighting_scheme = "sigma_sqrt"
sigmoid_scale = 1.0
discrete_flow_shift = 6.0
ip_noise_gamma = True
ip_noise_gamma_random_strength = 0.01
args = MockArgs()
scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000)
device = torch.device('cpu')
device = torch.device("cpu")
# Prepare mock latents and noise
latents = torch.randn(4, 16, 64, 64)
noise = torch.randn_like(latents)
# Test uniform sampling
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(
args, scheduler, latents, noise, device, torch.float32
)
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, scheduler, latents, noise, device, torch.float32)
# Validate output shapes and types
assert noisy_input.shape == latents.shape
assert timesteps.shape[0] == latents.shape[0]
assert noisy_input.dtype == torch.float32
assert timesteps.dtype == torch.float32
# Test different sampling methods
sampling_methods = ["sigmoid", "shift", "nextdit_shift"]
for method in sampling_methods:
args.timestep_sampling = method
noisy_input, timesteps, _ = get_noisy_model_input_and_timesteps(
args, scheduler, latents, noise, device, torch.float32
)
noisy_input, timesteps, _ = get_noisy_model_input_and_timesteps(args, scheduler, latents, noise, device, torch.float32)
assert noisy_input.shape == latents.shape
assert timesteps.shape[0] == latents.shape[0]