mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
Merge branch 'kohya-ss:main' into weighted_captions
This commit is contained in:
38
README.md
38
README.md
@@ -127,7 +127,43 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser
|
|||||||
|
|
||||||
## Change History
|
## Change History
|
||||||
|
|
||||||
- 4 Apr. 2023, 2023/4/4:
|
### 6 Apr. 2023, 2023/4/6:
|
||||||
|
- There may be bugs because I changed a lot. If you cannot revert the script to the previous version when a problem occurs, please wait for the update for a while.
|
||||||
|
|
||||||
|
- Added a feature to upload model and state to HuggingFace. Thanks to ddPn08 for the contribution! [PR #348](https://github.com/kohya-ss/sd-scripts/pull/348)
|
||||||
|
- When `--huggingface_repo_id` is specified, the model is uploaded to HuggingFace at the same time as saving the model.
|
||||||
|
- Please note that the access token is handled with caution. Please refer to the [HuggingFace documentation](https://huggingface.co/docs/hub/security-tokens).
|
||||||
|
- For example, specify other arguments as follows.
|
||||||
|
- `--huggingface_repo_id "your-hf-name/your-model" --huggingface_path_in_repo "path" --huggingface_repo_type model --huggingface_repo_visibility private --huggingface_token hf_YourAccessTokenHere`
|
||||||
|
- If `public` is specified for `--huggingface_repo_visibility`, the repository will be public. If the option is omitted or `private` (or anything other than `public`) is specified, it will be private.
|
||||||
|
- If you specify `--save_state` and `--save_state_to_huggingface`, the state will also be uploaded.
|
||||||
|
- If you specify `--resume` and `--resume_from_huggingface`, the state will be downloaded from HuggingFace and resumed.
|
||||||
|
- In this case, the `--resume` option is `--resume {repo_id}/{path_in_repo}:{revision}:{repo_type}`. For example: `--resume_from_huggingface --resume your-hf-name/your-model/path/test-000002-state:main:model`
|
||||||
|
- If you specify `--async_upload`, the upload will be done asynchronously.
|
||||||
|
- Added the documentation for applying LoRA to generate with the standard pipeline of Diffusers. [training LoRA](./train_network_README-ja.md#diffusersのpipelineで生成する) (Japanese only)
|
||||||
|
- Support for Attention Couple and regional LoRA in `gen_img_diffusers.py`.
|
||||||
|
- If you use ` AND ` to separate the prompts, each sub-prompt is sequentially applied to LoRA. `--mask_path` is treated as a mask image. The number of sub-prompts and the number of LoRA must match.
|
||||||
|
|
||||||
|
|
||||||
|
- 大きく変更したため不具合があるかもしれません。問題が起きた時にスクリプトを前のバージョンに戻せない場合は、しばらく更新を控えてください。
|
||||||
|
|
||||||
|
|
||||||
|
- モデルおよびstateをHuggingFaceにアップロードする機能を各スクリプトに追加しました。 [PR #348](https://github.com/kohya-ss/sd-scripts/pull/348) ddPn08 氏の貢献に感謝します。
|
||||||
|
- `--huggingface_repo_id`が指定されているとモデル保存時に同時にHuggingFaceにアップロードします。
|
||||||
|
- アクセストークンの取り扱いに注意してください。[HuggingFaceのドキュメント](https://huggingface.co/docs/hub/security-tokens)を参照してください。
|
||||||
|
- 他の引数をたとえば以下のように指定してください。
|
||||||
|
- `--huggingface_repo_id "your-hf-name/your-model" --huggingface_path_in_repo "path" --huggingface_repo_type model --huggingface_repo_visibility private --huggingface_token hf_YourAccessTokenHere`
|
||||||
|
- `--huggingface_repo_visibility`に`public`を指定するとリポジトリが公開されます。省略時または`private`(など`public`以外)を指定すると非公開になります。
|
||||||
|
- `--save_state`オプション指定時に`--save_state_to_huggingface`を指定するとstateもアップロードします。
|
||||||
|
- `--resume`オプション指定時に`--resume_from_huggingface`を指定するとHuggingFaceからstateをダウンロードして再開します。
|
||||||
|
- その時の `--resume`オプションは `--resume {repo_id}/{path_in_repo}:{revision}:{repo_type}`になります。例: `--resume_from_huggingface --resume your-hf-name/your-model/path/test-000002-state:main:model`
|
||||||
|
- `--async_upload`オプションを指定するとアップロードを非同期で行います。
|
||||||
|
- [LoRAの文書](./train_network_README-ja.md#diffusersのpipelineで生成する)に、LoRAを適用してDiffusersの標準的なパイプラインで生成する方法を追記しました。
|
||||||
|
- `gen_img_diffusers.py` で Attention Couple および領域別LoRAに対応しました。
|
||||||
|
- プロンプトを` AND `で区切ると各サブプロンプトが順にLoRAに適用されます。`--mask_path` がマスク画像として扱われます。サブプロンプトの数とLoRAの数は一致している必要があります。
|
||||||
|
|
||||||
|
|
||||||
|
### 4 Apr. 2023, 2023/4/4, Release 0.6.0:
|
||||||
- There may be bugs because I changed a lot. If you cannot revert the script to the previous version when a problem occurs, please wait for the update for a while.
|
- There may be bugs because I changed a lot. If you cannot revert the script to the previous version when a problem occurs, please wait for the update for a while.
|
||||||
- The learning rate and dim (rank) of each block may not work with other modules (LyCORIS, etc.) because the module needs to be changed.
|
- The learning rate and dim (rank) of each block may not work with other modules (LyCORIS, etc.) because the module needs to be changed.
|
||||||
|
|
||||||
|
|||||||
@@ -231,9 +231,7 @@ def train(args):
|
|||||||
train_util.patch_accelerator_for_fp16_training(accelerator)
|
train_util.patch_accelerator_for_fp16_training(accelerator)
|
||||||
|
|
||||||
# resumeする
|
# resumeする
|
||||||
if args.resume is not None:
|
train_util.resume_from_local_or_hf_if_specified(accelerator, args)
|
||||||
print(f"resume training from state: {args.resume}")
|
|
||||||
accelerator.load_state(args.resume)
|
|
||||||
|
|
||||||
# epoch数を計算する
|
# epoch数を計算する
|
||||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||||
|
|||||||
@@ -92,6 +92,7 @@ from PIL.PngImagePlugin import PngInfo
|
|||||||
|
|
||||||
import library.model_util as model_util
|
import library.model_util as model_util
|
||||||
import library.train_util as train_util
|
import library.train_util as train_util
|
||||||
|
from networks.lora import LoRANetwork
|
||||||
import tools.original_control_net as original_control_net
|
import tools.original_control_net as original_control_net
|
||||||
from tools.original_control_net import ControlNetInfo
|
from tools.original_control_net import ControlNetInfo
|
||||||
|
|
||||||
@@ -634,6 +635,7 @@ class PipelineLike:
|
|||||||
img2img_noise=None,
|
img2img_noise=None,
|
||||||
clip_prompts=None,
|
clip_prompts=None,
|
||||||
clip_guide_images=None,
|
clip_guide_images=None,
|
||||||
|
networks: Optional[List[LoRANetwork]] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
@@ -717,6 +719,7 @@ class PipelineLike:
|
|||||||
batch_size = len(prompt)
|
batch_size = len(prompt)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||||
|
reginonal_network = " AND " in prompt[0]
|
||||||
|
|
||||||
vae_batch_size = (
|
vae_batch_size = (
|
||||||
batch_size
|
batch_size
|
||||||
@@ -1010,6 +1013,11 @@ class PipelineLike:
|
|||||||
|
|
||||||
# predict the noise residual
|
# predict the noise residual
|
||||||
if self.control_nets:
|
if self.control_nets:
|
||||||
|
if reginonal_network:
|
||||||
|
num_sub_and_neg_prompts = len(text_embeddings) // batch_size
|
||||||
|
text_emb_last = text_embeddings[num_sub_and_neg_prompts - 2::num_sub_and_neg_prompts] # last subprompt
|
||||||
|
else:
|
||||||
|
text_emb_last = text_embeddings
|
||||||
noise_pred = original_control_net.call_unet_and_control_net(
|
noise_pred = original_control_net.call_unet_and_control_net(
|
||||||
i,
|
i,
|
||||||
num_latent_input,
|
num_latent_input,
|
||||||
@@ -1019,7 +1027,7 @@ class PipelineLike:
|
|||||||
i / len(timesteps),
|
i / len(timesteps),
|
||||||
latent_model_input,
|
latent_model_input,
|
||||||
t,
|
t,
|
||||||
text_embeddings,
|
text_emb_last,
|
||||||
).sample
|
).sample
|
||||||
else:
|
else:
|
||||||
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
|
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
|
||||||
@@ -1890,6 +1898,12 @@ def get_weighted_text_embeddings(
|
|||||||
if isinstance(prompt, str):
|
if isinstance(prompt, str):
|
||||||
prompt = [prompt]
|
prompt = [prompt]
|
||||||
|
|
||||||
|
# split the prompts with "AND". each prompt must have the same number of splits
|
||||||
|
new_prompts = []
|
||||||
|
for p in prompt:
|
||||||
|
new_prompts.extend(p.split(" AND "))
|
||||||
|
prompt = new_prompts
|
||||||
|
|
||||||
if not skip_parsing:
|
if not skip_parsing:
|
||||||
prompt_tokens, prompt_weights = get_prompts_with_weights(pipe, prompt, max_length - 2, layer=layer)
|
prompt_tokens, prompt_weights = get_prompts_with_weights(pipe, prompt, max_length - 2, layer=layer)
|
||||||
if uncond_prompt is not None:
|
if uncond_prompt is not None:
|
||||||
@@ -2059,6 +2073,7 @@ class BatchDataExt(NamedTuple):
|
|||||||
negative_scale: float
|
negative_scale: float
|
||||||
strength: float
|
strength: float
|
||||||
network_muls: Tuple[float]
|
network_muls: Tuple[float]
|
||||||
|
num_sub_prompts: int
|
||||||
|
|
||||||
|
|
||||||
class BatchData(NamedTuple):
|
class BatchData(NamedTuple):
|
||||||
@@ -2276,14 +2291,18 @@ def main(args):
|
|||||||
print(f"metadata for: {network_weight}: {metadata}")
|
print(f"metadata for: {network_weight}: {metadata}")
|
||||||
|
|
||||||
network, weights_sd = imported_module.create_network_from_weights(
|
network, weights_sd = imported_module.create_network_from_weights(
|
||||||
network_mul, network_weight, vae, text_encoder, unet, **net_kwargs
|
network_mul, network_weight, vae, text_encoder, unet, for_inference=True, **net_kwargs
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError("No weight. Weight is required.")
|
raise ValueError("No weight. Weight is required.")
|
||||||
if network is None:
|
if network is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
if not args.network_merge:
|
mergiable = hasattr(network, "merge_to")
|
||||||
|
if args.network_merge and not mergiable:
|
||||||
|
print("network is not mergiable. ignore merge option.")
|
||||||
|
|
||||||
|
if not args.network_merge or not mergiable:
|
||||||
network.apply_to(text_encoder, unet)
|
network.apply_to(text_encoder, unet)
|
||||||
info = network.load_state_dict(weights_sd, False) # network.load_weightsを使うようにするとよい
|
info = network.load_state_dict(weights_sd, False) # network.load_weightsを使うようにするとよい
|
||||||
print(f"weights are loaded: {info}")
|
print(f"weights are loaded: {info}")
|
||||||
@@ -2349,12 +2368,12 @@ def main(args):
|
|||||||
if args.diffusers_xformers:
|
if args.diffusers_xformers:
|
||||||
pipe.enable_xformers_memory_efficient_attention()
|
pipe.enable_xformers_memory_efficient_attention()
|
||||||
|
|
||||||
|
# Extended Textual Inversion および Textual Inversionを処理する
|
||||||
if args.XTI_embeddings:
|
if args.XTI_embeddings:
|
||||||
diffusers.models.UNet2DConditionModel.forward = unet_forward_XTI
|
diffusers.models.UNet2DConditionModel.forward = unet_forward_XTI
|
||||||
diffusers.models.unet_2d_blocks.CrossAttnDownBlock2D.forward = downblock_forward_XTI
|
diffusers.models.unet_2d_blocks.CrossAttnDownBlock2D.forward = downblock_forward_XTI
|
||||||
diffusers.models.unet_2d_blocks.CrossAttnUpBlock2D.forward = upblock_forward_XTI
|
diffusers.models.unet_2d_blocks.CrossAttnUpBlock2D.forward = upblock_forward_XTI
|
||||||
|
|
||||||
# Textual Inversionを処理する
|
|
||||||
if args.textual_inversion_embeddings:
|
if args.textual_inversion_embeddings:
|
||||||
token_ids_embeds = []
|
token_ids_embeds = []
|
||||||
for embeds_file in args.textual_inversion_embeddings:
|
for embeds_file in args.textual_inversion_embeddings:
|
||||||
@@ -2558,16 +2577,22 @@ def main(args):
|
|||||||
print(f"resize img2img mask images to {args.W}*{args.H}")
|
print(f"resize img2img mask images to {args.W}*{args.H}")
|
||||||
mask_images = resize_images(mask_images, (args.W, args.H))
|
mask_images = resize_images(mask_images, (args.W, args.H))
|
||||||
|
|
||||||
|
regional_network = False
|
||||||
if networks and mask_images:
|
if networks and mask_images:
|
||||||
# mask を領域情報として流用する、現在は1枚だけ対応
|
# mask を領域情報として流用する、現在は一回のコマンド呼び出しで1枚だけ対応
|
||||||
# TODO 複数のnetwork classの混在時の考慮
|
regional_network = True
|
||||||
print("use mask as region")
|
print("use mask as region")
|
||||||
# import cv2
|
|
||||||
# for i in range(3):
|
size = None
|
||||||
# cv2.imshow("msk", np.array(mask_images[0])[:,:,i])
|
for i, network in enumerate(networks):
|
||||||
# cv2.waitKey()
|
if i < 3:
|
||||||
# cv2.destroyAllWindows()
|
np_mask = np.array(mask_images[0])
|
||||||
networks[0].__class__.set_regions(networks, np.array(mask_images[0]))
|
np_mask = np_mask[:, :, i]
|
||||||
|
size = np_mask.shape
|
||||||
|
else:
|
||||||
|
np_mask = np.full(size, 255, dtype=np.uint8)
|
||||||
|
mask = torch.from_numpy(np_mask.astype(np.float32) / 255.0)
|
||||||
|
network.set_region(i, i == len(networks) - 1, mask)
|
||||||
mask_images = None
|
mask_images = None
|
||||||
|
|
||||||
prev_image = None # for VGG16 guided
|
prev_image = None # for VGG16 guided
|
||||||
@@ -2623,7 +2648,14 @@ def main(args):
|
|||||||
height_1st = height_1st - height_1st % 32
|
height_1st = height_1st - height_1st % 32
|
||||||
|
|
||||||
ext_1st = BatchDataExt(
|
ext_1st = BatchDataExt(
|
||||||
width_1st, height_1st, args.highres_fix_steps, ext.scale, ext.negative_scale, ext.strength, ext.network_muls
|
width_1st,
|
||||||
|
height_1st,
|
||||||
|
args.highres_fix_steps,
|
||||||
|
ext.scale,
|
||||||
|
ext.negative_scale,
|
||||||
|
ext.strength,
|
||||||
|
ext.network_muls,
|
||||||
|
ext.num_sub_prompts,
|
||||||
)
|
)
|
||||||
batch_1st.append(BatchData(args.highres_fix_latents_upscaling, base, ext_1st))
|
batch_1st.append(BatchData(args.highres_fix_latents_upscaling, base, ext_1st))
|
||||||
images_1st = process_batch(batch_1st, True, True)
|
images_1st = process_batch(batch_1st, True, True)
|
||||||
@@ -2651,7 +2683,7 @@ def main(args):
|
|||||||
(
|
(
|
||||||
return_latents,
|
return_latents,
|
||||||
(step_first, _, _, _, init_image, mask_image, _, guide_image),
|
(step_first, _, _, _, init_image, mask_image, _, guide_image),
|
||||||
(width, height, steps, scale, negative_scale, strength, network_muls),
|
(width, height, steps, scale, negative_scale, strength, network_muls, num_sub_prompts),
|
||||||
) = batch[0]
|
) = batch[0]
|
||||||
noise_shape = (LATENT_CHANNELS, height // DOWNSAMPLING_FACTOR, width // DOWNSAMPLING_FACTOR)
|
noise_shape = (LATENT_CHANNELS, height // DOWNSAMPLING_FACTOR, width // DOWNSAMPLING_FACTOR)
|
||||||
|
|
||||||
@@ -2743,8 +2775,11 @@ def main(args):
|
|||||||
|
|
||||||
# generate
|
# generate
|
||||||
if networks:
|
if networks:
|
||||||
|
shared = {}
|
||||||
for n, m in zip(networks, network_muls if network_muls else network_default_muls):
|
for n, m in zip(networks, network_muls if network_muls else network_default_muls):
|
||||||
n.set_multiplier(m)
|
n.set_multiplier(m)
|
||||||
|
if regional_network:
|
||||||
|
n.set_current_generation(batch_size, num_sub_prompts, width, height, shared)
|
||||||
|
|
||||||
images = pipe(
|
images = pipe(
|
||||||
prompts,
|
prompts,
|
||||||
@@ -2969,11 +3004,26 @@ def main(args):
|
|||||||
print("Use previous image as guide image.")
|
print("Use previous image as guide image.")
|
||||||
guide_image = prev_image
|
guide_image = prev_image
|
||||||
|
|
||||||
|
if regional_network:
|
||||||
|
num_sub_prompts = len(prompt.split(" AND "))
|
||||||
|
assert (
|
||||||
|
len(networks) <= num_sub_prompts
|
||||||
|
), "Number of networks must be less than or equal to number of sub prompts."
|
||||||
|
else:
|
||||||
|
num_sub_prompts = None
|
||||||
|
|
||||||
b1 = BatchData(
|
b1 = BatchData(
|
||||||
False,
|
False,
|
||||||
BatchDataBase(global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image),
|
BatchDataBase(global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image),
|
||||||
BatchDataExt(
|
BatchDataExt(
|
||||||
width, height, steps, scale, negative_scale, strength, tuple(network_muls) if network_muls else None
|
width,
|
||||||
|
height,
|
||||||
|
steps,
|
||||||
|
scale,
|
||||||
|
negative_scale,
|
||||||
|
strength,
|
||||||
|
tuple(network_muls) if network_muls else None,
|
||||||
|
num_sub_prompts,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
if len(batch_data) > 0 and batch_data[-1].ext != b1.ext: # バッチ分割必要?
|
if len(batch_data) > 0 and batch_data[-1].ext != b1.ext: # バッチ分割必要?
|
||||||
@@ -3197,6 +3247,9 @@ def setup_parser() -> argparse.ArgumentParser:
|
|||||||
nargs="*",
|
nargs="*",
|
||||||
help="ControlNet guidance ratio for steps / ControlNetでガイドするステップ比率",
|
help="ControlNet guidance ratio for steps / ControlNetでガイドするステップ比率",
|
||||||
)
|
)
|
||||||
|
# parser.add_argument(
|
||||||
|
# "--control_net_image_path", type=str, default=None, nargs="*", help="image for ControlNet guidance / ControlNetでガイドに使う画像"
|
||||||
|
# )
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|||||||
78
library/huggingface_util.py
Normal file
78
library/huggingface_util.py
Normal file
@@ -0,0 +1,78 @@
|
|||||||
|
from typing import *
|
||||||
|
from huggingface_hub import HfApi
|
||||||
|
from pathlib import Path
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
|
||||||
|
from library.utils import fire_in_thread
|
||||||
|
|
||||||
|
|
||||||
|
def exists_repo(
|
||||||
|
repo_id: str, repo_type: str, revision: str = "main", token: str = None
|
||||||
|
):
|
||||||
|
api = HfApi(
|
||||||
|
token=token,
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
api.repo_info(repo_id=repo_id, revision=revision, repo_type=repo_type)
|
||||||
|
return True
|
||||||
|
except:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def upload(
|
||||||
|
args: argparse.Namespace,
|
||||||
|
src: Union[str, Path, bytes, BinaryIO],
|
||||||
|
dest_suffix: str = "",
|
||||||
|
force_sync_upload: bool = False,
|
||||||
|
):
|
||||||
|
repo_id = args.huggingface_repo_id
|
||||||
|
repo_type = args.huggingface_repo_type
|
||||||
|
token = args.huggingface_token
|
||||||
|
path_in_repo = args.huggingface_path_in_repo + dest_suffix
|
||||||
|
private = args.huggingface_repo_visibility is None or args.huggingface_repo_visibility != "public"
|
||||||
|
api = HfApi(token=token)
|
||||||
|
if not exists_repo(repo_id=repo_id, repo_type=repo_type, token=token):
|
||||||
|
api.create_repo(repo_id=repo_id, repo_type=repo_type, private=private)
|
||||||
|
|
||||||
|
is_folder = (type(src) == str and os.path.isdir(src)) or (
|
||||||
|
isinstance(src, Path) and src.is_dir()
|
||||||
|
)
|
||||||
|
|
||||||
|
def uploader():
|
||||||
|
if is_folder:
|
||||||
|
api.upload_folder(
|
||||||
|
repo_id=repo_id,
|
||||||
|
repo_type=repo_type,
|
||||||
|
folder_path=src,
|
||||||
|
path_in_repo=path_in_repo,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
api.upload_file(
|
||||||
|
repo_id=repo_id,
|
||||||
|
repo_type=repo_type,
|
||||||
|
path_or_fileobj=src,
|
||||||
|
path_in_repo=path_in_repo,
|
||||||
|
)
|
||||||
|
|
||||||
|
if args.async_upload and not force_sync_upload:
|
||||||
|
fire_in_thread(uploader)
|
||||||
|
else:
|
||||||
|
uploader()
|
||||||
|
|
||||||
|
|
||||||
|
def list_dir(
|
||||||
|
repo_id: str,
|
||||||
|
subfolder: str,
|
||||||
|
repo_type: str,
|
||||||
|
revision: str = "main",
|
||||||
|
token: str = None,
|
||||||
|
):
|
||||||
|
api = HfApi(
|
||||||
|
token=token,
|
||||||
|
)
|
||||||
|
repo_info = api.repo_info(repo_id=repo_id, revision=revision, repo_type=repo_type)
|
||||||
|
file_list = [
|
||||||
|
file for file in repo_info.siblings if file.rfilename.startswith(subfolder)
|
||||||
|
]
|
||||||
|
return file_list
|
||||||
@@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import ast
|
import ast
|
||||||
|
import asyncio
|
||||||
import importlib
|
import importlib
|
||||||
import json
|
import json
|
||||||
import pathlib
|
import pathlib
|
||||||
@@ -49,6 +50,7 @@ from diffusers import (
|
|||||||
KDPM2DiscreteScheduler,
|
KDPM2DiscreteScheduler,
|
||||||
KDPM2AncestralDiscreteScheduler,
|
KDPM2AncestralDiscreteScheduler,
|
||||||
)
|
)
|
||||||
|
from huggingface_hub import hf_hub_download
|
||||||
import albumentations as albu
|
import albumentations as albu
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
@@ -58,6 +60,7 @@ from torch import einsum
|
|||||||
import safetensors.torch
|
import safetensors.torch
|
||||||
from library.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipeline
|
from library.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipeline
|
||||||
import library.model_util as model_util
|
import library.model_util as model_util
|
||||||
|
import library.huggingface_util as huggingface_util
|
||||||
|
|
||||||
# Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う
|
# Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う
|
||||||
TOKENIZER_PATH = "openai/clip-vit-large-patch14"
|
TOKENIZER_PATH = "openai/clip-vit-large-patch14"
|
||||||
@@ -1441,7 +1444,6 @@ def glob_images_pathlib(dir_path, recursive):
|
|||||||
|
|
||||||
# endregion
|
# endregion
|
||||||
|
|
||||||
|
|
||||||
# region モジュール入れ替え部
|
# region モジュール入れ替え部
|
||||||
"""
|
"""
|
||||||
高速化のためのモジュール入れ替え
|
高速化のためのモジュール入れ替え
|
||||||
@@ -1896,6 +1898,38 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser):
|
|||||||
def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool):
|
def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool):
|
||||||
parser.add_argument("--output_dir", type=str, default=None, help="directory to output trained model / 学習後のモデル出力先ディレクトリ")
|
parser.add_argument("--output_dir", type=str, default=None, help="directory to output trained model / 学習後のモデル出力先ディレクトリ")
|
||||||
parser.add_argument("--output_name", type=str, default=None, help="base name of trained model file / 学習後のモデルの拡張子を除くファイル名")
|
parser.add_argument("--output_name", type=str, default=None, help="base name of trained model file / 学習後のモデルの拡張子を除くファイル名")
|
||||||
|
parser.add_argument(
|
||||||
|
"--huggingface_repo_id", type=str, default=None, help="huggingface repo name to upload / huggingfaceにアップロードするリポジトリ名"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--huggingface_repo_type", type=str, default=None, help="huggingface repo type to upload / huggingfaceにアップロードするリポジトリの種類"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--huggingface_path_in_repo",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="huggingface model path to upload files / huggingfaceにアップロードするファイルのパス",
|
||||||
|
)
|
||||||
|
parser.add_argument("--huggingface_token", type=str, default=None, help="huggingface token / huggingfaceのトークン")
|
||||||
|
parser.add_argument(
|
||||||
|
"--huggingface_repo_visibility",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="huggingface repository visibility ('public' for public, 'private' or None for private) / huggingfaceにアップロードするリポジトリの公開設定('public'で公開、'private'またはNoneで非公開)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--save_state_to_huggingface", action="store_true", help="save state to huggingface / huggingfaceにstateを保存する"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--resume_from_huggingface",
|
||||||
|
action="store_true",
|
||||||
|
help="resume from huggingface (ex: --resume {repo_id}/{path_in_repo}:{revision}:{repo_type}) / huggingfaceから学習を再開する(例: --resume {repo_id}/{path_in_repo}:{revision}:{repo_type})",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--async_upload",
|
||||||
|
action="store_true",
|
||||||
|
help="upload to huggingface asynchronously / huggingfaceに非同期でアップロードする",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--save_precision",
|
"--save_precision",
|
||||||
type=str,
|
type=str,
|
||||||
@@ -2261,6 +2295,57 @@ def read_config_from_file(args: argparse.Namespace, parser: argparse.ArgumentPar
|
|||||||
# region utils
|
# region utils
|
||||||
|
|
||||||
|
|
||||||
|
def resume_from_local_or_hf_if_specified(accelerator, args):
|
||||||
|
if not args.resume:
|
||||||
|
return
|
||||||
|
|
||||||
|
if not args.resume_from_huggingface:
|
||||||
|
print(f"resume training from local state: {args.resume}")
|
||||||
|
accelerator.load_state(args.resume)
|
||||||
|
return
|
||||||
|
|
||||||
|
print(f"resume training from huggingface state: {args.resume}")
|
||||||
|
repo_id = args.resume.split("/")[0] + "/" + args.resume.split("/")[1]
|
||||||
|
path_in_repo = "/".join(args.resume.split("/")[2:])
|
||||||
|
revision = None
|
||||||
|
repo_type = None
|
||||||
|
if ":" in path_in_repo:
|
||||||
|
divided = path_in_repo.split(":")
|
||||||
|
if len(divided) == 2:
|
||||||
|
path_in_repo, revision = divided
|
||||||
|
repo_type = "model"
|
||||||
|
else:
|
||||||
|
path_in_repo, revision, repo_type = divided
|
||||||
|
print(f"Downloading state from huggingface: {repo_id}/{path_in_repo}@{revision}")
|
||||||
|
|
||||||
|
list_files = huggingface_util.list_dir(
|
||||||
|
repo_id=repo_id,
|
||||||
|
subfolder=path_in_repo,
|
||||||
|
revision=revision,
|
||||||
|
token=args.huggingface_token,
|
||||||
|
repo_type=repo_type,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def download(filename) -> str:
|
||||||
|
def task():
|
||||||
|
return hf_hub_download(
|
||||||
|
repo_id=repo_id,
|
||||||
|
filename=filename,
|
||||||
|
revision=revision,
|
||||||
|
repo_type=repo_type,
|
||||||
|
token=args.huggingface_token,
|
||||||
|
)
|
||||||
|
|
||||||
|
return await asyncio.get_event_loop().run_in_executor(None, task)
|
||||||
|
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
results = loop.run_until_complete(asyncio.gather(*[download(filename=filename.rfilename) for filename in list_files]))
|
||||||
|
if len(results) == 0:
|
||||||
|
raise ValueError("No files found in the specified repo id/path/revision / 指定されたリポジトリID/パス/リビジョンにファイルが見つかりませんでした")
|
||||||
|
dirname = os.path.dirname(results[0])
|
||||||
|
accelerator.load_state(dirname)
|
||||||
|
|
||||||
|
|
||||||
def get_optimizer(args, trainable_params):
|
def get_optimizer(args, trainable_params):
|
||||||
# "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation, Adafactor"
|
# "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation, Adafactor"
|
||||||
|
|
||||||
@@ -2645,7 +2730,7 @@ def prepare_dtype(args: argparse.Namespace):
|
|||||||
return weight_dtype, save_dtype
|
return weight_dtype, save_dtype
|
||||||
|
|
||||||
|
|
||||||
def load_target_model(args: argparse.Namespace, weight_dtype, device='cpu'):
|
def load_target_model(args: argparse.Namespace, weight_dtype, device="cpu"):
|
||||||
name_or_path = args.pretrained_model_name_or_path
|
name_or_path = args.pretrained_model_name_or_path
|
||||||
name_or_path = os.readlink(name_or_path) if os.path.islink(name_or_path) else name_or_path
|
name_or_path = os.readlink(name_or_path) if os.path.islink(name_or_path) else name_or_path
|
||||||
load_stable_diffusion_format = os.path.isfile(name_or_path) # determine SD or Diffusers
|
load_stable_diffusion_format = os.path.isfile(name_or_path) # determine SD or Diffusers
|
||||||
@@ -2772,6 +2857,8 @@ def save_sd_model_on_epoch_end(
|
|||||||
model_util.save_stable_diffusion_checkpoint(
|
model_util.save_stable_diffusion_checkpoint(
|
||||||
args.v2, ckpt_file, text_encoder, unet, src_path, epoch_no, global_step, save_dtype, vae
|
args.v2, ckpt_file, text_encoder, unet, src_path, epoch_no, global_step, save_dtype, vae
|
||||||
)
|
)
|
||||||
|
if args.huggingface_repo_id is not None:
|
||||||
|
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name)
|
||||||
|
|
||||||
def remove_sd(old_epoch_no):
|
def remove_sd(old_epoch_no):
|
||||||
_, old_ckpt_name = get_epoch_ckpt_name(args, use_safetensors, old_epoch_no)
|
_, old_ckpt_name = get_epoch_ckpt_name(args, use_safetensors, old_epoch_no)
|
||||||
@@ -2791,6 +2878,8 @@ def save_sd_model_on_epoch_end(
|
|||||||
model_util.save_diffusers_checkpoint(
|
model_util.save_diffusers_checkpoint(
|
||||||
args.v2, out_dir, text_encoder, unet, src_path, vae=vae, use_safetensors=use_safetensors
|
args.v2, out_dir, text_encoder, unet, src_path, vae=vae, use_safetensors=use_safetensors
|
||||||
)
|
)
|
||||||
|
if args.huggingface_repo_id is not None:
|
||||||
|
huggingface_util.upload(args, out_dir, "/" + model_name)
|
||||||
|
|
||||||
def remove_du(old_epoch_no):
|
def remove_du(old_epoch_no):
|
||||||
out_dir_old = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(model_name, old_epoch_no))
|
out_dir_old = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(model_name, old_epoch_no))
|
||||||
@@ -2808,7 +2897,11 @@ def save_sd_model_on_epoch_end(
|
|||||||
|
|
||||||
def save_state_on_epoch_end(args: argparse.Namespace, accelerator, model_name, epoch_no):
|
def save_state_on_epoch_end(args: argparse.Namespace, accelerator, model_name, epoch_no):
|
||||||
print("saving state.")
|
print("saving state.")
|
||||||
accelerator.save_state(os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, epoch_no)))
|
state_dir = os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, epoch_no))
|
||||||
|
accelerator.save_state(state_dir)
|
||||||
|
if args.save_state_to_huggingface:
|
||||||
|
print("uploading state to huggingface.")
|
||||||
|
huggingface_util.upload(args, state_dir, "/" + EPOCH_STATE_NAME.format(model_name, epoch_no))
|
||||||
|
|
||||||
last_n_epochs = args.save_last_n_epochs_state if args.save_last_n_epochs_state else args.save_last_n_epochs
|
last_n_epochs = args.save_last_n_epochs_state if args.save_last_n_epochs_state else args.save_last_n_epochs
|
||||||
if last_n_epochs is not None:
|
if last_n_epochs is not None:
|
||||||
@@ -2819,6 +2912,17 @@ def save_state_on_epoch_end(args: argparse.Namespace, accelerator, model_name, e
|
|||||||
shutil.rmtree(state_dir_old)
|
shutil.rmtree(state_dir_old)
|
||||||
|
|
||||||
|
|
||||||
|
def save_state_on_train_end(args: argparse.Namespace, accelerator):
|
||||||
|
print("saving last state.")
|
||||||
|
os.makedirs(args.output_dir, exist_ok=True)
|
||||||
|
model_name = DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name
|
||||||
|
state_dir = os.path.join(args.output_dir, LAST_STATE_NAME.format(model_name))
|
||||||
|
accelerator.save_state(state_dir)
|
||||||
|
if args.save_state_to_huggingface:
|
||||||
|
print("uploading last state to huggingface.")
|
||||||
|
huggingface_util.upload(args, state_dir, "/" + LAST_STATE_NAME.format(model_name))
|
||||||
|
|
||||||
|
|
||||||
def save_sd_model_on_train_end(
|
def save_sd_model_on_train_end(
|
||||||
args: argparse.Namespace,
|
args: argparse.Namespace,
|
||||||
src_path: str,
|
src_path: str,
|
||||||
@@ -2843,6 +2947,8 @@ def save_sd_model_on_train_end(
|
|||||||
model_util.save_stable_diffusion_checkpoint(
|
model_util.save_stable_diffusion_checkpoint(
|
||||||
args.v2, ckpt_file, text_encoder, unet, src_path, epoch, global_step, save_dtype, vae
|
args.v2, ckpt_file, text_encoder, unet, src_path, epoch, global_step, save_dtype, vae
|
||||||
)
|
)
|
||||||
|
if args.huggingface_repo_id is not None:
|
||||||
|
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=True)
|
||||||
else:
|
else:
|
||||||
out_dir = os.path.join(args.output_dir, model_name)
|
out_dir = os.path.join(args.output_dir, model_name)
|
||||||
os.makedirs(out_dir, exist_ok=True)
|
os.makedirs(out_dir, exist_ok=True)
|
||||||
@@ -2851,13 +2957,8 @@ def save_sd_model_on_train_end(
|
|||||||
model_util.save_diffusers_checkpoint(
|
model_util.save_diffusers_checkpoint(
|
||||||
args.v2, out_dir, text_encoder, unet, src_path, vae=vae, use_safetensors=use_safetensors
|
args.v2, out_dir, text_encoder, unet, src_path, vae=vae, use_safetensors=use_safetensors
|
||||||
)
|
)
|
||||||
|
if args.huggingface_repo_id is not None:
|
||||||
|
huggingface_util.upload(args, out_dir, "/" + model_name, force_sync_upload=True)
|
||||||
def save_state_on_train_end(args: argparse.Namespace, accelerator):
|
|
||||||
print("saving last state.")
|
|
||||||
os.makedirs(args.output_dir, exist_ok=True)
|
|
||||||
model_name = DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name
|
|
||||||
accelerator.save_state(os.path.join(args.output_dir, LAST_STATE_NAME.format(model_name)))
|
|
||||||
|
|
||||||
|
|
||||||
# scheduler:
|
# scheduler:
|
||||||
|
|||||||
6
library/utils.py
Normal file
6
library/utils.py
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
import threading
|
||||||
|
from typing import *
|
||||||
|
|
||||||
|
|
||||||
|
def fire_in_thread(f, *args, **kwargs):
|
||||||
|
threading.Thread(target=f, args=args, kwargs=kwargs).start()
|
||||||
261
networks/lora.py
261
networks/lora.py
@@ -10,7 +10,6 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
import re
|
import re
|
||||||
|
|
||||||
from library import train_util
|
|
||||||
|
|
||||||
RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")
|
RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")
|
||||||
|
|
||||||
@@ -61,8 +60,6 @@ class LoRAModule(torch.nn.Module):
|
|||||||
|
|
||||||
self.multiplier = multiplier
|
self.multiplier = multiplier
|
||||||
self.org_module = org_module # remove in applying
|
self.org_module = org_module # remove in applying
|
||||||
self.region = None
|
|
||||||
self.region_mask = None
|
|
||||||
|
|
||||||
def apply_to(self):
|
def apply_to(self):
|
||||||
self.org_forward = self.org_module.forward
|
self.org_forward = self.org_module.forward
|
||||||
@@ -105,39 +102,187 @@ class LoRAModule(torch.nn.Module):
|
|||||||
self.region_mask = None
|
self.region_mask = None
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
if self.region is None:
|
|
||||||
return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
|
return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
|
||||||
|
|
||||||
# regional LoRA FIXME same as additional-network extension
|
|
||||||
if x.size()[1] % 77 == 0:
|
class LoRAInfModule(LoRAModule):
|
||||||
# print(f"LoRA for context: {self.lora_name}")
|
def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4, alpha=1):
|
||||||
self.region = None
|
super().__init__(lora_name, org_module, multiplier, lora_dim, alpha)
|
||||||
|
|
||||||
|
# check regional or not by lora_name
|
||||||
|
self.text_encoder = False
|
||||||
|
if lora_name.startswith("lora_te_"):
|
||||||
|
self.regional = False
|
||||||
|
self.use_sub_prompt = True
|
||||||
|
self.text_encoder = True
|
||||||
|
elif "attn2_to_k" in lora_name or "attn2_to_v" in lora_name:
|
||||||
|
self.regional = False
|
||||||
|
self.use_sub_prompt = True
|
||||||
|
elif "time_emb" in lora_name:
|
||||||
|
self.regional = False
|
||||||
|
self.use_sub_prompt = False
|
||||||
|
else:
|
||||||
|
self.regional = True
|
||||||
|
self.use_sub_prompt = False
|
||||||
|
|
||||||
|
self.network: LoRANetwork = None
|
||||||
|
|
||||||
|
def set_network(self, network):
|
||||||
|
self.network = network
|
||||||
|
|
||||||
|
def default_forward(self, x):
|
||||||
|
# print("default_forward", self.lora_name, x.size())
|
||||||
return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
|
return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
|
||||||
|
|
||||||
# calculate region mask first time
|
def forward(self, x):
|
||||||
if self.region_mask is None:
|
if self.network is None or self.network.sub_prompt_index is None:
|
||||||
|
return self.default_forward(x)
|
||||||
|
if not self.regional and not self.use_sub_prompt:
|
||||||
|
return self.default_forward(x)
|
||||||
|
|
||||||
|
if self.regional:
|
||||||
|
return self.regional_forward(x)
|
||||||
|
else:
|
||||||
|
return self.sub_prompt_forward(x)
|
||||||
|
|
||||||
|
def get_mask_for_x(self, x):
|
||||||
|
# calculate size from shape of x
|
||||||
if len(x.size()) == 4:
|
if len(x.size()) == 4:
|
||||||
h, w = x.size()[2:4]
|
h, w = x.size()[2:4]
|
||||||
|
area = h * w
|
||||||
else:
|
else:
|
||||||
seq_len = x.size()[1]
|
area = x.size()[1]
|
||||||
ratio = math.sqrt((self.region.size()[0] * self.region.size()[1]) / seq_len)
|
|
||||||
h = int(self.region.size()[0] / ratio + 0.5)
|
|
||||||
w = seq_len // h
|
|
||||||
|
|
||||||
r = self.region.to(x.device)
|
mask = self.network.mask_dic[area]
|
||||||
if r.dtype == torch.bfloat16:
|
if mask is None:
|
||||||
r = r.to(torch.float)
|
raise ValueError(f"mask is None for resolution {area}")
|
||||||
r = r.unsqueeze(0).unsqueeze(1)
|
if len(x.size()) != 4:
|
||||||
# print(self.lora_name, self.region.size(), x.size(), r.size(), h, w)
|
mask = torch.reshape(mask, (1, -1, 1))
|
||||||
r = torch.nn.functional.interpolate(r, (h, w), mode="bilinear")
|
return mask
|
||||||
r = r.to(x.dtype)
|
|
||||||
|
|
||||||
if len(x.size()) == 3:
|
def regional_forward(self, x):
|
||||||
r = torch.reshape(r, (1, x.size()[1], -1))
|
if "attn2_to_out" in self.lora_name:
|
||||||
|
return self.to_out_forward(x)
|
||||||
|
|
||||||
self.region_mask = r
|
if self.network.mask_dic is None: # sub_prompt_index >= 3
|
||||||
|
return self.default_forward(x)
|
||||||
|
|
||||||
return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale * self.region_mask
|
# apply mask for LoRA result
|
||||||
|
lx = self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
|
||||||
|
mask = self.get_mask_for_x(lx)
|
||||||
|
# print("regional", self.lora_name, self.network.sub_prompt_index, lx.size(), mask.size())
|
||||||
|
lx = lx * mask
|
||||||
|
|
||||||
|
x = self.org_forward(x)
|
||||||
|
x = x + lx
|
||||||
|
|
||||||
|
if "attn2_to_q" in self.lora_name and self.network.is_last_network:
|
||||||
|
x = self.postp_to_q(x)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
def postp_to_q(self, x):
|
||||||
|
# repeat x to num_sub_prompts
|
||||||
|
has_real_uncond = x.size()[0] // self.network.batch_size == 3
|
||||||
|
qc = self.network.batch_size # uncond
|
||||||
|
qc += self.network.batch_size * self.network.num_sub_prompts # cond
|
||||||
|
if has_real_uncond:
|
||||||
|
qc += self.network.batch_size # real_uncond
|
||||||
|
|
||||||
|
query = torch.zeros((qc, x.size()[1], x.size()[2]), device=x.device, dtype=x.dtype)
|
||||||
|
query[: self.network.batch_size] = x[: self.network.batch_size]
|
||||||
|
|
||||||
|
for i in range(self.network.batch_size):
|
||||||
|
qi = self.network.batch_size + i * self.network.num_sub_prompts
|
||||||
|
query[qi : qi + self.network.num_sub_prompts] = x[self.network.batch_size + i]
|
||||||
|
|
||||||
|
if has_real_uncond:
|
||||||
|
query[-self.network.batch_size :] = x[-self.network.batch_size :]
|
||||||
|
|
||||||
|
# print("postp_to_q", self.lora_name, x.size(), query.size(), self.network.num_sub_prompts)
|
||||||
|
return query
|
||||||
|
|
||||||
|
def sub_prompt_forward(self, x):
|
||||||
|
if x.size()[0] == self.network.batch_size: # if uncond in text_encoder, do not apply LoRA
|
||||||
|
return self.org_forward(x)
|
||||||
|
|
||||||
|
emb_idx = self.network.sub_prompt_index
|
||||||
|
if not self.text_encoder:
|
||||||
|
emb_idx += self.network.batch_size
|
||||||
|
|
||||||
|
# apply sub prompt of X
|
||||||
|
lx = x[emb_idx :: self.network.num_sub_prompts]
|
||||||
|
lx = self.lora_up(self.lora_down(lx)) * self.multiplier * self.scale
|
||||||
|
|
||||||
|
# print("sub_prompt_forward", self.lora_name, x.size(), lx.size(), emb_idx)
|
||||||
|
|
||||||
|
x = self.org_forward(x)
|
||||||
|
x[emb_idx :: self.network.num_sub_prompts] += lx
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
def to_out_forward(self, x):
|
||||||
|
# print("to_out_forward", self.lora_name, x.size(), self.network.is_last_network)
|
||||||
|
|
||||||
|
if self.network.is_last_network:
|
||||||
|
masks = [None] * self.network.num_sub_prompts
|
||||||
|
self.network.shared[self.lora_name] = (None, masks)
|
||||||
|
else:
|
||||||
|
lx, masks = self.network.shared[self.lora_name]
|
||||||
|
|
||||||
|
# call own LoRA
|
||||||
|
x1 = x[self.network.batch_size + self.network.sub_prompt_index :: self.network.num_sub_prompts]
|
||||||
|
lx1 = self.lora_up(self.lora_down(x1)) * self.multiplier * self.scale
|
||||||
|
|
||||||
|
if self.network.is_last_network:
|
||||||
|
lx = torch.zeros(
|
||||||
|
(self.network.num_sub_prompts * self.network.batch_size, *lx1.size()[1:]), device=lx1.device, dtype=lx1.dtype
|
||||||
|
)
|
||||||
|
self.network.shared[self.lora_name] = (lx, masks)
|
||||||
|
|
||||||
|
# print("to_out_forward", lx.size(), lx1.size(), self.network.sub_prompt_index, self.network.num_sub_prompts)
|
||||||
|
lx[self.network.sub_prompt_index :: self.network.num_sub_prompts] += lx1
|
||||||
|
masks[self.network.sub_prompt_index] = self.get_mask_for_x(lx1)
|
||||||
|
|
||||||
|
# if not last network, return x and masks
|
||||||
|
x = self.org_forward(x)
|
||||||
|
if not self.network.is_last_network:
|
||||||
|
return x
|
||||||
|
|
||||||
|
lx, masks = self.network.shared.pop(self.lora_name)
|
||||||
|
|
||||||
|
# if last network, combine separated x with mask weighted sum
|
||||||
|
has_real_uncond = x.size()[0] // self.network.batch_size == self.network.num_sub_prompts + 2
|
||||||
|
|
||||||
|
out = torch.zeros((self.network.batch_size * (3 if has_real_uncond else 2), *x.size()[1:]), device=x.device, dtype=x.dtype)
|
||||||
|
out[: self.network.batch_size] = x[: self.network.batch_size] # uncond
|
||||||
|
if has_real_uncond:
|
||||||
|
out[-self.network.batch_size :] = x[-self.network.batch_size :] # real_uncond
|
||||||
|
|
||||||
|
# print("to_out_forward", self.lora_name, self.network.sub_prompt_index, self.network.num_sub_prompts)
|
||||||
|
# for i in range(len(masks)):
|
||||||
|
# if masks[i] is None:
|
||||||
|
# masks[i] = torch.zeros_like(masks[-1])
|
||||||
|
|
||||||
|
mask = torch.cat(masks)
|
||||||
|
mask_sum = torch.sum(mask, dim=0) + 1e-4
|
||||||
|
for i in range(self.network.batch_size):
|
||||||
|
# 1枚の画像ごとに処理する
|
||||||
|
lx1 = lx[i * self.network.num_sub_prompts : (i + 1) * self.network.num_sub_prompts]
|
||||||
|
lx1 = lx1 * mask
|
||||||
|
lx1 = torch.sum(lx1, dim=0)
|
||||||
|
|
||||||
|
xi = self.network.batch_size + i * self.network.num_sub_prompts
|
||||||
|
x1 = x[xi : xi + self.network.num_sub_prompts]
|
||||||
|
x1 = x1 * mask
|
||||||
|
x1 = torch.sum(x1, dim=0)
|
||||||
|
x1 = x1 / mask_sum
|
||||||
|
|
||||||
|
x1 = x1 + lx1
|
||||||
|
out[self.network.batch_size + i] = x1
|
||||||
|
|
||||||
|
# print("to_out_forward", x.size(), out.size(), has_real_uncond)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs):
|
def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs):
|
||||||
@@ -421,7 +566,7 @@ def get_block_index(lora_name: str) -> int:
|
|||||||
|
|
||||||
|
|
||||||
# Create network from weights for inference, weights are not loaded here (because can be merged)
|
# Create network from weights for inference, weights are not loaded here (because can be merged)
|
||||||
def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, **kwargs):
|
def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, for_inference=False, **kwargs):
|
||||||
if weights_sd is None:
|
if weights_sd is None:
|
||||||
if os.path.splitext(file)[1] == ".safetensors":
|
if os.path.splitext(file)[1] == ".safetensors":
|
||||||
from safetensors.torch import load_file, safe_open
|
from safetensors.torch import load_file, safe_open
|
||||||
@@ -450,7 +595,11 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh
|
|||||||
if key not in modules_alpha:
|
if key not in modules_alpha:
|
||||||
modules_alpha = modules_dim[key]
|
modules_alpha = modules_dim[key]
|
||||||
|
|
||||||
network = LoRANetwork(text_encoder, unet, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha)
|
module_class = LoRAInfModule if for_inference else LoRAModule
|
||||||
|
|
||||||
|
network = LoRANetwork(
|
||||||
|
text_encoder, unet, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha, module_class=module_class
|
||||||
|
)
|
||||||
return network, weights_sd
|
return network, weights_sd
|
||||||
|
|
||||||
|
|
||||||
@@ -479,6 +628,7 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
conv_block_alphas=None,
|
conv_block_alphas=None,
|
||||||
modules_dim=None,
|
modules_dim=None,
|
||||||
modules_alpha=None,
|
modules_alpha=None,
|
||||||
|
module_class=LoRAModule,
|
||||||
varbose=False,
|
varbose=False,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
@@ -554,7 +704,7 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
skipped.append(lora_name)
|
skipped.append(lora_name)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
lora = LoRAModule(lora_name, child_module, self.multiplier, dim, alpha)
|
lora = module_class(lora_name, child_module, self.multiplier, dim, alpha)
|
||||||
loras.append(lora)
|
loras.append(lora)
|
||||||
return loras, skipped
|
return loras, skipped
|
||||||
|
|
||||||
@@ -750,6 +900,7 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
|
|
||||||
if os.path.splitext(file)[1] == ".safetensors":
|
if os.path.splitext(file)[1] == ".safetensors":
|
||||||
from safetensors.torch import save_file
|
from safetensors.torch import save_file
|
||||||
|
from library import train_util
|
||||||
|
|
||||||
# Precalculate model hashes to save time on indexing
|
# Precalculate model hashes to save time on indexing
|
||||||
if metadata is None:
|
if metadata is None:
|
||||||
@@ -762,17 +913,45 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
else:
|
else:
|
||||||
torch.save(state_dict, file)
|
torch.save(state_dict, file)
|
||||||
|
|
||||||
@staticmethod
|
# mask is a tensor with values from 0 to 1
|
||||||
def set_regions(networks, image):
|
def set_region(self, sub_prompt_index, is_last_network, mask):
|
||||||
image = image.astype(np.float32) / 255.0
|
if mask.max() == 0:
|
||||||
for i, network in enumerate(networks[:3]):
|
mask = torch.ones_like(mask)
|
||||||
# NOTE: consider averaging overwrapping area
|
|
||||||
region = image[:, :, i]
|
|
||||||
if region.max() == 0:
|
|
||||||
continue
|
|
||||||
region = torch.tensor(region)
|
|
||||||
network.set_region(region)
|
|
||||||
|
|
||||||
def set_region(self, region):
|
self.mask = mask
|
||||||
for lora in self.unet_loras:
|
self.sub_prompt_index = sub_prompt_index
|
||||||
lora.set_region(region)
|
self.is_last_network = is_last_network
|
||||||
|
|
||||||
|
for lora in self.text_encoder_loras + self.unet_loras:
|
||||||
|
lora.set_network(self)
|
||||||
|
|
||||||
|
def set_current_generation(self, batch_size, num_sub_prompts, width, height, shared):
|
||||||
|
self.batch_size = batch_size
|
||||||
|
self.num_sub_prompts = num_sub_prompts
|
||||||
|
self.current_size = (height, width)
|
||||||
|
self.shared = shared
|
||||||
|
|
||||||
|
# create masks
|
||||||
|
mask = self.mask
|
||||||
|
mask_dic = {}
|
||||||
|
mask = mask.unsqueeze(0).unsqueeze(1) # b(1),c(1),h,w
|
||||||
|
ref_weight = self.text_encoder_loras[0].lora_down.weight if self.text_encoder_loras else self.unet_loras[0].lora_down.weight
|
||||||
|
dtype = ref_weight.dtype
|
||||||
|
device = ref_weight.device
|
||||||
|
|
||||||
|
def resize_add(mh, mw):
|
||||||
|
# print(mh, mw, mh * mw)
|
||||||
|
m = torch.nn.functional.interpolate(mask, (mh, mw), mode="bilinear") # doesn't work in bf16
|
||||||
|
m = m.to(device, dtype=dtype)
|
||||||
|
mask_dic[mh * mw] = m
|
||||||
|
|
||||||
|
h = height // 8
|
||||||
|
w = width // 8
|
||||||
|
for _ in range(4):
|
||||||
|
resize_add(h, w)
|
||||||
|
if h % 2 == 1 or w % 2 == 1: # add extra shape if h/w is not divisible by 2
|
||||||
|
resize_add(h + h % 2, w + w % 2)
|
||||||
|
h = (h + 1) // 2
|
||||||
|
w = (w + 1) // 2
|
||||||
|
|
||||||
|
self.mask_dic = mask_dic
|
||||||
|
|||||||
@@ -21,6 +21,6 @@ fairscale==0.4.13
|
|||||||
# for WD14 captioning
|
# for WD14 captioning
|
||||||
# tensorflow<2.11
|
# tensorflow<2.11
|
||||||
tensorflow==2.10.1
|
tensorflow==2.10.1
|
||||||
huggingface-hub==0.12.0
|
huggingface-hub==0.13.3
|
||||||
# for kohya_ss library
|
# for kohya_ss library
|
||||||
.
|
.
|
||||||
|
|||||||
@@ -201,9 +201,7 @@ def train(args):
|
|||||||
train_util.patch_accelerator_for_fp16_training(accelerator)
|
train_util.patch_accelerator_for_fp16_training(accelerator)
|
||||||
|
|
||||||
# resumeする
|
# resumeする
|
||||||
if args.resume is not None:
|
train_util.resume_from_local_or_hf_if_specified(accelerator, args)
|
||||||
print(f"resume training from state: {args.resume}")
|
|
||||||
accelerator.load_state(args.resume)
|
|
||||||
|
|
||||||
# epoch数を計算する
|
# epoch数を計算する
|
||||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ from library.config_util import (
|
|||||||
ConfigSanitizer,
|
ConfigSanitizer,
|
||||||
BlueprintGenerator,
|
BlueprintGenerator,
|
||||||
)
|
)
|
||||||
|
import library.huggingface_util as huggingface_util
|
||||||
import library.custom_train_functions as custom_train_functions
|
import library.custom_train_functions as custom_train_functions
|
||||||
from library.custom_train_functions import apply_snr_weight, get_weighted_text_embeddings
|
from library.custom_train_functions import apply_snr_weight, get_weighted_text_embeddings
|
||||||
|
|
||||||
@@ -71,7 +72,8 @@ def train(args):
|
|||||||
use_dreambooth_method = args.in_json is None
|
use_dreambooth_method = args.in_json is None
|
||||||
use_user_config = args.dataset_config is not None
|
use_user_config = args.dataset_config is not None
|
||||||
|
|
||||||
if args.seed is not None:
|
if args.seed is None:
|
||||||
|
args.seed = random.randint(0, 2**32)
|
||||||
set_seed(args.seed)
|
set_seed(args.seed)
|
||||||
|
|
||||||
tokenizer = train_util.load_tokenizer(args)
|
tokenizer = train_util.load_tokenizer(args)
|
||||||
@@ -308,9 +310,7 @@ def train(args):
|
|||||||
train_util.patch_accelerator_for_fp16_training(accelerator)
|
train_util.patch_accelerator_for_fp16_training(accelerator)
|
||||||
|
|
||||||
# resumeする
|
# resumeする
|
||||||
if args.resume is not None:
|
train_util.resume_from_local_or_hf_if_specified(accelerator, args)
|
||||||
print(f"resume training from state: {args.resume}")
|
|
||||||
accelerator.load_state(args.resume)
|
|
||||||
|
|
||||||
# epoch数を計算する
|
# epoch数を計算する
|
||||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||||
@@ -658,6 +658,8 @@ def train(args):
|
|||||||
metadata["ss_training_finished_at"] = str(time.time())
|
metadata["ss_training_finished_at"] = str(time.time())
|
||||||
print(f"saving checkpoint: {ckpt_file}")
|
print(f"saving checkpoint: {ckpt_file}")
|
||||||
unwrap_model(network).save_weights(ckpt_file, save_dtype, minimum_metadata if args.no_metadata else metadata)
|
unwrap_model(network).save_weights(ckpt_file, save_dtype, minimum_metadata if args.no_metadata else metadata)
|
||||||
|
if args.huggingface_repo_id is not None:
|
||||||
|
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name)
|
||||||
|
|
||||||
def remove_old_func(old_epoch_no):
|
def remove_old_func(old_epoch_no):
|
||||||
old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + "." + args.save_model_as
|
old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + "." + args.save_model_as
|
||||||
@@ -697,6 +699,8 @@ def train(args):
|
|||||||
|
|
||||||
print(f"save trained model to {ckpt_file}")
|
print(f"save trained model to {ckpt_file}")
|
||||||
network.save_weights(ckpt_file, save_dtype, minimum_metadata if args.no_metadata else metadata)
|
network.save_weights(ckpt_file, save_dtype, minimum_metadata if args.no_metadata else metadata)
|
||||||
|
if args.huggingface_repo_id is not None:
|
||||||
|
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=True)
|
||||||
print("model saved.")
|
print("model saved.")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -188,6 +188,73 @@ gen_img_diffusers.pyに、--network_module、--network_weightsの各オプショ
|
|||||||
|
|
||||||
--network_mulオプションで0~1.0の数値を指定すると、LoRAの適用率を変えられます。
|
--network_mulオプションで0~1.0の数値を指定すると、LoRAの適用率を変えられます。
|
||||||
|
|
||||||
|
## Diffusersのpipelineで生成する
|
||||||
|
|
||||||
|
以下の例を参考にしてください。必要なファイルはnetworks/lora.pyのみです。Diffusersのバージョンは0.10.2以外では動作しない可能性があります。
|
||||||
|
|
||||||
|
```python
|
||||||
|
import torch
|
||||||
|
from diffusers import StableDiffusionPipeline
|
||||||
|
from networks.lora import LoRAModule, create_network_from_weights
|
||||||
|
from safetensors.torch import load_file
|
||||||
|
|
||||||
|
# if the ckpt is CompVis based, convert it to Diffusers beforehand with tools/convert_diffusers20_original_sd.py. See --help for more details.
|
||||||
|
|
||||||
|
model_id_or_dir = r"model_id_on_hugging_face_or_dir"
|
||||||
|
device = "cuda"
|
||||||
|
|
||||||
|
# create pipe
|
||||||
|
print(f"creating pipe from {model_id_or_dir}...")
|
||||||
|
pipe = StableDiffusionPipeline.from_pretrained(model_id_or_dir, revision="fp16", torch_dtype=torch.float16)
|
||||||
|
pipe = pipe.to(device)
|
||||||
|
vae = pipe.vae
|
||||||
|
text_encoder = pipe.text_encoder
|
||||||
|
unet = pipe.unet
|
||||||
|
|
||||||
|
# load lora networks
|
||||||
|
print(f"loading lora networks...")
|
||||||
|
|
||||||
|
lora_path1 = r"lora1.safetensors"
|
||||||
|
sd = load_file(lora_path1) # If the file is .ckpt, use torch.load instead.
|
||||||
|
network1, sd = create_network_from_weights(0.5, None, vae, text_encoder,unet, sd)
|
||||||
|
network1.apply_to(text_encoder, unet)
|
||||||
|
network1.load_state_dict(sd)
|
||||||
|
network1.to(device, dtype=torch.float16)
|
||||||
|
|
||||||
|
# # You can merge weights instead of apply_to+load_state_dict. network.set_multiplier does not work
|
||||||
|
# network.merge_to(text_encoder, unet, sd)
|
||||||
|
|
||||||
|
lora_path2 = r"lora2.safetensors"
|
||||||
|
sd = load_file(lora_path2)
|
||||||
|
network2, sd = create_network_from_weights(0.7, None, vae, text_encoder,unet, sd)
|
||||||
|
network2.apply_to(text_encoder, unet)
|
||||||
|
network2.load_state_dict(sd)
|
||||||
|
network2.to(device, dtype=torch.float16)
|
||||||
|
|
||||||
|
lora_path3 = r"lora3.safetensors"
|
||||||
|
sd = load_file(lora_path3)
|
||||||
|
network3, sd = create_network_from_weights(0.5, None, vae, text_encoder,unet, sd)
|
||||||
|
network3.apply_to(text_encoder, unet)
|
||||||
|
network3.load_state_dict(sd)
|
||||||
|
network3.to(device, dtype=torch.float16)
|
||||||
|
|
||||||
|
# prompts
|
||||||
|
prompt = "masterpiece, best quality, 1girl, in white shirt, looking at viewer"
|
||||||
|
negative_prompt = "bad quality, worst quality, bad anatomy, bad hands"
|
||||||
|
|
||||||
|
# exec pipe
|
||||||
|
print("generating image...")
|
||||||
|
with torch.autocast("cuda"):
|
||||||
|
image = pipe(prompt, guidance_scale=7.5, negative_prompt=negative_prompt).images[0]
|
||||||
|
|
||||||
|
# if not merged, you can use set_multiplier
|
||||||
|
# network1.set_multiplier(0.8)
|
||||||
|
# and generate image again...
|
||||||
|
|
||||||
|
# save image
|
||||||
|
image.save(r"by_diffusers..png")
|
||||||
|
```
|
||||||
|
|
||||||
## 二つのモデルの差分からLoRAモデルを作成する
|
## 二つのモデルの差分からLoRAモデルを作成する
|
||||||
|
|
||||||
[こちらのディスカッション](https://github.com/cloneofsimo/lora/discussions/56)を参考に実装したものです。数式はそのまま使わせていただきました(よく理解していませんが近似には特異値分解を用いるようです)。
|
[こちらのディスカッション](https://github.com/cloneofsimo/lora/discussions/56)を参考に実装したものです。数式はそのまま使わせていただきました(よく理解していませんが近似には特異値分解を用いるようです)。
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ import diffusers
|
|||||||
from diffusers import DDPMScheduler
|
from diffusers import DDPMScheduler
|
||||||
|
|
||||||
import library.train_util as train_util
|
import library.train_util as train_util
|
||||||
|
import library.huggingface_util as huggingface_util
|
||||||
import library.config_util as config_util
|
import library.config_util as config_util
|
||||||
from library.config_util import (
|
from library.config_util import (
|
||||||
ConfigSanitizer,
|
ConfigSanitizer,
|
||||||
@@ -304,9 +305,7 @@ def train(args):
|
|||||||
text_encoder.to(weight_dtype)
|
text_encoder.to(weight_dtype)
|
||||||
|
|
||||||
# resumeする
|
# resumeする
|
||||||
if args.resume is not None:
|
train_util.resume_from_local_or_hf_if_specified(accelerator, args)
|
||||||
print(f"resume training from state: {args.resume}")
|
|
||||||
accelerator.load_state(args.resume)
|
|
||||||
|
|
||||||
# epoch数を計算する
|
# epoch数を計算する
|
||||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||||
@@ -452,6 +451,8 @@ def train(args):
|
|||||||
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
||||||
print(f"saving checkpoint: {ckpt_file}")
|
print(f"saving checkpoint: {ckpt_file}")
|
||||||
save_weights(ckpt_file, updated_embs, save_dtype)
|
save_weights(ckpt_file, updated_embs, save_dtype)
|
||||||
|
if args.huggingface_repo_id is not None:
|
||||||
|
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name)
|
||||||
|
|
||||||
def remove_old_func(old_epoch_no):
|
def remove_old_func(old_epoch_no):
|
||||||
old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + "." + args.save_model_as
|
old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + "." + args.save_model_as
|
||||||
@@ -492,6 +493,8 @@ def train(args):
|
|||||||
|
|
||||||
print(f"save trained model to {ckpt_file}")
|
print(f"save trained model to {ckpt_file}")
|
||||||
save_weights(ckpt_file, updated_embs, save_dtype)
|
save_weights(ckpt_file, updated_embs, save_dtype)
|
||||||
|
if args.huggingface_repo_id is not None:
|
||||||
|
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=True)
|
||||||
print("model saved.")
|
print("model saved.")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ import diffusers
|
|||||||
from diffusers import DDPMScheduler
|
from diffusers import DDPMScheduler
|
||||||
|
|
||||||
import library.train_util as train_util
|
import library.train_util as train_util
|
||||||
|
import library.huggingface_util as huggingface_util
|
||||||
import library.config_util as config_util
|
import library.config_util as config_util
|
||||||
from library.config_util import (
|
from library.config_util import (
|
||||||
ConfigSanitizer,
|
ConfigSanitizer,
|
||||||
@@ -340,9 +341,7 @@ def train(args):
|
|||||||
text_encoder.to(weight_dtype)
|
text_encoder.to(weight_dtype)
|
||||||
|
|
||||||
# resumeする
|
# resumeする
|
||||||
if args.resume is not None:
|
train_util.resume_from_local_or_hf_if_specified(accelerator, args)
|
||||||
print(f"resume training from state: {args.resume}")
|
|
||||||
accelerator.load_state(args.resume)
|
|
||||||
|
|
||||||
# epoch数を計算する
|
# epoch数を計算する
|
||||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||||
@@ -493,6 +492,8 @@ def train(args):
|
|||||||
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
||||||
print(f"saving checkpoint: {ckpt_file}")
|
print(f"saving checkpoint: {ckpt_file}")
|
||||||
save_weights(ckpt_file, updated_embs, save_dtype)
|
save_weights(ckpt_file, updated_embs, save_dtype)
|
||||||
|
if args.huggingface_repo_id is not None:
|
||||||
|
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name)
|
||||||
|
|
||||||
def remove_old_func(old_epoch_no):
|
def remove_old_func(old_epoch_no):
|
||||||
old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + "." + args.save_model_as
|
old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + "." + args.save_model_as
|
||||||
@@ -534,6 +535,8 @@ def train(args):
|
|||||||
|
|
||||||
print(f"save trained model to {ckpt_file}")
|
print(f"save trained model to {ckpt_file}")
|
||||||
save_weights(ckpt_file, updated_embs, save_dtype)
|
save_weights(ckpt_file, updated_embs, save_dtype)
|
||||||
|
if args.huggingface_repo_id is not None:
|
||||||
|
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=True)
|
||||||
print("model saved.")
|
print("model saved.")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user