mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
Merge remote-tracking branch 'upstream/main'
48
README.md
@@ -127,6 +127,43 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser
|
||||
|
||||
## Change History
|
||||
|
||||
- 28 Mar. 2023, 2023/3/28:
|
||||
- Fix an issue where the training script crashes when `max_data_loader_n_workers` is 0.
|
||||
- `max_data_loader_n_workers` が0の時に学習スクリプトがエラーとなる不具合を修正しました。
|
||||
|
||||
- 27 Mar. 2023, 2023/3/27:
|
||||
- Fix issues when `--persistent_data_loader_workers` is specified.
|
||||
- The batch members of the bucket are not shuffled.
|
||||
- `--caption_dropout_every_n_epochs` does not work.
|
||||
- These issues occurred because the epoch transition was not recognized correctly. Thanks to u-haru for reporting the issue.
|
||||
- Fix an issue where images are loaded twice on Windows.
|
||||
- Add Min-SNR Weighting strategy. Details are in [#308](https://github.com/kohya-ss/sd-scripts/pull/308). Thank you to AI-Casanova for this great work!
|
||||
- Add `--min_snr_gamma` option to training scripts. The paper recommends a value of 5.
|
||||
|
||||
- Add tag warmup. Details are in [#322](https://github.com/kohya-ss/sd-scripts/pull/322). Thanks to u-haru!
|
||||
- Add `token_warmup_min` and `token_warmup_step` to dataset settings.
|
||||
- Gradually increase the number of tokens used, starting at `token_warmup_min` tokens and reaching all tokens at step `token_warmup_step`.
|
||||
- For example, if `token_warmup_min` is `3` and `token_warmup_step` is `10`, the first step will use the first 3 tokens, and the 10th step will use all tokens.
|
||||
- Fix a bug in `resize_lora.py`. Thanks to mgz-dev! [#328](https://github.com/kohya-ss/sd-scripts/pull/328)
|
||||
- With the `--debug_dataset` option, press the `S` key to advance to the next step and the `E` key to advance to the next epoch.
|
||||
- Fix other bugs.
|
||||
|
||||
- `--persistent_data_loader_workers` を指定した時の各種不具合を修正しました。
|
||||
- `--caption_dropout_every_n_epochs` が効かない。
|
||||
- バケットのバッチメンバーがシャッフルされない。
|
||||
- エポックの遷移が正しく認識されないために発生していました。ご指摘いただいたu-haru氏に感謝します。
|
||||
- Windows環境で画像が二重に読み込まれる不具合を修正しました。
|
||||
- Min-SNR Weighting strategyを追加しました。 詳細は [#308](https://github.com/kohya-ss/sd-scripts/pull/308) をご参照ください。AI-Casanova氏の素晴らしい貢献に感謝します。
|
||||
- `--min_snr_gamma` オプションを学習スクリプトに追加しました。論文では5が推奨されています。
|
||||
- タグのウォームアップを追加しました。詳細は [#322](https://github.com/kohya-ss/sd-scripts/pull/322) をご参照ください。u-haru氏に感謝します。
|
||||
- データセット設定に `token_warmup_min` と `token_warmup_step` を追加しました。
|
||||
- `token_warmup_min` で指定した数のトークン(カンマ区切りの文字列)から、`token_warmup_step` で指定したステップまで、段階的にトークンを増やしていきます。
|
||||
- たとえば `token_warmup_min`に `3` を、`token_warmup_step` に `10` を指定すると、最初のステップでは最初から3個のトークンが使われ、10ステップ目では全てのトークンが使われます。
|
||||
- `resize_lora.py` の不具合を修正しました。mgz-dev氏に感謝します。[#328](https://github.com/kohya-ss/sd-scripts/pull/328)
|
||||
- `--debug_dataset` オプションで、`S`キーで次のステップへ、`E`キーで次のエポックへ進めるようにしました。
|
||||
- その他の不具合を修正しました。
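To make the tag warmup arithmetic in the entry above concrete, here is a minimal sketch of the schedule, modeled on the formula in the `train_util.py` hunk of this commit. The helper name `warmup_token_count` is hypothetical and not part of sd-scripts. It assumes `token_warmup_step` is already an absolute step count (the diff converts values below 1 into a fraction of `max_train_steps`) and that steps are counted from 0.

```python
import math


def warmup_token_count(step, n_tokens, warmup_min, warmup_step):
    """Return how many leading comma-separated tags to keep at `step`.

    Hypothetical helper: mirrors the linear ramp in this commit's diff,
    counting steps from 0.
    """
    if not warmup_step or step >= warmup_step:
        return n_tokens  # warmup disabled or finished: keep every tag
    ramp = math.floor(step * ((n_tokens - warmup_min) / warmup_step))
    return ramp + warmup_min


caption = "1girl, solo, smile, long hair, dress, outdoors, sky, tree"
tags = [t.strip() for t in caption.split(",")]
for step in (0, 5, 10):
    kept = tags[: warmup_token_count(step, len(tags), 3, 10)]
    print(step, kept)
```

With `token_warmup_min` of 3 and `token_warmup_step` of 10, this keeps 3 tags at step 0 and all 8 tags from step 10 onward.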
|
||||
|
||||
|
||||
- 21 Mar. 2023, 2023/3/21:
|
||||
- Add `--vae_batch_size` for faster latents caching to each training script. This batches VAE calls.
|
||||
- Please start with `2` or `4` depending on the size of VRAM.
|
||||
@@ -143,8 +180,7 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser
|
||||
- Added support for uppercase image file extensions on non-Windows environments.
|
||||
- Fixed a bug where `resize_lora.py` did not work correctly with dynamic rank (including cases where the rank differs per LoRA module, or where `conv_dim` differs from `network_dim`). Thanks to toshiaki.
|
||||
|
||||
|
||||
- Sample image generation:
|
||||
## Sample image generation during training
|
||||
A prompt file might look like this, for example
|
||||
|
||||
```
|
||||
@@ -166,15 +202,15 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser
|
||||
|
||||
Prompt weighting with `( )` and `[ ]` is supported.
|
||||
|
||||
- サンプル画像生成:
|
||||
## サンプル画像生成
|
||||
プロンプトファイルは例えば以下のようになります。
|
||||
|
||||
```
|
||||
# prompt 1
|
||||
masterpiece, best quality, 1girl, in white shirts, upper body, looking at viewer, simple background --n low quality, worst quality, bad anatomy,bad composition, poor, low effort --w 768 --h 768 --d 1 --l 7.5 --s 28
|
||||
masterpiece, best quality, (1girl), in white shirts, upper body, looking at viewer, simple background --n low quality, worst quality, bad anatomy,bad composition, poor, low effort --w 768 --h 768 --d 1 --l 7.5 --s 28
|
||||
|
||||
# prompt 2
|
||||
masterpiece, best quality, 1boy, in business suit, standing at street, looking back --n low quality, worst quality, bad anatomy,bad composition, poor, low effort --w 576 --h 832 --d 2 --l 5.5 --s 40
|
||||
masterpiece, best quality, 1boy, in business suit, standing at street, looking back --n (low quality, worst quality), bad anatomy,bad composition, poor, low effort --w 576 --h 832 --d 2 --l 5.5 --s 40
|
||||
```
|
||||
|
||||
Lines beginning with `#` are comments. Options are given in the form "two hyphens + one lowercase letter", such as `--n`. The following options are available.
|
||||
@@ -186,7 +222,7 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser
|
||||
* `--l` Specifies the CFG scale of the generated image.
|
||||
* `--s` Specifies the number of steps in the generation.
|
||||
|
||||
`( )` や `[ ]` などの重みづけは動作しません。
|
||||
`( )` や `[ ]` などの重みづけも動作します。
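As an illustration of the prompt file layout described above (comment lines starting with `#`, options written as two hyphens plus a lowercase letter), the sketch below splits one line into the prompt text and its options. `parse_prompt_line` is a hypothetical helper written for this note, not the parser used by the training scripts, and it does not interpret the option values.

```python
import re


def parse_prompt_line(line):
    """Split one prompt-file line into the prompt text and its `--x value` options.

    Hypothetical parser for illustration; returns None for blank lines and comments.
    """
    line = line.strip()
    if not line or line.startswith("#"):
        return None
    # The capturing group makes re.split keep each option letter in the result.
    parts = re.split(r"\s--([a-z])\s", line)
    parsed = {"prompt": parts[0].strip()}
    for letter, value in zip(parts[1::2], parts[2::2]):
        parsed[letter] = value.strip()
    return parsed


example = (
    "masterpiece, best quality, (1girl), upper body "
    "--n low quality, worst quality --w 768 --h 768 --d 1 --l 7.5 --s 28"
)
print(parse_prompt_line("# prompt 1"))
print(parse_prompt_line(example))
```

Values are kept as strings here; a real consumer would convert `--l` (CFG scale) to a float and `--s` (number of steps) to an integer.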
|
||||
|
||||
Please read [Releases](https://github.com/kohya-ss/sd-scripts/releases) for recent updates.
|
||||
最近の更新情報は [Release](https://github.com/kohya-ss/sd-scripts/releases) をご覧ください。
|
||||
|
||||
32
fine_tune.py
@@ -6,6 +6,7 @@ import gc
|
||||
import math
|
||||
import os
|
||||
import toml
|
||||
from multiprocessing import Value
|
||||
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
@@ -19,10 +20,8 @@ from library.config_util import (
|
||||
ConfigSanitizer,
|
||||
BlueprintGenerator,
|
||||
)
|
||||
|
||||
|
||||
def collate_fn(examples):
|
||||
return examples[0]
|
||||
import library.custom_train_functions as custom_train_functions
|
||||
from library.custom_train_functions import apply_snr_weight
|
||||
|
||||
|
||||
def train(args):
|
||||
@@ -64,6 +63,11 @@ def train(args):
|
||||
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
|
||||
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||
|
||||
current_epoch = Value("i", 0)
|
||||
current_step = Value("i", 0)
|
||||
ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None
|
||||
collater = train_util.collater_class(current_epoch, current_step, ds_for_collater)
|
||||
|
||||
if args.debug_dataset:
|
||||
train_util.debug_dataset(train_dataset_group)
|
||||
return
|
||||
@@ -187,16 +191,21 @@ def train(args):
|
||||
train_dataset_group,
|
||||
batch_size=1,
|
||||
shuffle=True,
|
||||
collate_fn=collate_fn,
|
||||
collate_fn=collater,
|
||||
num_workers=n_workers,
|
||||
persistent_workers=args.persistent_data_loader_workers,
|
||||
)
|
||||
|
||||
# Calculate the number of training steps
|
||||
if args.max_train_epochs is not None:
|
||||
args.max_train_steps = args.max_train_epochs * math.ceil(len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps)
|
||||
args.max_train_steps = args.max_train_epochs * math.ceil(
|
||||
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
|
||||
)
|
||||
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
|
||||
|
||||
# Also send the number of training steps to the dataset
|
||||
train_dataset_group.set_max_train_steps(args.max_train_steps)
|
||||
|
||||
# Prepare the lr scheduler
|
||||
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
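The step override in this hunk is simple arithmetic, and a worked example may help when checking a configuration. The sketch below restates it as a standalone function. The name `steps_for_epochs` and its parameter names are hypothetical, and `batches_per_epoch` stands in for `len(train_dataloader)`.

```python
import math


def steps_for_epochs(max_train_epochs, batches_per_epoch, num_processes, gradient_accumulation_steps):
    """Optimizer steps needed to cover `max_train_epochs` epochs."""
    steps_per_epoch = math.ceil(batches_per_epoch / num_processes / gradient_accumulation_steps)
    return max_train_epochs * steps_per_epoch


# 100 batches over 2 processes with 4 accumulation steps: ceil(12.5) = 13 steps per epoch.
print(steps_for_epochs(3, 100, 2, 4))
```

Because of the ceiling, a partial accumulation window at the end of an epoch still counts as a full step.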
|
||||
|
||||
@@ -255,13 +264,14 @@ def train(args):
|
||||
|
||||
for epoch in range(num_train_epochs):
|
||||
print(f"epoch {epoch+1}/{num_train_epochs}")
|
||||
train_dataset_group.set_current_epoch(epoch + 1)
|
||||
current_epoch.value = epoch + 1
|
||||
|
||||
for m in training_models:
|
||||
m.train()
|
||||
|
||||
loss_total = 0
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
current_step.value = global_step
|
||||
with accelerator.accumulate(training_models[0]): # this does not seem to support multiple models, but leave it like this for now
|
||||
with torch.no_grad():
|
||||
if "latents" in batch and batch["latents"] is not None:
|
||||
@@ -302,6 +312,13 @@ def train(args):
|
||||
else:
|
||||
target = noise
|
||||
|
||||
if args.min_snr_gamma:
|
||||
# do not mean over batch dimension for snr weight
|
||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
|
||||
loss = loss.mean([1, 2, 3])
|
||||
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
|
||||
loss = loss.mean() # mean over batch dimension
|
||||
else:
|
||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean")
|
||||
|
||||
accelerator.backward(loss)
|
||||
@@ -396,6 +413,7 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
train_util.add_sd_saving_arguments(parser)
|
||||
train_util.add_optimizer_arguments(parser)
|
||||
config_util.add_config_arguments(parser)
|
||||
custom_train_functions.add_custom_train_arguments(parser)
|
||||
|
||||
parser.add_argument("--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する")
|
||||
parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する")
|
||||
|
||||
File diff suppressed because it is too large
@@ -4,6 +4,7 @@ from dataclasses import (
|
||||
dataclass,
|
||||
)
|
||||
import functools
|
||||
import random
|
||||
from textwrap import dedent, indent
|
||||
import json
|
||||
from pathlib import Path
|
||||
@@ -56,6 +57,8 @@ class BaseSubsetParams:
|
||||
caption_dropout_rate: float = 0.0
|
||||
caption_dropout_every_n_epochs: int = 0
|
||||
caption_tag_dropout_rate: float = 0.0
|
||||
token_warmup_min: int = 1
|
||||
token_warmup_step: float = 0
|
||||
|
||||
@dataclass
|
||||
class DreamBoothSubsetParams(BaseSubsetParams):
|
||||
@@ -137,6 +140,8 @@ class ConfigSanitizer:
|
||||
"random_crop": bool,
|
||||
"shuffle_caption": bool,
|
||||
"keep_tokens": int,
|
||||
"token_warmup_min": int,
|
||||
"token_warmup_step": Any(float,int),
|
||||
}
|
||||
# DO means DropOut
|
||||
DO_SUBSET_ASCENDABLE_SCHEMA = {
|
||||
@@ -406,6 +411,8 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
|
||||
flip_aug: {subset.flip_aug}
|
||||
face_crop_aug_range: {subset.face_crop_aug_range}
|
||||
random_crop: {subset.random_crop}
|
||||
token_warmup_min: {subset.token_warmup_min},
|
||||
token_warmup_step: {subset.token_warmup_step},
|
||||
"""), " ")
|
||||
|
||||
if is_dreambooth:
|
||||
@@ -422,9 +429,12 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
|
||||
print(info)
|
||||
|
||||
# make buckets first because it determines the length of dataset
|
||||
# and set the same seed for all datasets
|
||||
seed = random.randint(0, 2**31) # actual seed is seed + epoch_no
|
||||
for i, dataset in enumerate(datasets):
|
||||
print(f"[Dataset {i}]")
|
||||
dataset.make_buckets()
|
||||
dataset.set_seed(seed)
|
||||
|
||||
return DatasetGroup(datasets)
|
||||
|
||||
@@ -491,7 +501,6 @@ def load_user_config(file: str) -> dict:
|
||||
|
||||
return config
|
||||
|
||||
|
||||
# for config test
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
18
library/custom_train_functions.py
Normal file
@@ -0,0 +1,18 @@
import torch
import argparse

def apply_snr_weight(loss, timesteps, noise_scheduler, gamma):
    alphas_cumprod = noise_scheduler.alphas_cumprod
    sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
    sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
    alpha = sqrt_alphas_cumprod
    sigma = sqrt_one_minus_alphas_cumprod
    all_snr = (alpha / sigma) ** 2
    snr = torch.stack([all_snr[t] for t in timesteps])
    gamma_over_snr = torch.div(torch.ones_like(snr)*gamma,snr)
    snr_weight = torch.minimum(gamma_over_snr,torch.ones_like(gamma_over_snr)).float() #from paper
    loss = loss * snr_weight
    return loss

def add_custom_train_arguments(parser: argparse.ArgumentParser):
    parser.add_argument("--min_snr_gamma", type=float, default=None, help="gamma for reducing the weight of high loss timesteps. Lower numbers have stronger effect. 5 is recommended by paper. / 低いタイムステップでの高いlossに対して重みを減らすためのgamma値、低いほど効果が強く、論文では5が推奨")
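As a rough illustration of what `apply_snr_weight` computes, the sketch below reproduces the per-timestep weight min(gamma / SNR, 1) in plain Python, without torch. The function names are hypothetical. The beta schedule (a squared linear interpolation between the square roots of `beta_start` and `beta_end`) is an assumption based on the `scaled_linear` `DDPMScheduler` settings that appear elsewhere in this commit.

```python
def scaled_linear_alphas_cumprod(beta_start=0.00085, beta_end=0.012, n=1000):
    """Cumulative product of (1 - beta_t) for an assumed scaled-linear schedule."""
    lo, hi = beta_start ** 0.5, beta_end ** 0.5
    alphas_cumprod, running = [], 1.0
    for i in range(n):
        beta = (lo + (hi - lo) * i / (n - 1)) ** 2
        running *= 1.0 - beta
        alphas_cumprod.append(running)
    return alphas_cumprod


def min_snr_weights(alphas_cumprod, gamma=5.0):
    """Weight each timestep by min(gamma / SNR, 1), where SNR = a / (1 - a)."""
    weights = []
    for a in alphas_cumprod:
        snr = a / (1.0 - a)
        weights.append(min(gamma / snr, 1.0))
    return weights


weights = min_snr_weights(scaled_linear_alphas_cumprod())
print(f"t=0: {weights[0]:.4f}, t=999: {weights[-1]:.4f}")
```

Under these assumptions, timesteps with little noise (high SNR) are down-weighted, while timesteps whose SNR is at or below gamma keep a weight of 1.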
|
||||
@@ -1046,10 +1046,14 @@ def save_stable_diffusion_checkpoint(v2, output_file, text_encoder, unet, ckpt_p
|
||||
key_count = len(state_dict.keys())
|
||||
new_ckpt = {'state_dict': state_dict}
|
||||
|
||||
# epoch and global_step are sometimes not int
|
||||
try:
|
||||
if 'epoch' in checkpoint:
|
||||
epochs += checkpoint['epoch']
|
||||
if 'global_step' in checkpoint:
|
||||
steps += checkpoint['global_step']
|
||||
except:
|
||||
pass
|
||||
|
||||
new_ckpt['epoch'] = epochs
|
||||
new_ckpt['global_step'] = steps
|
||||
|
||||
@@ -276,6 +276,8 @@ class BaseSubset:
|
||||
caption_dropout_rate: float,
|
||||
caption_dropout_every_n_epochs: int,
|
||||
caption_tag_dropout_rate: float,
|
||||
token_warmup_min: int,
|
||||
token_warmup_step: Union[float, int],
|
||||
) -> None:
|
||||
self.image_dir = image_dir
|
||||
self.num_repeats = num_repeats
|
||||
@@ -289,6 +291,9 @@ class BaseSubset:
|
||||
self.caption_dropout_every_n_epochs = caption_dropout_every_n_epochs
|
||||
self.caption_tag_dropout_rate = caption_tag_dropout_rate
|
||||
|
||||
self.token_warmup_min = token_warmup_min # number of tags at step=0
|
||||
self.token_warmup_step = token_warmup_step # the number of tags reaches its maximum at step N (or N*max_train_steps if N<1)
|
||||
|
||||
self.img_count = 0
|
||||
|
||||
|
||||
@@ -309,6 +314,8 @@ class DreamBoothSubset(BaseSubset):
|
||||
caption_dropout_rate,
|
||||
caption_dropout_every_n_epochs,
|
||||
caption_tag_dropout_rate,
|
||||
token_warmup_min,
|
||||
token_warmup_step,
|
||||
) -> None:
|
||||
assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です"
|
||||
|
||||
@@ -324,6 +331,8 @@ class DreamBoothSubset(BaseSubset):
|
||||
caption_dropout_rate,
|
||||
caption_dropout_every_n_epochs,
|
||||
caption_tag_dropout_rate,
|
||||
token_warmup_min,
|
||||
token_warmup_step,
|
||||
)
|
||||
|
||||
self.is_reg = is_reg
|
||||
@@ -351,6 +360,8 @@ class FineTuningSubset(BaseSubset):
|
||||
caption_dropout_rate,
|
||||
caption_dropout_every_n_epochs,
|
||||
caption_tag_dropout_rate,
|
||||
token_warmup_min,
|
||||
token_warmup_step,
|
||||
) -> None:
|
||||
assert metadata_file is not None, "metadata_file must be specified / metadata_fileは指定が必須です"
|
||||
|
||||
@@ -366,6 +377,8 @@ class FineTuningSubset(BaseSubset):
|
||||
caption_dropout_rate,
|
||||
caption_dropout_every_n_epochs,
|
||||
caption_tag_dropout_rate,
|
||||
token_warmup_min,
|
||||
token_warmup_step,
|
||||
)
|
||||
|
||||
self.metadata_file = metadata_file
|
||||
@@ -406,6 +419,10 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
|
||||
self.current_epoch: int = 0 # instances appear to be recreated every epoch, so this must be passed in from outside
|
||||
|
||||
self.current_step: int = 0
|
||||
self.max_train_steps: int = 0
|
||||
self.seed: int = 0
|
||||
|
||||
# augmentation
|
||||
self.aug_helper = AugHelper()
|
||||
|
||||
@@ -421,9 +438,19 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
|
||||
self.replacements = {}
|
||||
|
||||
def set_seed(self, seed):
|
||||
self.seed = seed
|
||||
|
||||
def set_current_epoch(self, epoch):
|
||||
self.current_epoch = epoch
|
||||
if not self.current_epoch == epoch: # shuffle the buckets when the epoch changes
|
||||
self.shuffle_buckets()
|
||||
self.current_epoch = epoch
|
||||
|
||||
def set_current_step(self, step):
|
||||
self.current_step = step
|
||||
|
||||
def set_max_train_steps(self, max_train_steps):
|
||||
self.max_train_steps = max_train_steps
|
||||
|
||||
def set_tag_frequency(self, dir_name, captions):
|
||||
frequency_for_dir = self.tag_frequency.get(dir_name, {})
|
||||
@@ -458,7 +485,16 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
if is_drop_out:
|
||||
caption = ""
|
||||
else:
|
||||
if subset.shuffle_caption or subset.caption_tag_dropout_rate > 0:
|
||||
if subset.shuffle_caption or subset.token_warmup_step > 0 or subset.caption_tag_dropout_rate > 0:
|
||||
tokens = [t.strip() for t in caption.strip().split(",")]
|
||||
if subset.token_warmup_step < 1: # overwrite on first use
|
||||
subset.token_warmup_step = math.floor(subset.token_warmup_step * self.max_train_steps)
|
||||
if subset.token_warmup_step and self.current_step < subset.token_warmup_step:
|
||||
tokens_len = (
|
||||
math.floor((self.current_step) * ((len(tokens) - subset.token_warmup_min) / (subset.token_warmup_step)))
|
||||
+ subset.token_warmup_min
|
||||
)
|
||||
tokens = tokens[:tokens_len]
|
||||
|
||||
def dropout_tags(tokens):
|
||||
if subset.caption_tag_dropout_rate <= 0:
|
||||
@@ -470,10 +506,10 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
return l
|
||||
|
||||
fixed_tokens = []
|
||||
flex_tokens = [t.strip() for t in caption.strip().split(",")]
|
||||
flex_tokens = tokens[:]
|
||||
if subset.keep_tokens > 0:
|
||||
fixed_tokens = flex_tokens[: subset.keep_tokens]
|
||||
flex_tokens = flex_tokens[subset.keep_tokens :]
|
||||
flex_tokens = tokens[subset.keep_tokens :]
|
||||
|
||||
if subset.shuffle_caption:
|
||||
random.shuffle(flex_tokens)
|
||||
@@ -643,6 +679,9 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
self._length = len(self.buckets_indices)
|
||||
|
||||
def shuffle_buckets(self):
|
||||
# set random seed for this epoch
|
||||
random.seed(self.seed + self.current_epoch)
|
||||
|
||||
random.shuffle(self.buckets_indices)
|
||||
self.bucket_manager.shuffle()
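Seeding with `seed + epoch` means every dataset that shares the same seed derives the same shuffle order for a given epoch. A minimal sketch of that idea follows. Note that it swaps the global `random.seed` call used in this hunk for a local `random.Random` instance, and the function name `shuffled_for_epoch` is hypothetical.

```python
import random


def shuffled_for_epoch(items, seed, epoch):
    """Return a copy of `items` shuffled deterministically for (seed, epoch)."""
    rng = random.Random(seed + epoch)  # local generator; global state is untouched
    shuffled = list(items)
    rng.shuffle(shuffled)
    return shuffled


bucket_indices = list(range(10))
first = shuffled_for_epoch(bucket_indices, seed=1234, epoch=1)
again = shuffled_for_epoch(bucket_indices, seed=1234, epoch=1)
print(first == again)
```

A local generator avoids disturbing other code that relies on the global `random` state, at the cost of departing from what the diff itself does.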
|
||||
|
||||
@@ -1062,7 +1101,7 @@ class DreamBoothDataset(BaseDataset):
|
||||
self.register_image(info, subset)
|
||||
n += info.num_repeats
|
||||
else:
|
||||
info.num_repeats += 1
|
||||
info.num_repeats += 1 # rewrite registered info
|
||||
n += 1
|
||||
if n >= num_train_images:
|
||||
break
|
||||
@@ -1123,6 +1162,8 @@ class FineTuningDataset(BaseDataset):
|
||||
# build the path information
|
||||
if os.path.exists(image_key):
|
||||
abs_path = image_key
|
||||
elif os.path.exists(os.path.splitext(image_key)[0] + ".npz"):
|
||||
abs_path = os.path.splitext(image_key)[0] + ".npz"
|
||||
else:
|
||||
npz_path = os.path.join(subset.image_dir, image_key + ".npz")
|
||||
if os.path.exists(npz_path):
|
||||
@@ -1308,6 +1349,14 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
|
||||
for dataset in self.datasets:
|
||||
dataset.set_current_epoch(epoch)
|
||||
|
||||
def set_current_step(self, step):
|
||||
for dataset in self.datasets:
|
||||
dataset.set_current_step(step)
|
||||
|
||||
def set_max_train_steps(self, max_train_steps):
|
||||
for dataset in self.datasets:
|
||||
dataset.set_max_train_steps(max_train_steps)
|
||||
|
||||
def disable_token_padding(self):
|
||||
for dataset in self.datasets:
|
||||
dataset.disable_token_padding()
|
||||
@@ -1315,13 +1364,22 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
|
||||
|
||||
def debug_dataset(train_dataset, show_input_ids=False):
|
||||
print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}")
|
||||
print("Escape for exit. / Escキーで中断、終了します")
|
||||
print("`S` for next step, `E` for next epoch no. , Escape for exit. / Sキーで次のステップ、Eキーで次のエポック、Escキーで中断、終了します")
|
||||
|
||||
train_dataset.set_current_epoch(1)
|
||||
k = 0
|
||||
epoch = 1
|
||||
while True:
|
||||
print(f"epoch: {epoch}")
|
||||
|
||||
steps = (epoch - 1) * len(train_dataset) + 1
|
||||
indices = list(range(len(train_dataset)))
|
||||
random.shuffle(indices)
|
||||
|
||||
k = 0
|
||||
for i, idx in enumerate(indices):
|
||||
train_dataset.set_current_epoch(epoch)
|
||||
train_dataset.set_current_step(steps)
|
||||
print(f"steps: {steps} ({i + 1}/{len(train_dataset)})")
|
||||
|
||||
example = train_dataset[idx]
|
||||
if example["latents"] is not None:
|
||||
print(f"sample has latents from npz file: {example['latents'].size()}")
|
||||
@@ -1341,10 +1399,19 @@ def debug_dataset(train_dataset, show_input_ids=False):
|
||||
cv2.imshow("img", im)
|
||||
k = cv2.waitKey()
|
||||
cv2.destroyAllWindows()
|
||||
if k == 27:
|
||||
if k == 27 or k == ord("s") or k == ord("e"):
|
||||
break
|
||||
steps += 1
|
||||
|
||||
if k == ord("e"):
|
||||
break
|
||||
if k == 27 or (example["images"] is None and i >= 8):
|
||||
k = 27
|
||||
break
|
||||
if k == 27:
|
||||
break
|
||||
|
||||
epoch += 1
|
||||
|
||||
|
||||
def glob_images(directory, base="*"):
|
||||
@@ -1354,8 +1421,8 @@ def glob_images(directory, base="*"):
|
||||
img_paths.extend(glob.glob(os.path.join(glob.escape(directory), base + ext)))
|
||||
else:
|
||||
img_paths.extend(glob.glob(glob.escape(os.path.join(directory, base + ext))))
|
||||
# img_paths = list(set(img_paths)) # remove duplicates
|
||||
# img_paths.sort()
|
||||
img_paths = list(set(img_paths)) # remove duplicates
|
||||
img_paths.sort()
|
||||
return img_paths
|
||||
|
||||
|
||||
@@ -1367,8 +1434,8 @@ def glob_images_pathlib(dir_path, recursive):
|
||||
else:
|
||||
for ext in IMAGE_EXTENSIONS:
|
||||
image_paths += list(dir_path.glob("*" + ext))
|
||||
# image_paths = list(set(image_paths)) # remove duplicates
|
||||
# image_paths.sort()
|
||||
image_paths = list(set(image_paths)) # remove duplicates
|
||||
image_paths.sort()
|
||||
return image_paths
|
||||
|
||||
|
||||
@@ -2061,6 +2128,20 @@ def add_dataset_arguments(
|
||||
"--bucket_no_upscale", action="store_true", help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--token_warmup_min",
|
||||
type=int,
|
||||
default=1,
|
||||
help="start learning at N tags (token means comma separated strings) / タグ数をN個から増やしながら学習する",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--token_warmup_step",
|
||||
type=float,
|
||||
default=0,
|
||||
help="tag length reaches maximum on N steps (or N*max_train_steps if N<1) / N(N<1ならN*max_train_steps)ステップでタグ長が最大になる。デフォルトは0(最初から最大)",
|
||||
)
|
||||
|
||||
if support_caption_dropout:
|
||||
# Textual Inversion does not support caption dropout
|
||||
# prefix the options with "caption" to avoid confusion with the usual tensor Dropout; every_n_epochs defaults to None for consistency with the other options
|
||||
@@ -2995,3 +3076,24 @@ class ImageLoadingDataset(torch.utils.data.Dataset):
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
|
||||
# for collate_fn; epoch and step are multiprocessing.Value objects
class collater_class:
    def __init__(self, epoch, step, dataset):
        self.current_epoch = epoch
        self.current_step = step
        self.dataset = dataset  # not used if worker_info is not None, in case of multiprocessing

    def __call__(self, examples):
        worker_info = torch.utils.data.get_worker_info()
        # worker_info is None in the main process
        if worker_info is not None:
            dataset = worker_info.dataset
        else:
            dataset = self.dataset

        # set epoch and step
        dataset.set_current_epoch(self.current_epoch.value)
        dataset.set_current_step(self.current_step.value)
        return examples[0]
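The collater holds `multiprocessing.Value` objects rather than plain integers so that it can read `.value` each time it is called and pick up updates made by the training loop after construction. The sketch below shows only that call-time lookup, in a single process. `ToyDataset` and `ToyCollater` are hypothetical stand-ins, and no `DataLoader` or worker process is involved.

```python
from multiprocessing import Value


class ToyDataset:
    """Hypothetical stand-in for the dataset group; only records what it is told."""

    def __init__(self):
        self.current_epoch = 0
        self.current_step = 0

    def set_current_epoch(self, epoch):
        self.current_epoch = epoch

    def set_current_step(self, step):
        self.current_step = step


class ToyCollater:
    """Reads the shared counters at call time instead of copying them at construction."""

    def __init__(self, epoch, step, dataset):
        self.epoch = epoch  # multiprocessing.Value("i", ...)
        self.step = step
        self.dataset = dataset

    def __call__(self, examples):
        self.dataset.set_current_epoch(self.epoch.value)
        self.dataset.set_current_step(self.step.value)
        return examples[0]  # with batch_size=1 the single item is returned as the batch


epoch, step = Value("i", 0), Value("i", 0)
dataset = ToyDataset()
collater = ToyCollater(epoch, step, dataset)

epoch.value, step.value = 2, 150  # what the training loop would do
batch = collater([{"id": "batch-0"}])
print(dataset.current_epoch, dataset.current_step, batch)
```

Had the collater stored plain `int` copies, the dataset would keep seeing epoch 0 and step 0 no matter how far training had progressed.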
|
||||
|
||||
@@ -11,6 +11,8 @@ import numpy as np
|
||||
|
||||
MIN_SV = 1e-6
|
||||
|
||||
# Model save and load functions
|
||||
|
||||
def load_state_dict(file_name, dtype):
|
||||
if model_util.is_safetensors(file_name):
|
||||
sd = load_file(file_name)
|
||||
@@ -39,12 +41,13 @@ def save_to_file(file_name, model, state_dict, dtype, metadata):
|
||||
torch.save(model, file_name)
|
||||
|
||||
|
||||
# Indexing functions
|
||||
|
||||
def index_sv_cumulative(S, target):
|
||||
original_sum = float(torch.sum(S))
|
||||
cumulative_sums = torch.cumsum(S, dim=0)/original_sum
|
||||
index = int(torch.searchsorted(cumulative_sums, target)) + 1
|
||||
if index >= len(S):
|
||||
index = len(S) - 1
|
||||
index = max(1, min(index, len(S)-1))
|
||||
|
||||
return index
|
||||
|
||||
@@ -54,8 +57,16 @@ def index_sv_fro(S, target):
|
||||
s_fro_sq = float(torch.sum(S_squared))
|
||||
sum_S_squared = torch.cumsum(S_squared, dim=0)/s_fro_sq
|
||||
index = int(torch.searchsorted(sum_S_squared, target**2)) + 1
|
||||
if index >= len(S):
|
||||
index = len(S) - 1
|
||||
index = max(1, min(index, len(S)-1))
|
||||
|
||||
return index
|
||||
|
||||
|
||||
def index_sv_ratio(S, target):
|
||||
max_sv = S[0]
|
||||
min_sv = max_sv/target
|
||||
index = int(torch.sum(S > min_sv).item())
|
||||
index = max(1, min(index, len(S)-1))
|
||||
|
||||
return index
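For readers without torch at hand, the following is a plain-Python restatement of the `index_sv_ratio` helper shown in this hunk. It assumes `S` is a list of singular values sorted in descending order. The list-based version is an adaptation for illustration, not the script's tensor code.

```python
def index_sv_ratio(S, target):
    """Count singular values larger than S[0] / target, clamped to [1, len(S) - 1]."""
    max_sv = S[0]
    min_sv = max_sv / target
    index = sum(1 for s in S if s > min_sv)
    return max(1, min(index, len(S) - 1))


singular_values = [10.0, 5.0, 1.0, 0.1]
print(index_sv_ratio(singular_values, 4))     # threshold is 10.0 / 4 = 2.5
print(index_sv_ratio(singular_values, 1000))  # every value passes, so the clamp applies
```

The clamp keeps the result a valid index into `S` even when every singular value exceeds the threshold.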
|
||||
|
||||
@@ -125,26 +136,24 @@ def merge_linear(lora_down, lora_up, device):
|
||||
return weight
|
||||
|
||||
|
||||
# Calculate new rank
|
||||
|
||||
def rank_resize(S, rank, dynamic_method, dynamic_param, scale=1):
|
||||
param_dict = {}
|
||||
|
||||
if dynamic_method=="sv_ratio":
|
||||
# Calculate new dim and alpha based off ratio
|
||||
max_sv = S[0]
|
||||
min_sv = max_sv/dynamic_param
|
||||
new_rank = max(torch.sum(S > min_sv).item(),1)
|
||||
new_rank = index_sv_ratio(S, dynamic_param) + 1
|
||||
new_alpha = float(scale*new_rank)
|
||||
|
||||
elif dynamic_method=="sv_cumulative":
|
||||
# Calculate new dim and alpha based off cumulative sum
|
||||
new_rank = index_sv_cumulative(S, dynamic_param)
|
||||
new_rank = max(new_rank, 1)
|
||||
new_rank = index_sv_cumulative(S, dynamic_param) + 1
|
||||
new_alpha = float(scale*new_rank)
|
||||
|
||||
elif dynamic_method=="sv_fro":
|
||||
# Calculate new dim and alpha based off sqrt sum of squares
|
||||
new_rank = index_sv_fro(S, dynamic_param)
|
||||
new_rank = min(max(new_rank, 1), len(S)-1)
|
||||
new_rank = index_sv_fro(S, dynamic_param) + 1
|
||||
new_alpha = float(scale*new_rank)
|
||||
else:
|
||||
new_rank = rank
|
||||
@@ -172,7 +181,7 @@ def rank_resize(S, rank, dynamic_method, dynamic_param, scale=1):
|
||||
param_dict["new_alpha"] = new_alpha
|
||||
param_dict["sum_retained"] = (s_rank)/s_sum
|
||||
param_dict["fro_retained"] = fro_percent
|
||||
param_dict["max_ratio"] = S[0]/S[new_rank]
|
||||
param_dict["max_ratio"] = S[0]/S[new_rank - 1]
|
||||
|
||||
return param_dict
|
||||
|
||||
|
||||
28
train_db.py
@@ -8,6 +8,7 @@ import itertools
|
||||
import math
|
||||
import os
|
||||
import toml
|
||||
from multiprocessing import Value
|
||||
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
@@ -21,10 +22,8 @@ from library.config_util import (
|
||||
ConfigSanitizer,
|
||||
BlueprintGenerator,
|
||||
)
|
||||
|
||||
|
||||
def collate_fn(examples):
|
||||
return examples[0]
|
||||
import library.custom_train_functions as custom_train_functions
|
||||
from library.custom_train_functions import apply_snr_weight
|
||||
|
||||
|
||||
def train(args):
|
||||
@@ -59,6 +58,11 @@ def train(args):
|
||||
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
|
||||
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||
|
||||
current_epoch = Value("i", 0)
|
||||
current_step = Value("i", 0)
|
||||
ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None
|
||||
collater = train_util.collater_class(current_epoch, current_step, ds_for_collater)
|
||||
|
||||
if args.no_token_padding:
|
||||
train_dataset_group.disable_token_padding()
|
||||
|
||||
@@ -152,16 +156,21 @@ def train(args):
|
||||
train_dataset_group,
|
||||
batch_size=1,
|
||||
shuffle=True,
|
||||
collate_fn=collate_fn,
|
||||
collate_fn=collater,
|
||||
num_workers=n_workers,
|
||||
persistent_workers=args.persistent_data_loader_workers,
|
||||
)
|
||||
|
||||
# Calculate the number of training steps
|
||||
if args.max_train_epochs is not None:
|
||||
args.max_train_steps = args.max_train_epochs * math.ceil(len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps)
|
||||
args.max_train_steps = args.max_train_epochs * math.ceil(
|
||||
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
|
||||
)
|
||||
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
|
||||
|
||||
# Also send the number of training steps to the dataset
|
||||
train_dataset_group.set_max_train_steps(args.max_train_steps)
|
||||
|
||||
if args.stop_text_encoder_training is None:
|
||||
args.stop_text_encoder_training = args.max_train_steps + 1 # do not stop until end
|
||||
|
||||
@@ -229,7 +238,7 @@ def train(args):
|
||||
loss_total = 0.0
|
||||
for epoch in range(num_train_epochs):
|
||||
print(f"epoch {epoch+1}/{num_train_epochs}")
|
||||
train_dataset_group.set_current_epoch(epoch + 1)
|
||||
current_epoch.value = epoch + 1
|
||||
|
||||
# Train the Text Encoder up to the specified number of steps: state at the start of the epoch
|
||||
unet.train()
|
||||
@@ -238,6 +247,7 @@ def train(args):
|
||||
text_encoder.train()
|
||||
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
current_step.value = global_step
|
||||
# Stop training the Text Encoder at the specified step
|
||||
if global_step == args.stop_text_encoder_training:
|
||||
print(f"stop text encoder training at step {global_step}")
|
||||
@@ -291,6 +301,9 @@ def train(args):
|
||||
loss_weights = batch["loss_weights"] # weight for each sample
|
||||
loss = loss * loss_weights
|
||||
|
||||
if args.min_snr_gamma:
|
||||
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
|
||||
|
||||
loss = loss.mean() # this is a mean, so there is no need to divide by batch_size
|
||||
|
||||
accelerator.backward(loss)
|
||||
@@ -390,6 +403,7 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
train_util.add_sd_saving_arguments(parser)
|
||||
train_util.add_optimizer_arguments(parser)
|
||||
config_util.add_config_arguments(parser)
|
||||
custom_train_functions.add_custom_train_arguments(parser)
|
||||
|
||||
parser.add_argument(
|
||||
"--no_token_padding",
|
||||
|
||||
@@ -8,6 +8,7 @@ import random
|
||||
import time
|
||||
import json
|
||||
import toml
|
||||
from multiprocessing import Value
|
||||
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
@@ -23,10 +24,8 @@ from library.config_util import (
|
||||
ConfigSanitizer,
|
||||
BlueprintGenerator,
|
||||
)
|
||||
|
||||
|
||||
def collate_fn(examples):
|
||||
return examples[0]
|
||||
import library.custom_train_functions as custom_train_functions
|
||||
from library.custom_train_functions import apply_snr_weight
|
||||
|
||||
|
||||
# TODO share this with the other scripts
|
||||
@@ -100,6 +99,11 @@ def train(args):
|
||||
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
|
||||
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||
|
||||
current_epoch = Value('i',0)
|
||||
current_step = Value('i',0)
|
||||
ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None
|
||||
collater = train_util.collater_class(current_epoch,current_step, ds_for_collater)
|
||||
|
||||
if args.debug_dataset:
|
||||
train_util.debug_dataset(train_dataset_group)
|
||||
return
|
||||
@@ -185,11 +189,12 @@ def train(args):
|
||||
# Prepare the dataloader
|
||||
# Number of DataLoader processes: 0 means the main process
|
||||
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1, but no more than the specified number
|
||||
|
||||
train_dataloader = torch.utils.data.DataLoader(
|
||||
train_dataset_group,
|
||||
batch_size=1,
|
||||
shuffle=True,
|
||||
collate_fn=collate_fn,
|
||||
collate_fn=collater,
|
||||
num_workers=n_workers,
|
||||
persistent_workers=args.persistent_data_loader_workers,
|
||||
)
|
||||
@@ -200,6 +205,9 @@ def train(args):
|
||||
if is_main_process:
|
||||
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
|
||||
|
||||
# Also send the number of training steps to the dataset
|
||||
train_dataset_group.set_max_train_steps(args.max_train_steps)
|
||||
|
||||
# Prepare the lr scheduler
|
||||
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
|
||||
|
||||
@@ -488,22 +496,23 @@ def train(args):
|
||||
noise_scheduler = DDPMScheduler(
|
||||
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
|
||||
)
|
||||
|
||||
if accelerator.is_main_process:
|
||||
accelerator.init_trackers("network_train")
|
||||
|
||||
loss_list = []
|
||||
loss_total = 0.0
|
||||
del train_dataset_group
|
||||
for epoch in range(num_train_epochs):
|
||||
if is_main_process:
|
||||
print(f"epoch {epoch+1}/{num_train_epochs}")
|
||||
train_dataset_group.set_current_epoch(epoch + 1)
|
||||
current_epoch.value = epoch+1
|
||||
|
||||
metadata["ss_epoch"] = str(epoch + 1)
|
||||
|
||||
network.on_epoch_start(text_encoder, unet)
|
||||
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
current_step.value = global_step
|
||||
with accelerator.accumulate(network):
|
||||
with torch.no_grad():
|
||||
if "latents" in batch and batch["latents"] is not None:
|
||||
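The hunks above replace the direct `train_dataset_group.set_current_epoch(...)` call with two shared `multiprocessing.Value` counters that the collater can read. The body of `train_util.collater_class` is not part of this diff, so the following is only a minimal sketch of the pattern, assuming a hypothetical `EpochAwareCollater` that attaches the counters to each batch for illustration; the real class may behave differently.

```python
from multiprocessing import Value


class EpochAwareCollater:
    """Hypothetical stand-in for train_util.collater_class (not shown in this diff)."""

    def __init__(self, current_epoch, current_step, dataset=None):
        # Shared integer counters that the training loop updates.
        self.current_epoch = current_epoch
        self.current_step = current_step
        # Only passed when loading runs in the main process (n_workers == 0).
        self.dataset = dataset

    def __call__(self, examples):
        # batch_size=1 and the dataset already yields a whole batch, so unwrap it.
        batch = dict(examples[0])
        # Illustration only: record which epoch/step the collater observed.
        batch["epoch"] = self.current_epoch.value
        batch["step"] = self.current_step.value
        return batch


current_epoch = Value("i", 0)
current_step = Value("i", 0)
collater = EpochAwareCollater(current_epoch, current_step)

seen = []
global_step = 0
for epoch in range(2):
    current_epoch.value = epoch + 1  # mirrors `current_epoch.value = epoch+1`
    for _ in range(3):
        current_step.value = global_step  # mirrors `current_step.value = global_step`
        seen.append(collater([{"captions": ["a photo"]}]))
        global_step += 1
```

The sketch runs in a single process; the point of using `Value` rather than a plain attribute is that a value written by the training loop stays visible to code running in DataLoader worker processes, including persistent ones.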
@@ -528,7 +537,6 @@ def train(args):
|
||||
# Sample a random timestep for each image
|
||||
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device)
|
||||
timesteps = timesteps.long()
|
||||
|
||||
# Add noise to the latents according to the noise magnitude at each timestep
|
||||
# (this is the forward diffusion process)
|
||||
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
||||
@@ -549,6 +557,9 @@ def train(args):
|
||||
loss_weights = batch["loss_weights"] # per-sample weight
|
||||
loss = loss * loss_weights
|
||||
|
||||
if args.min_snr_gamma:
|
||||
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
|
||||
|
||||
loss = loss.mean() # this is a mean, so there is no need to divide by batch_size
|
||||
|
||||
accelerator.backward(loss)
|
||||
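The hunk above scales the per-sample loss with `apply_snr_weight` when `--min_snr_gamma` is set. That function's body is not included in this diff, so what follows is a rough pure-Python sketch of the Min-SNR idea rather than the library's code. It assumes the weight is `min(gamma / SNR(t), 1)` with `SNR(t) = alphas_cumprod[t] / (1 - alphas_cumprod[t])`, and it rebuilds a `scaled_linear` schedule from the `DDPMScheduler` arguments shown earlier. The helper names `scaled_linear_alphas_cumprod` and `min_snr_weights` are hypothetical.

```python
import math


def scaled_linear_alphas_cumprod(beta_start=0.00085, beta_end=0.012, num_train_timesteps=1000):
    """Cumulative product of (1 - beta_t), assuming betas are linear in sqrt space."""
    lo, hi = math.sqrt(beta_start), math.sqrt(beta_end)
    out, prod = [], 1.0
    for i in range(num_train_timesteps):
        beta = (lo + (hi - lo) * i / (num_train_timesteps - 1)) ** 2
        prod *= 1.0 - beta
        out.append(prod)
    return out


def min_snr_weights(timesteps, alphas_cumprod, gamma):
    """Per-sample weight min(gamma / SNR(t), 1): down-weights low-noise timesteps."""
    weights = []
    for t in timesteps:
        snr = alphas_cumprod[t] / (1.0 - alphas_cumprod[t])
        weights.append(min(gamma / snr, 1.0))
    return weights


alphas_cumprod = scaled_linear_alphas_cumprod()
timesteps = [0, 500, 999]          # one sampled timestep per image
per_sample_loss = [0.2, 0.2, 0.2]  # loss already reduced to one value per sample
weights = min_snr_weights(timesteps, alphas_cumprod, gamma=5.0)
weighted_loss = [l * w for l, w in zip(per_sample_loss, weights)]
```

Under these assumptions, timesteps with very little noise (high SNR) contribute much less to the loss, while noisy timesteps keep a weight of 1.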
@@ -652,6 +663,7 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
train_util.add_training_arguments(parser, True)
|
||||
train_util.add_optimizer_arguments(parser)
|
||||
config_util.add_config_arguments(parser)
|
||||
custom_train_functions.add_custom_train_arguments(parser)
|
||||
|
||||
parser.add_argument("--no_metadata", action="store_true", help="do not save metadata in output model / メタデータを出力先モデルに保存しない")
|
||||
parser.add_argument(
|
||||
|
||||
@@ -4,6 +4,7 @@ import gc
|
||||
import math
|
||||
import os
|
||||
import toml
|
||||
from multiprocessing import Value
|
||||
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
@@ -17,6 +18,8 @@ from library.config_util import (
|
||||
ConfigSanitizer,
|
||||
BlueprintGenerator,
|
||||
)
|
||||
import library.custom_train_functions as custom_train_functions
|
||||
from library.custom_train_functions import apply_snr_weight
|
||||
|
||||
imagenet_templates_small = [
|
||||
"a photo of a {}",
|
||||
@@ -71,10 +74,6 @@ imagenet_style_templates_small = [
|
||||
]
|
||||
|
||||
|
||||
def collate_fn(examples):
|
||||
return examples[0]
|
||||
|
||||
|
||||
def train(args):
|
||||
if args.output_name is None:
|
||||
args.output_name = args.token_string
|
||||
@@ -185,6 +184,11 @@ def train(args):
|
||||
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
|
||||
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||
|
||||
current_epoch = Value('i',0)
|
||||
current_step = Value('i',0)
|
||||
ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None
|
||||
collater = train_util.collater_class(current_epoch,current_step, ds_for_collater)
|
||||
|
||||
# make captions: a very crude implementation that rewrites each caption into the string "tokenstring tokenstring1 tokenstring2 ...tokenstringn"
|
||||
if use_template:
|
||||
print(f"use template for training captions. is object: {args.use_object_template}")
|
||||
@@ -250,7 +254,7 @@ def train(args):
|
||||
train_dataset_group,
|
||||
batch_size=1,
|
||||
shuffle=True,
|
||||
collate_fn=collate_fn,
|
||||
collate_fn=collater,
|
||||
num_workers=n_workers,
|
||||
persistent_workers=args.persistent_data_loader_workers,
|
||||
)
|
||||
@@ -260,6 +264,9 @@ def train(args):
|
||||
args.max_train_steps = args.max_train_epochs * math.ceil(len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps)
|
||||
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
|
||||
|
||||
# also send the number of training steps to the dataset
|
||||
train_dataset_group.set_max_train_steps(args.max_train_steps)
|
||||
|
||||
# prepare the lr scheduler
|
||||
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
|
||||
|
||||
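The `args.max_train_steps` override in the hunk above can be checked with a small worked example. This sketch copies the formula from the diff; the function name `steps_for_epochs` and the input numbers are made up for illustration.

```python
import math


def steps_for_epochs(max_train_epochs, batches_per_epoch, num_processes, gradient_accumulation_steps):
    """Optimizer steps needed to cover the requested number of epochs."""
    steps_per_epoch = math.ceil(batches_per_epoch / num_processes / gradient_accumulation_steps)
    return max_train_epochs * steps_per_epoch


# 250 batches per epoch, split over 2 processes, accumulating 4 batches per step:
# 250 / 2 / 4 = 31.25, rounded up to 32 steps per epoch.
total_steps = steps_for_epochs(
    max_train_epochs=10,
    batches_per_epoch=250,
    num_processes=2,
    gradient_accumulation_steps=4,
)
```

The rounding up means a partially filled final accumulation window still counts as a step, and this total is the value the diff then forwards to the dataset through `set_max_train_steps`.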
@@ -331,12 +338,14 @@ def train(args):
|
||||
|
||||
for epoch in range(num_train_epochs):
|
||||
print(f"epoch {epoch+1}/{num_train_epochs}")
|
||||
train_dataset_group.set_current_epoch(epoch + 1)
|
||||
current_epoch.value = epoch+1
|
||||
|
||||
text_encoder.train()
|
||||
|
||||
loss_total = 0
|
||||
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
current_step.value = global_step
|
||||
with accelerator.accumulate(text_encoder):
|
||||
with torch.no_grad():
|
||||
if "latents" in batch and batch["latents"] is not None:
|
||||
@@ -378,6 +387,9 @@ def train(args):
|
||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
|
||||
loss = loss.mean([1, 2, 3])
|
||||
|
||||
if args.min_snr_gamma:
|
||||
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
|
||||
|
||||
loss_weights = batch["loss_weights"] # per-sample weight
|
||||
loss = loss * loss_weights
|
||||
|
||||
@@ -534,6 +546,7 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
train_util.add_training_arguments(parser, True)
|
||||
train_util.add_optimizer_arguments(parser)
|
||||
config_util.add_config_arguments(parser)
|
||||
custom_train_functions.add_custom_train_arguments(parser)
|
||||
|
||||
parser.add_argument(
|
||||
"--save_model_as",
|
||||
|
||||