mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-10 15:00:23 +00:00
@@ -96,7 +96,7 @@ def prepare_deepspeed_plugin(args: argparse.Namespace):
|
|||||||
deepspeed_plugin.deepspeed_config["train_batch_size"] = (
|
deepspeed_plugin.deepspeed_config["train_batch_size"] = (
|
||||||
args.train_batch_size * args.gradient_accumulation_steps * int(os.environ["WORLD_SIZE"])
|
args.train_batch_size * args.gradient_accumulation_steps * int(os.environ["WORLD_SIZE"])
|
||||||
)
|
)
|
||||||
|
|
||||||
deepspeed_plugin.set_mixed_precision(args.mixed_precision)
|
deepspeed_plugin.set_mixed_precision(args.mixed_precision)
|
||||||
if args.mixed_precision.lower() == "fp16":
|
if args.mixed_precision.lower() == "fp16":
|
||||||
deepspeed_plugin.deepspeed_config["fp16"]["initial_scale_power"] = 0 # preventing overflow.
|
deepspeed_plugin.deepspeed_config["fp16"]["initial_scale_power"] = 0 # preventing overflow.
|
||||||
@@ -125,18 +125,18 @@ def prepare_deepspeed_model(args: argparse.Namespace, **models):
|
|||||||
class DeepSpeedWrapper(torch.nn.Module):
|
class DeepSpeedWrapper(torch.nn.Module):
|
||||||
def __init__(self, **kw_models) -> None:
|
def __init__(self, **kw_models) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.models = torch.nn.ModuleDict()
|
self.models = torch.nn.ModuleDict()
|
||||||
|
|
||||||
wrap_model_forward_with_torch_autocast = args.mixed_precision is not "no"
|
wrap_model_forward_with_torch_autocast = args.mixed_precision != "no"
|
||||||
|
|
||||||
for key, model in kw_models.items():
|
for key, model in kw_models.items():
|
||||||
if isinstance(model, list):
|
if isinstance(model, list):
|
||||||
model = torch.nn.ModuleList(model)
|
model = torch.nn.ModuleList(model)
|
||||||
|
|
||||||
if wrap_model_forward_with_torch_autocast:
|
if wrap_model_forward_with_torch_autocast:
|
||||||
model = self.__wrap_model_with_torch_autocast(model)
|
model = self.__wrap_model_with_torch_autocast(model)
|
||||||
|
|
||||||
assert isinstance(
|
assert isinstance(
|
||||||
model, torch.nn.Module
|
model, torch.nn.Module
|
||||||
), f"model must be an instance of torch.nn.Module, but got {key} is {type(model)}"
|
), f"model must be an instance of torch.nn.Module, but got {key} is {type(model)}"
|
||||||
@@ -151,7 +151,7 @@ def prepare_deepspeed_model(args: argparse.Namespace, **models):
|
|||||||
return model
|
return model
|
||||||
|
|
||||||
def __wrap_model_forward_with_torch_autocast(self, model):
|
def __wrap_model_forward_with_torch_autocast(self, model):
|
||||||
|
|
||||||
assert hasattr(model, "forward"), f"model must have a forward method."
|
assert hasattr(model, "forward"), f"model must have a forward method."
|
||||||
|
|
||||||
forward_fn = model.forward
|
forward_fn = model.forward
|
||||||
@@ -161,20 +161,19 @@ def prepare_deepspeed_model(args: argparse.Namespace, **models):
|
|||||||
device_type = model.device.type
|
device_type = model.device.type
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"[DeepSpeed] model.device is not available. Using get_preferred_device() "
|
"[DeepSpeed] model.device is not available. Using get_preferred_device() "
|
||||||
"to determine the device_type for torch.autocast()."
|
"to determine the device_type for torch.autocast()."
|
||||||
)
|
)
|
||||||
device_type = get_preferred_device().type
|
device_type = get_preferred_device().type
|
||||||
|
|
||||||
with torch.autocast(device_type = device_type):
|
with torch.autocast(device_type=device_type):
|
||||||
return forward_fn(*args, **kwargs)
|
return forward_fn(*args, **kwargs)
|
||||||
|
|
||||||
model.forward = forward
|
model.forward = forward
|
||||||
return model
|
return model
|
||||||
|
|
||||||
def get_models(self):
|
def get_models(self):
|
||||||
return self.models
|
return self.models
|
||||||
|
|
||||||
|
|
||||||
ds_model = DeepSpeedWrapper(**models)
|
ds_model = DeepSpeedWrapper(**models)
|
||||||
return ds_model
|
return ds_model
|
||||||
|
|||||||
@@ -34,18 +34,18 @@ from library import custom_offloading_utils
|
|||||||
try:
|
try:
|
||||||
from flash_attn import flash_attn_varlen_func
|
from flash_attn import flash_attn_varlen_func
|
||||||
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
|
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
|
||||||
except:
|
except ImportError:
|
||||||
# flash_attn may not be available but it is not required
|
# flash_attn may not be available but it is not required
|
||||||
pass
|
pass
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from sageattention import sageattn
|
from sageattention import sageattn
|
||||||
except:
|
except ImportError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from apex.normalization import FusedRMSNorm as RMSNorm
|
from apex.normalization import FusedRMSNorm as RMSNorm
|
||||||
except:
|
except ImportError:
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
warnings.warn("Cannot import apex RMSNorm, switch to vanilla implementation")
|
warnings.warn("Cannot import apex RMSNorm, switch to vanilla implementation")
|
||||||
@@ -98,7 +98,7 @@ except:
|
|||||||
x_dtype = x.dtype
|
x_dtype = x.dtype
|
||||||
# To handle float8 we need to convert the tensor to float
|
# To handle float8 we need to convert the tensor to float
|
||||||
x = x.float()
|
x = x.float()
|
||||||
rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
|
rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + self.eps)
|
||||||
return ((x * rrms) * self.weight.float()).to(dtype=x_dtype)
|
return ((x * rrms) * self.weight.float()).to(dtype=x_dtype)
|
||||||
|
|
||||||
|
|
||||||
@@ -370,7 +370,7 @@ class JointAttention(nn.Module):
|
|||||||
if self.use_sage_attn:
|
if self.use_sage_attn:
|
||||||
# Handle GQA (Grouped Query Attention) if needed
|
# Handle GQA (Grouped Query Attention) if needed
|
||||||
n_rep = self.n_local_heads // self.n_local_kv_heads
|
n_rep = self.n_local_heads // self.n_local_kv_heads
|
||||||
if n_rep >= 1:
|
if n_rep > 1:
|
||||||
xk = xk.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
|
xk = xk.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
|
||||||
xv = xv.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
|
xv = xv.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
|
||||||
|
|
||||||
@@ -379,7 +379,7 @@ class JointAttention(nn.Module):
|
|||||||
output = self.flash_attn(xq, xk, xv, x_mask, softmax_scale)
|
output = self.flash_attn(xq, xk, xv, x_mask, softmax_scale)
|
||||||
else:
|
else:
|
||||||
n_rep = self.n_local_heads // self.n_local_kv_heads
|
n_rep = self.n_local_heads // self.n_local_kv_heads
|
||||||
if n_rep >= 1:
|
if n_rep > 1:
|
||||||
xk = xk.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
|
xk = xk.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
|
||||||
xv = xv.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
|
xv = xv.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
|
||||||
|
|
||||||
@@ -456,51 +456,47 @@ class JointAttention(nn.Module):
|
|||||||
bsz = q.shape[0]
|
bsz = q.shape[0]
|
||||||
seqlen = q.shape[1]
|
seqlen = q.shape[1]
|
||||||
|
|
||||||
# Transpose tensors to match SageAttention's expected format (HND layout)
|
# Transpose to SageAttention's expected HND layout: [batch, heads, seq_len, head_dim]
|
||||||
q_transposed = q.permute(0, 2, 1, 3) # [batch, heads, seq_len, head_dim]
|
q_transposed = q.permute(0, 2, 1, 3)
|
||||||
k_transposed = k.permute(0, 2, 1, 3) # [batch, heads, seq_len, head_dim]
|
k_transposed = k.permute(0, 2, 1, 3)
|
||||||
v_transposed = v.permute(0, 2, 1, 3) # [batch, heads, seq_len, head_dim]
|
v_transposed = v.permute(0, 2, 1, 3)
|
||||||
|
|
||||||
# Handle masking for SageAttention
|
# Fast path: if all tokens are valid, run batched SageAttention directly
|
||||||
# We need to filter out masked positions - this approach handles variable sequence lengths
|
if x_mask.all():
|
||||||
outputs = []
|
output = sageattn(
|
||||||
for b in range(bsz):
|
q_transposed, k_transposed, v_transposed,
|
||||||
# Find valid token positions from the mask
|
tensor_layout="HND", is_causal=False, sm_scale=softmax_scale,
|
||||||
valid_indices = torch.nonzero(x_mask[b], as_tuple=False).squeeze(-1)
|
)
|
||||||
if valid_indices.numel() == 0:
|
# output: [batch, heads, seq_len, head_dim] -> [batch, seq_len, heads, head_dim]
|
||||||
# If all tokens are masked, create a zero output
|
output = output.permute(0, 2, 1, 3)
|
||||||
batch_output = torch.zeros(
|
else:
|
||||||
seqlen, self.n_local_heads, self.head_dim,
|
# Slow path: per-batch loop to handle variable-length masking
|
||||||
device=q.device, dtype=q.dtype
|
# SageAttention does not support attention masks natively
|
||||||
)
|
outputs = []
|
||||||
else:
|
for b in range(bsz):
|
||||||
# Extract only valid tokens for this batch
|
valid_indices = x_mask[b].nonzero(as_tuple=True)[0]
|
||||||
batch_q = q_transposed[b, :, valid_indices, :]
|
if valid_indices.numel() == 0:
|
||||||
batch_k = k_transposed[b, :, valid_indices, :]
|
outputs.append(torch.zeros(
|
||||||
batch_v = v_transposed[b, :, valid_indices, :]
|
seqlen, self.n_local_heads, self.head_dim,
|
||||||
|
device=q.device, dtype=q.dtype,
|
||||||
# Run SageAttention on valid tokens only
|
))
|
||||||
|
continue
|
||||||
|
|
||||||
batch_output_valid = sageattn(
|
batch_output_valid = sageattn(
|
||||||
batch_q.unsqueeze(0), # Add batch dimension back
|
q_transposed[b:b+1, :, valid_indices, :],
|
||||||
batch_k.unsqueeze(0),
|
k_transposed[b:b+1, :, valid_indices, :],
|
||||||
batch_v.unsqueeze(0),
|
v_transposed[b:b+1, :, valid_indices, :],
|
||||||
tensor_layout="HND",
|
tensor_layout="HND", is_causal=False, sm_scale=softmax_scale,
|
||||||
is_causal=False,
|
|
||||||
sm_scale=softmax_scale
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create output tensor with zeros for masked positions
|
|
||||||
batch_output = torch.zeros(
|
batch_output = torch.zeros(
|
||||||
seqlen, self.n_local_heads, self.head_dim,
|
seqlen, self.n_local_heads, self.head_dim,
|
||||||
device=q.device, dtype=q.dtype
|
device=q.device, dtype=q.dtype,
|
||||||
)
|
)
|
||||||
# Place valid outputs back in the right positions
|
|
||||||
batch_output[valid_indices] = batch_output_valid.squeeze(0).permute(1, 0, 2)
|
batch_output[valid_indices] = batch_output_valid.squeeze(0).permute(1, 0, 2)
|
||||||
|
outputs.append(batch_output)
|
||||||
outputs.append(batch_output)
|
|
||||||
|
output = torch.stack(outputs, dim=0)
|
||||||
# Stack batch outputs and reshape to expected format
|
|
||||||
output = torch.stack(outputs, dim=0) # [batch, seq_len, heads, head_dim]
|
|
||||||
except NameError as e:
|
except NameError as e:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"Could not load Sage Attention. Please install https://github.com/thu-ml/SageAttention. / Sage Attention を読み込めませんでした。https://github.com/thu-ml/SageAttention をインストールしてください。 / {e}"
|
f"Could not load Sage Attention. Please install https://github.com/thu-ml/SageAttention. / Sage Attention を読み込めませんでした。https://github.com/thu-ml/SageAttention をインストールしてください。 / {e}"
|
||||||
@@ -1113,10 +1109,9 @@ class NextDiT(nn.Module):
|
|||||||
|
|
||||||
x = x.view(bsz, channels, height // pH, pH, width // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2)
|
x = x.view(bsz, channels, height // pH, pH, width // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2)
|
||||||
|
|
||||||
x_mask = torch.zeros(bsz, image_seq_len, dtype=torch.bool, device=device)
|
# x.shape[1] == image_seq_len after patchify, so this was assigning to itself.
|
||||||
for i in range(bsz):
|
# The mask can be set without a loop since all samples have the same image_seq_len.
|
||||||
x[i, :image_seq_len] = x[i]
|
x_mask = torch.ones(bsz, image_seq_len, dtype=torch.bool, device=device)
|
||||||
x_mask[i, :image_seq_len] = True
|
|
||||||
|
|
||||||
x = self.x_embedder(x)
|
x = self.x_embedder(x)
|
||||||
|
|
||||||
@@ -1389,4 +1384,4 @@ def NextDiT_7B_GQA_patch2_Adaln_Refiner(**kwargs):
|
|||||||
axes_dims=[40, 40, 40],
|
axes_dims=[40, 40, 40],
|
||||||
axes_lens=[300, 512, 512],
|
axes_lens=[300, 512, 512],
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
@@ -334,32 +334,35 @@ def sample_image_inference(
|
|||||||
|
|
||||||
# No need to add system prompt here, as it has been handled in the tokenize_strategy
|
# No need to add system prompt here, as it has been handled in the tokenize_strategy
|
||||||
|
|
||||||
# Get sample prompts from cache
|
# Get sample prompts from cache, fallback to live encoding
|
||||||
|
gemma2_conds = None
|
||||||
|
neg_gemma2_conds = None
|
||||||
|
|
||||||
if sample_prompts_gemma2_outputs and prompt in sample_prompts_gemma2_outputs:
|
if sample_prompts_gemma2_outputs and prompt in sample_prompts_gemma2_outputs:
|
||||||
gemma2_conds = sample_prompts_gemma2_outputs[prompt]
|
gemma2_conds = sample_prompts_gemma2_outputs[prompt]
|
||||||
logger.info(f"Using cached Gemma2 outputs for prompt: {prompt}")
|
logger.info(f"Using cached Gemma2 outputs for prompt: {prompt}")
|
||||||
|
|
||||||
if (
|
if sample_prompts_gemma2_outputs and negative_prompt in sample_prompts_gemma2_outputs:
|
||||||
sample_prompts_gemma2_outputs
|
|
||||||
and negative_prompt in sample_prompts_gemma2_outputs
|
|
||||||
):
|
|
||||||
neg_gemma2_conds = sample_prompts_gemma2_outputs[negative_prompt]
|
neg_gemma2_conds = sample_prompts_gemma2_outputs[negative_prompt]
|
||||||
logger.info(
|
logger.info(f"Using cached Gemma2 outputs for negative prompt: {negative_prompt}")
|
||||||
f"Using cached Gemma2 outputs for negative prompt: {negative_prompt}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Load sample prompts from Gemma 2
|
# Only encode if not found in cache
|
||||||
if gemma2_model is not None:
|
if gemma2_conds is None and gemma2_model is not None:
|
||||||
tokens_and_masks = tokenize_strategy.tokenize(prompt)
|
tokens_and_masks = tokenize_strategy.tokenize(prompt)
|
||||||
gemma2_conds = encoding_strategy.encode_tokens(
|
gemma2_conds = encoding_strategy.encode_tokens(
|
||||||
tokenize_strategy, gemma2_model, tokens_and_masks
|
tokenize_strategy, gemma2_model, tokens_and_masks
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if neg_gemma2_conds is None and gemma2_model is not None:
|
||||||
tokens_and_masks = tokenize_strategy.tokenize(negative_prompt, is_negative=True)
|
tokens_and_masks = tokenize_strategy.tokenize(negative_prompt, is_negative=True)
|
||||||
neg_gemma2_conds = encoding_strategy.encode_tokens(
|
neg_gemma2_conds = encoding_strategy.encode_tokens(
|
||||||
tokenize_strategy, gemma2_model, tokens_and_masks
|
tokenize_strategy, gemma2_model, tokens_and_masks
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if gemma2_conds is None or neg_gemma2_conds is None:
|
||||||
|
logger.error(f"Cannot generate sample: no cached outputs and no text encoder available for prompt: {prompt}")
|
||||||
|
continue
|
||||||
|
|
||||||
# Unpack Gemma2 outputs
|
# Unpack Gemma2 outputs
|
||||||
gemma2_hidden_states, _, gemma2_attn_mask = gemma2_conds
|
gemma2_hidden_states, _, gemma2_attn_mask = gemma2_conds
|
||||||
neg_gemma2_hidden_states, _, neg_gemma2_attn_mask = neg_gemma2_conds
|
neg_gemma2_hidden_states, _, neg_gemma2_attn_mask = neg_gemma2_conds
|
||||||
@@ -475,6 +478,7 @@ def sample_image_inference(
|
|||||||
|
|
||||||
|
|
||||||
def time_shift(mu: float, sigma: float, t: torch.Tensor):
|
def time_shift(mu: float, sigma: float, t: torch.Tensor):
|
||||||
|
"""Apply time shifting to timesteps."""
|
||||||
t = math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
|
t = math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
|
||||||
return t
|
return t
|
||||||
|
|
||||||
@@ -483,7 +487,7 @@ def get_lin_function(
|
|||||||
x1: float = 256, x2: float = 4096, y1: float = 0.5, y2: float = 1.15
|
x1: float = 256, x2: float = 4096, y1: float = 0.5, y2: float = 1.15
|
||||||
) -> Callable[[float], float]:
|
) -> Callable[[float], float]:
|
||||||
"""
|
"""
|
||||||
Get linear function
|
Get linear function for resolution-dependent shifting.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
image_seq_len,
|
image_seq_len,
|
||||||
@@ -528,6 +532,7 @@ def get_schedule(
|
|||||||
mu = get_lin_function(y1=base_shift, y2=max_shift, x1=256, x2=4096)(
|
mu = get_lin_function(y1=base_shift, y2=max_shift, x1=256, x2=4096)(
|
||||||
image_seq_len
|
image_seq_len
|
||||||
)
|
)
|
||||||
|
timesteps = torch.clamp(timesteps, min=1e-7).to(timesteps.device)
|
||||||
timesteps = time_shift(mu, 1.0, timesteps)
|
timesteps = time_shift(mu, 1.0, timesteps)
|
||||||
|
|
||||||
return timesteps.tolist()
|
return timesteps.tolist()
|
||||||
@@ -689,15 +694,15 @@ def denoise(
|
|||||||
|
|
||||||
img_dtype = img.dtype
|
img_dtype = img.dtype
|
||||||
|
|
||||||
if img.dtype != img_dtype:
|
|
||||||
if torch.backends.mps.is_available():
|
|
||||||
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
|
||||||
img = img.to(img_dtype)
|
|
||||||
|
|
||||||
# compute the previous noisy sample x_t -> x_t-1
|
# compute the previous noisy sample x_t -> x_t-1
|
||||||
noise_pred = -noise_pred
|
noise_pred = -noise_pred
|
||||||
img = scheduler.step(noise_pred, t, img, return_dict=False)[0]
|
img = scheduler.step(noise_pred, t, img, return_dict=False)[0]
|
||||||
|
|
||||||
|
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
||||||
|
if img.dtype != img_dtype:
|
||||||
|
if torch.backends.mps.is_available():
|
||||||
|
img = img.to(img_dtype)
|
||||||
|
|
||||||
model.prepare_block_swap_before_forward()
|
model.prepare_block_swap_before_forward()
|
||||||
return img
|
return img
|
||||||
|
|
||||||
@@ -823,6 +828,7 @@ def get_noisy_model_input_and_timesteps(
|
|||||||
timesteps = sigmas * num_timesteps
|
timesteps = sigmas * num_timesteps
|
||||||
elif args.timestep_sampling == "nextdit_shift":
|
elif args.timestep_sampling == "nextdit_shift":
|
||||||
sigmas = torch.rand((bsz,), device=device)
|
sigmas = torch.rand((bsz,), device=device)
|
||||||
|
sigmas = torch.clamp(sigmas, min=1e-7).to(device)
|
||||||
mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2))
|
mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2))
|
||||||
sigmas = time_shift(mu, 1.0, sigmas)
|
sigmas = time_shift(mu, 1.0, sigmas)
|
||||||
|
|
||||||
@@ -831,6 +837,7 @@ def get_noisy_model_input_and_timesteps(
|
|||||||
sigmas = torch.randn(bsz, device=device)
|
sigmas = torch.randn(bsz, device=device)
|
||||||
sigmas = sigmas * args.sigmoid_scale # larger scale for more uniform sampling
|
sigmas = sigmas * args.sigmoid_scale # larger scale for more uniform sampling
|
||||||
sigmas = sigmas.sigmoid()
|
sigmas = sigmas.sigmoid()
|
||||||
|
sigmas = torch.clamp(sigmas, min=1e-7).to(device)
|
||||||
mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2)) # we are pre-packed so must adjust for packed size
|
mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2)) # we are pre-packed so must adjust for packed size
|
||||||
sigmas = time_shift(mu, 1.0, sigmas)
|
sigmas = time_shift(mu, 1.0, sigmas)
|
||||||
timesteps = sigmas * num_timesteps
|
timesteps = sigmas * num_timesteps
|
||||||
|
|||||||
@@ -370,19 +370,25 @@ def train(args):
|
|||||||
grouped_params = []
|
grouped_params = []
|
||||||
param_group = {}
|
param_group = {}
|
||||||
for group in params_to_optimize:
|
for group in params_to_optimize:
|
||||||
named_parameters = list(nextdit.named_parameters())
|
named_parameters = [(n, p) for n, p in nextdit.named_parameters() if p.requires_grad]
|
||||||
assert len(named_parameters) == len(
|
assert len(named_parameters) == len(
|
||||||
group["params"]
|
group["params"]
|
||||||
), "number of parameters does not match"
|
), f"number of trainable parameters ({len(named_parameters)}) does not match optimizer group ({len(group['params'])})"
|
||||||
for p, np in zip(group["params"], named_parameters):
|
for p, np in zip(group["params"], named_parameters):
|
||||||
# determine target layer and block index for each parameter
|
# determine target layer and block index for each parameter
|
||||||
block_type = "other" # double, single or other
|
# Lumina NextDiT architecture:
|
||||||
if np[0].startswith("double_blocks"):
|
# - "layers.{i}.*" : main transformer blocks (e.g. 32 blocks for 2B)
|
||||||
|
# - "context_refiner.{i}.*" : context refiner blocks (2 blocks)
|
||||||
|
# - "noise_refiner.{i}.*" : noise refiner blocks (2 blocks)
|
||||||
|
# - others: t_embedder, cap_embedder, x_embedder, norm_final, final_layer
|
||||||
|
block_type = "other"
|
||||||
|
if np[0].startswith("layers."):
|
||||||
block_index = int(np[0].split(".")[1])
|
block_index = int(np[0].split(".")[1])
|
||||||
block_type = "double"
|
block_type = "main"
|
||||||
elif np[0].startswith("single_blocks"):
|
elif np[0].startswith("context_refiner.") or np[0].startswith("noise_refiner."):
|
||||||
block_index = int(np[0].split(".")[1])
|
# All refiner blocks (context + noise) grouped together
|
||||||
block_type = "single"
|
block_index = -1
|
||||||
|
block_type = "refiner"
|
||||||
else:
|
else:
|
||||||
block_index = -1
|
block_index = -1
|
||||||
|
|
||||||
@@ -759,7 +765,7 @@ def train(args):
|
|||||||
|
|
||||||
# calculate loss
|
# calculate loss
|
||||||
huber_c = train_util.get_huber_threshold_if_needed(
|
huber_c = train_util.get_huber_threshold_if_needed(
|
||||||
args, timesteps, noise_scheduler
|
args, 1000 - timesteps, noise_scheduler
|
||||||
)
|
)
|
||||||
loss = train_util.conditional_loss(
|
loss = train_util.conditional_loss(
|
||||||
model_pred.float(), target.float(), args.loss_type, "none", huber_c
|
model_pred.float(), target.float(), args.loss_type, "none", huber_c
|
||||||
|
|||||||
@@ -43,9 +43,9 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
|
|||||||
logger.warning("Enabling cache_text_encoder_outputs due to disk caching")
|
logger.warning("Enabling cache_text_encoder_outputs due to disk caching")
|
||||||
args.cache_text_encoder_outputs = True
|
args.cache_text_encoder_outputs = True
|
||||||
|
|
||||||
train_dataset_group.verify_bucket_reso_steps(32)
|
train_dataset_group.verify_bucket_reso_steps(16)
|
||||||
if val_dataset_group is not None:
|
if val_dataset_group is not None:
|
||||||
val_dataset_group.verify_bucket_reso_steps(32)
|
val_dataset_group.verify_bucket_reso_steps(16)
|
||||||
|
|
||||||
self.train_gemma2 = not args.network_train_unet_only
|
self.train_gemma2 = not args.network_train_unet_only
|
||||||
|
|
||||||
@@ -134,13 +134,16 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
|
|||||||
|
|
||||||
# When TE is not being trained, it will not be prepared so we need to use explicit autocast
|
# When TE is not being trained, it will not be prepared so we need to use explicit autocast
|
||||||
logger.info("move text encoders to gpu")
|
logger.info("move text encoders to gpu")
|
||||||
text_encoders[0].to(accelerator.device, dtype=weight_dtype) # always not fp8
|
# Lumina uses a single text encoder (Gemma2) at index 0.
|
||||||
|
# Check original dtype BEFORE casting to preserve fp8 detection.
|
||||||
|
gemma2_original_dtype = text_encoders[0].dtype
|
||||||
|
text_encoders[0].to(accelerator.device)
|
||||||
|
|
||||||
if text_encoders[0].dtype == torch.float8_e4m3fn:
|
if gemma2_original_dtype == torch.float8_e4m3fn:
|
||||||
# if we load fp8 weights, the model is already fp8, so we use it as is
|
# Model was loaded as fp8 — apply fp8 optimization
|
||||||
self.prepare_text_encoder_fp8(1, text_encoders[1], text_encoders[1].dtype, weight_dtype)
|
self.prepare_text_encoder_fp8(0, text_encoders[0], gemma2_original_dtype, weight_dtype)
|
||||||
else:
|
else:
|
||||||
# otherwise, we need to convert it to target dtype
|
# Otherwise, cast to target dtype
|
||||||
text_encoders[0].to(weight_dtype)
|
text_encoders[0].to(weight_dtype)
|
||||||
|
|
||||||
with accelerator.autocast():
|
with accelerator.autocast():
|
||||||
|
|||||||
@@ -227,19 +227,16 @@ class LoRAInfModule(LoRAModule):
|
|||||||
org_sd["weight"] = weight.to(dtype)
|
org_sd["weight"] = weight.to(dtype)
|
||||||
self.org_module.load_state_dict(org_sd)
|
self.org_module.load_state_dict(org_sd)
|
||||||
else:
|
else:
|
||||||
# split_dims
|
# split_dims: merge each split's LoRA into the correct slice of the fused QKV weight
|
||||||
total_dims = sum(self.split_dims)
|
|
||||||
for i in range(len(self.split_dims)):
|
for i in range(len(self.split_dims)):
|
||||||
# get up/down weight
|
# get up/down weight
|
||||||
down_weight = sd[f"lora_down.{i}.weight"].to(torch.float).to(device) # (rank, in_dim)
|
down_weight = sd[f"lora_down.{i}.weight"].to(torch.float).to(device) # (rank, in_dim)
|
||||||
up_weight = sd[f"lora_up.{i}.weight"].to(torch.float).to(device) # (split dim, rank)
|
up_weight = sd[f"lora_up.{i}.weight"].to(torch.float).to(device) # (split_dim, rank)
|
||||||
|
|
||||||
# pad up_weight -> (total_dims, rank)
|
# merge into the correct slice of the fused weight
|
||||||
padded_up_weight = torch.zeros((total_dims, up_weight.size(0)), device=device, dtype=torch.float)
|
start = sum(self.split_dims[:i])
|
||||||
padded_up_weight[sum(self.split_dims[:i]) : sum(self.split_dims[: i + 1])] = up_weight
|
end = sum(self.split_dims[:i + 1])
|
||||||
|
weight[start:end] += self.multiplier * (up_weight @ down_weight) * self.scale
|
||||||
# merge weight
|
|
||||||
weight = weight + self.multiplier * (up_weight @ down_weight) * self.scale
|
|
||||||
|
|
||||||
# set weight to org_module
|
# set weight to org_module
|
||||||
org_sd["weight"] = weight.to(dtype)
|
org_sd["weight"] = weight.to(dtype)
|
||||||
@@ -250,6 +247,17 @@ class LoRAInfModule(LoRAModule):
|
|||||||
if multiplier is None:
|
if multiplier is None:
|
||||||
multiplier = self.multiplier
|
multiplier = self.multiplier
|
||||||
|
|
||||||
|
# Handle split_dims case where lora_down/lora_up are ModuleList
|
||||||
|
if self.split_dims is not None:
|
||||||
|
# Each sub-module produces a partial weight; concatenate along output dim
|
||||||
|
weights = []
|
||||||
|
for lora_up, lora_down in zip(self.lora_up, self.lora_down):
|
||||||
|
up_w = lora_up.weight.to(torch.float)
|
||||||
|
down_w = lora_down.weight.to(torch.float)
|
||||||
|
weights.append(up_w @ down_w)
|
||||||
|
weight = self.multiplier * torch.cat(weights, dim=0) * self.scale
|
||||||
|
return weight
|
||||||
|
|
||||||
# get up/down weight from module
|
# get up/down weight from module
|
||||||
up_weight = self.lora_up.weight.to(torch.float)
|
up_weight = self.lora_up.weight.to(torch.float)
|
||||||
down_weight = self.lora_down.weight.to(torch.float)
|
down_weight = self.lora_down.weight.to(torch.float)
|
||||||
@@ -409,7 +417,7 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, lumina, wei
|
|||||||
|
|
||||||
weights_sd = load_file(file)
|
weights_sd = load_file(file)
|
||||||
else:
|
else:
|
||||||
weights_sd = torch.load(file, map_location="cpu")
|
weights_sd = torch.load(file, map_location="cpu", weights_only=False)
|
||||||
|
|
||||||
# get dim/alpha mapping, and train t5xxl
|
# get dim/alpha mapping, and train t5xxl
|
||||||
modules_dim = {}
|
modules_dim = {}
|
||||||
@@ -634,20 +642,30 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
skipped_te += skipped
|
skipped_te += skipped
|
||||||
|
|
||||||
# create LoRA for U-Net
|
# create LoRA for U-Net
|
||||||
|
target_replace_modules = LoRANetwork.LUMINA_TARGET_REPLACE_MODULE
|
||||||
|
# Filter by block type using name-based filtering in create_modules
|
||||||
|
# All block types use JointTransformerBlock, so we filter by module path name
|
||||||
|
block_filter = None # None means no filtering (train all)
|
||||||
if self.train_blocks == "all":
|
if self.train_blocks == "all":
|
||||||
target_replace_modules = LoRANetwork.LUMINA_TARGET_REPLACE_MODULE
|
block_filter = None
|
||||||
# TODO: limit different blocks
|
|
||||||
elif self.train_blocks == "transformer":
|
elif self.train_blocks == "transformer":
|
||||||
target_replace_modules = LoRANetwork.LUMINA_TARGET_REPLACE_MODULE
|
block_filter = "layers_" # main transformer blocks: "lora_unet_layers_N_..."
|
||||||
elif self.train_blocks == "refiners":
|
|
||||||
target_replace_modules = LoRANetwork.LUMINA_TARGET_REPLACE_MODULE
|
|
||||||
elif self.train_blocks == "noise_refiner":
|
elif self.train_blocks == "noise_refiner":
|
||||||
target_replace_modules = LoRANetwork.LUMINA_TARGET_REPLACE_MODULE
|
block_filter = "noise_refiner"
|
||||||
elif self.train_blocks == "cap_refiner":
|
elif self.train_blocks == "context_refiner":
|
||||||
target_replace_modules = LoRANetwork.LUMINA_TARGET_REPLACE_MODULE
|
block_filter = "context_refiner"
|
||||||
|
elif self.train_blocks == "refiners":
|
||||||
|
block_filter = None # handled below with two calls
|
||||||
|
|
||||||
self.unet_loras: List[Union[LoRAModule, LoRAInfModule]]
|
self.unet_loras: List[Union[LoRAModule, LoRAInfModule]]
|
||||||
self.unet_loras, skipped_un = create_modules(True, unet, target_replace_modules)
|
if self.train_blocks == "refiners":
|
||||||
|
# Refiners = noise_refiner + context_refiner, need two calls
|
||||||
|
noise_loras, skipped_noise = create_modules(True, unet, target_replace_modules, filter="noise_refiner")
|
||||||
|
context_loras, skipped_context = create_modules(True, unet, target_replace_modules, filter="context_refiner")
|
||||||
|
self.unet_loras = noise_loras + context_loras
|
||||||
|
skipped_un = skipped_noise + skipped_context
|
||||||
|
else:
|
||||||
|
self.unet_loras, skipped_un = create_modules(True, unet, target_replace_modules, filter=block_filter)
|
||||||
|
|
||||||
# Handle embedders
|
# Handle embedders
|
||||||
if self.embedder_dims:
|
if self.embedder_dims:
|
||||||
@@ -689,7 +707,7 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
|
|
||||||
weights_sd = load_file(file)
|
weights_sd = load_file(file)
|
||||||
else:
|
else:
|
||||||
weights_sd = torch.load(file, map_location="cpu")
|
weights_sd = torch.load(file, map_location="cpu", weights_only=False)
|
||||||
|
|
||||||
info = self.load_state_dict(weights_sd, False)
|
info = self.load_state_dict(weights_sd, False)
|
||||||
return info
|
return info
|
||||||
@@ -751,10 +769,10 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
state_dict = super().state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
|
state_dict = super().state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
|
||||||
new_state_dict = {}
|
new_state_dict = {}
|
||||||
for key in list(state_dict.keys()):
|
for key in list(state_dict.keys()):
|
||||||
if "double" in key and "qkv" in key:
|
if "qkv" in key:
|
||||||
split_dims = [3072] * 3
|
# Lumina 2B: dim=2304, n_heads=24, n_kv_heads=8, head_dim=96
|
||||||
elif "single" in key and "linear1" in key:
|
# Q=24*96=2304, K=8*96=768, V=8*96=768
|
||||||
split_dims = [3072] * 3 + [12288]
|
split_dims = [2304, 768, 768]
|
||||||
else:
|
else:
|
||||||
new_state_dict[key] = state_dict[key]
|
new_state_dict[key] = state_dict[key]
|
||||||
continue
|
continue
|
||||||
@@ -1035,4 +1053,4 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
scalednorm = updown.norm() * ratio
|
scalednorm = updown.norm() * ratio
|
||||||
norms.append(scalednorm.item())
|
norms.append(scalednorm.item())
|
||||||
|
|
||||||
return keys_scaled, sum(norms) / len(norms), max(norms)
|
return keys_scaled, sum(norms) / len(norms), max(norms)
|
||||||
Reference in New Issue
Block a user