Merge pull request #2268 from kohya-ss/sd3

merge sd3 to main
This commit is contained in:
Kohya S.
2026-02-16 08:07:29 +09:00
committed by GitHub
6 changed files with 150 additions and 122 deletions

View File

@@ -96,7 +96,7 @@ def prepare_deepspeed_plugin(args: argparse.Namespace):
deepspeed_plugin.deepspeed_config["train_batch_size"] = ( deepspeed_plugin.deepspeed_config["train_batch_size"] = (
args.train_batch_size * args.gradient_accumulation_steps * int(os.environ["WORLD_SIZE"]) args.train_batch_size * args.gradient_accumulation_steps * int(os.environ["WORLD_SIZE"])
) )
deepspeed_plugin.set_mixed_precision(args.mixed_precision) deepspeed_plugin.set_mixed_precision(args.mixed_precision)
if args.mixed_precision.lower() == "fp16": if args.mixed_precision.lower() == "fp16":
deepspeed_plugin.deepspeed_config["fp16"]["initial_scale_power"] = 0 # preventing overflow. deepspeed_plugin.deepspeed_config["fp16"]["initial_scale_power"] = 0 # preventing overflow.
@@ -125,18 +125,18 @@ def prepare_deepspeed_model(args: argparse.Namespace, **models):
class DeepSpeedWrapper(torch.nn.Module): class DeepSpeedWrapper(torch.nn.Module):
def __init__(self, **kw_models) -> None: def __init__(self, **kw_models) -> None:
super().__init__() super().__init__()
self.models = torch.nn.ModuleDict() self.models = torch.nn.ModuleDict()
wrap_model_forward_with_torch_autocast = args.mixed_precision is not "no" wrap_model_forward_with_torch_autocast = args.mixed_precision != "no"
for key, model in kw_models.items(): for key, model in kw_models.items():
if isinstance(model, list): if isinstance(model, list):
model = torch.nn.ModuleList(model) model = torch.nn.ModuleList(model)
if wrap_model_forward_with_torch_autocast: if wrap_model_forward_with_torch_autocast:
model = self.__wrap_model_with_torch_autocast(model) model = self.__wrap_model_with_torch_autocast(model)
assert isinstance( assert isinstance(
model, torch.nn.Module model, torch.nn.Module
), f"model must be an instance of torch.nn.Module, but got {key} is {type(model)}" ), f"model must be an instance of torch.nn.Module, but got {key} is {type(model)}"
@@ -151,7 +151,7 @@ def prepare_deepspeed_model(args: argparse.Namespace, **models):
return model return model
def __wrap_model_forward_with_torch_autocast(self, model): def __wrap_model_forward_with_torch_autocast(self, model):
assert hasattr(model, "forward"), f"model must have a forward method." assert hasattr(model, "forward"), f"model must have a forward method."
forward_fn = model.forward forward_fn = model.forward
@@ -161,20 +161,19 @@ def prepare_deepspeed_model(args: argparse.Namespace, **models):
device_type = model.device.type device_type = model.device.type
except AttributeError: except AttributeError:
logger.warning( logger.warning(
"[DeepSpeed] model.device is not available. Using get_preferred_device() " "[DeepSpeed] model.device is not available. Using get_preferred_device() "
"to determine the device_type for torch.autocast()." "to determine the device_type for torch.autocast()."
) )
device_type = get_preferred_device().type device_type = get_preferred_device().type
with torch.autocast(device_type = device_type): with torch.autocast(device_type=device_type):
return forward_fn(*args, **kwargs) return forward_fn(*args, **kwargs)
model.forward = forward model.forward = forward
return model return model
def get_models(self): def get_models(self):
return self.models return self.models
ds_model = DeepSpeedWrapper(**models) ds_model = DeepSpeedWrapper(**models)
return ds_model return ds_model

View File

@@ -34,18 +34,18 @@ from library import custom_offloading_utils
try: try:
from flash_attn import flash_attn_varlen_func from flash_attn import flash_attn_varlen_func
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
except: except ImportError:
# flash_attn may not be available but it is not required # flash_attn may not be available but it is not required
pass pass
try: try:
from sageattention import sageattn from sageattention import sageattn
except: except ImportError:
pass pass
try: try:
from apex.normalization import FusedRMSNorm as RMSNorm from apex.normalization import FusedRMSNorm as RMSNorm
except: except ImportError:
import warnings import warnings
warnings.warn("Cannot import apex RMSNorm, switch to vanilla implementation") warnings.warn("Cannot import apex RMSNorm, switch to vanilla implementation")
@@ -98,7 +98,7 @@ except:
x_dtype = x.dtype x_dtype = x.dtype
# To handle float8 we need to convert the tensor to float # To handle float8 we need to convert the tensor to float
x = x.float() x = x.float()
rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6) rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + self.eps)
return ((x * rrms) * self.weight.float()).to(dtype=x_dtype) return ((x * rrms) * self.weight.float()).to(dtype=x_dtype)
@@ -370,7 +370,7 @@ class JointAttention(nn.Module):
if self.use_sage_attn: if self.use_sage_attn:
# Handle GQA (Grouped Query Attention) if needed # Handle GQA (Grouped Query Attention) if needed
n_rep = self.n_local_heads // self.n_local_kv_heads n_rep = self.n_local_heads // self.n_local_kv_heads
if n_rep >= 1: if n_rep > 1:
xk = xk.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3) xk = xk.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
xv = xv.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3) xv = xv.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
@@ -379,7 +379,7 @@ class JointAttention(nn.Module):
output = self.flash_attn(xq, xk, xv, x_mask, softmax_scale) output = self.flash_attn(xq, xk, xv, x_mask, softmax_scale)
else: else:
n_rep = self.n_local_heads // self.n_local_kv_heads n_rep = self.n_local_heads // self.n_local_kv_heads
if n_rep >= 1: if n_rep > 1:
xk = xk.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3) xk = xk.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
xv = xv.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3) xv = xv.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
@@ -456,51 +456,47 @@ class JointAttention(nn.Module):
bsz = q.shape[0] bsz = q.shape[0]
seqlen = q.shape[1] seqlen = q.shape[1]
# Transpose tensors to match SageAttention's expected format (HND layout) # Transpose to SageAttention's expected HND layout: [batch, heads, seq_len, head_dim]
q_transposed = q.permute(0, 2, 1, 3) # [batch, heads, seq_len, head_dim] q_transposed = q.permute(0, 2, 1, 3)
k_transposed = k.permute(0, 2, 1, 3) # [batch, heads, seq_len, head_dim] k_transposed = k.permute(0, 2, 1, 3)
v_transposed = v.permute(0, 2, 1, 3) # [batch, heads, seq_len, head_dim] v_transposed = v.permute(0, 2, 1, 3)
# Handle masking for SageAttention # Fast path: if all tokens are valid, run batched SageAttention directly
# We need to filter out masked positions - this approach handles variable sequence lengths if x_mask.all():
outputs = [] output = sageattn(
for b in range(bsz): q_transposed, k_transposed, v_transposed,
# Find valid token positions from the mask tensor_layout="HND", is_causal=False, sm_scale=softmax_scale,
valid_indices = torch.nonzero(x_mask[b], as_tuple=False).squeeze(-1) )
if valid_indices.numel() == 0: # output: [batch, heads, seq_len, head_dim] -> [batch, seq_len, heads, head_dim]
# If all tokens are masked, create a zero output output = output.permute(0, 2, 1, 3)
batch_output = torch.zeros( else:
seqlen, self.n_local_heads, self.head_dim, # Slow path: per-batch loop to handle variable-length masking
device=q.device, dtype=q.dtype # SageAttention does not support attention masks natively
) outputs = []
else: for b in range(bsz):
# Extract only valid tokens for this batch valid_indices = x_mask[b].nonzero(as_tuple=True)[0]
batch_q = q_transposed[b, :, valid_indices, :] if valid_indices.numel() == 0:
batch_k = k_transposed[b, :, valid_indices, :] outputs.append(torch.zeros(
batch_v = v_transposed[b, :, valid_indices, :] seqlen, self.n_local_heads, self.head_dim,
device=q.device, dtype=q.dtype,
# Run SageAttention on valid tokens only ))
continue
batch_output_valid = sageattn( batch_output_valid = sageattn(
batch_q.unsqueeze(0), # Add batch dimension back q_transposed[b:b+1, :, valid_indices, :],
batch_k.unsqueeze(0), k_transposed[b:b+1, :, valid_indices, :],
batch_v.unsqueeze(0), v_transposed[b:b+1, :, valid_indices, :],
tensor_layout="HND", tensor_layout="HND", is_causal=False, sm_scale=softmax_scale,
is_causal=False,
sm_scale=softmax_scale
) )
# Create output tensor with zeros for masked positions
batch_output = torch.zeros( batch_output = torch.zeros(
seqlen, self.n_local_heads, self.head_dim, seqlen, self.n_local_heads, self.head_dim,
device=q.device, dtype=q.dtype device=q.device, dtype=q.dtype,
) )
# Place valid outputs back in the right positions
batch_output[valid_indices] = batch_output_valid.squeeze(0).permute(1, 0, 2) batch_output[valid_indices] = batch_output_valid.squeeze(0).permute(1, 0, 2)
outputs.append(batch_output)
outputs.append(batch_output)
output = torch.stack(outputs, dim=0)
# Stack batch outputs and reshape to expected format
output = torch.stack(outputs, dim=0) # [batch, seq_len, heads, head_dim]
except NameError as e: except NameError as e:
raise RuntimeError( raise RuntimeError(
f"Could not load Sage Attention. Please install https://github.com/thu-ml/SageAttention. / Sage Attention を読み込めませんでした。https://github.com/thu-ml/SageAttention をインストールしてください。 / {e}" f"Could not load Sage Attention. Please install https://github.com/thu-ml/SageAttention. / Sage Attention を読み込めませんでした。https://github.com/thu-ml/SageAttention をインストールしてください。 / {e}"
@@ -1113,10 +1109,9 @@ class NextDiT(nn.Module):
x = x.view(bsz, channels, height // pH, pH, width // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2) x = x.view(bsz, channels, height // pH, pH, width // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2)
x_mask = torch.zeros(bsz, image_seq_len, dtype=torch.bool, device=device) # x.shape[1] == image_seq_len after patchify, so this was assigning to itself.
for i in range(bsz): # The mask can be set without a loop since all samples have the same image_seq_len.
x[i, :image_seq_len] = x[i] x_mask = torch.ones(bsz, image_seq_len, dtype=torch.bool, device=device)
x_mask[i, :image_seq_len] = True
x = self.x_embedder(x) x = self.x_embedder(x)
@@ -1389,4 +1384,4 @@ def NextDiT_7B_GQA_patch2_Adaln_Refiner(**kwargs):
axes_dims=[40, 40, 40], axes_dims=[40, 40, 40],
axes_lens=[300, 512, 512], axes_lens=[300, 512, 512],
**kwargs, **kwargs,
) )

View File

@@ -334,32 +334,35 @@ def sample_image_inference(
# No need to add system prompt here, as it has been handled in the tokenize_strategy # No need to add system prompt here, as it has been handled in the tokenize_strategy
# Get sample prompts from cache # Get sample prompts from cache, fallback to live encoding
gemma2_conds = None
neg_gemma2_conds = None
if sample_prompts_gemma2_outputs and prompt in sample_prompts_gemma2_outputs: if sample_prompts_gemma2_outputs and prompt in sample_prompts_gemma2_outputs:
gemma2_conds = sample_prompts_gemma2_outputs[prompt] gemma2_conds = sample_prompts_gemma2_outputs[prompt]
logger.info(f"Using cached Gemma2 outputs for prompt: {prompt}") logger.info(f"Using cached Gemma2 outputs for prompt: {prompt}")
if ( if sample_prompts_gemma2_outputs and negative_prompt in sample_prompts_gemma2_outputs:
sample_prompts_gemma2_outputs
and negative_prompt in sample_prompts_gemma2_outputs
):
neg_gemma2_conds = sample_prompts_gemma2_outputs[negative_prompt] neg_gemma2_conds = sample_prompts_gemma2_outputs[negative_prompt]
logger.info( logger.info(f"Using cached Gemma2 outputs for negative prompt: {negative_prompt}")
f"Using cached Gemma2 outputs for negative prompt: {negative_prompt}"
)
# Load sample prompts from Gemma 2 # Only encode if not found in cache
if gemma2_model is not None: if gemma2_conds is None and gemma2_model is not None:
tokens_and_masks = tokenize_strategy.tokenize(prompt) tokens_and_masks = tokenize_strategy.tokenize(prompt)
gemma2_conds = encoding_strategy.encode_tokens( gemma2_conds = encoding_strategy.encode_tokens(
tokenize_strategy, gemma2_model, tokens_and_masks tokenize_strategy, gemma2_model, tokens_and_masks
) )
if neg_gemma2_conds is None and gemma2_model is not None:
tokens_and_masks = tokenize_strategy.tokenize(negative_prompt, is_negative=True) tokens_and_masks = tokenize_strategy.tokenize(negative_prompt, is_negative=True)
neg_gemma2_conds = encoding_strategy.encode_tokens( neg_gemma2_conds = encoding_strategy.encode_tokens(
tokenize_strategy, gemma2_model, tokens_and_masks tokenize_strategy, gemma2_model, tokens_and_masks
) )
if gemma2_conds is None or neg_gemma2_conds is None:
logger.error(f"Cannot generate sample: no cached outputs and no text encoder available for prompt: {prompt}")
continue
# Unpack Gemma2 outputs # Unpack Gemma2 outputs
gemma2_hidden_states, _, gemma2_attn_mask = gemma2_conds gemma2_hidden_states, _, gemma2_attn_mask = gemma2_conds
neg_gemma2_hidden_states, _, neg_gemma2_attn_mask = neg_gemma2_conds neg_gemma2_hidden_states, _, neg_gemma2_attn_mask = neg_gemma2_conds
@@ -475,6 +478,7 @@ def sample_image_inference(
def time_shift(mu: float, sigma: float, t: torch.Tensor): def time_shift(mu: float, sigma: float, t: torch.Tensor):
"""Apply time shifting to timesteps."""
t = math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) t = math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
return t return t
@@ -483,7 +487,7 @@ def get_lin_function(
x1: float = 256, x2: float = 4096, y1: float = 0.5, y2: float = 1.15 x1: float = 256, x2: float = 4096, y1: float = 0.5, y2: float = 1.15
) -> Callable[[float], float]: ) -> Callable[[float], float]:
""" """
Get linear function Get linear function for resolution-dependent shifting.
Args: Args:
image_seq_len, image_seq_len,
@@ -528,6 +532,7 @@ def get_schedule(
mu = get_lin_function(y1=base_shift, y2=max_shift, x1=256, x2=4096)( mu = get_lin_function(y1=base_shift, y2=max_shift, x1=256, x2=4096)(
image_seq_len image_seq_len
) )
timesteps = torch.clamp(timesteps, min=1e-7).to(timesteps.device)
timesteps = time_shift(mu, 1.0, timesteps) timesteps = time_shift(mu, 1.0, timesteps)
return timesteps.tolist() return timesteps.tolist()
@@ -689,15 +694,15 @@ def denoise(
img_dtype = img.dtype img_dtype = img.dtype
if img.dtype != img_dtype:
if torch.backends.mps.is_available():
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
img = img.to(img_dtype)
# compute the previous noisy sample x_t -> x_t-1 # compute the previous noisy sample x_t -> x_t-1
noise_pred = -noise_pred noise_pred = -noise_pred
img = scheduler.step(noise_pred, t, img, return_dict=False)[0] img = scheduler.step(noise_pred, t, img, return_dict=False)[0]
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
if img.dtype != img_dtype:
if torch.backends.mps.is_available():
img = img.to(img_dtype)
model.prepare_block_swap_before_forward() model.prepare_block_swap_before_forward()
return img return img
@@ -823,6 +828,7 @@ def get_noisy_model_input_and_timesteps(
timesteps = sigmas * num_timesteps timesteps = sigmas * num_timesteps
elif args.timestep_sampling == "nextdit_shift": elif args.timestep_sampling == "nextdit_shift":
sigmas = torch.rand((bsz,), device=device) sigmas = torch.rand((bsz,), device=device)
sigmas = torch.clamp(sigmas, min=1e-7).to(device)
mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2)) mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2))
sigmas = time_shift(mu, 1.0, sigmas) sigmas = time_shift(mu, 1.0, sigmas)
@@ -831,6 +837,7 @@ def get_noisy_model_input_and_timesteps(
sigmas = torch.randn(bsz, device=device) sigmas = torch.randn(bsz, device=device)
sigmas = sigmas * args.sigmoid_scale # larger scale for more uniform sampling sigmas = sigmas * args.sigmoid_scale # larger scale for more uniform sampling
sigmas = sigmas.sigmoid() sigmas = sigmas.sigmoid()
sigmas = torch.clamp(sigmas, min=1e-7).to(device)
mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2)) # we are pre-packed so must adjust for packed size mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2)) # we are pre-packed so must adjust for packed size
sigmas = time_shift(mu, 1.0, sigmas) sigmas = time_shift(mu, 1.0, sigmas)
timesteps = sigmas * num_timesteps timesteps = sigmas * num_timesteps

View File

@@ -370,19 +370,25 @@ def train(args):
grouped_params = [] grouped_params = []
param_group = {} param_group = {}
for group in params_to_optimize: for group in params_to_optimize:
named_parameters = list(nextdit.named_parameters()) named_parameters = [(n, p) for n, p in nextdit.named_parameters() if p.requires_grad]
assert len(named_parameters) == len( assert len(named_parameters) == len(
group["params"] group["params"]
), "number of parameters does not match" ), f"number of trainable parameters ({len(named_parameters)}) does not match optimizer group ({len(group['params'])})"
for p, np in zip(group["params"], named_parameters): for p, np in zip(group["params"], named_parameters):
# determine target layer and block index for each parameter # determine target layer and block index for each parameter
block_type = "other" # double, single or other # Lumina NextDiT architecture:
if np[0].startswith("double_blocks"): # - "layers.{i}.*" : main transformer blocks (e.g. 32 blocks for 2B)
# - "context_refiner.{i}.*" : context refiner blocks (2 blocks)
# - "noise_refiner.{i}.*" : noise refiner blocks (2 blocks)
# - others: t_embedder, cap_embedder, x_embedder, norm_final, final_layer
block_type = "other"
if np[0].startswith("layers."):
block_index = int(np[0].split(".")[1]) block_index = int(np[0].split(".")[1])
block_type = "double" block_type = "main"
elif np[0].startswith("single_blocks"): elif np[0].startswith("context_refiner.") or np[0].startswith("noise_refiner."):
block_index = int(np[0].split(".")[1]) # All refiner blocks (context + noise) grouped together
block_type = "single" block_index = -1
block_type = "refiner"
else: else:
block_index = -1 block_index = -1
@@ -759,7 +765,7 @@ def train(args):
# calculate loss # calculate loss
huber_c = train_util.get_huber_threshold_if_needed( huber_c = train_util.get_huber_threshold_if_needed(
args, timesteps, noise_scheduler args, 1000 - timesteps, noise_scheduler
) )
loss = train_util.conditional_loss( loss = train_util.conditional_loss(
model_pred.float(), target.float(), args.loss_type, "none", huber_c model_pred.float(), target.float(), args.loss_type, "none", huber_c

View File

@@ -43,9 +43,9 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
logger.warning("Enabling cache_text_encoder_outputs due to disk caching") logger.warning("Enabling cache_text_encoder_outputs due to disk caching")
args.cache_text_encoder_outputs = True args.cache_text_encoder_outputs = True
train_dataset_group.verify_bucket_reso_steps(32) train_dataset_group.verify_bucket_reso_steps(16)
if val_dataset_group is not None: if val_dataset_group is not None:
val_dataset_group.verify_bucket_reso_steps(32) val_dataset_group.verify_bucket_reso_steps(16)
self.train_gemma2 = not args.network_train_unet_only self.train_gemma2 = not args.network_train_unet_only
@@ -134,13 +134,16 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
# When TE is not be trained, it will not be prepared so we need to use explicit autocast # When TE is not be trained, it will not be prepared so we need to use explicit autocast
logger.info("move text encoders to gpu") logger.info("move text encoders to gpu")
text_encoders[0].to(accelerator.device, dtype=weight_dtype) # always not fp8 # Lumina uses a single text encoder (Gemma2) at index 0.
# Check original dtype BEFORE casting to preserve fp8 detection.
gemma2_original_dtype = text_encoders[0].dtype
text_encoders[0].to(accelerator.device)
if text_encoders[0].dtype == torch.float8_e4m3fn: if gemma2_original_dtype == torch.float8_e4m3fn:
# if we load fp8 weights, the model is already fp8, so we use it as is # Model was loaded as fp8 — apply fp8 optimization
self.prepare_text_encoder_fp8(1, text_encoders[1], text_encoders[1].dtype, weight_dtype) self.prepare_text_encoder_fp8(0, text_encoders[0], gemma2_original_dtype, weight_dtype)
else: else:
# otherwise, we need to convert it to target dtype # Otherwise, cast to target dtype
text_encoders[0].to(weight_dtype) text_encoders[0].to(weight_dtype)
with accelerator.autocast(): with accelerator.autocast():

View File

@@ -227,19 +227,16 @@ class LoRAInfModule(LoRAModule):
org_sd["weight"] = weight.to(dtype) org_sd["weight"] = weight.to(dtype)
self.org_module.load_state_dict(org_sd) self.org_module.load_state_dict(org_sd)
else: else:
# split_dims # split_dims: merge each split's LoRA into the correct slice of the fused QKV weight
total_dims = sum(self.split_dims)
for i in range(len(self.split_dims)): for i in range(len(self.split_dims)):
# get up/down weight # get up/down weight
down_weight = sd[f"lora_down.{i}.weight"].to(torch.float).to(device) # (rank, in_dim) down_weight = sd[f"lora_down.{i}.weight"].to(torch.float).to(device) # (rank, in_dim)
up_weight = sd[f"lora_up.{i}.weight"].to(torch.float).to(device) # (split dim, rank) up_weight = sd[f"lora_up.{i}.weight"].to(torch.float).to(device) # (split_dim, rank)
# pad up_weight -> (total_dims, rank) # merge into the correct slice of the fused weight
padded_up_weight = torch.zeros((total_dims, up_weight.size(0)), device=device, dtype=torch.float) start = sum(self.split_dims[:i])
padded_up_weight[sum(self.split_dims[:i]) : sum(self.split_dims[: i + 1])] = up_weight end = sum(self.split_dims[:i + 1])
weight[start:end] += self.multiplier * (up_weight @ down_weight) * self.scale
# merge weight
weight = weight + self.multiplier * (up_weight @ down_weight) * self.scale
# set weight to org_module # set weight to org_module
org_sd["weight"] = weight.to(dtype) org_sd["weight"] = weight.to(dtype)
@@ -250,6 +247,17 @@ class LoRAInfModule(LoRAModule):
if multiplier is None: if multiplier is None:
multiplier = self.multiplier multiplier = self.multiplier
# Handle split_dims case where lora_down/lora_up are ModuleList
if self.split_dims is not None:
# Each sub-module produces a partial weight; concatenate along output dim
weights = []
for lora_up, lora_down in zip(self.lora_up, self.lora_down):
up_w = lora_up.weight.to(torch.float)
down_w = lora_down.weight.to(torch.float)
weights.append(up_w @ down_w)
weight = self.multiplier * torch.cat(weights, dim=0) * self.scale
return weight
# get up/down weight from module # get up/down weight from module
up_weight = self.lora_up.weight.to(torch.float) up_weight = self.lora_up.weight.to(torch.float)
down_weight = self.lora_down.weight.to(torch.float) down_weight = self.lora_down.weight.to(torch.float)
@@ -409,7 +417,7 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, lumina, wei
weights_sd = load_file(file) weights_sd = load_file(file)
else: else:
weights_sd = torch.load(file, map_location="cpu") weights_sd = torch.load(file, map_location="cpu", weights_only=False)
# get dim/alpha mapping, and train t5xxl # get dim/alpha mapping, and train t5xxl
modules_dim = {} modules_dim = {}
@@ -634,20 +642,30 @@ class LoRANetwork(torch.nn.Module):
skipped_te += skipped skipped_te += skipped
# create LoRA for U-Net # create LoRA for U-Net
target_replace_modules = LoRANetwork.LUMINA_TARGET_REPLACE_MODULE
# Filter by block type using name-based filtering in create_modules
# All block types use JointTransformerBlock, so we filter by module path name
block_filter = None # None means no filtering (train all)
if self.train_blocks == "all": if self.train_blocks == "all":
target_replace_modules = LoRANetwork.LUMINA_TARGET_REPLACE_MODULE block_filter = None
# TODO: limit different blocks
elif self.train_blocks == "transformer": elif self.train_blocks == "transformer":
target_replace_modules = LoRANetwork.LUMINA_TARGET_REPLACE_MODULE block_filter = "layers_" # main transformer blocks: "lora_unet_layers_N_..."
elif self.train_blocks == "refiners":
target_replace_modules = LoRANetwork.LUMINA_TARGET_REPLACE_MODULE
elif self.train_blocks == "noise_refiner": elif self.train_blocks == "noise_refiner":
target_replace_modules = LoRANetwork.LUMINA_TARGET_REPLACE_MODULE block_filter = "noise_refiner"
elif self.train_blocks == "cap_refiner": elif self.train_blocks == "context_refiner":
target_replace_modules = LoRANetwork.LUMINA_TARGET_REPLACE_MODULE block_filter = "context_refiner"
elif self.train_blocks == "refiners":
block_filter = None # handled below with two calls
self.unet_loras: List[Union[LoRAModule, LoRAInfModule]] self.unet_loras: List[Union[LoRAModule, LoRAInfModule]]
self.unet_loras, skipped_un = create_modules(True, unet, target_replace_modules) if self.train_blocks == "refiners":
# Refiners = noise_refiner + context_refiner, need two calls
noise_loras, skipped_noise = create_modules(True, unet, target_replace_modules, filter="noise_refiner")
context_loras, skipped_context = create_modules(True, unet, target_replace_modules, filter="context_refiner")
self.unet_loras = noise_loras + context_loras
skipped_un = skipped_noise + skipped_context
else:
self.unet_loras, skipped_un = create_modules(True, unet, target_replace_modules, filter=block_filter)
# Handle embedders # Handle embedders
if self.embedder_dims: if self.embedder_dims:
@@ -689,7 +707,7 @@ class LoRANetwork(torch.nn.Module):
weights_sd = load_file(file) weights_sd = load_file(file)
else: else:
weights_sd = torch.load(file, map_location="cpu") weights_sd = torch.load(file, map_location="cpu", weights_only=False)
info = self.load_state_dict(weights_sd, False) info = self.load_state_dict(weights_sd, False)
return info return info
@@ -751,10 +769,10 @@ class LoRANetwork(torch.nn.Module):
state_dict = super().state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars) state_dict = super().state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
new_state_dict = {} new_state_dict = {}
for key in list(state_dict.keys()): for key in list(state_dict.keys()):
if "double" in key and "qkv" in key: if "qkv" in key:
split_dims = [3072] * 3 # Lumina 2B: dim=2304, n_heads=24, n_kv_heads=8, head_dim=96
elif "single" in key and "linear1" in key: # Q=24*96=2304, K=8*96=768, V=8*96=768
split_dims = [3072] * 3 + [12288] split_dims = [2304, 768, 768]
else: else:
new_state_dict[key] = state_dict[key] new_state_dict[key] = state_dict[key]
continue continue
@@ -1035,4 +1053,4 @@ class LoRANetwork(torch.nn.Module):
scalednorm = updown.norm() * ratio scalednorm = updown.norm() * ratio
norms.append(scalednorm.item()) norms.append(scalednorm.item())
return keys_scaled, sum(norms) / len(norms), max(norms) return keys_scaled, sum(norms) / len(norms), max(norms)