Compare commits

...

4 Commits

Author SHA1 Message Date
woctordho
141b101619 Merge 872124c5e1 into 8b5ce3e641 2026-01-20 04:04:55 +01:00
Kohya S.
8b5ce3e641 Merge pull request #2255 from cgcalatrava/fix-diffusers-unet-import
Fix AttributeError for UNet2DConditionModel with newer diffusers versions
2026-01-20 07:50:04 +09:00
cgcalatrava
da07e4c617 Make UNet2DConditionModel import compatible with old and new diffusers versions 2026-01-19 20:53:00 +01:00
woctordho
872124c5e1 Use svd_lowrank for large matrices in resize_lora.py 2025-11-17 10:14:17 +08:00
2 changed files with 20 additions and 3 deletions

View File

@@ -87,7 +87,14 @@ def index_sv_ratio(S, target):
# Modified from Kohaku-blueleaf's extract/merge functions
def extract_conv(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1):
out_size, in_size, kernel_size, _ = weight.size()
U, S, Vh = torch.linalg.svd(weight.reshape(out_size, -1).to(device))
weight = weight.reshape(out_size, -1)
_in_size = in_size * kernel_size * kernel_size
if out_size > 2048 and _in_size > 2048:
U, S, V = torch.svd_lowrank(weight.to(device), q=min(2 * lora_rank, out_size, _in_size))
Vh = V.T
else:
U, S, Vh = torch.linalg.svd(weight.to(device))
param_dict = rank_resize(S, lora_rank, dynamic_method, dynamic_param, scale)
lora_rank = param_dict["new_rank"]
@@ -106,7 +113,11 @@ def extract_conv(weight, lora_rank, dynamic_method, dynamic_param, device, scale
def extract_linear(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1):
out_size, in_size = weight.size()
U, S, Vh = torch.linalg.svd(weight.to(device))
if out_size > 2048 and in_size > 2048:
U, S, V = torch.svd_lowrank(weight.to(device), q=min(2 * lora_rank, out_size, in_size))
Vh = V.T
else:
U, S, Vh = torch.linalg.svd(weight.to(device))
param_dict = rank_resize(S, lora_rank, dynamic_method, dynamic_param, scale)
lora_rank = param_dict["new_rank"]

View File

@@ -15,6 +15,12 @@ import random
import re
import diffusers
# Compatible import for diffusers old/new UNet path
try:
from diffusers.models.unet_2d_condition import UNet2DConditionModel
except ImportError:
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
import numpy as np
import torch
@@ -80,7 +86,7 @@ CLIP_VISION_MODEL = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
"""
def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers, sdpa):
def replace_unet_modules(unet: UNet2DConditionModel, mem_eff_attn, xformers, sdpa):
if mem_eff_attn:
logger.info("Enable memory efficient attention for U-Net")