formatting, update README

Kohya S, 2024-09-13 19:45:42 +09:00
commit 3387dc7306 (parent 57ae44eb61)
2 changed files with 54 additions and 38 deletions

README.md

@@ -137,6 +137,12 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser
 ## Change History
+
+### Sep 13, 2024 / 2024-09-13:
+
+- `sdxl_merge_lora.py` now supports OFT. Thanks to Maru-mee for the PR [#1580](https://github.com/kohya-ss/sd-scripts/pull/1580). Will be included in the next release.
+- `sdxl_merge_lora.py` が OFT をサポートしました。PR [#1580](https://github.com/kohya-ss/sd-scripts/pull/1580) Maru-mee 氏に感謝します。次のリリースに含まれます。
+
 ### Jun 23, 2024 / 2024-06-23:
 - Fixed `cache_latents.py` and `cache_text_encoder_outputs.py` not working. (Will be included in the next release.)

networks/sdxl_merge_lora.py

@@ -10,11 +10,14 @@ import library.model_util as model_util
 import lora
 import oft
 from library.utils import setup_logging
+
 setup_logging()
 import logging
+
 logger = logging.getLogger(__name__)
 
 import concurrent.futures
+
 def load_state_dict(file_name, dtype):
     if os.path.splitext(file_name)[1] == ".safetensors":
         sd = load_file(file_name)
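The hunk cuts off inside `load_state_dict`. For orientation, here is a minimal sketch of what a loader of this shape typically does; the metadata handling is an assumption inferred from the `lora_sd, _ = load_state_dict(...)` call sites below, not the verbatim function.

```python
import os

import torch
from safetensors.torch import load_file, safe_open


def load_state_dict_sketch(file_name, dtype):
    # Sketch only: the function in the diff above is truncated.
    if os.path.splitext(file_name)[1] == ".safetensors":
        sd = load_file(file_name)
        with safe_open(file_name, framework="pt") as f:
            metadata = f.metadata()  # may be None if the file carries none
    else:
        sd = torch.load(file_name, map_location="cpu")
        metadata = None

    # cast all tensors to the requested merge dtype
    for key in list(sd.keys()):
        if isinstance(sd[key], torch.Tensor):
            sd[key] = sd[key].to(dtype)

    return sd, metadata
```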
@@ -41,20 +44,22 @@ def save_to_file(file_name, model, state_dict, dtype, metadata):
     else:
         torch.save(model, file_name)
 
 
 def detect_method_from_training_model(models, dtype):
     for model in models:
         lora_sd, _ = load_state_dict(model, dtype)
         for key in tqdm(lora_sd.keys()):
-            if 'lora_up' in key or 'lora_down' in key:
-                return 'LoRA'
+            if "lora_up" in key or "lora_down" in key:
+                return "LoRA"
             elif "oft_blocks" in key:
-                return 'OFT'
+                return "OFT"
 
 
 def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_dtype):
     text_encoder1.to(merge_dtype)
     text_encoder2.to(merge_dtype)
     unet.to(merge_dtype)
 
     # detect the method: OFT or LoRA_module
     method = detect_method_from_training_model(models, merge_dtype)
     logger.info(f"method:{method}")
@@ -62,7 +67,7 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_
     # create module map
     name_to_module = {}
     for i, root_module in enumerate([text_encoder1, text_encoder2, unet]):
-        if method == 'LoRA':
+        if method == "LoRA":
             if i <= 1:
                 if i == 0:
                     prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER1
@@ -72,9 +77,9 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_
             else:
                 prefix = lora.LoRANetwork.LORA_PREFIX_UNET
                 target_replace_modules = (
                     lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE + lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
                 )
-        elif method == 'OFT':
+        elif method == "OFT":
             prefix = oft.OFTNetwork.OFT_PREFIX_UNET
             # ALL_LINEAR includes ATTN_ONLY, so we don't need to specify ATTN_ONLY
             target_replace_modules = (
@@ -88,15 +93,14 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_
                     lora_name = prefix + "." + name + "." + child_name
                     lora_name = lora_name.replace(".", "_")
                     name_to_module[lora_name] = child_module
 
     for model, ratio in zip(models, ratios):
         logger.info(f"loading: {model}")
         lora_sd, _ = load_state_dict(model, merge_dtype)
 
         logger.info(f"merging...")
-        if method == 'LoRA':
+        if method == "LoRA":
             for key in tqdm(lora_sd.keys()):
                 if "lora_down" in key:
                     up_key = key.replace("lora_down", "lora_up")
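The module map keys follow the trainer's naming scheme: join prefix, parent module path, and child name with dots, then flatten every dot to an underscore. A small sketch with a hypothetical module path:

```python
# Hypothetical module path; real names come from root_module.named_modules().
prefix = "lora_unet"
name = "down_blocks.0.attentions.0.transformer_blocks.0.attn1"
child_name = "to_q"

lora_name = (prefix + "." + name + "." + child_name).replace(".", "_")
print(lora_name)
# lora_unet_down_blocks_0_attentions_0_transformer_blocks_0_attn1_to_q
```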
@@ -139,12 +143,11 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_
                         module.weight = torch.nn.Parameter(weight)
 
-        elif method == 'OFT':
-            multiplier=1.0
-            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        elif method == "OFT":
+            multiplier = 1.0
+            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
             for key in tqdm(lora_sd.keys()):
                 if "oft_blocks" in key:
                     oft_blocks = lora_sd[key]
@@ -154,12 +157,12 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_
                 if "alpha" in key:
                     oft_blocks = lora_sd[key]
                     alpha = oft_blocks.item()
                     break
 
             def merge_to(key):
                 if "alpha" in key:
                     return
 
                 # find original module for this OFT
                 module_name = ".".join(key.split(".")[:-1])
                 if module_name not in name_to_module:
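Each `merge_to` worker recovers the module-map key by dropping the last dotted component of the state-dict key, so the `oft_blocks` suffix falls away. With an illustrative key:

```python
key = "oft_unet_attn_to_q.oft_blocks"  # illustrative OFT state-dict key
module_name = ".".join(key.split(".")[:-1])
print(module_name)  # oft_unet_attn_to_q
```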
@@ -168,18 +171,18 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_
                 module = name_to_module[module_name]
                 # logger.info(f"apply {key} to {module}")
 
                 oft_blocks = lora_sd[key]
 
                 if isinstance(module, torch.nn.Linear):
                     out_dim = module.out_features
                 elif isinstance(module, torch.nn.Conv2d):
                     out_dim = module.out_channels
 
                 num_blocks = dim
                 block_size = out_dim // dim
                 constraint = (0 if alpha is None else alpha) * out_dim
 
                 block_Q = oft_blocks - oft_blocks.transpose(1, 2)
                 norm_Q = torch.norm(block_Q.flatten())
                 new_norm_Q = torch.clamp(norm_Q, max=constraint)
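`block_Q` is made skew-symmetric per block (Q = B - B^T) and its norm is clamped to `constraint`; the next hunk then applies the Cayley transform R = (I + Q)(I - Q)^(-1), which maps any skew-symmetric Q to an orthogonal matrix. A standalone check of that property:

```python
import torch

torch.manual_seed(0)
b = 4
B = torch.randn(b, b)
Q = B - B.T  # skew-symmetric: Q.T == -Q
I = torch.eye(b)

R = (I + Q) @ torch.linalg.inv(I - Q)  # Cayley transform
print(torch.allclose(R @ R.T, I, atol=1e-5))  # True: R is orthogonal
```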
@@ -188,24 +191,24 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_
                 block_R = torch.matmul(I + block_Q, (I - block_Q).inverse())
                 block_R_weighted = multiplier * block_R + (1 - multiplier) * I
                 R = torch.block_diag(*block_R_weighted)
 
                 # get org weight
                 org_sd = module.state_dict()
                 org_weight = org_sd["weight"].to(device)
 
                 R = R.to(org_weight.device, dtype=org_weight.dtype)
 
                 if org_weight.dim() == 4:
                     weight = torch.einsum("oihw, op -> pihw", org_weight, R)
                 else:
                     weight = torch.einsum("oi, op -> pi", org_weight, R)
 
                 weight = weight.contiguous()  # Make Tensor contiguous; required due to ThreadPoolExecutor
 
                 module.weight = torch.nn.Parameter(weight)
 
             # TODO multi-threading may cause OOM on CPU if cpu_count is too high and RAM is not enough
             max_workers = 1 if device.type != "cpu" else None  # avoid OOM on GPU
             with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
                 list(tqdm(executor.map(merge_to, lora_sd.keys()), total=len(lora_sd.keys())))
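`torch.block_diag` stitches the per-block rotations into one `out_dim x out_dim` matrix `R`, and the einsum rotates only the output-channel axis of the original weight; `max_workers` is pinned to 1 on GPU because every worker materializes a full rotated weight. A quick sanity sketch of the einsum (an identity `R` must leave the weight unchanged):

```python
import torch

out_c, in_c, kh, kw = 8, 3, 3, 3
w = torch.randn(out_c, in_c, kh, kw)  # stand-in for a Conv2d weight

R = torch.eye(out_c)  # identity rotation as a sanity check
w_rot = torch.einsum("oihw, op -> pihw", w, R)
print(torch.allclose(w, w_rot))  # True: identity R changes nothing
```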
@@ -258,7 +261,7 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False):
         for key in tqdm(lora_sd.keys()):
             if "alpha" in key:
                 continue
 
             if "lora_up" in key and concat:
                 concat_dim = 1
             elif "lora_down" in key and concat:
@@ -272,8 +275,8 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False):
             alpha = alphas[lora_module_name]
             scale = math.sqrt(alpha / base_alpha) * ratio
             scale = abs(scale) if "lora_up" in key else scale  # マイナスの重みに対応する。
 
             if key in merged_sd:
                 assert (
                     merged_sd[key].size() == lora_sd[key].size() or concat_dim is not None
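The Japanese comment on the `abs()` line means "handle negative weights": both `lora_up` and `lora_down` get multiplied by `scale`, so without the `abs()` the two negative factors would cancel and a negative ratio would behave like a positive one. Keeping the sign only on the `lora_down` side preserves it in the product:

```python
import math

alpha, base_alpha, ratio = 8.0, 4.0, -0.5      # negative ratio: subtract this LoRA
scale = math.sqrt(alpha / base_alpha) * ratio  # about -0.7071
up_scale = abs(scale)                          # applied to lora_up
down_scale = scale                             # applied to lora_down, keeps the sign
print(up_scale * down_scale)                   # about -0.5, not +0.5
```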
@@ -295,7 +298,7 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False):
                 dim = merged_sd[key_down].shape[0]
                 perm = torch.randperm(dim)
                 merged_sd[key_down] = merged_sd[key_down][perm]
-                merged_sd[key_up] = merged_sd[key_up][:,perm]
+                merged_sd[key_up] = merged_sd[key_up][:, perm]
 
     logger.info("merged model")
     logger.info(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}")
@@ -323,7 +326,9 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False):
 def merge(args):
-    assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください"
+    assert len(args.models) == len(
+        args.ratios
+    ), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください"
 
     def str_to_dtype(p):
         if p == "float":
@@ -410,10 +415,16 @@ def setup_parser() -> argparse.ArgumentParser:
         help="Stable Diffusion model to load: ckpt or safetensors file, merge LoRA models if omitted / 読み込むモデル、ckptまたはsafetensors。省略時はLoRAモデル同士をマージする",
     )
     parser.add_argument(
-        "--save_to", type=str, default=None, help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors"
+        "--save_to",
+        type=str,
+        default=None,
+        help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors",
     )
     parser.add_argument(
-        "--models", type=str, nargs="*", help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors"
+        "--models",
+        type=str,
+        nargs="*",
+        help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors",
     )
     parser.add_argument("--ratios", type=float, nargs="*", help="ratios for each model / それぞれのLoRAモデルの比率")
     parser.add_argument(
@@ -431,8 +442,7 @@ def setup_parser() -> argparse.ArgumentParser:
     parser.add_argument(
         "--shuffle",
         action="store_true",
-        help="shuffle lora weight./ "
-        + "LoRAの重みをシャッフルする",
+        help="shuffle lora weight./ " + "LoRAの重みをシャッフルする",
     )
 
     return parser