mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
Merge branch 'dev' into sd3
This commit is contained in:
26
README.md
26
README.md
@@ -567,7 +567,31 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser
|
|||||||
- transformers, accelerate and huggingface_hub are updated.
|
- transformers, accelerate and huggingface_hub are updated.
|
||||||
- If you encounter any issues, please report them.
|
- If you encounter any issues, please report them.
|
||||||
|
|
||||||
- en: The INVERSE_SQRT, COSINE_WITH_MIN_LR, and WARMUP_STABLE_DECAY learning rate schedules are now available in the transformers library. See PR [#1393](https://github.com/kohya-ss/sd-scripts/pull/1393) for details. Thanks to sdbds!
|
- Improvements in OFT (Orthogonal Finetuning) Implementation
|
||||||
|
1. Optimization of Calculation Order:
|
||||||
|
- Changed the calculation order in the forward method from (Wx)R to W(xR).
|
||||||
|
- This has improved computational efficiency and processing speed.
|
||||||
|
2. Correction of Bias Application:
|
||||||
|
- In the previous implementation, R was incorrectly applied to the bias.
|
||||||
|
- The new implementation now correctly handles bias by using F.conv2d and F.linear.
|
||||||
|
3. Efficiency Enhancement in Matrix Operations:
|
||||||
|
- Introduced einsum in both the forward and merge_to methods.
|
||||||
|
- This has optimized matrix operations, resulting in further speed improvements.
|
||||||
|
4. Proper Handling of Data Types:
|
||||||
|
- Improved to use torch.float32 during calculations and convert results back to the original data type.
|
||||||
|
- This maintains precision while ensuring compatibility with the original model.
|
||||||
|
5. Unified Processing for Conv2d and Linear Layers:
|
||||||
|
- Implemented a consistent method for applying OFT to both layer types.
|
||||||
|
- These changes have made the OFT implementation more efficient and accurate, potentially leading to improved model performance and training stability.
|
||||||
|
|
||||||
|
- Additional Information
|
||||||
|
* Recommended α value for OFT constraint: We recommend using α values between 1e-4 and 1e-2. This differs slightly from the original implementation of "(α\*out_dim\*out_dim)". Our implementation uses "(α\*out_dim)", hence we recommend higher values than the 1e-5 suggested in the original implementation.
|
||||||
|
|
||||||
|
* Performance Improvement: Training speed has been improved by approximately 30%.
|
||||||
|
|
||||||
|
* Inference Environment: This implementation is compatible with and operates within Stable Diffusion web UI (SD1/2 and SDXL).
|
||||||
|
|
||||||
|
- The INVERSE_SQRT, COSINE_WITH_MIN_LR, and WARMUP_STABLE_DECAY learning rate schedules are now available in the transformers library. See PR [#1393](https://github.com/kohya-ss/sd-scripts/pull/1393) for details. Thanks to sdbds!
|
||||||
- See the [transformers documentation](https://huggingface.co/docs/transformers/v4.44.2/en/main_classes/optimizer_schedules#schedules) for details on each scheduler.
|
- See the [transformers documentation](https://huggingface.co/docs/transformers/v4.44.2/en/main_classes/optimizer_schedules#schedules) for details on each scheduler.
|
||||||
- `--lr_warmup_steps` and `--lr_decay_steps` can now be specified as a ratio of the number of training steps, not just the step value. Example: `--lr_warmup_steps=0.1` or `--lr_warmup_steps=10%`, etc.
|
- `--lr_warmup_steps` and `--lr_decay_steps` can now be specified as a ratio of the number of training steps, not just the step value. Example: `--lr_warmup_steps=0.1` or `--lr_warmup_steps=10%`, etc.
|
||||||
|
|
||||||
|
|||||||
@@ -86,7 +86,8 @@ CLIP_VISION_MODEL = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers, sdpa):
|
# def replace_unet_modules(unet: diffusers.models.unets.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers, sdpa):
|
||||||
|
def replace_unet_modules(unet, mem_eff_attn, xformers, sdpa):
|
||||||
if mem_eff_attn:
|
if mem_eff_attn:
|
||||||
logger.info("Enable memory efficient attention for U-Net")
|
logger.info("Enable memory efficient attention for U-Net")
|
||||||
|
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ def main(file):
|
|||||||
|
|
||||||
keys = list(sd.keys())
|
keys = list(sd.keys())
|
||||||
for key in keys:
|
for key in keys:
|
||||||
if "lora_up" in key or "lora_down" in key or "lora_A" in key or "lora_B" in key:
|
if "lora_up" in key or "lora_down" in key or "lora_A" in key or "lora_B" in key or "oft_" in key:
|
||||||
values.append((key, sd[key]))
|
values.append((key, sd[key]))
|
||||||
print(f"number of LoRA modules: {len(values)}")
|
print(f"number of LoRA modules: {len(values)}")
|
||||||
|
|
||||||
|
|||||||
@@ -4,13 +4,17 @@ import math
|
|||||||
import os
|
import os
|
||||||
from typing import Dict, List, Optional, Tuple, Type, Union
|
from typing import Dict, List, Optional, Tuple, Type, Union
|
||||||
from diffusers import AutoencoderKL
|
from diffusers import AutoencoderKL
|
||||||
|
import einops
|
||||||
from transformers import CLIPTextModel
|
from transformers import CLIPTextModel
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
import re
|
import re
|
||||||
from library.utils import setup_logging
|
from library.utils import setup_logging
|
||||||
|
|
||||||
setup_logging()
|
setup_logging()
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")
|
RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")
|
||||||
@@ -45,11 +49,16 @@ class OFTModule(torch.nn.Module):
|
|||||||
|
|
||||||
if type(alpha) == torch.Tensor:
|
if type(alpha) == torch.Tensor:
|
||||||
alpha = alpha.detach().numpy()
|
alpha = alpha.detach().numpy()
|
||||||
|
|
||||||
|
# constraint in original paper is alpha * out_dim * out_dim, but we use alpha * out_dim for backward compatibility
|
||||||
|
# original alpha is 1e-6, so we use 1e-3 or 1e-4 for alpha
|
||||||
self.constraint = alpha * out_dim
|
self.constraint = alpha * out_dim
|
||||||
|
|
||||||
self.register_buffer("alpha", torch.tensor(alpha))
|
self.register_buffer("alpha", torch.tensor(alpha))
|
||||||
|
|
||||||
self.block_size = out_dim // self.num_blocks
|
self.block_size = out_dim // self.num_blocks
|
||||||
self.oft_blocks = torch.nn.Parameter(torch.zeros(self.num_blocks, self.block_size, self.block_size))
|
self.oft_blocks = torch.nn.Parameter(torch.zeros(self.num_blocks, self.block_size, self.block_size))
|
||||||
|
self.I = torch.eye(self.block_size).unsqueeze(0).repeat(self.num_blocks, 1, 1) # cpu
|
||||||
|
|
||||||
self.out_dim = out_dim
|
self.out_dim = out_dim
|
||||||
self.shape = org_module.weight.shape
|
self.shape = org_module.weight.shape
|
||||||
@@ -69,27 +78,36 @@ class OFTModule(torch.nn.Module):
|
|||||||
norm_Q = torch.norm(block_Q.flatten())
|
norm_Q = torch.norm(block_Q.flatten())
|
||||||
new_norm_Q = torch.clamp(norm_Q, max=self.constraint)
|
new_norm_Q = torch.clamp(norm_Q, max=self.constraint)
|
||||||
block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8))
|
block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8))
|
||||||
I = torch.eye(self.block_size, device=self.oft_blocks.device).unsqueeze(0).repeat(self.num_blocks, 1, 1)
|
|
||||||
block_R = torch.matmul(I + block_Q, (I - block_Q).inverse())
|
|
||||||
|
|
||||||
block_R_weighted = self.multiplier * block_R + (1 - self.multiplier) * I
|
if self.I.device != block_Q.device:
|
||||||
R = torch.block_diag(*block_R_weighted)
|
self.I = self.I.to(block_Q.device)
|
||||||
|
I = self.I
|
||||||
return R
|
block_R = torch.matmul(I + block_Q, (I - block_Q).float().inverse())
|
||||||
|
block_R_weighted = self.multiplier * (block_R - I) + I
|
||||||
|
return block_R_weighted
|
||||||
|
|
||||||
def forward(self, x, scale=None):
|
def forward(self, x, scale=None):
|
||||||
x = self.org_forward(x)
|
|
||||||
if self.multiplier == 0.0:
|
if self.multiplier == 0.0:
|
||||||
return x
|
return self.org_forward(x)
|
||||||
|
org_module = self.org_module[0]
|
||||||
|
org_dtype = x.dtype
|
||||||
|
|
||||||
R = self.get_weight().to(x.device, dtype=x.dtype)
|
R = self.get_weight().to(torch.float32)
|
||||||
if x.dim() == 4:
|
W = org_module.weight.to(torch.float32)
|
||||||
x = x.permute(0, 2, 3, 1)
|
|
||||||
x = torch.matmul(x, R)
|
if len(W.shape) == 4: # Conv2d
|
||||||
x = x.permute(0, 3, 1, 2)
|
W_reshaped = einops.rearrange(W, "(k n) ... -> k n ...", k=self.num_blocks, n=self.block_size)
|
||||||
else:
|
RW = torch.einsum("k n m, k n ... -> k m ...", R, W_reshaped)
|
||||||
x = torch.matmul(x, R)
|
RW = einops.rearrange(RW, "k m ... -> (k m) ...")
|
||||||
return x
|
result = F.conv2d(
|
||||||
|
x, RW.to(org_dtype), org_module.bias, org_module.stride, org_module.padding, org_module.dilation, org_module.groups
|
||||||
|
)
|
||||||
|
else: # Linear
|
||||||
|
W_reshaped = einops.rearrange(W, "(k n) m -> k n m", k=self.num_blocks, n=self.block_size)
|
||||||
|
RW = torch.einsum("k n m, k n p -> k m p", R, W_reshaped)
|
||||||
|
RW = einops.rearrange(RW, "k m p -> (k m) p")
|
||||||
|
result = F.linear(x, RW.to(org_dtype), org_module.bias)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
class OFTInfModule(OFTModule):
|
class OFTInfModule(OFTModule):
|
||||||
@@ -115,18 +133,19 @@ class OFTInfModule(OFTModule):
|
|||||||
return self.org_forward(x)
|
return self.org_forward(x)
|
||||||
return super().forward(x, scale)
|
return super().forward(x, scale)
|
||||||
|
|
||||||
def merge_to(self, multiplier=None, sign=1):
|
def merge_to(self, multiplier=None):
|
||||||
R = self.get_weight(multiplier) * sign
|
|
||||||
|
|
||||||
# get org weight
|
# get org weight
|
||||||
org_sd = self.org_module[0].state_dict()
|
org_sd = self.org_module[0].state_dict()
|
||||||
org_weight = org_sd["weight"]
|
org_weight = org_sd["weight"].to(torch.float32)
|
||||||
R = R.to(org_weight.device, dtype=org_weight.dtype)
|
|
||||||
|
|
||||||
if org_weight.dim() == 4:
|
R = self.get_weight(multiplier).to(torch.float32)
|
||||||
weight = torch.einsum("oihw, op -> pihw", org_weight, R)
|
|
||||||
else:
|
weight = org_weight.reshape(self.num_blocks, self.block_size, -1)
|
||||||
weight = torch.einsum("oi, op -> pi", org_weight, R)
|
weight = torch.einsum("k n m, k n ... -> k m ...", R, weight)
|
||||||
|
weight = weight.reshape(org_weight.shape)
|
||||||
|
|
||||||
|
# convert back to original dtype
|
||||||
|
weight = weight.to(org_sd["weight"].dtype)
|
||||||
|
|
||||||
# set weight to org_module
|
# set weight to org_module
|
||||||
org_sd["weight"] = weight
|
org_sd["weight"] = weight
|
||||||
@@ -145,8 +164,16 @@ def create_network(
|
|||||||
):
|
):
|
||||||
if network_dim is None:
|
if network_dim is None:
|
||||||
network_dim = 4 # default
|
network_dim = 4 # default
|
||||||
if network_alpha is None:
|
if network_alpha is None: # should be set
|
||||||
network_alpha = 1.0
|
logger.info(
|
||||||
|
"network_alpha is not set, use default value 1e-3 / network_alphaが設定されていないのでデフォルト値 1e-3 を使用します"
|
||||||
|
)
|
||||||
|
network_alpha = 1e-3
|
||||||
|
elif network_alpha >= 1:
|
||||||
|
logger.warning(
|
||||||
|
"network_alpha is too large (>=1, maybe default value is too large), please consider to set smaller value like 1e-3"
|
||||||
|
" / network_alphaが大きすぎるようです(>=1, デフォルト値が大きすぎる可能性があります)。1e-3のような小さな値を推奨"
|
||||||
|
)
|
||||||
|
|
||||||
enable_all_linear = kwargs.get("enable_all_linear", None)
|
enable_all_linear = kwargs.get("enable_all_linear", None)
|
||||||
enable_conv = kwargs.get("enable_conv", None)
|
enable_conv = kwargs.get("enable_conv", None)
|
||||||
@@ -190,12 +217,11 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh
|
|||||||
else:
|
else:
|
||||||
if dim is None:
|
if dim is None:
|
||||||
dim = param.size()[0]
|
dim = param.size()[0]
|
||||||
if has_conv2d is None and param.dim() == 4:
|
if has_conv2d is None and "in_layers_2" in name:
|
||||||
has_conv2d = True
|
has_conv2d = True
|
||||||
if all_linear is None:
|
if all_linear is None and "_ff_" in name:
|
||||||
if param.dim() == 3 and "attn" not in name:
|
all_linear = True
|
||||||
all_linear = True
|
if dim is not None and alpha is not None and has_conv2d is not None and all_linear is not None:
|
||||||
if dim is not None and alpha is not None and has_conv2d is not None:
|
|
||||||
break
|
break
|
||||||
if has_conv2d is None:
|
if has_conv2d is None:
|
||||||
has_conv2d = False
|
has_conv2d = False
|
||||||
@@ -241,7 +267,7 @@ class OFTNetwork(torch.nn.Module):
|
|||||||
self.alpha = alpha
|
self.alpha = alpha
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"create OFT network. num blocks: {self.dim}, constraint: {self.alpha}, multiplier: {self.multiplier}, enable_conv: {enable_conv}"
|
f"create OFT network. num blocks: {self.dim}, constraint: {self.alpha}, multiplier: {self.multiplier}, enable_conv: {enable_conv}, enable_all_linear: {enable_all_linear}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# create module instances
|
# create module instances
|
||||||
|
|||||||
Reference in New Issue
Block a user