Merge branch 'sd3' into feat-support-lokr-loha

Kohya S
2026-02-23 15:08:56 +09:00
5 changed files with 137 additions and 108 deletions


@@ -227,19 +227,16 @@ class LoRAInfModule(LoRAModule):
org_sd["weight"] = weight.to(dtype)
self.org_module.load_state_dict(org_sd)
else:
# split_dims
total_dims = sum(self.split_dims)
# split_dims: merge each split's LoRA into the correct slice of the fused QKV weight
for i in range(len(self.split_dims)):
# get up/down weight
down_weight = sd[f"lora_down.{i}.weight"].to(torch.float).to(device) # (rank, in_dim)
up_weight = sd[f"lora_up.{i}.weight"].to(torch.float).to(device) # (split dim, rank)
up_weight = sd[f"lora_up.{i}.weight"].to(torch.float).to(device) # (split_dim, rank)
# pad up_weight -> (total_dims, rank)
padded_up_weight = torch.zeros((total_dims, up_weight.size(0)), device=device, dtype=torch.float)
padded_up_weight[sum(self.split_dims[:i]) : sum(self.split_dims[: i + 1])] = up_weight
# merge weight
weight = weight + self.multiplier * (up_weight @ down_weight) * self.scale
# merge into the correct slice of the fused weight
start = sum(self.split_dims[:i])
end = sum(self.split_dims[:i + 1])
weight[start:end] += self.multiplier * (up_weight @ down_weight) * self.scale
# set weight to org_module
org_sd["weight"] = weight.to(dtype)
@@ -250,6 +247,17 @@ class LoRAInfModule(LoRAModule):
         if multiplier is None:
             multiplier = self.multiplier
+        # Handle split_dims case where lora_down/lora_up are ModuleList
+        if self.split_dims is not None:
+            # Each sub-module produces a partial weight; concatenate along output dim
+            weights = []
+            for lora_up, lora_down in zip(self.lora_up, self.lora_down):
+                up_w = lora_up.weight.to(torch.float)
+                down_w = lora_down.weight.to(torch.float)
+                weights.append(up_w @ down_w)
+            weight = multiplier * torch.cat(weights, dim=0) * self.scale
+            return weight

         # get up/down weight from module
         up_weight = self.lora_up.weight.to(torch.float)
         down_weight = self.lora_down.weight.to(torch.float)
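For the split_dims path of get_weight, the per-split products stack along the output dimension to recover the fused delta. A minimal, self-contained sketch of that shape logic (module sizes are assumptions, not the network's actual config):

import torch
import torch.nn as nn

split_dims, in_dim, rank = [2304, 768, 768], 2304, 4  # illustrative sizes
lora_down = nn.ModuleList(nn.Linear(in_dim, rank, bias=False) for _ in split_dims)
lora_up = nn.ModuleList(nn.Linear(rank, d, bias=False) for d in split_dims)

# each up.weight @ down.weight is (split_dim, in_dim); cat(dim=0) yields the fused delta
weights = [up.weight.float() @ down.weight.float() for up, down in zip(lora_up, lora_down)]
fused = torch.cat(weights, dim=0)
assert fused.shape == (sum(split_dims), in_dim)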
@@ -409,7 +417,7 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, lumina, wei
         weights_sd = load_file(file)
     else:
-        weights_sd = torch.load(file, map_location="cpu")
+        weights_sd = torch.load(file, map_location="cpu", weights_only=False)

     # get dim/alpha mapping, and train t5xxl
     modules_dim = {}
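The explicit weights_only=False matters because PyTorch 2.6 changed the default of torch.load to weights_only=True, which rejects checkpoints containing arbitrary pickled objects. Passing False restores the pre-2.6 behavior; it should only be used with checkpoints from trusted sources, since full unpickling can execute arbitrary code. A sketch (the path is hypothetical):

import torch

# Under PyTorch >= 2.6 the strict default can fail on .pt files that hold
# pickled non-tensor objects (e.g. metadata classes):
#   weights_sd = torch.load("lora_weights.pt", map_location="cpu")
# Opting out keeps old checkpoints loadable -- trusted files only:
weights_sd = torch.load("lora_weights.pt", map_location="cpu", weights_only=False)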
@@ -634,20 +642,30 @@ class LoRANetwork(torch.nn.Module):
         skipped_te += skipped

         # create LoRA for U-Net
+        target_replace_modules = LoRANetwork.LUMINA_TARGET_REPLACE_MODULE
+        # Filter by block type using name-based filtering in create_modules
+        # All block types use JointTransformerBlock, so we filter by module path name
+        block_filter = None  # None means no filtering (train all)
         if self.train_blocks == "all":
-            target_replace_modules = LoRANetwork.LUMINA_TARGET_REPLACE_MODULE
-            # TODO: limit different blocks
+            block_filter = None
         elif self.train_blocks == "transformer":
-            target_replace_modules = LoRANetwork.LUMINA_TARGET_REPLACE_MODULE
-        elif self.train_blocks == "refiners":
-            target_replace_modules = LoRANetwork.LUMINA_TARGET_REPLACE_MODULE
+            block_filter = "layers_"  # main transformer blocks: "lora_unet_layers_N_..."
         elif self.train_blocks == "noise_refiner":
-            target_replace_modules = LoRANetwork.LUMINA_TARGET_REPLACE_MODULE
-        elif self.train_blocks == "cap_refiner":
-            target_replace_modules = LoRANetwork.LUMINA_TARGET_REPLACE_MODULE
+            block_filter = "noise_refiner"
+        elif self.train_blocks == "context_refiner":
+            block_filter = "context_refiner"
+        elif self.train_blocks == "refiners":
+            block_filter = None  # handled below with two calls

         self.unet_loras: List[Union[LoRAModule, LoRAInfModule]]
-        self.unet_loras, skipped_un = create_modules(True, unet, target_replace_modules)
+        if self.train_blocks == "refiners":
+            # Refiners = noise_refiner + context_refiner, need two calls
+            noise_loras, skipped_noise = create_modules(True, unet, target_replace_modules, filter="noise_refiner")
+            context_loras, skipped_context = create_modules(True, unet, target_replace_modules, filter="context_refiner")
+            self.unet_loras = noise_loras + context_loras
+            skipped_un = skipped_noise + skipped_context
+        else:
+            self.unet_loras, skipped_un = create_modules(True, unet, target_replace_modules, filter=block_filter)

         # Handle embedders
         if self.embedder_dims:
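Since every Lumina block type shares the same module class, selection happens by substring match on the generated LoRA module name rather than by target class. A sketch of that name-based test (this helper is illustrative; in the file the check lives inside create_modules, and the module names follow the "lora_unet_..." pattern quoted in the comment above):

from typing import Optional

def block_matches(lora_name: str, block_filter: Optional[str]) -> bool:
    # None disables filtering, i.e. every candidate module gets a LoRA
    return block_filter is None or block_filter in lora_name

assert block_matches("lora_unet_layers_5_attention_qkv", "layers_")
assert block_matches("lora_unet_noise_refiner_0_attention_qkv", "noise_refiner")
assert not block_matches("lora_unet_layers_5_attention_qkv", "noise_refiner")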
@@ -689,7 +707,7 @@ class LoRANetwork(torch.nn.Module):
             weights_sd = load_file(file)
         else:
-            weights_sd = torch.load(file, map_location="cpu")
+            weights_sd = torch.load(file, map_location="cpu", weights_only=False)

         info = self.load_state_dict(weights_sd, False)
         return info
@@ -751,10 +769,10 @@ class LoRANetwork(torch.nn.Module):
         state_dict = super().state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)

         new_state_dict = {}
         for key in list(state_dict.keys()):
-            if "double" in key and "qkv" in key:
-                split_dims = [3072] * 3
-            elif "single" in key and "linear1" in key:
-                split_dims = [3072] * 3 + [12288]
+            if "qkv" in key:
+                # Lumina 2B: dim=2304, n_heads=24, n_kv_heads=8, head_dim=96
+                # Q=24*96=2304, K=8*96=768, V=8*96=768
+                split_dims = [2304, 768, 768]
             else:
                 new_state_dict[key] = state_dict[key]
                 continue
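The hardcoded split follows directly from the attention geometry quoted in the new comment; a quick check of the arithmetic (constants copied from that comment):

dim, n_heads, n_kv_heads, head_dim = 2304, 24, 8, 96
split_dims = [n_heads * head_dim, n_kv_heads * head_dim, n_kv_heads * head_dim]
assert split_dims == [2304, 768, 768]  # Q, K, V rows of the fused qkv weight
assert sum(split_dims) == 3840         # total output dim of the fused projection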
@@ -1035,4 +1053,4 @@ class LoRANetwork(torch.nn.Module):
             scalednorm = updown.norm() * ratio
             norms.append(scalednorm.item())

-        return keys_scaled, sum(norms) / len(norms), max(norms)
\ No newline at end of file
+        return keys_scaled, sum(norms) / len(norms), max(norms)
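The tail of the file computes the statistics returned by the max-norm pass: how many keys were rescaled, plus the mean and max of the per-key scaled norms. A self-contained sketch of that bookkeeping (the updown tensors and ratios here are made up; in the file they come from the per-key update matrices and their norm-cap ratios):

import torch

keys_scaled, norms = 0, []
for updown, ratio in [(torch.randn(16, 16), 0.8), (torch.randn(16, 16), 1.0)]:
    if ratio < 1.0:  # this key's update was shrunk to satisfy the norm cap
        keys_scaled += 1
    scalednorm = updown.norm() * ratio
    norms.append(scalednorm.item())

result = keys_scaled, sum(norms) / len(norms), max(norms)  # (count, mean norm, max norm)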