Merge branch 'kohya-ss:main' into weighted_captions

This commit is contained in:
AI-Casanova
2023-04-07 14:55:40 -05:00
committed by GitHub
13 changed files with 700 additions and 174 deletions

View File

@@ -127,7 +127,43 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser
 ## Change History

-- 4 Apr. 2023, 2023/4/4:
### 6 Apr. 2023, 2023/4/6:
- Because a lot has changed, there may be bugs. If a problem occurs and you cannot revert the script to the previous version, please wait a while for an update.
- Added a feature to upload the model and state to HuggingFace. Thanks to ddPn08 for the contribution! [PR #348](https://github.com/kohya-ss/sd-scripts/pull/348)
  - When `--huggingface_repo_id` is specified, the model is uploaded to HuggingFace at the same time the model is saved.
  - Handle the access token with caution. Refer to the [HuggingFace documentation](https://huggingface.co/docs/hub/security-tokens).
  - Specify the other arguments as follows, for example:
    - `--huggingface_repo_id "your-hf-name/your-model" --huggingface_path_in_repo "path" --huggingface_repo_type model --huggingface_repo_visibility private --huggingface_token hf_YourAccessTokenHere`
  - If `public` is specified for `--huggingface_repo_visibility`, the repository will be public. If the option is omitted, or anything other than `public` (such as `private`) is specified, it will be private.
  - If you specify `--save_state` together with `--save_state_to_huggingface`, the state is uploaded as well.
  - If you specify `--resume` together with `--resume_from_huggingface`, the state is downloaded from HuggingFace and training resumes.
    - In this case, the `--resume` option takes the form `--resume {repo_id}/{path_in_repo}:{revision}:{repo_type}`. For example: `--resume_from_huggingface --resume your-hf-name/your-model/path/test-000002-state:main:model`
  - If you specify `--async_upload`, the upload is done asynchronously.
- Added documentation on applying LoRA when generating with the standard Diffusers pipeline. [training LoRA](./train_network_README-ja.md#diffusersのpipelineで生成する) (Japanese only)
- Support for Attention Couple and regional LoRA in `gen_img_diffusers.py`.
  - If you use ` AND ` to separate the prompts, each sub-prompt is applied to the corresponding LoRA in order. `--mask_path` is treated as a mask image. The number of sub-prompts must match the number of LoRAs.
### 4 Apr. 2023, 2023/4/4, Release 0.6.0:
- Because a lot has changed, there may be bugs. If a problem occurs and you cannot revert the script to the previous version, please wait a while for an update.
- The learning rate and dim (rank) of each block may not work with other modules (LyCORIS, etc.) because the module needs to be changed.

View File

@@ -231,9 +231,7 @@ def train(args):
    train_util.patch_accelerator_for_fp16_training(accelerator)

    # resume training if specified
-   if args.resume is not None:
-       print(f"resume training from state: {args.resume}")
-       accelerator.load_state(args.resume)
    train_util.resume_from_local_or_hf_if_specified(accelerator, args)

    # calculate the number of epochs
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)

View File

@@ -92,6 +92,7 @@ from PIL.PngImagePlugin import PngInfo
import library.model_util as model_util
import library.train_util as train_util
from networks.lora import LoRANetwork
import tools.original_control_net as original_control_net
from tools.original_control_net import ControlNetInfo
@@ -634,6 +635,7 @@ class PipelineLike:
        img2img_noise=None,
        clip_prompts=None,
        clip_guide_images=None,
        networks: Optional[List[LoRANetwork]] = None,
        **kwargs,
    ):
        r"""
@@ -717,6 +719,7 @@ class PipelineLike:
            batch_size = len(prompt)
        else:
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        regional_network = " AND " in prompt[0]

        vae_batch_size = (
            batch_size
@@ -1010,6 +1013,11 @@ class PipelineLike:
                # predict the noise residual
                if self.control_nets:
                    if regional_network:
                        num_sub_and_neg_prompts = len(text_embeddings) // batch_size
                        text_emb_last = text_embeddings[num_sub_and_neg_prompts - 2 :: num_sub_and_neg_prompts]  # last sub-prompt
                    else:
                        text_emb_last = text_embeddings
                    noise_pred = original_control_net.call_unet_and_control_net(
                        i,
                        num_latent_input,
@@ -1019,7 +1027,7 @@ class PipelineLike:
                        i / len(timesteps),
                        latent_model_input,
                        t,
-                       text_embeddings,
                        text_emb_last,
                    ).sample
                else:
                    noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
@@ -1890,6 +1898,12 @@ def get_weighted_text_embeddings(
    if isinstance(prompt, str):
        prompt = [prompt]

    # split the prompts with " AND "; each prompt must have the same number of splits
    new_prompts = []
    for p in prompt:
        new_prompts.extend(p.split(" AND "))
    prompt = new_prompts

    if not skip_parsing:
        prompt_tokens, prompt_weights = get_prompts_with_weights(pipe, prompt, max_length - 2, layer=layer)
        if uncond_prompt is not None:
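To illustrate the splitting behavior added above (an editor's sketch, not part of the commit):

```python
# Editor's illustration of the " AND " split above; not part of the commit.
prompts = ["blue sky AND red castle", "day AND night"]
new_prompts = []
for p in prompts:
    new_prompts.extend(p.split(" AND "))
print(new_prompts)  # ['blue sky', 'red castle', 'day', 'night']
# Every prompt in a batch must yield the same number of sub-prompts,
# since downstream code indexes the embeddings with a fixed stride.
```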
@@ -2059,6 +2073,7 @@ class BatchDataExt(NamedTuple):
    negative_scale: float
    strength: float
    network_muls: Tuple[float]
    num_sub_prompts: int


class BatchData(NamedTuple):
@@ -2276,14 +2291,18 @@ def main(args):
print(f"metadata for: {network_weight}: {metadata}") print(f"metadata for: {network_weight}: {metadata}")
network, weights_sd = imported_module.create_network_from_weights( network, weights_sd = imported_module.create_network_from_weights(
network_mul, network_weight, vae, text_encoder, unet, **net_kwargs network_mul, network_weight, vae, text_encoder, unet, for_inference=True, **net_kwargs
) )
else: else:
raise ValueError("No weight. Weight is required.") raise ValueError("No weight. Weight is required.")
if network is None: if network is None:
return return
if not args.network_merge: mergiable = hasattr(network, "merge_to")
if args.network_merge and not mergiable:
print("network is not mergiable. ignore merge option.")
if not args.network_merge or not mergiable:
network.apply_to(text_encoder, unet) network.apply_to(text_encoder, unet)
info = network.load_state_dict(weights_sd, False) # network.load_weightsを使うようにするとよい info = network.load_state_dict(weights_sd, False) # network.load_weightsを使うようにするとよい
print(f"weights are loaded: {info}") print(f"weights are loaded: {info}")
@@ -2349,12 +2368,12 @@ def main(args):
    if args.diffusers_xformers:
        pipe.enable_xformers_memory_efficient_attention()

    # handle Extended Textual Inversion and Textual Inversion
    if args.XTI_embeddings:
        diffusers.models.UNet2DConditionModel.forward = unet_forward_XTI
        diffusers.models.unet_2d_blocks.CrossAttnDownBlock2D.forward = downblock_forward_XTI
        diffusers.models.unet_2d_blocks.CrossAttnUpBlock2D.forward = upblock_forward_XTI

-   # handle Textual Inversion
    if args.textual_inversion_embeddings:
        token_ids_embeds = []
        for embeds_file in args.textual_inversion_embeddings:
@@ -2558,16 +2577,22 @@ def main(args):
print(f"resize img2img mask images to {args.W}*{args.H}") print(f"resize img2img mask images to {args.W}*{args.H}")
mask_images = resize_images(mask_images, (args.W, args.H)) mask_images = resize_images(mask_images, (args.W, args.H))
regional_network = False
if networks and mask_images: if networks and mask_images:
# mask を領域情報として流用する、現在は1枚だけ対応 # mask を領域情報として流用する、現在は一回のコマンド呼び出しで1枚だけ対応
# TODO 複数のnetwork classの混在時の考慮 regional_network = True
print("use mask as region") print("use mask as region")
# import cv2
# for i in range(3): size = None
# cv2.imshow("msk", np.array(mask_images[0])[:,:,i]) for i, network in enumerate(networks):
# cv2.waitKey() if i < 3:
# cv2.destroyAllWindows() np_mask = np.array(mask_images[0])
networks[0].__class__.set_regions(networks, np.array(mask_images[0])) np_mask = np_mask[:, :, i]
size = np_mask.shape
else:
np_mask = np.full(size, 255, dtype=np.uint8)
mask = torch.from_numpy(np_mask.astype(np.float32) / 255.0)
network.set_region(i, i == len(networks) - 1, mask)
mask_images = None mask_images = None
prev_image = None # for VGG16 guided prev_image = None # for VGG16 guided
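Editor's note on the new mask handling above: the R, G, and B channels of the single mask image act as the region masks for the first three networks, while any network beyond the third receives a full (all-255) mask. The last network is flagged via `i == len(networks) - 1` so that it can later combine the per-region outputs.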
@@ -2623,7 +2648,14 @@ def main(args):
                height_1st = height_1st - height_1st % 32

                ext_1st = BatchDataExt(
-                   width_1st, height_1st, args.highres_fix_steps, ext.scale, ext.negative_scale, ext.strength, ext.network_muls
                    width_1st,
                    height_1st,
                    args.highres_fix_steps,
                    ext.scale,
                    ext.negative_scale,
                    ext.strength,
                    ext.network_muls,
                    ext.num_sub_prompts,
                )
                batch_1st.append(BatchData(args.highres_fix_latents_upscaling, base, ext_1st))
            images_1st = process_batch(batch_1st, True, True)
@@ -2651,7 +2683,7 @@ def main(args):
        (
            return_latents,
            (step_first, _, _, _, init_image, mask_image, _, guide_image),
-           (width, height, steps, scale, negative_scale, strength, network_muls),
            (width, height, steps, scale, negative_scale, strength, network_muls, num_sub_prompts),
        ) = batch[0]
        noise_shape = (LATENT_CHANNELS, height // DOWNSAMPLING_FACTOR, width // DOWNSAMPLING_FACTOR)
@@ -2743,8 +2775,11 @@ def main(args):
        # generate
        if networks:
            shared = {}
            for n, m in zip(networks, network_muls if network_muls else network_default_muls):
                n.set_multiplier(m)
                if regional_network:
                    n.set_current_generation(batch_size, num_sub_prompts, width, height, shared)

        images = pipe(
            prompts,
@@ -2969,11 +3004,26 @@ def main(args):
print("Use previous image as guide image.") print("Use previous image as guide image.")
guide_image = prev_image guide_image = prev_image
if regional_network:
num_sub_prompts = len(prompt.split(" AND "))
assert (
len(networks) <= num_sub_prompts
), "Number of networks must be less than or equal to number of sub prompts."
else:
num_sub_prompts = None
b1 = BatchData( b1 = BatchData(
False, False,
BatchDataBase(global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image), BatchDataBase(global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image),
BatchDataExt( BatchDataExt(
width, height, steps, scale, negative_scale, strength, tuple(network_muls) if network_muls else None width,
height,
steps,
scale,
negative_scale,
strength,
tuple(network_muls) if network_muls else None,
num_sub_prompts,
), ),
) )
if len(batch_data) > 0 and batch_data[-1].ext != b1.ext: # バッチ分割必要? if len(batch_data) > 0 and batch_data[-1].ext != b1.ext: # バッチ分割必要?
@@ -3197,6 +3247,9 @@ def setup_parser() -> argparse.ArgumentParser:
nargs="*", nargs="*",
help="ControlNet guidance ratio for steps / ControlNetでガイドするステップ比率", help="ControlNet guidance ratio for steps / ControlNetでガイドするステップ比率",
) )
# parser.add_argument(
# "--control_net_image_path", type=str, default=None, nargs="*", help="image for ControlNet guidance / ControlNetでガイドに使う画像"
# )
return parser return parser
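As a rough illustration of the regional LoRA usage these changes enable (hypothetical file names, flags per this script's existing options), one might run: `python gen_img_diffusers.py --ckpt model.safetensors --network_module networks.lora networks.lora --network_weights char1.safetensors char2.safetensors --mask_path mask.png --prompt "blue sky AND red castle"`, where the red and green channels of `mask.png` select the regions for the first and second LoRA respectively.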

View File

@@ -0,0 +1,78 @@
from typing import Union, BinaryIO, Optional
from huggingface_hub import HfApi
from pathlib import Path
import argparse
import os

from library.utils import fire_in_thread


def exists_repo(repo_id: str, repo_type: str, revision: str = "main", token: Optional[str] = None):
    api = HfApi(token=token)
    try:
        api.repo_info(repo_id=repo_id, revision=revision, repo_type=repo_type)
        return True
    except Exception:
        return False


def upload(
    args: argparse.Namespace,
    src: Union[str, Path, bytes, BinaryIO],
    dest_suffix: str = "",
    force_sync_upload: bool = False,
):
    repo_id = args.huggingface_repo_id
    repo_type = args.huggingface_repo_type
    token = args.huggingface_token
    path_in_repo = args.huggingface_path_in_repo + dest_suffix
    private = args.huggingface_repo_visibility is None or args.huggingface_repo_visibility != "public"
    api = HfApi(token=token)
    if not exists_repo(repo_id=repo_id, repo_type=repo_type, token=token):
        api.create_repo(repo_id=repo_id, repo_type=repo_type, private=private)

    is_folder = (type(src) == str and os.path.isdir(src)) or (isinstance(src, Path) and src.is_dir())

    def uploader():
        if is_folder:
            api.upload_folder(
                repo_id=repo_id,
                repo_type=repo_type,
                folder_path=src,
                path_in_repo=path_in_repo,
            )
        else:
            api.upload_file(
                repo_id=repo_id,
                repo_type=repo_type,
                path_or_fileobj=src,
                path_in_repo=path_in_repo,
            )

    if args.async_upload and not force_sync_upload:
        fire_in_thread(uploader)
    else:
        uploader()


def list_dir(
    repo_id: str,
    subfolder: str,
    repo_type: str,
    revision: str = "main",
    token: Optional[str] = None,
):
    api = HfApi(token=token)
    repo_info = api.repo_info(repo_id=repo_id, revision=revision, repo_type=repo_type)
    file_list = [file for file in repo_info.siblings if file.rfilename.startswith(subfolder)]
    return file_list
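A minimal usage sketch of this new module (editor's illustration, not part of the commit; the `argparse.Namespace` fields mirror the command-line options added to `library/train_util.py` below):

```python
# Editor's sketch: calling upload() directly with hypothetical values.
import argparse
import library.huggingface_util as huggingface_util

args = argparse.Namespace(
    huggingface_repo_id="your-hf-name/your-model",  # hypothetical repo
    huggingface_repo_type="model",
    huggingface_path_in_repo="path",
    huggingface_repo_visibility=None,  # anything other than "public" -> private repo
    huggingface_token="hf_YourAccessTokenHere",
    async_upload=False,
)
# Uploads the file to "path/last.safetensors" in the repo, creating the repo if needed.
huggingface_util.upload(args, "output/last.safetensors", "/last.safetensors")
```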

View File

@@ -2,6 +2,7 @@
import argparse
import ast
import asyncio
import importlib
import json
import pathlib
@@ -49,6 +50,7 @@ from diffusers import (
    KDPM2DiscreteScheduler,
    KDPM2AncestralDiscreteScheduler,
)
from huggingface_hub import hf_hub_download
import albumentations as albu
import numpy as np
from PIL import Image
@@ -58,6 +60,7 @@ from torch import einsum
import safetensors.torch
from library.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipeline
import library.model_util as model_util
import library.huggingface_util as huggingface_util

# Tokenizer: use the pre-provided tokenizer instead of loading it from the checkpoint
TOKENIZER_PATH = "openai/clip-vit-large-patch14"
@@ -1441,7 +1444,6 @@ def glob_images_pathlib(dir_path, recursive):
# endregion

# region module replacement

"""
Module replacement for speed-up
@@ -1896,6 +1898,38 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser):
def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool):
    parser.add_argument("--output_dir", type=str, default=None, help="directory to output trained model / 学習後のモデル出力先ディレクトリ")
    parser.add_argument("--output_name", type=str, default=None, help="base name of trained model file / 学習後のモデルの拡張子を除くファイル名")
parser.add_argument(
"--huggingface_repo_id", type=str, default=None, help="huggingface repo name to upload / huggingfaceにアップロードするリポジトリ名"
)
parser.add_argument(
"--huggingface_repo_type", type=str, default=None, help="huggingface repo type to upload / huggingfaceにアップロードするリポジトリの種類"
)
parser.add_argument(
"--huggingface_path_in_repo",
type=str,
default=None,
help="huggingface model path to upload files / huggingfaceにアップロードするファイルのパス",
)
parser.add_argument("--huggingface_token", type=str, default=None, help="huggingface token / huggingfaceのトークン")
parser.add_argument(
"--huggingface_repo_visibility",
type=str,
default=None,
help="huggingface repository visibility ('public' for public, 'private' or None for private) / huggingfaceにアップロードするリポジトリの公開設定'public'で公開、'private'またはNoneで非公開",
)
parser.add_argument(
"--save_state_to_huggingface", action="store_true", help="save state to huggingface / huggingfaceにstateを保存する"
)
parser.add_argument(
"--resume_from_huggingface",
action="store_true",
help="resume from huggingface (ex: --resume {repo_id}/{path_in_repo}:{revision}:{repo_type}) / huggingfaceから学習を再開する(例: --resume {repo_id}/{path_in_repo}:{revision}:{repo_type})",
)
parser.add_argument(
"--async_upload",
action="store_true",
help="upload to huggingface asynchronously / huggingfaceに非同期でアップロードする",
)
    parser.add_argument(
        "--save_precision",
        type=str,
@@ -2261,6 +2295,57 @@ def read_config_from_file(args: argparse.Namespace, parser: argparse.ArgumentPar
# region utils
def resume_from_local_or_hf_if_specified(accelerator, args):
    if not args.resume:
        return

    if not args.resume_from_huggingface:
        print(f"resume training from local state: {args.resume}")
        accelerator.load_state(args.resume)
        return

    print(f"resume training from huggingface state: {args.resume}")
    repo_id = args.resume.split("/")[0] + "/" + args.resume.split("/")[1]
    path_in_repo = "/".join(args.resume.split("/")[2:])
    revision = None
    repo_type = None
    if ":" in path_in_repo:
        divided = path_in_repo.split(":")
        if len(divided) == 2:
            path_in_repo, revision = divided
            repo_type = "model"
        else:
            path_in_repo, revision, repo_type = divided
    print(f"Downloading state from huggingface: {repo_id}/{path_in_repo}@{revision}")

    list_files = huggingface_util.list_dir(
        repo_id=repo_id,
        subfolder=path_in_repo,
        revision=revision,
        token=args.huggingface_token,
        repo_type=repo_type,
    )

    async def download(filename) -> str:
        def task():
            return hf_hub_download(
                repo_id=repo_id,
                filename=filename,
                revision=revision,
                repo_type=repo_type,
                token=args.huggingface_token,
            )

        return await asyncio.get_event_loop().run_in_executor(None, task)

    loop = asyncio.get_event_loop()
    results = loop.run_until_complete(asyncio.gather(*[download(filename=filename.rfilename) for filename in list_files]))
    if len(results) == 0:
        raise ValueError("No files found in the specified repo id/path/revision / 指定されたリポジトリID/パス/リビジョンにファイルが見つかりませんでした")
    dirname = os.path.dirname(results[0])
    accelerator.load_state(dirname)
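Editor's worked example of the parsing above: with `--resume your-hf-name/your-model/path/test-000002-state:main:model`, the first two path components form `repo_id = "your-hf-name/your-model"` and the rest becomes `path_in_repo = "path/test-000002-state:main:model"`; the two colons then split it into `path_in_repo = "path/test-000002-state"`, `revision = "main"`, `repo_type = "model"`. With a single colon, `repo_type` defaults to `"model"`; with no colon, both `revision` and `repo_type` stay `None`.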
def get_optimizer(args, trainable_params):
    # "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation, Adafactor"
@@ -2645,7 +2730,7 @@ def prepare_dtype(args: argparse.Namespace):
    return weight_dtype, save_dtype


-def load_target_model(args: argparse.Namespace, weight_dtype, device='cpu'):
def load_target_model(args: argparse.Namespace, weight_dtype, device="cpu"):
    name_or_path = args.pretrained_model_name_or_path
    name_or_path = os.readlink(name_or_path) if os.path.islink(name_or_path) else name_or_path
    load_stable_diffusion_format = os.path.isfile(name_or_path)  # determine SD or Diffusers
@@ -2772,6 +2857,8 @@ def save_sd_model_on_epoch_end(
            model_util.save_stable_diffusion_checkpoint(
                args.v2, ckpt_file, text_encoder, unet, src_path, epoch_no, global_step, save_dtype, vae
            )
            if args.huggingface_repo_id is not None:
                huggingface_util.upload(args, ckpt_file, "/" + ckpt_name)

        def remove_sd(old_epoch_no):
            _, old_ckpt_name = get_epoch_ckpt_name(args, use_safetensors, old_epoch_no)
@@ -2791,6 +2878,8 @@ def save_sd_model_on_epoch_end(
            model_util.save_diffusers_checkpoint(
                args.v2, out_dir, text_encoder, unet, src_path, vae=vae, use_safetensors=use_safetensors
            )
            if args.huggingface_repo_id is not None:
                huggingface_util.upload(args, out_dir, "/" + model_name)

        def remove_du(old_epoch_no):
            out_dir_old = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(model_name, old_epoch_no))
@@ -2808,7 +2897,11 @@ def save_sd_model_on_epoch_end(
def save_state_on_epoch_end(args: argparse.Namespace, accelerator, model_name, epoch_no):
    print("saving state.")
-   accelerator.save_state(os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, epoch_no)))
    state_dir = os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, epoch_no))
    accelerator.save_state(state_dir)
    if args.save_state_to_huggingface:
        print("uploading state to huggingface.")
        huggingface_util.upload(args, state_dir, "/" + EPOCH_STATE_NAME.format(model_name, epoch_no))

    last_n_epochs = args.save_last_n_epochs_state if args.save_last_n_epochs_state else args.save_last_n_epochs
    if last_n_epochs is not None:
@@ -2819,6 +2912,17 @@ def save_state_on_epoch_end(args: argparse.Namespace, accelerator, model_name, e
            shutil.rmtree(state_dir_old)
def save_state_on_train_end(args: argparse.Namespace, accelerator):
    print("saving last state.")
    os.makedirs(args.output_dir, exist_ok=True)
    model_name = DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name
    state_dir = os.path.join(args.output_dir, LAST_STATE_NAME.format(model_name))
    accelerator.save_state(state_dir)
    if args.save_state_to_huggingface:
        print("uploading last state to huggingface.")
        huggingface_util.upload(args, state_dir, "/" + LAST_STATE_NAME.format(model_name))
def save_sd_model_on_train_end(
    args: argparse.Namespace,
    src_path: str,
@@ -2843,6 +2947,8 @@ def save_sd_model_on_train_end(
        model_util.save_stable_diffusion_checkpoint(
            args.v2, ckpt_file, text_encoder, unet, src_path, epoch, global_step, save_dtype, vae
        )
        if args.huggingface_repo_id is not None:
            huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=True)
    else:
        out_dir = os.path.join(args.output_dir, model_name)
        os.makedirs(out_dir, exist_ok=True)
@@ -2851,13 +2957,8 @@ def save_sd_model_on_train_end(
        model_util.save_diffusers_checkpoint(
            args.v2, out_dir, text_encoder, unet, src_path, vae=vae, use_safetensors=use_safetensors
        )
        if args.huggingface_repo_id is not None:
            huggingface_util.upload(args, out_dir, "/" + model_name, force_sync_upload=True)

-def save_state_on_train_end(args: argparse.Namespace, accelerator):
-   print("saving last state.")
-   os.makedirs(args.output_dir, exist_ok=True)
-   model_name = DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name
-   accelerator.save_state(os.path.join(args.output_dir, LAST_STATE_NAME.format(model_name)))

# scheduler:

library/utils.py (new file)
View File

@@ -0,0 +1,6 @@
import threading
from typing import *


def fire_in_thread(f, *args, **kwargs):
    threading.Thread(target=f, args=args, kwargs=kwargs).start()
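This one-line helper runs any callable on a background thread; it is what `--async_upload` uses so a checkpoint upload does not block training. A trivial usage sketch (editor's illustration):

```python
from library.utils import fire_in_thread

# returns immediately; the callable (here a stand-in for an upload) runs in the background
fire_in_thread(lambda: print("uploading in background..."))
```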

View File

@@ -10,7 +10,6 @@ import numpy as np
import torch
import re

-from library import train_util

RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")
@@ -61,8 +60,6 @@ class LoRAModule(torch.nn.Module):
        self.multiplier = multiplier
        self.org_module = org_module  # remove in applying

-       self.region = None
-       self.region_mask = None

    def apply_to(self):
        self.org_forward = self.org_module.forward
@@ -105,39 +102,187 @@ class LoRAModule(torch.nn.Module):
        self.region_mask = None

    def forward(self, x):
-       if self.region is None:
-           return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
-       # regional LoRA   FIXME same as additional-network extension
-       if x.size()[1] % 77 == 0:
-           # print(f"LoRA for context: {self.lora_name}")
-           self.region = None
-           return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
-       # calculate region mask first time
-       if self.region_mask is None:
-           if len(x.size()) == 4:
-               h, w = x.size()[2:4]
-           else:
-               seq_len = x.size()[1]
-               ratio = math.sqrt((self.region.size()[0] * self.region.size()[1]) / seq_len)
-               h = int(self.region.size()[0] / ratio + 0.5)
-               w = seq_len // h
-           r = self.region.to(x.device)
-           if r.dtype == torch.bfloat16:
-               r = r.to(torch.float)
-           r = r.unsqueeze(0).unsqueeze(1)
-           # print(self.lora_name, self.region.size(), x.size(), r.size(), h, w)
-           r = torch.nn.functional.interpolate(r, (h, w), mode="bilinear")
-           r = r.to(x.dtype)
-           if len(x.size()) == 3:
-               r = torch.reshape(r, (1, x.size()[1], -1))
-           self.region_mask = r
-       return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale * self.region_mask
        return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale


class LoRAInfModule(LoRAModule):
    def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4, alpha=1):
        super().__init__(lora_name, org_module, multiplier, lora_dim, alpha)

        # check regional or not by lora_name
        self.text_encoder = False
        if lora_name.startswith("lora_te_"):
            self.regional = False
            self.use_sub_prompt = True
            self.text_encoder = True
        elif "attn2_to_k" in lora_name or "attn2_to_v" in lora_name:
            self.regional = False
            self.use_sub_prompt = True
        elif "time_emb" in lora_name:
            self.regional = False
            self.use_sub_prompt = False
        else:
            self.regional = True
            self.use_sub_prompt = False

        self.network: LoRANetwork = None

    def set_network(self, network):
        self.network = network

    def default_forward(self, x):
        # print("default_forward", self.lora_name, x.size())
        return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale

    def forward(self, x):
        if self.network is None or self.network.sub_prompt_index is None:
            return self.default_forward(x)
        if not self.regional and not self.use_sub_prompt:
            return self.default_forward(x)

        if self.regional:
            return self.regional_forward(x)
        else:
            return self.sub_prompt_forward(x)

    def get_mask_for_x(self, x):
        # calculate size from shape of x
        if len(x.size()) == 4:
            h, w = x.size()[2:4]
            area = h * w
        else:
            area = x.size()[1]

        mask = self.network.mask_dic[area]
        if mask is None:
            raise ValueError(f"mask is None for resolution {area}")
        if len(x.size()) != 4:
            mask = torch.reshape(mask, (1, -1, 1))
        return mask

    def regional_forward(self, x):
        if "attn2_to_out" in self.lora_name:
            return self.to_out_forward(x)

        if self.network.mask_dic is None:  # sub_prompt_index >= 3
            return self.default_forward(x)

        # apply mask for LoRA result
        lx = self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
        mask = self.get_mask_for_x(lx)
        # print("regional", self.lora_name, self.network.sub_prompt_index, lx.size(), mask.size())
        lx = lx * mask

        x = self.org_forward(x)
        x = x + lx

        if "attn2_to_q" in self.lora_name and self.network.is_last_network:
            x = self.postp_to_q(x)

        return x

    def postp_to_q(self, x):
        # repeat x to num_sub_prompts
        has_real_uncond = x.size()[0] // self.network.batch_size == 3
        qc = self.network.batch_size  # uncond
        qc += self.network.batch_size * self.network.num_sub_prompts  # cond
        if has_real_uncond:
            qc += self.network.batch_size  # real_uncond

        query = torch.zeros((qc, x.size()[1], x.size()[2]), device=x.device, dtype=x.dtype)
        query[: self.network.batch_size] = x[: self.network.batch_size]

        for i in range(self.network.batch_size):
            qi = self.network.batch_size + i * self.network.num_sub_prompts
            query[qi : qi + self.network.num_sub_prompts] = x[self.network.batch_size + i]

        if has_real_uncond:
            query[-self.network.batch_size :] = x[-self.network.batch_size :]

        # print("postp_to_q", self.lora_name, x.size(), query.size(), self.network.num_sub_prompts)
        return query

    def sub_prompt_forward(self, x):
        if x.size()[0] == self.network.batch_size:  # if uncond in text_encoder, do not apply LoRA
            return self.org_forward(x)

        emb_idx = self.network.sub_prompt_index
        if not self.text_encoder:
            emb_idx += self.network.batch_size

        # apply sub prompt of X
        lx = x[emb_idx :: self.network.num_sub_prompts]
        lx = self.lora_up(self.lora_down(lx)) * self.multiplier * self.scale
        # print("sub_prompt_forward", self.lora_name, x.size(), lx.size(), emb_idx)

        x = self.org_forward(x)
        x[emb_idx :: self.network.num_sub_prompts] += lx

        return x

    def to_out_forward(self, x):
        # print("to_out_forward", self.lora_name, x.size(), self.network.is_last_network)

        if self.network.is_last_network:
            masks = [None] * self.network.num_sub_prompts
            self.network.shared[self.lora_name] = (None, masks)
        else:
            lx, masks = self.network.shared[self.lora_name]

        # call own LoRA
        x1 = x[self.network.batch_size + self.network.sub_prompt_index :: self.network.num_sub_prompts]
        lx1 = self.lora_up(self.lora_down(x1)) * self.multiplier * self.scale

        if self.network.is_last_network:
            lx = torch.zeros(
                (self.network.num_sub_prompts * self.network.batch_size, *lx1.size()[1:]), device=lx1.device, dtype=lx1.dtype
            )
            self.network.shared[self.lora_name] = (lx, masks)

        # print("to_out_forward", lx.size(), lx1.size(), self.network.sub_prompt_index, self.network.num_sub_prompts)
        lx[self.network.sub_prompt_index :: self.network.num_sub_prompts] += lx1
        masks[self.network.sub_prompt_index] = self.get_mask_for_x(lx1)

        # if not last network, return x and masks
        x = self.org_forward(x)
        if not self.network.is_last_network:
            return x

        lx, masks = self.network.shared.pop(self.lora_name)

        # if last network, combine separated x with mask weighted sum
        has_real_uncond = x.size()[0] // self.network.batch_size == self.network.num_sub_prompts + 2
        out = torch.zeros((self.network.batch_size * (3 if has_real_uncond else 2), *x.size()[1:]), device=x.device, dtype=x.dtype)
        out[: self.network.batch_size] = x[: self.network.batch_size]  # uncond
        if has_real_uncond:
            out[-self.network.batch_size :] = x[-self.network.batch_size :]  # real_uncond

        # print("to_out_forward", self.lora_name, self.network.sub_prompt_index, self.network.num_sub_prompts)
        # for i in range(len(masks)):
        #     if masks[i] is None:
        #         masks[i] = torch.zeros_like(masks[-1])

        mask = torch.cat(masks)
        mask_sum = torch.sum(mask, dim=0) + 1e-4
        for i in range(self.network.batch_size):
            # process each image in the batch
            lx1 = lx[i * self.network.num_sub_prompts : (i + 1) * self.network.num_sub_prompts]
            lx1 = lx1 * mask
            lx1 = torch.sum(lx1, dim=0)

            xi = self.network.batch_size + i * self.network.num_sub_prompts
            x1 = x[xi : xi + self.network.num_sub_prompts]
            x1 = x1 * mask
            x1 = torch.sum(x1, dim=0)
            x1 = x1 / mask_sum

            x1 = x1 + lx1
            out[self.network.batch_size + i] = x1

        # print("to_out_forward", x.size(), out.size(), has_real_uncond)
        return out
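Editor's note on the combination step above: the per-region masks are concatenated, each image's original attention outputs are weighted by their masks, summed over sub-prompts, and normalized by `mask_sum` (with a small epsilon), and the accumulated, unnormalized LoRA deltas `lx` are then added on top. The uncond rows (and the `real_uncond` rows, when a third group is present) pass through unchanged.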
def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs):
@@ -421,7 +566,7 @@ def get_block_index(lora_name: str) -> int:
# Create network from weights for inference, weights are not loaded here (because can be merged)
-def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, **kwargs):
def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, for_inference=False, **kwargs):
    if weights_sd is None:
        if os.path.splitext(file)[1] == ".safetensors":
            from safetensors.torch import load_file, safe_open
@@ -450,7 +595,11 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh
        if key not in modules_alpha:
            modules_alpha = modules_dim[key]

-   network = LoRANetwork(text_encoder, unet, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha)
    module_class = LoRAInfModule if for_inference else LoRAModule
    network = LoRANetwork(
        text_encoder, unet, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha, module_class=module_class
    )
    return network, weights_sd
@@ -479,6 +628,7 @@ class LoRANetwork(torch.nn.Module):
        conv_block_alphas=None,
        modules_dim=None,
        modules_alpha=None,
        module_class=LoRAModule,
        varbose=False,
    ) -> None:
        """
@@ -554,7 +704,7 @@ class LoRANetwork(torch.nn.Module):
                    skipped.append(lora_name)
                    continue

-               lora = LoRAModule(lora_name, child_module, self.multiplier, dim, alpha)
                lora = module_class(lora_name, child_module, self.multiplier, dim, alpha)
                loras.append(lora)
        return loras, skipped
@@ -750,6 +900,7 @@ class LoRANetwork(torch.nn.Module):
        if os.path.splitext(file)[1] == ".safetensors":
            from safetensors.torch import save_file
            from library import train_util

            # Precalculate model hashes to save time on indexing
            if metadata is None:
@@ -762,17 +913,45 @@ class LoRANetwork(torch.nn.Module):
        else:
            torch.save(state_dict, file)

-   @staticmethod
-   def set_regions(networks, image):
-       image = image.astype(np.float32) / 255.0
-       for i, network in enumerate(networks[:3]):
-           # NOTE: consider averaging overlapping area
-           region = image[:, :, i]
-           if region.max() == 0:
-               continue
-           region = torch.tensor(region)
-           network.set_region(region)
-
-   def set_region(self, region):
-       for lora in self.unet_loras:
-           lora.set_region(region)

    # mask is a tensor with values from 0 to 1
    def set_region(self, sub_prompt_index, is_last_network, mask):
        if mask.max() == 0:
            mask = torch.ones_like(mask)

        self.mask = mask
        self.sub_prompt_index = sub_prompt_index
        self.is_last_network = is_last_network

        for lora in self.text_encoder_loras + self.unet_loras:
            lora.set_network(self)

    def set_current_generation(self, batch_size, num_sub_prompts, width, height, shared):
        self.batch_size = batch_size
        self.num_sub_prompts = num_sub_prompts
        self.current_size = (height, width)
        self.shared = shared

        # create masks
        mask = self.mask
        mask_dic = {}
        mask = mask.unsqueeze(0).unsqueeze(1)  # b(1),c(1),h,w
        ref_weight = self.text_encoder_loras[0].lora_down.weight if self.text_encoder_loras else self.unet_loras[0].lora_down.weight
        dtype = ref_weight.dtype
        device = ref_weight.device

        def resize_add(mh, mw):
            # print(mh, mw, mh * mw)
            m = torch.nn.functional.interpolate(mask, (mh, mw), mode="bilinear")  # doesn't work in bf16
            m = m.to(device, dtype=dtype)
            mask_dic[mh * mw] = m

        h = height // 8
        w = width // 8
        for _ in range(4):
            resize_add(h, w)
            if h % 2 == 1 or w % 2 == 1:  # add an extra shape if h/w is not divisible by 2
                resize_add(h + h % 2, w + w % 2)
            h = (h + 1) // 2
            w = (w + 1) // 2

        self.mask_dic = mask_dic
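A worked example of the mask pyramid built by `set_current_generation` (editor's sketch, not part of the commit): for a 512x768 generation the latent is 64x96, and the loop registers masks keyed by spatial area for the four resolutions the SD U-Net actually sees.

```python
# Editor's sketch of the resolutions registered above for width=512, height=768.
h, w = 768 // 8, 512 // 8  # 96, 64
areas = []
for _ in range(4):
    areas.append((h, w, h * w))
    if h % 2 == 1 or w % 2 == 1:
        areas.append((h + h % 2, w + w % 2, (h + h % 2) * (w + w % 2)))
    h = (h + 1) // 2
    w = (w + 1) // 2
print(areas)  # [(96, 64, 6144), (48, 32, 1536), (24, 16, 384), (12, 8, 96)]
# get_mask_for_x() then looks a mask up by x's spatial area (h*w) or sequence length.
```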

View File

@@ -21,6 +21,6 @@ fairscale==0.4.13
# for WD14 captioning
# tensorflow<2.11
tensorflow==2.10.1

-huggingface-hub==0.12.0
huggingface-hub==0.13.3

# for kohya_ss library
.

View File

@@ -201,9 +201,7 @@ def train(args):
    train_util.patch_accelerator_for_fp16_training(accelerator)

    # resume training if specified
-   if args.resume is not None:
-       print(f"resume training from state: {args.resume}")
-       accelerator.load_state(args.resume)
    train_util.resume_from_local_or_hf_if_specified(accelerator, args)

    # calculate the number of epochs
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)

View File

@@ -24,6 +24,7 @@ from library.config_util import (
    ConfigSanitizer,
    BlueprintGenerator,
)
import library.huggingface_util as huggingface_util
import library.custom_train_functions as custom_train_functions
from library.custom_train_functions import apply_snr_weight, get_weighted_text_embeddings
@@ -71,7 +72,8 @@ def train(args):
    use_dreambooth_method = args.in_json is None
    use_user_config = args.dataset_config is not None

-   if args.seed is not None:
    if args.seed is None:
        args.seed = random.randint(0, 2**32)
    set_seed(args.seed)

    tokenizer = train_util.load_tokenizer(args)
@@ -308,9 +310,7 @@ def train(args):
    train_util.patch_accelerator_for_fp16_training(accelerator)

    # resume training if specified
-   if args.resume is not None:
-       print(f"resume training from state: {args.resume}")
-       accelerator.load_state(args.resume)
    train_util.resume_from_local_or_hf_if_specified(accelerator, args)

    # calculate the number of epochs
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
@@ -658,6 +658,8 @@ def train(args):
metadata["ss_training_finished_at"] = str(time.time()) metadata["ss_training_finished_at"] = str(time.time())
print(f"saving checkpoint: {ckpt_file}") print(f"saving checkpoint: {ckpt_file}")
unwrap_model(network).save_weights(ckpt_file, save_dtype, minimum_metadata if args.no_metadata else metadata) unwrap_model(network).save_weights(ckpt_file, save_dtype, minimum_metadata if args.no_metadata else metadata)
if args.huggingface_repo_id is not None:
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name)
def remove_old_func(old_epoch_no): def remove_old_func(old_epoch_no):
old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + "." + args.save_model_as old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + "." + args.save_model_as
@@ -697,6 +699,8 @@ def train(args):
print(f"save trained model to {ckpt_file}") print(f"save trained model to {ckpt_file}")
network.save_weights(ckpt_file, save_dtype, minimum_metadata if args.no_metadata else metadata) network.save_weights(ckpt_file, save_dtype, minimum_metadata if args.no_metadata else metadata)
if args.huggingface_repo_id is not None:
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=True)
print("model saved.") print("model saved.")

View File

@@ -188,6 +188,73 @@ Specify the --network_module and --network_weights options to gen_img_diffusers.py
Specifying a value from 0 to 1.0 with the --network_mul option changes the application rate of the LoRA.

## Generating with the Diffusers pipeline

Refer to the following example. The only file you need is networks/lora.py. Note that this may not work with Diffusers versions other than 0.10.2.

```python
import torch
from diffusers import StableDiffusionPipeline
from networks.lora import LoRAModule, create_network_from_weights
from safetensors.torch import load_file

# if the ckpt is CompVis based, convert it to Diffusers beforehand with tools/convert_diffusers20_original_sd.py. See --help for more details.

model_id_or_dir = r"model_id_on_hugging_face_or_dir"
device = "cuda"

# create pipe
print(f"creating pipe from {model_id_or_dir}...")
pipe = StableDiffusionPipeline.from_pretrained(model_id_or_dir, revision="fp16", torch_dtype=torch.float16)
pipe = pipe.to(device)
vae = pipe.vae
text_encoder = pipe.text_encoder
unet = pipe.unet

# load lora networks
print("loading lora networks...")

lora_path1 = r"lora1.safetensors"
sd = load_file(lora_path1)  # If the file is .ckpt, use torch.load instead.
network1, sd = create_network_from_weights(0.5, None, vae, text_encoder, unet, sd)
network1.apply_to(text_encoder, unet)
network1.load_state_dict(sd)
network1.to(device, dtype=torch.float16)

# # You can merge weights instead of apply_to + load_state_dict. network.set_multiplier does not work then.
# network.merge_to(text_encoder, unet, sd)

lora_path2 = r"lora2.safetensors"
sd = load_file(lora_path2)
network2, sd = create_network_from_weights(0.7, None, vae, text_encoder, unet, sd)
network2.apply_to(text_encoder, unet)
network2.load_state_dict(sd)
network2.to(device, dtype=torch.float16)

lora_path3 = r"lora3.safetensors"
sd = load_file(lora_path3)
network3, sd = create_network_from_weights(0.5, None, vae, text_encoder, unet, sd)
network3.apply_to(text_encoder, unet)
network3.load_state_dict(sd)
network3.to(device, dtype=torch.float16)

# prompts
prompt = "masterpiece, best quality, 1girl, in white shirt, looking at viewer"
negative_prompt = "bad quality, worst quality, bad anatomy, bad hands"

# exec pipe
print("generating image...")
with torch.autocast("cuda"):
    image = pipe(prompt, guidance_scale=7.5, negative_prompt=negative_prompt).images[0]

# if not merged, you can use set_multiplier
# network1.set_multiplier(0.8)
# and generate image again...

# save image
image.save(r"by_diffusers.png")
```
## Creating a LoRA model from the difference between two models

Implemented with reference to [this discussion](https://github.com/cloneofsimo/lora/discussions/56). The formulas are used as-is (I do not fully understand them, but the approximation appears to use singular value decomposition).

View File

@@ -13,6 +13,7 @@ import diffusers
from diffusers import DDPMScheduler

import library.train_util as train_util
import library.huggingface_util as huggingface_util
import library.config_util as config_util
from library.config_util import (
    ConfigSanitizer,
@@ -304,9 +305,7 @@ def train(args):
        text_encoder.to(weight_dtype)

    # resume training if specified
-   if args.resume is not None:
-       print(f"resume training from state: {args.resume}")
-       accelerator.load_state(args.resume)
    train_util.resume_from_local_or_hf_if_specified(accelerator, args)

    # calculate the number of epochs
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
@@ -452,6 +451,8 @@ def train(args):
            ckpt_file = os.path.join(args.output_dir, ckpt_name)
            print(f"saving checkpoint: {ckpt_file}")
            save_weights(ckpt_file, updated_embs, save_dtype)
            if args.huggingface_repo_id is not None:
                huggingface_util.upload(args, ckpt_file, "/" + ckpt_name)

        def remove_old_func(old_epoch_no):
            old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + "." + args.save_model_as
@@ -492,6 +493,8 @@ def train(args):
print(f"save trained model to {ckpt_file}") print(f"save trained model to {ckpt_file}")
save_weights(ckpt_file, updated_embs, save_dtype) save_weights(ckpt_file, updated_embs, save_dtype)
if args.huggingface_repo_id is not None:
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=True)
print("model saved.") print("model saved.")

View File

@@ -13,6 +13,7 @@ import diffusers
from diffusers import DDPMScheduler

import library.train_util as train_util
import library.huggingface_util as huggingface_util
import library.config_util as config_util
from library.config_util import (
    ConfigSanitizer,
@@ -340,9 +341,7 @@ def train(args):
        text_encoder.to(weight_dtype)

    # resume training if specified
-   if args.resume is not None:
-       print(f"resume training from state: {args.resume}")
-       accelerator.load_state(args.resume)
    train_util.resume_from_local_or_hf_if_specified(accelerator, args)

    # calculate the number of epochs
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
@@ -493,6 +492,8 @@ def train(args):
            ckpt_file = os.path.join(args.output_dir, ckpt_name)
            print(f"saving checkpoint: {ckpt_file}")
            save_weights(ckpt_file, updated_embs, save_dtype)
            if args.huggingface_repo_id is not None:
                huggingface_util.upload(args, ckpt_file, "/" + ckpt_name)

        def remove_old_func(old_epoch_no):
            old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + "." + args.save_model_as
@@ -534,6 +535,8 @@ def train(args):
print(f"save trained model to {ckpt_file}") print(f"save trained model to {ckpt_file}")
save_weights(ckpt_file, updated_embs, save_dtype) save_weights(ckpt_file, updated_embs, save_dtype)
if args.huggingface_repo_id is not None:
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=True)
print("model saved.") print("model saved.")