mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-10 15:00:23 +00:00
Compare commits
3 Commits
dependabot
...
resume-ste
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
faadc350a4 | ||
|
|
6d9338f8b5 | ||
|
|
5f0eebaa56 |
2
.github/workflows/typos.yml
vendored
2
.github/workflows/typos.yml
vendored
@@ -18,4 +18,4 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: typos-action
|
||||
uses: crate-ci/typos@v1.28.1
|
||||
uses: crate-ci/typos@v1.19.0
|
||||
|
||||
@@ -3,10 +3,6 @@ Stable Diffusionの学習、画像生成、その他のスクリプトを入れ
|
||||
|
||||
[README in English](./README.md) ←更新情報はこちらにあります
|
||||
|
||||
開発中のバージョンはdevブランチにあります。最新の変更点はdevブランチをご確認ください。
|
||||
|
||||
FLUX.1およびSD3/SD3.5対応はsd3ブランチで行っています。それらの学習を行う場合はsd3ブランチをご利用ください。
|
||||
|
||||
GUIやPowerShellスクリプトなど、より使いやすくする機能が[bmaltais氏のリポジトリ](https://github.com/bmaltais/kohya_ss)で提供されています(英語です)のであわせてご覧ください。bmaltais氏に感謝します。
|
||||
|
||||
以下のスクリプトがあります。
|
||||
|
||||
40
README.md
40
README.md
@@ -5,11 +5,6 @@ This repository contains training, generation and utility scripts for Stable Dif
|
||||
|
||||
[日本語版READMEはこちら](./README-ja.md)
|
||||
|
||||
The development version is in the `dev` branch. Please check the dev branch for the latest changes.
|
||||
|
||||
FLUX.1 and SD3/SD3.5 support is done in the `sd3` branch. If you want to train them, please use the sd3 branch.
|
||||
|
||||
|
||||
For easier use (GUI and PowerShell scripts etc...), please visit [the repository maintained by bmaltais](https://github.com/bmaltais/kohya_ss). Thanks to @bmaltais!
|
||||
|
||||
This repository contains the scripts for:
|
||||
@@ -142,41 +137,6 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser
|
||||
|
||||
## Change History
|
||||
|
||||
### Oct 27, 2024 / 2024-10-27:
|
||||
|
||||
- `svd_merge_lora.py` VRAM usage has been reduced. However, main memory usage will increase (32GB is sufficient).
|
||||
- This will be included in the next release.
|
||||
- `svd_merge_lora.py` のVRAM使用量を削減しました。ただし、メインメモリの使用量は増加します(32GBあれば十分です)。
|
||||
- これは次回リリースに含まれます。
|
||||
|
||||
### Oct 26, 2024 / 2024-10-26:
|
||||
|
||||
- Fixed a bug in `svd_merge_lora.py`, `sdxl_merge_lora.py`, and `resize_lora.py` where the hash value of LoRA metadata was not correctly calculated when the `save_precision` was different from the `precision` used in the calculation. See issue [#1722](https://github.com/kohya-ss/sd-scripts/pull/1722) for details. Thanks to JujoHotaru for raising the issue.
|
||||
- It will be included in the next release.
|
||||
|
||||
- `svd_merge_lora.py`、`sdxl_merge_lora.py`、`resize_lora.py`で、保存時の精度が計算時の精度と異なる場合、LoRAメタデータのハッシュ値が正しく計算されない不具合を修正しました。詳細は issue [#1722](https://github.com/kohya-ss/sd-scripts/pull/1722) をご覧ください。問題提起していただいた JujoHotaru 氏に感謝します。
|
||||
- 以上は次回リリースに含まれます。
|
||||
|
||||
### Sep 13, 2024 / 2024-09-13:
|
||||
|
||||
- `sdxl_merge_lora.py` now supports OFT. Thanks to Maru-mee for the PR [#1580](https://github.com/kohya-ss/sd-scripts/pull/1580).
|
||||
- `svd_merge_lora.py` now supports LBW. Thanks to terracottahaniwa. See PR [#1575](https://github.com/kohya-ss/sd-scripts/pull/1575) for details.
|
||||
- `sdxl_merge_lora.py` also supports LBW.
|
||||
- See [LoRA Block Weight](https://github.com/hako-mikan/sd-webui-lora-block-weight) by hako-mikan for details on LBW.
|
||||
- These will be included in the next release.
|
||||
|
||||
- `sdxl_merge_lora.py` が OFT をサポートされました。PR [#1580](https://github.com/kohya-ss/sd-scripts/pull/1580) Maru-mee 氏に感謝します。
|
||||
- `svd_merge_lora.py` で LBW がサポートされました。PR [#1575](https://github.com/kohya-ss/sd-scripts/pull/1575) terracottahaniwa 氏に感謝します。
|
||||
- `sdxl_merge_lora.py` でも LBW がサポートされました。
|
||||
- LBW の詳細は hako-mikan 氏の [LoRA Block Weight](https://github.com/hako-mikan/sd-webui-lora-block-weight) をご覧ください。
|
||||
- 以上は次回リリースに含まれます。
|
||||
|
||||
### Jun 23, 2024 / 2024-06-23:
|
||||
|
||||
- Fixed `cache_latents.py` and `cache_text_encoder_outputs.py` not working. (Will be included in the next release.)
|
||||
|
||||
- `cache_latents.py` および `cache_text_encoder_outputs.py` が動作しなくなっていたのを修正しました。(次回リリースに含まれます。)
|
||||
|
||||
### Apr 7, 2024 / 2024-04-07: v0.8.7
|
||||
|
||||
- The default value of `huber_schedule` in Scheduled Huber Loss is changed from `exponential` to `snr`, which is expected to give better results.
|
||||
|
||||
@@ -649,8 +649,15 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
|
||||
def set_current_epoch(self, epoch):
|
||||
if not self.current_epoch == epoch: # epochが切り替わったらバケツをシャッフルする
|
||||
self.shuffle_buckets()
|
||||
self.current_epoch = epoch
|
||||
if epoch > self.current_epoch:
|
||||
logger.info("epoch is incremented. current_epoch: {}, epoch: {}".format(self.current_epoch, epoch))
|
||||
num_epochs = epoch - self.current_epoch
|
||||
for _ in range(num_epochs):
|
||||
self.current_epoch += 1
|
||||
self.shuffle_buckets()
|
||||
else:
|
||||
logger.warning("epoch is not incremented. current_epoch: {}, epoch: {}".format(self.current_epoch, epoch))
|
||||
self.current_epoch = epoch
|
||||
|
||||
def set_current_step(self, step):
|
||||
self.current_step = step
|
||||
@@ -941,7 +948,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
self._length = len(self.buckets_indices)
|
||||
|
||||
def shuffle_buckets(self):
|
||||
# set random seed for this epoch
|
||||
# set random seed for this epoch: current_epoch is not incremented
|
||||
random.seed(self.seed + self.current_epoch)
|
||||
|
||||
random.shuffle(self.buckets_indices)
|
||||
@@ -1489,7 +1496,7 @@ class DreamBoothDataset(BaseDataset):
|
||||
def load_dreambooth_dir(subset: DreamBoothSubset):
|
||||
if not os.path.isdir(subset.image_dir):
|
||||
logger.warning(f"not directory: {subset.image_dir}")
|
||||
return [], [], []
|
||||
return [], []
|
||||
|
||||
info_cache_file = os.path.join(subset.image_dir, self.IMAGE_INFO_CACHE_FILE)
|
||||
use_cached_info_for_subset = subset.cache_info
|
||||
@@ -2346,10 +2353,10 @@ def load_arbitrary_dataset(args, tokenizer) -> MinimalDataset:
|
||||
|
||||
|
||||
def load_image(image_path):
|
||||
image = Image.open(image_path)
|
||||
if not image.mode == "RGB":
|
||||
image = image.convert("RGB")
|
||||
img = np.array(image, np.uint8)
|
||||
with Image.open(image_path) as image:
|
||||
if not image.mode == "RGB":
|
||||
image = image.convert("RGB")
|
||||
img = np.array(image, np.uint8)
|
||||
return img
|
||||
|
||||
|
||||
@@ -5387,7 +5394,7 @@ class LossRecorder:
|
||||
self.loss_total: float = 0.0
|
||||
|
||||
def add(self, *, epoch: int, step: int, loss: float) -> None:
|
||||
if epoch == 0:
|
||||
if epoch == 0 or step >= len(self.loss_list):
|
||||
self.loss_list.append(loss)
|
||||
else:
|
||||
self.loss_total -= self.loss_list[step]
|
||||
|
||||
@@ -39,7 +39,12 @@ def load_state_dict(file_name, dtype):
|
||||
return sd, metadata
|
||||
|
||||
|
||||
def save_to_file(file_name, state_dict, metadata):
|
||||
def save_to_file(file_name, state_dict, dtype, metadata):
|
||||
if dtype is not None:
|
||||
for key in list(state_dict.keys()):
|
||||
if type(state_dict[key]) == torch.Tensor:
|
||||
state_dict[key] = state_dict[key].to(dtype)
|
||||
|
||||
if model_util.is_safetensors(file_name):
|
||||
save_file(state_dict, file_name, metadata)
|
||||
else:
|
||||
@@ -344,18 +349,12 @@ def resize(args):
|
||||
metadata["ss_network_dim"] = "Dynamic"
|
||||
metadata["ss_network_alpha"] = "Dynamic"
|
||||
|
||||
# cast to save_dtype before calculating hashes
|
||||
for key in list(state_dict.keys()):
|
||||
value = state_dict[key]
|
||||
if type(value) == torch.Tensor and value.dtype.is_floating_point and value.dtype != save_dtype:
|
||||
state_dict[key] = value.to(save_dtype)
|
||||
|
||||
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
|
||||
metadata["sshs_model_hash"] = model_hash
|
||||
metadata["sshs_legacy_hash"] = legacy_hash
|
||||
|
||||
logger.info(f"saving model to: {args.save_to}")
|
||||
save_to_file(args.save_to, state_dict, metadata)
|
||||
save_to_file(args.save_to, state_dict, save_dtype, metadata)
|
||||
|
||||
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
|
||||
@@ -1,25 +1,18 @@
|
||||
import itertools
|
||||
import math
|
||||
import argparse
|
||||
import os
|
||||
import time
|
||||
import concurrent.futures
|
||||
import torch
|
||||
from safetensors.torch import load_file, save_file
|
||||
from tqdm import tqdm
|
||||
from library import sai_model_spec, sdxl_model_util, train_util
|
||||
import library.model_util as model_util
|
||||
import lora
|
||||
import oft
|
||||
from svd_merge_lora import format_lbws, get_lbw_block_index, LAYER26
|
||||
from library.utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def load_state_dict(file_name, dtype):
|
||||
if os.path.splitext(file_name)[1] == ".safetensors":
|
||||
sd = load_file(file_name)
|
||||
@@ -35,58 +28,36 @@ def load_state_dict(file_name, dtype):
|
||||
return sd, metadata
|
||||
|
||||
|
||||
def save_to_file(file_name, model, metadata):
|
||||
def save_to_file(file_name, model, state_dict, dtype, metadata):
|
||||
if dtype is not None:
|
||||
for key in list(state_dict.keys()):
|
||||
if type(state_dict[key]) == torch.Tensor:
|
||||
state_dict[key] = state_dict[key].to(dtype)
|
||||
|
||||
if os.path.splitext(file_name)[1] == ".safetensors":
|
||||
save_file(model, file_name, metadata=metadata)
|
||||
else:
|
||||
torch.save(model, file_name)
|
||||
|
||||
|
||||
def detect_method_from_training_model(models, dtype):
|
||||
for model in models:
|
||||
# TODO It is better to use key names to detect the method
|
||||
lora_sd, _ = load_state_dict(model, dtype)
|
||||
for key in tqdm(lora_sd.keys()):
|
||||
if "lora_up" in key or "lora_down" in key:
|
||||
return "LoRA"
|
||||
elif "oft_blocks" in key:
|
||||
return "OFT"
|
||||
|
||||
|
||||
def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, lbws, merge_dtype):
|
||||
def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_dtype):
|
||||
text_encoder1.to(merge_dtype)
|
||||
text_encoder1.to(merge_dtype)
|
||||
text_encoder2.to(merge_dtype)
|
||||
unet.to(merge_dtype)
|
||||
|
||||
# detect the method: OFT or LoRA_module
|
||||
method = detect_method_from_training_model(models, merge_dtype)
|
||||
logger.info(f"method:{method}")
|
||||
|
||||
if lbws:
|
||||
lbws, _, LBW_TARGET_IDX = format_lbws(lbws)
|
||||
else:
|
||||
LBW_TARGET_IDX = []
|
||||
|
||||
# create module map
|
||||
name_to_module = {}
|
||||
for i, root_module in enumerate([text_encoder1, text_encoder2, unet]):
|
||||
if method == "LoRA":
|
||||
if i <= 1:
|
||||
if i == 0:
|
||||
prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER1
|
||||
else:
|
||||
prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER2
|
||||
target_replace_modules = lora.LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE
|
||||
if i <= 1:
|
||||
if i == 0:
|
||||
prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER1
|
||||
else:
|
||||
prefix = lora.LoRANetwork.LORA_PREFIX_UNET
|
||||
target_replace_modules = (
|
||||
lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE + lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
|
||||
)
|
||||
elif method == "OFT":
|
||||
prefix = oft.OFTNetwork.OFT_PREFIX_UNET
|
||||
# ALL_LINEAR includes ATTN_ONLY, so we don't need to specify ATTN_ONLY
|
||||
prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER2
|
||||
target_replace_modules = lora.LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE
|
||||
else:
|
||||
prefix = lora.LoRANetwork.LORA_PREFIX_UNET
|
||||
target_replace_modules = (
|
||||
oft.OFTNetwork.UNET_TARGET_REPLACE_MODULE_ALL_LINEAR + oft.OFTNetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
|
||||
lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE + lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
|
||||
)
|
||||
|
||||
for name, module in root_module.named_modules():
|
||||
@@ -97,172 +68,65 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, lbws,
|
||||
lora_name = lora_name.replace(".", "_")
|
||||
name_to_module[lora_name] = child_module
|
||||
|
||||
for model, ratio, lbw in itertools.zip_longest(models, ratios, lbws):
|
||||
for model, ratio in zip(models, ratios):
|
||||
logger.info(f"loading: {model}")
|
||||
lora_sd, _ = load_state_dict(model, merge_dtype)
|
||||
|
||||
logger.info(f"merging...")
|
||||
for key in tqdm(lora_sd.keys()):
|
||||
if "lora_down" in key:
|
||||
up_key = key.replace("lora_down", "lora_up")
|
||||
alpha_key = key[: key.index("lora_down")] + "alpha"
|
||||
|
||||
if lbw:
|
||||
lbw_weights = [1] * 26
|
||||
for index, value in zip(LBW_TARGET_IDX, lbw):
|
||||
lbw_weights[index] = value
|
||||
logger.info(f"lbw: {dict(zip(LAYER26.keys(), lbw_weights))}")
|
||||
|
||||
if method == "LoRA":
|
||||
for key in tqdm(lora_sd.keys()):
|
||||
if "lora_down" in key:
|
||||
up_key = key.replace("lora_down", "lora_up")
|
||||
alpha_key = key[: key.index("lora_down")] + "alpha"
|
||||
|
||||
# find original module for this lora
|
||||
module_name = ".".join(key.split(".")[:-2]) # remove trailing ".lora_down.weight"
|
||||
if module_name not in name_to_module:
|
||||
logger.info(f"no module found for LoRA weight: {key}")
|
||||
continue
|
||||
module = name_to_module[module_name]
|
||||
# logger.info(f"apply {key} to {module}")
|
||||
|
||||
down_weight = lora_sd[key]
|
||||
up_weight = lora_sd[up_key]
|
||||
|
||||
dim = down_weight.size()[0]
|
||||
alpha = lora_sd.get(alpha_key, dim)
|
||||
scale = alpha / dim
|
||||
|
||||
if lbw:
|
||||
index = get_lbw_block_index(key, True)
|
||||
is_lbw_target = index in LBW_TARGET_IDX
|
||||
if is_lbw_target:
|
||||
scale *= lbw_weights[index] # keyがlbwの対象であれば、lbwの重みを掛ける
|
||||
|
||||
# W <- W + U * D
|
||||
weight = module.weight
|
||||
# logger.info(module_name, down_weight.size(), up_weight.size())
|
||||
if len(weight.size()) == 2:
|
||||
# linear
|
||||
weight = weight + ratio * (up_weight @ down_weight) * scale
|
||||
elif down_weight.size()[2:4] == (1, 1):
|
||||
# conv2d 1x1
|
||||
weight = (
|
||||
weight
|
||||
+ ratio
|
||||
* (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
|
||||
* scale
|
||||
)
|
||||
else:
|
||||
# conv2d 3x3
|
||||
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
|
||||
# logger.info(conved.size(), weight.size(), module.stride, module.padding)
|
||||
weight = weight + ratio * conved * scale
|
||||
|
||||
module.weight = torch.nn.Parameter(weight)
|
||||
|
||||
elif method == "OFT":
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
for key in tqdm(lora_sd.keys()):
|
||||
if "oft_blocks" in key:
|
||||
oft_blocks = lora_sd[key]
|
||||
dim = oft_blocks.shape[0]
|
||||
break
|
||||
for key in tqdm(lora_sd.keys()):
|
||||
if "alpha" in key:
|
||||
oft_blocks = lora_sd[key]
|
||||
alpha = oft_blocks.item()
|
||||
break
|
||||
|
||||
def merge_to(key):
|
||||
if "alpha" in key:
|
||||
return
|
||||
|
||||
# find original module for this OFT
|
||||
module_name = ".".join(key.split(".")[:-1])
|
||||
# find original module for this lora
|
||||
module_name = ".".join(key.split(".")[:-2]) # remove trailing ".lora_down.weight"
|
||||
if module_name not in name_to_module:
|
||||
logger.info(f"no module found for OFT weight: {key}")
|
||||
return
|
||||
logger.info(f"no module found for LoRA weight: {key}")
|
||||
continue
|
||||
module = name_to_module[module_name]
|
||||
|
||||
# logger.info(f"apply {key} to {module}")
|
||||
|
||||
oft_blocks = lora_sd[key]
|
||||
down_weight = lora_sd[key]
|
||||
up_weight = lora_sd[up_key]
|
||||
|
||||
if isinstance(module, torch.nn.Linear):
|
||||
out_dim = module.out_features
|
||||
elif isinstance(module, torch.nn.Conv2d):
|
||||
out_dim = module.out_channels
|
||||
dim = down_weight.size()[0]
|
||||
alpha = lora_sd.get(alpha_key, dim)
|
||||
scale = alpha / dim
|
||||
|
||||
num_blocks = dim
|
||||
block_size = out_dim // dim
|
||||
constraint = (0 if alpha is None else alpha) * out_dim
|
||||
|
||||
multiplier = 1
|
||||
if lbw:
|
||||
index = get_lbw_block_index(key, False)
|
||||
is_lbw_target = index in LBW_TARGET_IDX
|
||||
if is_lbw_target:
|
||||
multiplier *= lbw_weights[index]
|
||||
|
||||
block_Q = oft_blocks - oft_blocks.transpose(1, 2)
|
||||
norm_Q = torch.norm(block_Q.flatten())
|
||||
new_norm_Q = torch.clamp(norm_Q, max=constraint)
|
||||
block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8))
|
||||
I = torch.eye(block_size, device=oft_blocks.device).unsqueeze(0).repeat(num_blocks, 1, 1)
|
||||
block_R = torch.matmul(I + block_Q, (I - block_Q).inverse())
|
||||
block_R_weighted = multiplier * block_R + (1 - multiplier) * I
|
||||
R = torch.block_diag(*block_R_weighted)
|
||||
|
||||
# get org weight
|
||||
org_sd = module.state_dict()
|
||||
org_weight = org_sd["weight"].to(device)
|
||||
|
||||
R = R.to(org_weight.device, dtype=org_weight.dtype)
|
||||
|
||||
if org_weight.dim() == 4:
|
||||
weight = torch.einsum("oihw, op -> pihw", org_weight, R)
|
||||
# W <- W + U * D
|
||||
weight = module.weight
|
||||
# logger.info(module_name, down_weight.size(), up_weight.size())
|
||||
if len(weight.size()) == 2:
|
||||
# linear
|
||||
weight = weight + ratio * (up_weight @ down_weight) * scale
|
||||
elif down_weight.size()[2:4] == (1, 1):
|
||||
# conv2d 1x1
|
||||
weight = (
|
||||
weight
|
||||
+ ratio
|
||||
* (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
|
||||
* scale
|
||||
)
|
||||
else:
|
||||
weight = torch.einsum("oi, op -> pi", org_weight, R)
|
||||
|
||||
weight = weight.contiguous() # Make Tensor contiguous; required due to ThreadPoolExecutor
|
||||
# conv2d 3x3
|
||||
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
|
||||
# logger.info(conved.size(), weight.size(), module.stride, module.padding)
|
||||
weight = weight + ratio * conved * scale
|
||||
|
||||
module.weight = torch.nn.Parameter(weight)
|
||||
|
||||
# TODO multi-threading may cause OOM on CPU if cpu_count is too high and RAM is not enough
|
||||
max_workers = 1 if device.type != "cpu" else None # avoid OOM on GPU
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
list(tqdm(executor.map(merge_to, lora_sd.keys()), total=len(lora_sd.keys())))
|
||||
|
||||
|
||||
def merge_lora_models(models, ratios, lbws, merge_dtype, concat=False, shuffle=False):
|
||||
def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False):
|
||||
base_alphas = {} # alpha for merged model
|
||||
base_dims = {}
|
||||
|
||||
# detect the method: OFT or LoRA_module
|
||||
method = detect_method_from_training_model(models, merge_dtype)
|
||||
if method == "OFT":
|
||||
raise ValueError(
|
||||
"OFT model is not supported for merging OFT models. / OFTモデルはOFTモデル同士のマージには対応していません"
|
||||
)
|
||||
|
||||
if lbws:
|
||||
lbws, _, LBW_TARGET_IDX = format_lbws(lbws)
|
||||
else:
|
||||
LBW_TARGET_IDX = []
|
||||
|
||||
merged_sd = {}
|
||||
v2 = None
|
||||
base_model = None
|
||||
for model, ratio, lbw in itertools.zip_longest(models, ratios, lbws):
|
||||
for model, ratio in zip(models, ratios):
|
||||
logger.info(f"loading: {model}")
|
||||
lora_sd, lora_metadata = load_state_dict(model, merge_dtype)
|
||||
|
||||
if lbw:
|
||||
lbw_weights = [1] * 26
|
||||
for index, value in zip(LBW_TARGET_IDX, lbw):
|
||||
lbw_weights[index] = value
|
||||
logger.info(f"lbw: {dict(zip(LAYER26.keys(), lbw_weights))}")
|
||||
|
||||
if lora_metadata is not None:
|
||||
if v2 is None:
|
||||
v2 = lora_metadata.get(train_util.SS_METADATA_KEY_V2, None) # returns string, SDXLはv2がないのでFalseのはず
|
||||
@@ -300,7 +164,7 @@ def merge_lora_models(models, ratios, lbws, merge_dtype, concat=False, shuffle=F
|
||||
for key in tqdm(lora_sd.keys()):
|
||||
if "alpha" in key:
|
||||
continue
|
||||
|
||||
|
||||
if "lora_up" in key and concat:
|
||||
concat_dim = 1
|
||||
elif "lora_down" in key and concat:
|
||||
@@ -314,14 +178,8 @@ def merge_lora_models(models, ratios, lbws, merge_dtype, concat=False, shuffle=F
|
||||
alpha = alphas[lora_module_name]
|
||||
|
||||
scale = math.sqrt(alpha / base_alpha) * ratio
|
||||
scale = abs(scale) if "lora_up" in key else scale # マイナスの重みに対応する。
|
||||
|
||||
if lbw:
|
||||
index = get_lbw_block_index(key, True)
|
||||
is_lbw_target = index in LBW_TARGET_IDX
|
||||
if is_lbw_target:
|
||||
scale *= lbw_weights[index] # keyがlbwの対象であれば、lbwの重みを掛ける
|
||||
|
||||
scale = abs(scale) if "lora_up" in key else scale # マイナスの重みに対応する。
|
||||
|
||||
if key in merged_sd:
|
||||
assert (
|
||||
merged_sd[key].size() == lora_sd[key].size() or concat_dim is not None
|
||||
@@ -343,7 +201,7 @@ def merge_lora_models(models, ratios, lbws, merge_dtype, concat=False, shuffle=F
|
||||
dim = merged_sd[key_down].shape[0]
|
||||
perm = torch.randperm(dim)
|
||||
merged_sd[key_down] = merged_sd[key_down][perm]
|
||||
merged_sd[key_up] = merged_sd[key_up][:, perm]
|
||||
merged_sd[key_up] = merged_sd[key_up][:,perm]
|
||||
|
||||
logger.info("merged model")
|
||||
logger.info(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}")
|
||||
@@ -371,15 +229,7 @@ def merge_lora_models(models, ratios, lbws, merge_dtype, concat=False, shuffle=F
|
||||
|
||||
|
||||
def merge(args):
|
||||
assert len(args.models) == len(
|
||||
args.ratios
|
||||
), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください"
|
||||
if args.lbws:
|
||||
assert len(args.models) == len(
|
||||
args.lbws
|
||||
), f"number of models must be equal to number of ratios / モデルの数と層別適用率の数は合わせてください"
|
||||
else:
|
||||
args.lbws = [] # zip_longestで扱えるようにlbws未使用時には空のリストにしておく
|
||||
assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください"
|
||||
|
||||
def str_to_dtype(p):
|
||||
if p == "float":
|
||||
@@ -407,7 +257,7 @@ def merge(args):
|
||||
ckpt_info,
|
||||
) = sdxl_model_util.load_models_from_sdxl_checkpoint(sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, args.sd_model, "cpu")
|
||||
|
||||
merge_to_sd_model(text_model1, text_model2, unet, args.models, args.ratios, args.lbws, merge_dtype)
|
||||
merge_to_sd_model(text_model1, text_model2, unet, args.models, args.ratios, merge_dtype)
|
||||
|
||||
if args.no_metadata:
|
||||
sai_metadata = None
|
||||
@@ -423,13 +273,7 @@ def merge(args):
|
||||
args.save_to, text_model1, text_model2, unet, 0, 0, ckpt_info, vae, logit_scale, sai_metadata, save_dtype
|
||||
)
|
||||
else:
|
||||
state_dict, metadata = merge_lora_models(args.models, args.ratios, args.lbws, merge_dtype, args.concat, args.shuffle)
|
||||
|
||||
# cast to save_dtype before calculating hashes
|
||||
for key in list(state_dict.keys()):
|
||||
value = state_dict[key]
|
||||
if type(value) == torch.Tensor and value.dtype.is_floating_point and value.dtype != save_dtype:
|
||||
state_dict[key] = value.to(save_dtype)
|
||||
state_dict, metadata = merge_lora_models(args.models, args.ratios, merge_dtype, args.concat, args.shuffle)
|
||||
|
||||
logger.info(f"calculating hashes and creating metadata...")
|
||||
|
||||
@@ -446,7 +290,7 @@ def merge(args):
|
||||
metadata.update(sai_metadata)
|
||||
|
||||
logger.info(f"saving model to: {args.save_to}")
|
||||
save_to_file(args.save_to, state_dict, metadata)
|
||||
save_to_file(args.save_to, state_dict, state_dict, save_dtype, metadata)
|
||||
|
||||
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
@@ -472,19 +316,12 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
help="Stable Diffusion model to load: ckpt or safetensors file, merge LoRA models if omitted / 読み込むモデル、ckptまたはsafetensors。省略時はLoRAモデル同士をマージする",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save_to",
|
||||
type=str,
|
||||
default=None,
|
||||
help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors",
|
||||
"--save_to", type=str, default=None, help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--models",
|
||||
type=str,
|
||||
nargs="*",
|
||||
help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors",
|
||||
"--models", type=str, nargs="*", help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors"
|
||||
)
|
||||
parser.add_argument("--ratios", type=float, nargs="*", help="ratios for each model / それぞれのLoRAモデルの比率")
|
||||
parser.add_argument("--lbws", type=str, nargs="*", help="lbw for each model / それぞれのLoRAモデルの層別適用率")
|
||||
parser.add_argument(
|
||||
"--no_metadata",
|
||||
action="store_true",
|
||||
@@ -500,7 +337,8 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
parser.add_argument(
|
||||
"--shuffle",
|
||||
action="store_true",
|
||||
help="shuffle lora weight./ " + "LoRAの重みをシャッフルする",
|
||||
help="shuffle lora weight./ "
|
||||
+ "LoRAの重みをシャッフルする",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
@@ -1,8 +1,5 @@
|
||||
import argparse
|
||||
import itertools
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
import torch
|
||||
from safetensors.torch import load_file, save_file
|
||||
@@ -11,195 +8,12 @@ from library import sai_model_spec, train_util
|
||||
import library.model_util as model_util
|
||||
import lora
|
||||
from library.utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
CLAMP_QUANTILE = 0.99
|
||||
|
||||
ACCEPTABLE = [12, 17, 20, 26]
|
||||
SDXL_LAYER_NUM = [12, 20]
|
||||
|
||||
LAYER12 = {
|
||||
"BASE": True,
|
||||
"IN00": False,
|
||||
"IN01": False,
|
||||
"IN02": False,
|
||||
"IN03": False,
|
||||
"IN04": True,
|
||||
"IN05": True,
|
||||
"IN06": False,
|
||||
"IN07": True,
|
||||
"IN08": True,
|
||||
"IN09": False,
|
||||
"IN10": False,
|
||||
"IN11": False,
|
||||
"MID": True,
|
||||
"OUT00": True,
|
||||
"OUT01": True,
|
||||
"OUT02": True,
|
||||
"OUT03": True,
|
||||
"OUT04": True,
|
||||
"OUT05": True,
|
||||
"OUT06": False,
|
||||
"OUT07": False,
|
||||
"OUT08": False,
|
||||
"OUT09": False,
|
||||
"OUT10": False,
|
||||
"OUT11": False,
|
||||
}
|
||||
|
||||
LAYER17 = {
|
||||
"BASE": True,
|
||||
"IN00": False,
|
||||
"IN01": True,
|
||||
"IN02": True,
|
||||
"IN03": False,
|
||||
"IN04": True,
|
||||
"IN05": True,
|
||||
"IN06": False,
|
||||
"IN07": True,
|
||||
"IN08": True,
|
||||
"IN09": False,
|
||||
"IN10": False,
|
||||
"IN11": False,
|
||||
"MID": True,
|
||||
"OUT00": False,
|
||||
"OUT01": False,
|
||||
"OUT02": False,
|
||||
"OUT03": True,
|
||||
"OUT04": True,
|
||||
"OUT05": True,
|
||||
"OUT06": True,
|
||||
"OUT07": True,
|
||||
"OUT08": True,
|
||||
"OUT09": True,
|
||||
"OUT10": True,
|
||||
"OUT11": True,
|
||||
}
|
||||
|
||||
LAYER20 = {
|
||||
"BASE": True,
|
||||
"IN00": True,
|
||||
"IN01": True,
|
||||
"IN02": True,
|
||||
"IN03": True,
|
||||
"IN04": True,
|
||||
"IN05": True,
|
||||
"IN06": True,
|
||||
"IN07": True,
|
||||
"IN08": True,
|
||||
"IN09": False,
|
||||
"IN10": False,
|
||||
"IN11": False,
|
||||
"MID": True,
|
||||
"OUT00": True,
|
||||
"OUT01": True,
|
||||
"OUT02": True,
|
||||
"OUT03": True,
|
||||
"OUT04": True,
|
||||
"OUT05": True,
|
||||
"OUT06": True,
|
||||
"OUT07": True,
|
||||
"OUT08": True,
|
||||
"OUT09": False,
|
||||
"OUT10": False,
|
||||
"OUT11": False,
|
||||
}
|
||||
|
||||
LAYER26 = {
|
||||
"BASE": True,
|
||||
"IN00": True,
|
||||
"IN01": True,
|
||||
"IN02": True,
|
||||
"IN03": True,
|
||||
"IN04": True,
|
||||
"IN05": True,
|
||||
"IN06": True,
|
||||
"IN07": True,
|
||||
"IN08": True,
|
||||
"IN09": True,
|
||||
"IN10": True,
|
||||
"IN11": True,
|
||||
"MID": True,
|
||||
"OUT00": True,
|
||||
"OUT01": True,
|
||||
"OUT02": True,
|
||||
"OUT03": True,
|
||||
"OUT04": True,
|
||||
"OUT05": True,
|
||||
"OUT06": True,
|
||||
"OUT07": True,
|
||||
"OUT08": True,
|
||||
"OUT09": True,
|
||||
"OUT10": True,
|
||||
"OUT11": True,
|
||||
}
|
||||
|
||||
assert len([v for v in LAYER12.values() if v]) == 12
|
||||
assert len([v for v in LAYER17.values() if v]) == 17
|
||||
assert len([v for v in LAYER20.values() if v]) == 20
|
||||
assert len([v for v in LAYER26.values() if v]) == 26
|
||||
|
||||
RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")
|
||||
|
||||
|
||||
def get_lbw_block_index(lora_name: str, is_sdxl: bool = False) -> int:
|
||||
# lbw block index is 0-based, but 0 for text encoder, so we return 0 for text encoder
|
||||
if "text_model_encoder_" in lora_name: # LoRA for text encoder
|
||||
return 0
|
||||
|
||||
# lbw block index is 1-based for U-Net, and no "input_blocks.0" in CompVis SD, so "input_blocks.1" have index 2
|
||||
block_idx = -1 # invalid lora name
|
||||
if not is_sdxl:
|
||||
NUM_OF_BLOCKS = 12 # up/down blocks
|
||||
m = RE_UPDOWN.search(lora_name)
|
||||
if m:
|
||||
g = m.groups()
|
||||
up_down = g[0]
|
||||
i = int(g[1])
|
||||
j = int(g[3])
|
||||
if up_down == "down":
|
||||
if g[2] == "resnets" or g[2] == "attentions":
|
||||
idx = 3 * i + j + 1
|
||||
elif g[2] == "downsamplers":
|
||||
idx = 3 * (i + 1)
|
||||
else:
|
||||
return block_idx # invalid lora name
|
||||
elif up_down == "up":
|
||||
if g[2] == "resnets" or g[2] == "attentions":
|
||||
idx = 3 * i + j
|
||||
elif g[2] == "upsamplers":
|
||||
idx = 3 * i + 2
|
||||
else:
|
||||
return block_idx # invalid lora name
|
||||
|
||||
if g[0] == "down":
|
||||
block_idx = 1 + idx # 1-based index, down block index
|
||||
elif g[0] == "up":
|
||||
block_idx = 1 + NUM_OF_BLOCKS + 1 + idx # 1-based index, num blocks, mid block, up block index
|
||||
|
||||
elif "mid_block_" in lora_name:
|
||||
block_idx = 1 + NUM_OF_BLOCKS # 1-based index, num blocks, mid block
|
||||
else:
|
||||
# SDXL: some numbers are skipped
|
||||
if lora_name.startswith("lora_unet_"):
|
||||
name = lora_name[len("lora_unet_") :]
|
||||
if name.startswith("time_embed_") or name.startswith("label_emb_"): # 1, No LoRA in sd-scripts
|
||||
block_idx = 1
|
||||
elif name.startswith("input_blocks_"): # 1-8 to 2-9
|
||||
block_idx = 1 + int(name.split("_")[2])
|
||||
elif name.startswith("middle_block_"): # 13
|
||||
block_idx = 13
|
||||
elif name.startswith("output_blocks_"): # 0-8 to 14-22
|
||||
block_idx = 14 + int(name.split("_")[2])
|
||||
elif name.startswith("out_"): # 23, No LoRA in sd-scripts
|
||||
block_idx = 23
|
||||
|
||||
return block_idx
|
||||
|
||||
|
||||
def load_state_dict(file_name, dtype):
|
||||
if os.path.splitext(file_name)[1] == ".safetensors":
|
||||
@@ -216,53 +30,24 @@ def load_state_dict(file_name, dtype):
|
||||
return sd, metadata
|
||||
|
||||
|
||||
def save_to_file(file_name, state_dict, metadata):
|
||||
def save_to_file(file_name, state_dict, dtype, metadata):
|
||||
if dtype is not None:
|
||||
for key in list(state_dict.keys()):
|
||||
if type(state_dict[key]) == torch.Tensor:
|
||||
state_dict[key] = state_dict[key].to(dtype)
|
||||
|
||||
if os.path.splitext(file_name)[1] == ".safetensors":
|
||||
save_file(state_dict, file_name, metadata=metadata)
|
||||
else:
|
||||
torch.save(state_dict, file_name)
|
||||
|
||||
|
||||
def format_lbws(lbws):
|
||||
try:
|
||||
# lbwは"[1,1,1,1,1,1,1,1,1,1,1,1]"のような文字列で与えられることを期待している
|
||||
lbws = [json.loads(lbw) for lbw in lbws]
|
||||
except Exception:
|
||||
raise ValueError(f"format of lbws are must be json / 層別適用率はJSON形式で書いてください")
|
||||
assert all(isinstance(lbw, list) for lbw in lbws), f"lbws are must be list / 層別適用率はリストにしてください"
|
||||
assert len(set(len(lbw) for lbw in lbws)) == 1, "all lbws should have the same length / 層別適用率は同じ長さにしてください"
|
||||
assert all(
|
||||
len(lbw) in ACCEPTABLE for lbw in lbws
|
||||
), f"length of lbw are must be in {ACCEPTABLE} / 層別適用率の長さは{ACCEPTABLE}のいずれかにしてください"
|
||||
assert all(
|
||||
all(isinstance(weight, (int, float)) for weight in lbw) for lbw in lbws
|
||||
), f"values of lbs are must be numbers / 層別適用率の値はすべて数値にしてください"
|
||||
|
||||
layer_num = len(lbws[0])
|
||||
is_sdxl = True if layer_num in SDXL_LAYER_NUM else False
|
||||
FLAGS = {
|
||||
"12": LAYER12.values(),
|
||||
"17": LAYER17.values(),
|
||||
"20": LAYER20.values(),
|
||||
"26": LAYER26.values(),
|
||||
}[str(layer_num)]
|
||||
LBW_TARGET_IDX = [i for i, flag in enumerate(FLAGS) if flag]
|
||||
return lbws, is_sdxl, LBW_TARGET_IDX
|
||||
|
||||
|
||||
def merge_lora_models(models, ratios, lbws, new_rank, new_conv_rank, device, merge_dtype):
|
||||
def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dtype):
|
||||
logger.info(f"new rank: {new_rank}, new conv rank: {new_conv_rank}")
|
||||
merged_sd = {}
|
||||
v2 = None # This is meaning LoRA Metadata v2, Not meaning SD2
|
||||
v2 = None
|
||||
base_model = None
|
||||
|
||||
if lbws:
|
||||
lbws, is_sdxl, LBW_TARGET_IDX = format_lbws(lbws)
|
||||
else:
|
||||
is_sdxl = False
|
||||
LBW_TARGET_IDX = []
|
||||
|
||||
for model, ratio, lbw in itertools.zip_longest(models, ratios, lbws):
|
||||
for model, ratio in zip(models, ratios):
|
||||
logger.info(f"loading: {model}")
|
||||
lora_sd, lora_metadata = load_state_dict(model, merge_dtype)
|
||||
|
||||
@@ -272,12 +57,6 @@ def merge_lora_models(models, ratios, lbws, new_rank, new_conv_rank, device, mer
|
||||
if base_model is None:
|
||||
base_model = lora_metadata.get(train_util.SS_METADATA_KEY_BASE_MODEL_VERSION, None)
|
||||
|
||||
if lbw:
|
||||
lbw_weights = [1] * 26
|
||||
for index, value in zip(LBW_TARGET_IDX, lbw):
|
||||
lbw_weights[index] = value
|
||||
logger.info(f"lbw: {dict(zip(LAYER26.keys(), lbw_weights))}")
|
||||
|
||||
# merge
|
||||
logger.info(f"merging...")
|
||||
for key in tqdm(list(lora_sd.keys())):
|
||||
@@ -301,10 +80,10 @@ def merge_lora_models(models, ratios, lbws, new_rank, new_conv_rank, device, mer
|
||||
# make original weight if not exist
|
||||
if lora_module_name not in merged_sd:
|
||||
weight = torch.zeros((out_dim, in_dim, *kernel_size) if conv2d else (out_dim, in_dim), dtype=merge_dtype)
|
||||
if device:
|
||||
weight = weight.to(device)
|
||||
else:
|
||||
weight = merged_sd[lora_module_name]
|
||||
if device:
|
||||
weight = weight.to(device)
|
||||
|
||||
# merge to weight
|
||||
if device:
|
||||
@@ -314,12 +93,6 @@ def merge_lora_models(models, ratios, lbws, new_rank, new_conv_rank, device, mer
|
||||
# W <- W + U * D
|
||||
scale = alpha / network_dim
|
||||
|
||||
if lbw:
|
||||
index = get_lbw_block_index(key, is_sdxl)
|
||||
is_lbw_target = index in LBW_TARGET_IDX
|
||||
if is_lbw_target:
|
||||
scale *= lbw_weights[index] # keyがlbwの対象であれば、lbwの重みを掛ける
|
||||
|
||||
if device: # and isinstance(scale, torch.Tensor):
|
||||
scale = scale.to(device)
|
||||
|
||||
@@ -336,16 +109,13 @@ def merge_lora_models(models, ratios, lbws, new_rank, new_conv_rank, device, mer
|
||||
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
|
||||
weight = weight + ratio * conved * scale
|
||||
|
||||
merged_sd[lora_module_name] = weight.to("cpu")
|
||||
merged_sd[lora_module_name] = weight
|
||||
|
||||
# extract from merged weights
|
||||
logger.info("extract new lora...")
|
||||
merged_lora_sd = {}
|
||||
with torch.no_grad():
|
||||
for lora_module_name, mat in tqdm(list(merged_sd.items())):
|
||||
if device:
|
||||
mat = mat.to(device)
|
||||
|
||||
conv2d = len(mat.size()) == 4
|
||||
kernel_size = None if not conv2d else mat.size()[2:4]
|
||||
conv2d_3x3 = conv2d and kernel_size != (1, 1)
|
||||
@@ -384,7 +154,7 @@ def merge_lora_models(models, ratios, lbws, new_rank, new_conv_rank, device, mer
|
||||
|
||||
merged_lora_sd[lora_module_name + ".lora_up.weight"] = up_weight.to("cpu").contiguous()
|
||||
merged_lora_sd[lora_module_name + ".lora_down.weight"] = down_weight.to("cpu").contiguous()
|
||||
merged_lora_sd[lora_module_name + ".alpha"] = torch.tensor(module_new_rank, device="cpu")
|
||||
merged_lora_sd[lora_module_name + ".alpha"] = torch.tensor(module_new_rank)
|
||||
|
||||
# build minimum metadata
|
||||
dims = f"{new_rank}"
|
||||
@@ -399,15 +169,7 @@ def merge_lora_models(models, ratios, lbws, new_rank, new_conv_rank, device, mer
|
||||
|
||||
|
||||
def merge(args):
|
||||
assert len(args.models) == len(
|
||||
args.ratios
|
||||
), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください"
|
||||
if args.lbws:
|
||||
assert len(args.models) == len(
|
||||
args.lbws
|
||||
), f"number of models must be equal to number of ratios / モデルの数と層別適用率の数は合わせてください"
|
||||
else:
|
||||
args.lbws = [] # zip_longestで扱えるようにlbws未使用時には空のリストにしておく
|
||||
assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください"
|
||||
|
||||
def str_to_dtype(p):
|
||||
if p == "float":
|
||||
@@ -425,15 +187,9 @@ def merge(args):
|
||||
|
||||
new_conv_rank = args.new_conv_rank if args.new_conv_rank is not None else args.new_rank
|
||||
state_dict, metadata, v2, base_model = merge_lora_models(
|
||||
args.models, args.ratios, args.lbws, args.new_rank, new_conv_rank, args.device, merge_dtype
|
||||
args.models, args.ratios, args.new_rank, new_conv_rank, args.device, merge_dtype
|
||||
)
|
||||
|
||||
# cast to save_dtype before calculating hashes
|
||||
for key in list(state_dict.keys()):
|
||||
value = state_dict[key]
|
||||
if type(value) == torch.Tensor and value.dtype.is_floating_point and value.dtype != save_dtype:
|
||||
state_dict[key] = value.to(save_dtype)
|
||||
|
||||
logger.info(f"calculating hashes and creating metadata...")
|
||||
|
||||
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
|
||||
@@ -455,7 +211,7 @@ def merge(args):
|
||||
metadata.update(sai_metadata)
|
||||
|
||||
logger.info(f"saving model to: {args.save_to}")
|
||||
save_to_file(args.save_to, state_dict, metadata)
|
||||
save_to_file(args.save_to, state_dict, save_dtype, metadata)
|
||||
|
||||
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
@@ -475,19 +231,12 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
help="precision in merging (float is recommended) / マージの計算時の精度(floatを推奨)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save_to",
|
||||
type=str,
|
||||
default=None,
|
||||
help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors",
|
||||
"--save_to", type=str, default=None, help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--models",
|
||||
type=str,
|
||||
nargs="*",
|
||||
help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors",
|
||||
"--models", type=str, nargs="*", help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors"
|
||||
)
|
||||
parser.add_argument("--ratios", type=float, nargs="*", help="ratios for each model / それぞれのLoRAモデルの比率")
|
||||
parser.add_argument("--lbws", type=str, nargs="*", help="lbw for each model / それぞれのLoRAモデルの層別適用率")
|
||||
parser.add_argument("--new_rank", type=int, default=4, help="Specify rank of output LoRA / 出力するLoRAのrank (dim)")
|
||||
parser.add_argument(
|
||||
"--new_conv_rank",
|
||||
@@ -495,9 +244,7 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
default=None,
|
||||
help="Specify rank of output LoRA for Conv2d 3x3, None for same as new_rank / 出力するConv2D 3x3 LoRAのrank (dim)、Noneでnew_rankと同じ",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う"
|
||||
)
|
||||
parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
|
||||
parser.add_argument(
|
||||
"--no_metadata",
|
||||
action="store_true",
|
||||
|
||||
@@ -3,7 +3,7 @@ transformers==4.36.2
|
||||
diffusers[torch]==0.25.0
|
||||
ftfy==6.1.1
|
||||
# albumentations==1.3.0
|
||||
opencv-python==4.8.1.78
|
||||
opencv-python==4.7.0.68
|
||||
einops==0.7.0
|
||||
pytorch-lightning==1.9.0
|
||||
bitsandbytes==0.43.0
|
||||
|
||||
@@ -16,13 +16,12 @@ from library.config_util import (
|
||||
ConfigSanitizer,
|
||||
BlueprintGenerator,
|
||||
)
|
||||
from library.utils import setup_logging, add_logging_arguments
|
||||
from library.utils import setup_logging
|
||||
setup_logging()
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def cache_to_disk(args: argparse.Namespace) -> None:
|
||||
setup_logging(args, reset=True)
|
||||
train_util.prepare_dataset_args(args, True)
|
||||
|
||||
# check cache latents arg
|
||||
@@ -95,7 +94,6 @@ def cache_to_disk(args: argparse.Namespace) -> None:
|
||||
|
||||
# acceleratorを準備する
|
||||
logger.info("prepare accelerator")
|
||||
args.deepspeed = False
|
||||
accelerator = train_util.prepare_accelerator(args)
|
||||
|
||||
# mixed precisionに対応した型を用意しておき適宜castする
|
||||
@@ -172,7 +170,6 @@ def cache_to_disk(args: argparse.Namespace) -> None:
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
add_logging_arguments(parser)
|
||||
train_util.add_sd_models_arguments(parser)
|
||||
train_util.add_training_arguments(parser, True)
|
||||
train_util.add_dataset_arguments(parser, True, True, True)
|
||||
|
||||
@@ -16,13 +16,12 @@ from library.config_util import (
|
||||
ConfigSanitizer,
|
||||
BlueprintGenerator,
|
||||
)
|
||||
from library.utils import setup_logging, add_logging_arguments
|
||||
from library.utils import setup_logging
|
||||
setup_logging()
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def cache_to_disk(args: argparse.Namespace) -> None:
|
||||
setup_logging(args, reset=True)
|
||||
train_util.prepare_dataset_args(args, True)
|
||||
|
||||
# check cache arg
|
||||
@@ -100,7 +99,6 @@ def cache_to_disk(args: argparse.Namespace) -> None:
|
||||
|
||||
# acceleratorを準備する
|
||||
logger.info("prepare accelerator")
|
||||
args.deepspeed = False
|
||||
accelerator = train_util.prepare_accelerator(args)
|
||||
|
||||
# mixed precisionに対応した型を用意しておき適宜castする
|
||||
@@ -173,7 +171,6 @@ def cache_to_disk(args: argparse.Namespace) -> None:
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
add_logging_arguments(parser)
|
||||
train_util.add_sd_models_arguments(parser)
|
||||
train_util.add_training_arguments(parser, True)
|
||||
train_util.add_dataset_arguments(parser, True, True, True)
|
||||
|
||||
107
train_network.py
107
train_network.py
@@ -483,6 +483,15 @@ class NetworkTrainer:
|
||||
weights.pop(i)
|
||||
# print(f"save model hook: {len(weights)} weights will be saved")
|
||||
|
||||
# save current ecpoch and step
|
||||
train_state_file = os.path.join(output_dir, "train_state.json")
|
||||
# +1 is needed because the state is saved before current_step is set from global_step
|
||||
logger.info(f"save train state to {train_state_file} at epoch {current_epoch.value} step {current_step.value+1}")
|
||||
with open(train_state_file, "w", encoding="utf-8") as f:
|
||||
json.dump({"current_epoch": current_epoch.value, "current_step": current_step.value + 1}, f)
|
||||
|
||||
steps_from_state = None
|
||||
|
||||
def load_model_hook(models, input_dir):
|
||||
# remove models except network
|
||||
remove_indices = []
|
||||
@@ -493,6 +502,15 @@ class NetworkTrainer:
|
||||
models.pop(i)
|
||||
# print(f"load model hook: {len(models)} models will be loaded")
|
||||
|
||||
# load current epoch and step to
|
||||
nonlocal steps_from_state
|
||||
train_state_file = os.path.join(input_dir, "train_state.json")
|
||||
if os.path.exists(train_state_file):
|
||||
with open(train_state_file, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
steps_from_state = data["current_step"] + 1 # because
|
||||
logger.info(f"load train state from {train_state_file}: {data}")
|
||||
|
||||
accelerator.register_save_state_pre_hook(save_model_hook)
|
||||
accelerator.register_load_state_pre_hook(load_model_hook)
|
||||
|
||||
@@ -736,7 +754,52 @@ class NetworkTrainer:
|
||||
if key in metadata:
|
||||
minimum_metadata[key] = metadata[key]
|
||||
|
||||
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
|
||||
# calculate steps to skip when resuming or starting from a specific step
|
||||
initial_step = 0
|
||||
if args.initial_epoch is not None or args.initial_step is not None:
|
||||
# if initial_epoch or initial_step is specified, steps_from_state is ignored even when resuming
|
||||
if steps_from_state is not None:
|
||||
logger.warning(
|
||||
"steps from the state is ignored because initial_step is specified / initial_stepが指定されているため、stateからのステップ数は無視されます"
|
||||
)
|
||||
if args.initial_step is not None:
|
||||
initial_step = args.initial_step
|
||||
else:
|
||||
# num steps per epoch is calculated by num_processes and gradient_accumulation_steps
|
||||
initial_step = (args.initial_epoch - 1) * math.ceil(
|
||||
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
|
||||
)
|
||||
else:
|
||||
# if initial_epoch and initial_step are not specified, steps_from_state is used when resuming
|
||||
if steps_from_state is not None:
|
||||
initial_step = steps_from_state
|
||||
steps_from_state = None
|
||||
|
||||
if initial_step > 0:
|
||||
assert (
|
||||
args.max_train_steps > initial_step
|
||||
), f"max_train_steps should be greater than initial step / max_train_stepsは初期ステップより大きい必要があります: {args.max_train_steps} vs {initial_step}"
|
||||
|
||||
progress_bar = tqdm(
|
||||
range(args.max_train_steps - initial_step), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps"
|
||||
)
|
||||
|
||||
epoch_to_start = 0
|
||||
if initial_step > 0:
|
||||
if args.skip_until_initial_step:
|
||||
# if skip_until_initial_step is specified, load data and discard it to ensure the same data is used
|
||||
if not args.resume:
|
||||
logger.info(
|
||||
f"initial_step is specified but not resuming. lr scheduler will be started from the beginning / initial_stepが指定されていますがresumeしていないため、lr schedulerは最初から始まります"
|
||||
)
|
||||
logger.info(f"skipping {initial_step} steps / {initial_step}ステップをスキップします")
|
||||
else:
|
||||
# if not, only epoch no is skipped for informative purpose
|
||||
epoch_to_start = initial_step // math.ceil(
|
||||
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
|
||||
)
|
||||
initial_step = 0 # do not skip
|
||||
|
||||
global_step = 0
|
||||
|
||||
noise_scheduler = DDPMScheduler(
|
||||
@@ -793,16 +856,35 @@ class NetworkTrainer:
|
||||
self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
||||
|
||||
# training loop
|
||||
for epoch in range(num_train_epochs):
|
||||
if initial_step > 0:
|
||||
# set starting global step calculated from initial_step. because skipping steps doesn't increment global_step
|
||||
global_step = initial_step // (accelerator.num_processes * args.gradient_accumulation_steps)
|
||||
|
||||
for epoch in range(epoch_to_start, num_train_epochs):
|
||||
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
|
||||
current_epoch.value = epoch + 1
|
||||
|
||||
steps_per_epoch = math.ceil(len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps)
|
||||
if initial_step > steps_per_epoch:
|
||||
logger.info(f"skipping epoch {epoch+1} because initial_step (multiplied) is {initial_step}")
|
||||
initial_step -= steps_per_epoch
|
||||
continue
|
||||
|
||||
metadata["ss_epoch"] = str(epoch + 1)
|
||||
|
||||
accelerator.unwrap_model(network).on_epoch_start(text_encoder, unet)
|
||||
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
active_dataloader = train_dataloader
|
||||
if initial_step > 0:
|
||||
logger.info(f"skipping {initial_step} batches in epoch {epoch+1}")
|
||||
active_dataloader = accelerator.skip_first_batches(
|
||||
train_dataloader, initial_step * args.gradient_accumulation_steps
|
||||
)
|
||||
initial_step = 0
|
||||
|
||||
for step, batch in enumerate(active_dataloader):
|
||||
current_step.value = global_step
|
||||
|
||||
with accelerator.accumulate(training_model):
|
||||
on_step_start(text_encoder, unet)
|
||||
|
||||
@@ -1101,6 +1183,25 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
action="store_true",
|
||||
help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--skip_until_initial_step",
|
||||
action="store_true",
|
||||
help="skip training until initial_step is reached / initial_stepに到達するまで学習をスキップする",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--initial_epoch",
|
||||
type=int,
|
||||
default=None,
|
||||
help="initial epoch number, 1 means first epoch (same as not specifying). NOTE: initial_epoch/step doesn't affect to lr scheduler. Which means lr scheduler will start from 0 without `--resume`."
|
||||
+ " / 初期エポック数、1で最初のエポック(未指定時と同じ)。注意:initial_epoch/stepはlr schedulerに影響しないため、`--resume`しない場合はlr schedulerは0から始まる",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--initial_step",
|
||||
type=int,
|
||||
default=None,
|
||||
help="initial step number including all epochs, 0 means first step (same as not specifying). overwrites initial_epoch."
|
||||
+ " / 初期ステップ数、全エポックを含むステップ数、0で最初のステップ(未指定時と同じ)。initial_epochを上書きする",
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user