mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-18 01:30:02 +00:00
Compare commits
4 Commits
ea634a1f93
...
141b101619
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
141b101619 | ||
|
|
8b5ce3e641 | ||
|
|
da07e4c617 | ||
|
|
872124c5e1 |
@@ -87,7 +87,14 @@ def index_sv_ratio(S, target):
|
||||
# Modified from Kohaku-blueleaf's extract/merge functions
|
||||
def extract_conv(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1):
|
||||
out_size, in_size, kernel_size, _ = weight.size()
|
||||
U, S, Vh = torch.linalg.svd(weight.reshape(out_size, -1).to(device))
|
||||
weight = weight.reshape(out_size, -1)
|
||||
_in_size = in_size * kernel_size * kernel_size
|
||||
|
||||
if out_size > 2048 and _in_size > 2048:
|
||||
U, S, V = torch.svd_lowrank(weight.to(device), q=min(2 * lora_rank, out_size, _in_size))
|
||||
Vh = V.T
|
||||
else:
|
||||
U, S, Vh = torch.linalg.svd(weight.to(device))
|
||||
|
||||
param_dict = rank_resize(S, lora_rank, dynamic_method, dynamic_param, scale)
|
||||
lora_rank = param_dict["new_rank"]
|
||||
@@ -106,7 +113,11 @@ def extract_conv(weight, lora_rank, dynamic_method, dynamic_param, device, scale
|
||||
def extract_linear(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1):
|
||||
out_size, in_size = weight.size()
|
||||
|
||||
U, S, Vh = torch.linalg.svd(weight.to(device))
|
||||
if out_size > 2048 and in_size > 2048:
|
||||
U, S, V = torch.svd_lowrank(weight.to(device), q=min(2 * lora_rank, out_size, in_size))
|
||||
Vh = V.T
|
||||
else:
|
||||
U, S, Vh = torch.linalg.svd(weight.to(device))
|
||||
|
||||
param_dict = rank_resize(S, lora_rank, dynamic_method, dynamic_param, scale)
|
||||
lora_rank = param_dict["new_rank"]
|
||||
|
||||
@@ -15,6 +15,12 @@ import random
|
||||
import re
|
||||
|
||||
import diffusers
|
||||
|
||||
# Compatible import for diffusers old/new UNet path
|
||||
try:
|
||||
from diffusers.models.unet_2d_condition import UNet2DConditionModel
|
||||
except ImportError:
|
||||
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
@@ -80,7 +86,7 @@ CLIP_VISION_MODEL = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
|
||||
"""
|
||||
|
||||
|
||||
def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers, sdpa):
|
||||
def replace_unet_modules(unet: UNet2DConditionModel, mem_eff_attn, xformers, sdpa):
|
||||
if mem_eff_attn:
|
||||
logger.info("Enable memory efficient attention for U-Net")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user