mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
Merge branch 'dev' into sd3
This commit is contained in:
@@ -704,6 +704,13 @@ https://github.com/kohya-ss/sd-scripts/pull/1290) frodo821 氏に感謝します
|
|||||||
|
|
||||||
- `gen_imgs.py` のプロンプトオプションに、保存時のファイル名を指定する `--f` オプションを追加しました。また同スクリプトで Diffusers ベースのキーを持つ LoRA の重みに対応しました。
|
- `gen_imgs.py` のプロンプトオプションに、保存時のファイル名を指定する `--f` オプションを追加しました。また同スクリプトで Diffusers ベースのキーを持つ LoRA の重みに対応しました。
|
||||||
|
|
||||||
|
|
||||||
|
### Sep 13, 2024 / 2024-09-13:
|
||||||
|
|
||||||
|
- `sdxl_merge_lora.py` now supports OFT. Thanks to Maru-mee for the PR [#1580](https://github.com/kohya-ss/sd-scripts/pull/1580). Will be included in the next release.
|
||||||
|
|
||||||
|
- `sdxl_merge_lora.py` が OFT をサポートしました。PR [#1580](https://github.com/kohya-ss/sd-scripts/pull/1580) Maru-mee 氏に感謝します。次のリリースに含まれます。
|
||||||
|
|
||||||
### Jun 23, 2024 / 2024-06-23:
|
### Jun 23, 2024 / 2024-06-23:
|
||||||
|
|
||||||
- Fixed `cache_latents.py` and `cache_text_encoder_outputs.py` not working. (Will be included in the next release.)
|
- Fixed `cache_latents.py` and `cache_text_encoder_outputs.py` not working. (Will be included in the next release.)
|
||||||
|
|||||||
@@ -8,10 +8,15 @@ from tqdm import tqdm
|
|||||||
from library import sai_model_spec, sdxl_model_util, train_util
|
from library import sai_model_spec, sdxl_model_util, train_util
|
||||||
import library.model_util as model_util
|
import library.model_util as model_util
|
||||||
import lora
|
import lora
|
||||||
|
import oft
|
||||||
from library.utils import setup_logging
|
from library.utils import setup_logging
|
||||||
|
|
||||||
setup_logging()
|
setup_logging()
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
import concurrent.futures
|
||||||
|
|
||||||
|
|
||||||
def load_state_dict(file_name, dtype):
|
def load_state_dict(file_name, dtype):
|
||||||
if os.path.splitext(file_name)[1] == ".safetensors":
|
if os.path.splitext(file_name)[1] == ".safetensors":
|
||||||
@@ -40,14 +45,29 @@ def save_to_file(file_name, model, state_dict, dtype, metadata):
|
|||||||
torch.save(model, file_name)
|
torch.save(model, file_name)
|
||||||
|
|
||||||
|
|
||||||
|
def detect_method_from_training_model(models, dtype):
|
||||||
|
for model in models:
|
||||||
|
lora_sd, _ = load_state_dict(model, dtype)
|
||||||
|
for key in tqdm(lora_sd.keys()):
|
||||||
|
if "lora_up" in key or "lora_down" in key:
|
||||||
|
return "LoRA"
|
||||||
|
elif "oft_blocks" in key:
|
||||||
|
return "OFT"
|
||||||
|
|
||||||
|
|
||||||
def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_dtype):
|
def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_dtype):
|
||||||
text_encoder1.to(merge_dtype)
|
text_encoder1.to(merge_dtype)
|
||||||
text_encoder1.to(merge_dtype)
|
text_encoder1.to(merge_dtype)
|
||||||
unet.to(merge_dtype)
|
unet.to(merge_dtype)
|
||||||
|
|
||||||
|
# detect the method: OFT or LoRA_module
|
||||||
|
method = detect_method_from_training_model(models, merge_dtype)
|
||||||
|
logger.info(f"method:{method}")
|
||||||
|
|
||||||
# create module map
|
# create module map
|
||||||
name_to_module = {}
|
name_to_module = {}
|
||||||
for i, root_module in enumerate([text_encoder1, text_encoder2, unet]):
|
for i, root_module in enumerate([text_encoder1, text_encoder2, unet]):
|
||||||
|
if method == "LoRA":
|
||||||
if i <= 1:
|
if i <= 1:
|
||||||
if i == 0:
|
if i == 0:
|
||||||
prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER1
|
prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER1
|
||||||
@@ -59,6 +79,12 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_
|
|||||||
target_replace_modules = (
|
target_replace_modules = (
|
||||||
lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE + lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
|
lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE + lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
|
||||||
)
|
)
|
||||||
|
elif method == "OFT":
|
||||||
|
prefix = oft.OFTNetwork.OFT_PREFIX_UNET
|
||||||
|
# ALL_LINEAR includes ATTN_ONLY, so we don't need to specify ATTN_ONLY
|
||||||
|
target_replace_modules = (
|
||||||
|
oft.OFTNetwork.UNET_TARGET_REPLACE_MODULE_ALL_LINEAR + oft.OFTNetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
|
||||||
|
)
|
||||||
|
|
||||||
for name, module in root_module.named_modules():
|
for name, module in root_module.named_modules():
|
||||||
if module.__class__.__name__ in target_replace_modules:
|
if module.__class__.__name__ in target_replace_modules:
|
||||||
@@ -73,6 +99,8 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_
|
|||||||
lora_sd, _ = load_state_dict(model, merge_dtype)
|
lora_sd, _ = load_state_dict(model, merge_dtype)
|
||||||
|
|
||||||
logger.info(f"merging...")
|
logger.info(f"merging...")
|
||||||
|
|
||||||
|
if method == "LoRA":
|
||||||
for key in tqdm(lora_sd.keys()):
|
for key in tqdm(lora_sd.keys()):
|
||||||
if "lora_down" in key:
|
if "lora_down" in key:
|
||||||
up_key = key.replace("lora_down", "lora_up")
|
up_key = key.replace("lora_down", "lora_up")
|
||||||
@@ -115,6 +143,75 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_
|
|||||||
|
|
||||||
module.weight = torch.nn.Parameter(weight)
|
module.weight = torch.nn.Parameter(weight)
|
||||||
|
|
||||||
|
elif method == "OFT":
|
||||||
|
|
||||||
|
multiplier = 1.0
|
||||||
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
|
||||||
|
for key in tqdm(lora_sd.keys()):
|
||||||
|
if "oft_blocks" in key:
|
||||||
|
oft_blocks = lora_sd[key]
|
||||||
|
dim = oft_blocks.shape[0]
|
||||||
|
break
|
||||||
|
for key in tqdm(lora_sd.keys()):
|
||||||
|
if "alpha" in key:
|
||||||
|
oft_blocks = lora_sd[key]
|
||||||
|
alpha = oft_blocks.item()
|
||||||
|
break
|
||||||
|
|
||||||
|
def merge_to(key):
|
||||||
|
if "alpha" in key:
|
||||||
|
return
|
||||||
|
|
||||||
|
# find original module for this OFT
|
||||||
|
module_name = ".".join(key.split(".")[:-1])
|
||||||
|
if module_name not in name_to_module:
|
||||||
|
logger.info(f"no module found for OFT weight: {key}")
|
||||||
|
return
|
||||||
|
module = name_to_module[module_name]
|
||||||
|
|
||||||
|
# logger.info(f"apply {key} to {module}")
|
||||||
|
|
||||||
|
oft_blocks = lora_sd[key]
|
||||||
|
|
||||||
|
if isinstance(module, torch.nn.Linear):
|
||||||
|
out_dim = module.out_features
|
||||||
|
elif isinstance(module, torch.nn.Conv2d):
|
||||||
|
out_dim = module.out_channels
|
||||||
|
|
||||||
|
num_blocks = dim
|
||||||
|
block_size = out_dim // dim
|
||||||
|
constraint = (0 if alpha is None else alpha) * out_dim
|
||||||
|
|
||||||
|
block_Q = oft_blocks - oft_blocks.transpose(1, 2)
|
||||||
|
norm_Q = torch.norm(block_Q.flatten())
|
||||||
|
new_norm_Q = torch.clamp(norm_Q, max=constraint)
|
||||||
|
block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8))
|
||||||
|
I = torch.eye(block_size, device=oft_blocks.device).unsqueeze(0).repeat(num_blocks, 1, 1)
|
||||||
|
block_R = torch.matmul(I + block_Q, (I - block_Q).inverse())
|
||||||
|
block_R_weighted = multiplier * block_R + (1 - multiplier) * I
|
||||||
|
R = torch.block_diag(*block_R_weighted)
|
||||||
|
|
||||||
|
# get org weight
|
||||||
|
org_sd = module.state_dict()
|
||||||
|
org_weight = org_sd["weight"].to(device)
|
||||||
|
|
||||||
|
R = R.to(org_weight.device, dtype=org_weight.dtype)
|
||||||
|
|
||||||
|
if org_weight.dim() == 4:
|
||||||
|
weight = torch.einsum("oihw, op -> pihw", org_weight, R)
|
||||||
|
else:
|
||||||
|
weight = torch.einsum("oi, op -> pi", org_weight, R)
|
||||||
|
|
||||||
|
weight = weight.contiguous() # Make Tensor contiguous; required due to ThreadPoolExecutor
|
||||||
|
|
||||||
|
module.weight = torch.nn.Parameter(weight)
|
||||||
|
|
||||||
|
# TODO multi-threading may cause OOM on CPU if cpu_count is too high and RAM is not enough
|
||||||
|
max_workers = 1 if device.type != "cpu" else None # avoid OOM on GPU
|
||||||
|
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||||
|
list(tqdm(executor.map(merge_to, lora_sd.keys()), total=len(lora_sd.keys())))
|
||||||
|
|
||||||
|
|
||||||
def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False):
|
def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False):
|
||||||
base_alphas = {} # alpha for merged model
|
base_alphas = {} # alpha for merged model
|
||||||
@@ -229,7 +326,9 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False):
|
|||||||
|
|
||||||
|
|
||||||
def merge(args):
|
def merge(args):
|
||||||
assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください"
|
assert len(args.models) == len(
|
||||||
|
args.ratios
|
||||||
|
), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください"
|
||||||
|
|
||||||
def str_to_dtype(p):
|
def str_to_dtype(p):
|
||||||
if p == "float":
|
if p == "float":
|
||||||
@@ -316,10 +415,16 @@ def setup_parser() -> argparse.ArgumentParser:
|
|||||||
help="Stable Diffusion model to load: ckpt or safetensors file, merge LoRA models if omitted / 読み込むモデル、ckptまたはsafetensors。省略時はLoRAモデル同士をマージする",
|
help="Stable Diffusion model to load: ckpt or safetensors file, merge LoRA models if omitted / 読み込むモデル、ckptまたはsafetensors。省略時はLoRAモデル同士をマージする",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--save_to", type=str, default=None, help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors"
|
"--save_to",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--models", type=str, nargs="*", help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors"
|
"--models",
|
||||||
|
type=str,
|
||||||
|
nargs="*",
|
||||||
|
help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors",
|
||||||
)
|
)
|
||||||
parser.add_argument("--ratios", type=float, nargs="*", help="ratios for each model / それぞれのLoRAモデルの比率")
|
parser.add_argument("--ratios", type=float, nargs="*", help="ratios for each model / それぞれのLoRAモデルの比率")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@@ -337,8 +442,7 @@ def setup_parser() -> argparse.ArgumentParser:
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--shuffle",
|
"--shuffle",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
help="shuffle lora weight./ "
|
help="shuffle lora weight./ " + "LoRAの重みをシャッフルする",
|
||||||
+ "LoRAの重みをシャッフルする",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|||||||
Reference in New Issue
Block a user