mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-17 17:24:21 +00:00
Merge branch 'sd3' into feat-support-lokr-loha
This commit is contained in:
@@ -227,19 +227,16 @@ class LoRAInfModule(LoRAModule):
|
||||
org_sd["weight"] = weight.to(dtype)
|
||||
self.org_module.load_state_dict(org_sd)
|
||||
else:
|
||||
# split_dims
|
||||
total_dims = sum(self.split_dims)
|
||||
# split_dims: merge each split's LoRA into the correct slice of the fused QKV weight
|
||||
for i in range(len(self.split_dims)):
|
||||
# get up/down weight
|
||||
down_weight = sd[f"lora_down.{i}.weight"].to(torch.float).to(device) # (rank, in_dim)
|
||||
up_weight = sd[f"lora_up.{i}.weight"].to(torch.float).to(device) # (split dim, rank)
|
||||
up_weight = sd[f"lora_up.{i}.weight"].to(torch.float).to(device) # (split_dim, rank)
|
||||
|
||||
# pad up_weight -> (total_dims, rank)
|
||||
padded_up_weight = torch.zeros((total_dims, up_weight.size(0)), device=device, dtype=torch.float)
|
||||
padded_up_weight[sum(self.split_dims[:i]) : sum(self.split_dims[: i + 1])] = up_weight
|
||||
|
||||
# merge weight
|
||||
weight = weight + self.multiplier * (up_weight @ down_weight) * self.scale
|
||||
# merge into the correct slice of the fused weight
|
||||
start = sum(self.split_dims[:i])
|
||||
end = sum(self.split_dims[:i + 1])
|
||||
weight[start:end] += self.multiplier * (up_weight @ down_weight) * self.scale
|
||||
|
||||
# set weight to org_module
|
||||
org_sd["weight"] = weight.to(dtype)
|
||||
@@ -250,6 +247,17 @@ class LoRAInfModule(LoRAModule):
|
||||
if multiplier is None:
|
||||
multiplier = self.multiplier
|
||||
|
||||
# Handle split_dims case where lora_down/lora_up are ModuleList
|
||||
if self.split_dims is not None:
|
||||
# Each sub-module produces a partial weight; concatenate along output dim
|
||||
weights = []
|
||||
for lora_up, lora_down in zip(self.lora_up, self.lora_down):
|
||||
up_w = lora_up.weight.to(torch.float)
|
||||
down_w = lora_down.weight.to(torch.float)
|
||||
weights.append(up_w @ down_w)
|
||||
weight = self.multiplier * torch.cat(weights, dim=0) * self.scale
|
||||
return weight
|
||||
|
||||
# get up/down weight from module
|
||||
up_weight = self.lora_up.weight.to(torch.float)
|
||||
down_weight = self.lora_down.weight.to(torch.float)
|
||||
@@ -409,7 +417,7 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, lumina, wei
|
||||
|
||||
weights_sd = load_file(file)
|
||||
else:
|
||||
weights_sd = torch.load(file, map_location="cpu")
|
||||
weights_sd = torch.load(file, map_location="cpu", weights_only=False)
|
||||
|
||||
# get dim/alpha mapping, and train t5xxl
|
||||
modules_dim = {}
|
||||
@@ -634,20 +642,30 @@ class LoRANetwork(torch.nn.Module):
|
||||
skipped_te += skipped
|
||||
|
||||
# create LoRA for U-Net
|
||||
target_replace_modules = LoRANetwork.LUMINA_TARGET_REPLACE_MODULE
|
||||
# Filter by block type using name-based filtering in create_modules
|
||||
# All block types use JointTransformerBlock, so we filter by module path name
|
||||
block_filter = None # None means no filtering (train all)
|
||||
if self.train_blocks == "all":
|
||||
target_replace_modules = LoRANetwork.LUMINA_TARGET_REPLACE_MODULE
|
||||
# TODO: limit different blocks
|
||||
block_filter = None
|
||||
elif self.train_blocks == "transformer":
|
||||
target_replace_modules = LoRANetwork.LUMINA_TARGET_REPLACE_MODULE
|
||||
elif self.train_blocks == "refiners":
|
||||
target_replace_modules = LoRANetwork.LUMINA_TARGET_REPLACE_MODULE
|
||||
block_filter = "layers_" # main transformer blocks: "lora_unet_layers_N_..."
|
||||
elif self.train_blocks == "noise_refiner":
|
||||
target_replace_modules = LoRANetwork.LUMINA_TARGET_REPLACE_MODULE
|
||||
elif self.train_blocks == "cap_refiner":
|
||||
target_replace_modules = LoRANetwork.LUMINA_TARGET_REPLACE_MODULE
|
||||
block_filter = "noise_refiner"
|
||||
elif self.train_blocks == "context_refiner":
|
||||
block_filter = "context_refiner"
|
||||
elif self.train_blocks == "refiners":
|
||||
block_filter = None # handled below with two calls
|
||||
|
||||
self.unet_loras: List[Union[LoRAModule, LoRAInfModule]]
|
||||
self.unet_loras, skipped_un = create_modules(True, unet, target_replace_modules)
|
||||
if self.train_blocks == "refiners":
|
||||
# Refiners = noise_refiner + context_refiner, need two calls
|
||||
noise_loras, skipped_noise = create_modules(True, unet, target_replace_modules, filter="noise_refiner")
|
||||
context_loras, skipped_context = create_modules(True, unet, target_replace_modules, filter="context_refiner")
|
||||
self.unet_loras = noise_loras + context_loras
|
||||
skipped_un = skipped_noise + skipped_context
|
||||
else:
|
||||
self.unet_loras, skipped_un = create_modules(True, unet, target_replace_modules, filter=block_filter)
|
||||
|
||||
# Handle embedders
|
||||
if self.embedder_dims:
|
||||
@@ -689,7 +707,7 @@ class LoRANetwork(torch.nn.Module):
|
||||
|
||||
weights_sd = load_file(file)
|
||||
else:
|
||||
weights_sd = torch.load(file, map_location="cpu")
|
||||
weights_sd = torch.load(file, map_location="cpu", weights_only=False)
|
||||
|
||||
info = self.load_state_dict(weights_sd, False)
|
||||
return info
|
||||
@@ -751,10 +769,10 @@ class LoRANetwork(torch.nn.Module):
|
||||
state_dict = super().state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
|
||||
new_state_dict = {}
|
||||
for key in list(state_dict.keys()):
|
||||
if "double" in key and "qkv" in key:
|
||||
split_dims = [3072] * 3
|
||||
elif "single" in key and "linear1" in key:
|
||||
split_dims = [3072] * 3 + [12288]
|
||||
if "qkv" in key:
|
||||
# Lumina 2B: dim=2304, n_heads=24, n_kv_heads=8, head_dim=96
|
||||
# Q=24*96=2304, K=8*96=768, V=8*96=768
|
||||
split_dims = [2304, 768, 768]
|
||||
else:
|
||||
new_state_dict[key] = state_dict[key]
|
||||
continue
|
||||
@@ -1035,4 +1053,4 @@ class LoRANetwork(torch.nn.Module):
|
||||
scalednorm = updown.norm() * ratio
|
||||
norms.append(scalednorm.item())
|
||||
|
||||
return keys_scaled, sum(norms) / len(norms), max(norms)
|
||||
return keys_scaled, sum(norms) / len(norms), max(norms)
|
||||
Reference in New Issue
Block a user