mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
support textual inversion training
This commit is contained in:
30
README.md
30
README.md
@@ -25,18 +25,31 @@ The feature of SDXL training is now available in sdxl branch as an experimental
|
|||||||
Summary of the feature:
|
Summary of the feature:
|
||||||
|
|
||||||
- `sdxl_train.py` is a script for SDXL fine-tuning. The usage is almost the same as `fine_tune.py`, but it also supports DreamBooth dataset.
|
- `sdxl_train.py` is a script for SDXL fine-tuning. The usage is almost the same as `fine_tune.py`, but it also supports DreamBooth dataset.
|
||||||
- `--full_bf16` option is added. This option enables the full bfloat16 training. This option is useful to reduce the GPU memory usage.
|
- `--full_bf16` option is added. Thanks to KohakuBlueleaf!
|
||||||
|
- This option enables the full bfloat16 training (includes gradients). This option is useful to reduce the GPU memory usage.
|
||||||
- However, bitsandbytes==0.35 doesn't seem to support this. Please use a newer version of bitsandbytes or another optimizer.
|
- However, bitsandbytes==0.35 doesn't seem to support this. Please use a newer version of bitsandbytes or another optimizer.
|
||||||
- I cannot find bitsandbytes>0.35.0 that works correctly on Windows.
|
- I cannot find bitsandbytes>0.35.0 that works correctly on Windows.
|
||||||
|
- In addition, the full bfloat16 training might be unstable. Please use it at your own risk.
|
||||||
- `prepare_buckets_latents.py` now supports SDXL fine-tuning.
|
- `prepare_buckets_latents.py` now supports SDXL fine-tuning.
|
||||||
- `sdxl_train_network.py` is a script for LoRA training for SDXL. The usage is almost the same as `train_network.py`.
|
- `sdxl_train_network.py` is a script for LoRA training for SDXL. The usage is almost the same as `train_network.py`.
|
||||||
- Both scripts has following additional options:
|
- Both scripts has following additional options:
|
||||||
- `--cache_text_encoder_outputs`: Cache the outputs of the text encoders. This option is useful to reduce the GPU memory usage. This option cannot be used with options for shuffling or dropping the captions.
|
- `--cache_text_encoder_outputs`: Cache the outputs of the text encoders. This option is useful to reduce the GPU memory usage. This option cannot be used with options for shuffling or dropping the captions.
|
||||||
- `--no_half_vae`: Disable the half-precision (mixed-precision) VAE. VAE for SDXL seems to produce NaNs in some cases. This option is useful to avoid the NaNs.
|
- `--no_half_vae`: Disable the half-precision (mixed-precision) VAE. VAE for SDXL seems to produce NaNs in some cases. This option is useful to avoid the NaNs.
|
||||||
- The image generation during training is now available. However, the VAE for SDXL seems to produce NaNs in some cases when using `fp16`. The images will be black. Currently, the NaNs cannot be avoided even with `--no_half_vae` option. It works with `bf16` or without mixed precision.
|
- The image generation during training is now available. However, the VAE for SDXL seems to produce NaNs in some cases when using `fp16`. The images will be black. Currently, the NaNs cannot be avoided even with `--no_half_vae` option. It works with `bf16` or without mixed precision.
|
||||||
- `--weighted_captions` option is not supported yet.
|
|
||||||
|
- `--weighted_captions` option is not supported yet for both scripts.
|
||||||
- `--min_timestep` and `--max_timestep` options are added to each training script. These options can be used to train U-Net with different timesteps. The default values are 0 and 1000.
|
- `--min_timestep` and `--max_timestep` options are added to each training script. These options can be used to train U-Net with different timesteps. The default values are 0 and 1000.
|
||||||
|
|
||||||
|
- `sdxl_train_textual_inversion.py` is a script for Textual Inversion training for SDXL. The usage is almost the same as `train_textual_inversion.py`.
|
||||||
|
- `--cache_text_encoder_outputs` is not supported.
|
||||||
|
- `token_string` must be alphabet only currently, due to the limitation of the open-clip tokenizer.
|
||||||
|
- There are two options for captions:
|
||||||
|
1. Training with captions. All captions must include the token string. The token string is replaced with multiple tokens.
|
||||||
|
2. Use `--use_object_template` or `--use_style_template` option. The captions are generated from the template. The existing captions are ignored.
|
||||||
|
- See below for the format of the embeddings.
|
||||||
|
|
||||||
- `sdxl_gen_img.py` is added. This script can be used to generate images with SDXL, including LoRA. See the help message for the usage.
|
- `sdxl_gen_img.py` is added. This script can be used to generate images with SDXL, including LoRA. See the help message for the usage.
|
||||||
|
- Textual Inversion is supported, but the name for the embeds in the caption becomes alphabet only. For example, `neg_hand_v1.safetensors` can be activated with `neghandv`.
|
||||||
|
|
||||||
`requirements.txt` is updated to support SDXL training.
|
`requirements.txt` is updated to support SDXL training.
|
||||||
|
|
||||||
@@ -54,7 +67,7 @@ Summary of the feature:
|
|||||||
- `--bucket_reso_steps` can be set to 32 instead of the default value 64. Smaller values than 32 will not work for SDXL training.
|
- `--bucket_reso_steps` can be set to 32 instead of the default value 64. Smaller values than 32 will not work for SDXL training.
|
||||||
|
|
||||||
Example of the optimizer settings for Adafactor with the fixed learning rate:
|
Example of the optimizer settings for Adafactor with the fixed learning rate:
|
||||||
```
|
```toml
|
||||||
optimizer_type = "adafactor"
|
optimizer_type = "adafactor"
|
||||||
optimizer_args = [ "scale_parameter=False", "relative_step=False", "warmup_init=False" ]
|
optimizer_args = [ "scale_parameter=False", "relative_step=False", "warmup_init=False" ]
|
||||||
lr_scheduler = "constant_with_warmup"
|
lr_scheduler = "constant_with_warmup"
|
||||||
@@ -62,13 +75,22 @@ lr_warmup_steps = 100
|
|||||||
learning_rate = 4e-7 # SDXL original learning rate
|
learning_rate = 4e-7 # SDXL original learning rate
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Format of Textual Inversion embeddings
|
||||||
|
|
||||||
|
```python
|
||||||
|
from safetensors.torch import save_file
|
||||||
|
|
||||||
|
state_dict = {"clip_g": embs_for_text_encoder_1280, "clip_l": embs_for_text_encoder_768}
|
||||||
|
save_file(state_dict, file)
|
||||||
|
```
|
||||||
|
|
||||||
### TODO
|
### TODO
|
||||||
|
|
||||||
- [ ] Support Textual Inversion training.
|
|
||||||
- [ ] Support conversion of Diffusers SDXL models.
|
- [ ] Support conversion of Diffusers SDXL models.
|
||||||
- [ ] Support `--weighted_captions` option.
|
- [ ] Support `--weighted_captions` option.
|
||||||
- [ ] Change `--output_config` option to continue the training.
|
- [ ] Change `--output_config` option to continue the training.
|
||||||
- [ ] Extend `--full_bf16` for all the scripts.
|
- [ ] Extend `--full_bf16` for all the scripts.
|
||||||
|
- [x] Support Textual Inversion training.
|
||||||
|
|
||||||
## About requirements.txt
|
## About requirements.txt
|
||||||
|
|
||||||
|
|||||||
@@ -78,12 +78,13 @@ def _load_target_model(args: argparse.Namespace, model_version: str, weight_dtyp
|
|||||||
|
|
||||||
class WrapperTokenizer:
|
class WrapperTokenizer:
|
||||||
# open clipのtokenizerをHuggingFaceのtokenizerと同じ形で使えるようにする
|
# open clipのtokenizerをHuggingFaceのtokenizerと同じ形で使えるようにする
|
||||||
|
# make open clip tokenizer compatible with HuggingFace tokenizer
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
open_clip_tokenizer = open_clip.tokenizer._tokenizer
|
open_clip_tokenizer = open_clip.tokenizer._tokenizer
|
||||||
self.model_max_length = 77
|
self.model_max_length = 77
|
||||||
self.bos_token_id = open_clip_tokenizer.all_special_ids[0]
|
self.bos_token_id = open_clip_tokenizer.all_special_ids[0]
|
||||||
self.eos_token_id = open_clip_tokenizer.all_special_ids[1]
|
self.eos_token_id = open_clip_tokenizer.all_special_ids[1]
|
||||||
self.pad_token_id = 0 # 結果から推定している
|
self.pad_token_id = 0 # 結果から推定している assumption from result
|
||||||
|
|
||||||
def __call__(self, *args: Any, **kwds: Any) -> Any:
|
def __call__(self, *args: Any, **kwds: Any) -> Any:
|
||||||
return self.tokenize(*args, **kwds)
|
return self.tokenize(*args, **kwds)
|
||||||
@@ -107,6 +108,42 @@ class WrapperTokenizer:
|
|||||||
input_ids = input_ids[: eos_index + 1] # include eos
|
input_ids = input_ids[: eos_index + 1] # include eos
|
||||||
return SimpleNamespace(**{"input_ids": input_ids})
|
return SimpleNamespace(**{"input_ids": input_ids})
|
||||||
|
|
||||||
|
# for Textual Inversion
|
||||||
|
# わりと面倒くさいな……これWeb UIとかでどうするんだろう / this is a bit annoying... how to do this in Web UI?
|
||||||
|
|
||||||
|
def encode(self, text, add_special_tokens=False):
|
||||||
|
assert not add_special_tokens
|
||||||
|
input_ids = open_clip.tokenizer._tokenizer.encode(text)
|
||||||
|
return input_ids
|
||||||
|
|
||||||
|
def add_tokens(self, new_tokens):
|
||||||
|
tokens_to_add = []
|
||||||
|
for token in new_tokens:
|
||||||
|
token = token.lower()
|
||||||
|
if token + "</w>" not in open_clip.tokenizer._tokenizer.encoder:
|
||||||
|
tokens_to_add.append(token)
|
||||||
|
|
||||||
|
# open clipのtokenizerに直接追加する / add tokens to open clip tokenizer
|
||||||
|
for token in tokens_to_add:
|
||||||
|
open_clip.tokenizer._tokenizer.encoder[token + "</w>"] = len(open_clip.tokenizer._tokenizer.encoder)
|
||||||
|
open_clip.tokenizer._tokenizer.decoder[len(open_clip.tokenizer._tokenizer.decoder)] = token + "</w>"
|
||||||
|
open_clip.tokenizer._tokenizer.vocab_size += 1
|
||||||
|
|
||||||
|
# open clipのtokenizerのcacheに直接設定することで、bpeとかいうやつに含まれていなくてもtokenizeできるようにする
|
||||||
|
# めちゃくちゃ乱暴なので、open clipのtokenizerの仕様が変わったら動かなくなる
|
||||||
|
# set cache of open clip tokenizer directly to enable tokenization even if the token is not included in bpe
|
||||||
|
# this is very rough, so it will not work if the specification of open clip tokenizer changes
|
||||||
|
open_clip.tokenizer._tokenizer.cache[token] = token + "</w>"
|
||||||
|
|
||||||
|
return len(tokens_to_add)
|
||||||
|
|
||||||
|
def convert_tokens_to_ids(self, tokens):
|
||||||
|
input_ids = [open_clip.tokenizer._tokenizer.encoder[token + "</w>"] for token in tokens]
|
||||||
|
return input_ids
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return open_clip.tokenizer._tokenizer.vocab_size
|
||||||
|
|
||||||
|
|
||||||
def load_tokenizers(args: argparse.Namespace):
|
def load_tokenizers(args: argparse.Namespace):
|
||||||
print("prepare tokenizers")
|
print("prepare tokenizers")
|
||||||
@@ -392,7 +429,7 @@ def verify_sdxl_training_args(args: argparse.Namespace):
|
|||||||
print(f"noise_offset is set to {args.noise_offset} / noise_offsetが{args.noise_offset}に設定されました")
|
print(f"noise_offset is set to {args.noise_offset} / noise_offsetが{args.noise_offset}に設定されました")
|
||||||
|
|
||||||
assert (
|
assert (
|
||||||
not args.weighted_captions
|
not hasattr(args, "weighted_captions") or not args.weighted_captions
|
||||||
), "weighted_captions cannot be enabled in SDXL training currently / SDXL学習では今のところweighted_captionsを有効にすることはできません"
|
), "weighted_captions cannot be enabled in SDXL training currently / SDXL学習では今のところweighted_captionsを有効にすることはできません"
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -320,7 +320,7 @@ class PipelineLike:
|
|||||||
self.scheduler = scheduler
|
self.scheduler = scheduler
|
||||||
self.safety_checker = None
|
self.safety_checker = None
|
||||||
|
|
||||||
# Textual Inversion # not tested yet
|
# Textual Inversion
|
||||||
self.token_replacements_list = []
|
self.token_replacements_list = []
|
||||||
for _ in range(len(self.text_encoders)):
|
for _ in range(len(self.text_encoders)):
|
||||||
self.token_replacements_list.append({})
|
self.token_replacements_list.append({})
|
||||||
@@ -341,6 +341,10 @@ class PipelineLike:
|
|||||||
token_replacements = self.token_replacements_list[tokenizer_index]
|
token_replacements = self.token_replacements_list[tokenizer_index]
|
||||||
|
|
||||||
def replace_tokens(tokens):
|
def replace_tokens(tokens):
|
||||||
|
# print("replace_tokens", tokens, "=>", token_replacements)
|
||||||
|
if isinstance(tokens, torch.Tensor):
|
||||||
|
tokens = tokens.tolist()
|
||||||
|
|
||||||
new_tokens = []
|
new_tokens = []
|
||||||
for token in tokens:
|
for token in tokens:
|
||||||
if token in token_replacements:
|
if token in token_replacements:
|
||||||
@@ -1594,19 +1598,26 @@ def main(args):
|
|||||||
|
|
||||||
if "string_to_param" in data:
|
if "string_to_param" in data:
|
||||||
data = data["string_to_param"]
|
data = data["string_to_param"]
|
||||||
embeds1 = data["clip_l"]
|
|
||||||
embeds2 = data["clip_g"]
|
embeds1 = data["clip_l"] # text encoder 1
|
||||||
|
embeds2 = data["clip_g"] # text encoder 2
|
||||||
|
|
||||||
num_vectors_per_token = embeds1.size()[0]
|
num_vectors_per_token = embeds1.size()[0]
|
||||||
token_string = os.path.splitext(os.path.basename(embeds_file))[0]
|
token_string = os.path.splitext(os.path.basename(embeds_file))[0]
|
||||||
token_strings = [token_string] + [f"{token_string}{i+1}" for i in range(num_vectors_per_token - 1)]
|
|
||||||
|
# remove non-alphabet characters to avoid splitting by tokenizer
|
||||||
|
# TODO make random alphabet string
|
||||||
|
token_string = "".join([c for c in token_string if c.isalpha()])
|
||||||
|
|
||||||
|
token_strings = [token_string] + [f"{token_string}{chr(ord('a') + i)}" for i in range(num_vectors_per_token - 1)]
|
||||||
|
|
||||||
# add new word to tokenizer, count is num_vectors_per_token
|
# add new word to tokenizer, count is num_vectors_per_token
|
||||||
num_added_tokens1 = tokenizer1.add_tokens(token_strings)
|
num_added_tokens1 = tokenizer1.add_tokens(token_strings)
|
||||||
num_added_tokens2 = tokenizer2.add_tokens(token_strings) # not working now
|
num_added_tokens2 = tokenizer2.add_tokens(token_strings)
|
||||||
assert (
|
assert num_added_tokens1 == num_vectors_per_token and num_added_tokens2 == num_vectors_per_token, (
|
||||||
num_added_tokens1 == num_vectors_per_token and num_added_tokens2 == num_vectors_per_token
|
f"tokenizer has same word to token string (filename). characters except alphabet are removed: {embeds_file}"
|
||||||
), f"tokenizer has same word to token string (filename). please rename the file / 指定した名前(ファイル名)のトークンが既に存在します。ファイルをリネームしてください: {embeds_file}"
|
+ f" / 指定した名前(ファイル名)のトークンが既に存在します。アルファベット以外の文字は削除されます: {embeds_file}"
|
||||||
|
)
|
||||||
|
|
||||||
token_ids1 = tokenizer1.convert_tokens_to_ids(token_strings)
|
token_ids1 = tokenizer1.convert_tokens_to_ids(token_strings)
|
||||||
token_ids2 = tokenizer2.convert_tokens_to_ids(token_strings)
|
token_ids2 = tokenizer2.convert_tokens_to_ids(token_strings)
|
||||||
@@ -1617,11 +1628,11 @@ def main(args):
|
|||||||
assert (
|
assert (
|
||||||
min(token_ids2) == token_ids2[0] and token_ids2[-1] == token_ids2[0] + len(token_ids2) - 1
|
min(token_ids2) == token_ids2[0] and token_ids2[-1] == token_ids2[0] + len(token_ids2) - 1
|
||||||
), f"token ids2 is not ordered"
|
), f"token ids2 is not ordered"
|
||||||
assert len(tokenizer1) - 1 == token_ids[-1], f"token ids is not end of tokenize: {len(tokenizer1)}"
|
assert len(tokenizer1) - 1 == token_ids1[-1], f"token ids 1 is not end of tokenize: {len(tokenizer1)}"
|
||||||
assert len(tokenizer2) - 1 == token_ids[-1], f"token ids is not end of tokenize: {len(tokenizer2)}"
|
assert len(tokenizer2) - 1 == token_ids2[-1], f"token ids 2 is not end of tokenize: {len(tokenizer2)}"
|
||||||
|
|
||||||
if num_vectors_per_token > 1:
|
if num_vectors_per_token > 1:
|
||||||
pipe.add_token_replacement(0, token_ids1[0], token_ids1)
|
pipe.add_token_replacement(0, token_ids1[0], token_ids1) # hoge -> hoge, hogea, hogeb, ...
|
||||||
pipe.add_token_replacement(1, token_ids2[0], token_ids2)
|
pipe.add_token_replacement(1, token_ids2[0], token_ids2)
|
||||||
|
|
||||||
token_ids_embeds1.append((token_ids1, embeds1))
|
token_ids_embeds1.append((token_ids1, embeds1))
|
||||||
|
|||||||
142
sdxl_train_textual_inversion.py
Normal file
142
sdxl_train_textual_inversion.py
Normal file
@@ -0,0 +1,142 @@
|
|||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
|
||||||
|
import regex
|
||||||
|
import torch
|
||||||
|
import open_clip
|
||||||
|
from library import sdxl_model_util, sdxl_train_util, train_util
|
||||||
|
|
||||||
|
import train_textual_inversion
|
||||||
|
|
||||||
|
|
||||||
|
class SdxlTextualInversionTrainer(train_textual_inversion.TextualInversionTrainer):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.vae_scale_factor = sdxl_model_util.VAE_SCALE_FACTOR
|
||||||
|
|
||||||
|
def assert_extra_args(self, args, train_dataset_group):
|
||||||
|
super().assert_extra_args(args, train_dataset_group)
|
||||||
|
sdxl_train_util.verify_sdxl_training_args(args)
|
||||||
|
|
||||||
|
def load_target_model(self, args, weight_dtype, accelerator):
|
||||||
|
(
|
||||||
|
load_stable_diffusion_format,
|
||||||
|
text_encoder1,
|
||||||
|
text_encoder2,
|
||||||
|
vae,
|
||||||
|
unet,
|
||||||
|
logit_scale,
|
||||||
|
ckpt_info,
|
||||||
|
) = sdxl_train_util.load_target_model(args, accelerator, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V0_9, weight_dtype)
|
||||||
|
|
||||||
|
self.load_stable_diffusion_format = load_stable_diffusion_format
|
||||||
|
self.logit_scale = logit_scale
|
||||||
|
self.ckpt_info = ckpt_info
|
||||||
|
|
||||||
|
return sdxl_model_util.MODEL_VERSION_SDXL_BASE_V0_9, [text_encoder1, text_encoder2], vae, unet
|
||||||
|
|
||||||
|
def load_tokenizer(self, args):
|
||||||
|
tokenizer = sdxl_train_util.load_tokenizers(args)
|
||||||
|
return tokenizer
|
||||||
|
|
||||||
|
def assert_token_string(self, token_string, tokenizers):
|
||||||
|
# tokenizer 1 is seems to be ok
|
||||||
|
|
||||||
|
# count words for token string: regular expression from open_clip
|
||||||
|
pat = regex.compile(r"""'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", regex.IGNORECASE)
|
||||||
|
words = regex.findall(pat, token_string)
|
||||||
|
word_count = len(words)
|
||||||
|
assert word_count == 1, (
|
||||||
|
f"token string {token_string} contain {word_count} words, please don't use digits, punctuation, or special characters"
|
||||||
|
+ f" / トークン文字列 {token_string} には{word_count}個の単語が含まれています。数字、句読点、特殊文字は使用しないでください"
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype):
|
||||||
|
input_ids1 = batch["input_ids"]
|
||||||
|
input_ids2 = batch["input_ids2"]
|
||||||
|
with torch.enable_grad():
|
||||||
|
input_ids1 = input_ids1.to(accelerator.device)
|
||||||
|
input_ids2 = input_ids2.to(accelerator.device)
|
||||||
|
encoder_hidden_states1, encoder_hidden_states2, pool2 = sdxl_train_util.get_hidden_states(
|
||||||
|
args,
|
||||||
|
input_ids1,
|
||||||
|
input_ids2,
|
||||||
|
tokenizers[0],
|
||||||
|
tokenizers[1],
|
||||||
|
text_encoders[0],
|
||||||
|
text_encoders[1],
|
||||||
|
None if not args.full_fp16 else weight_dtype,
|
||||||
|
)
|
||||||
|
return encoder_hidden_states1, encoder_hidden_states2, pool2
|
||||||
|
|
||||||
|
def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype):
|
||||||
|
noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype
|
||||||
|
|
||||||
|
# get size embeddings
|
||||||
|
orig_size = batch["original_sizes_hw"]
|
||||||
|
crop_size = batch["crop_top_lefts"]
|
||||||
|
target_size = batch["target_sizes_hw"]
|
||||||
|
embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, accelerator.device).to(weight_dtype)
|
||||||
|
|
||||||
|
# concat embeddings
|
||||||
|
encoder_hidden_states1, encoder_hidden_states2, pool2 = text_conds
|
||||||
|
vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype)
|
||||||
|
text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype)
|
||||||
|
|
||||||
|
noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding)
|
||||||
|
return noise_pred
|
||||||
|
|
||||||
|
def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement):
|
||||||
|
sdxl_train_util.sample_images(
|
||||||
|
accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement
|
||||||
|
)
|
||||||
|
|
||||||
|
def save_weights(self, file, updated_embs, save_dtype):
|
||||||
|
state_dict = {"clip_l": updated_embs[0], "clip_g": updated_embs[1]}
|
||||||
|
|
||||||
|
if save_dtype is not None:
|
||||||
|
for key in list(state_dict.keys()):
|
||||||
|
v = state_dict[key]
|
||||||
|
v = v.detach().clone().to("cpu").to(save_dtype)
|
||||||
|
state_dict[key] = v
|
||||||
|
|
||||||
|
if os.path.splitext(file)[1] == ".safetensors":
|
||||||
|
from safetensors.torch import save_file
|
||||||
|
|
||||||
|
save_file(state_dict, file)
|
||||||
|
else:
|
||||||
|
torch.save(state_dict, file)
|
||||||
|
|
||||||
|
def load_weights(self, file):
|
||||||
|
if os.path.splitext(file)[1] == ".safetensors":
|
||||||
|
from safetensors.torch import load_file
|
||||||
|
|
||||||
|
data = load_file(file)
|
||||||
|
else:
|
||||||
|
data = torch.load(file, map_location="cpu")
|
||||||
|
|
||||||
|
emb_l = data.get("clib_l", None) # ViT-L text encoder 1
|
||||||
|
emb_g = data.get("clib_g", None) # BiG-G text encoder 2
|
||||||
|
|
||||||
|
assert (
|
||||||
|
emb_l is not None or emb_g is not None
|
||||||
|
), f"weight file does not contains weights for text encoder 1 or 2 / 重みファイルにテキストエンコーダー1または2の重みが含まれていません: {file}"
|
||||||
|
|
||||||
|
return [emb_l, emb_g]
|
||||||
|
|
||||||
|
|
||||||
|
def setup_parser() -> argparse.ArgumentParser:
|
||||||
|
parser = train_textual_inversion.setup_parser()
|
||||||
|
# don't add sdxl_train_util.add_sdxl_training_arguments(parser): because it only adds text encoder caching
|
||||||
|
# sdxl_train_util.add_sdxl_training_arguments(parser)
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = setup_parser()
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
args = train_util.read_config_from_file(args, parser)
|
||||||
|
|
||||||
|
trainer = SdxlTextualInversionTrainer()
|
||||||
|
trainer.train(args)
|
||||||
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user