Merge remote-tracking branch 'upstream/main'

Author: Jakaline-dev
Date: 2023-03-30 01:35:38 +09:00
11 changed files with 3050 additions and 2526 deletions

README.md

@@ -127,6 +127,43 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser
## Change History
- 28 Mar. 2023, 2023/3/28:
- Fix an issue that the training script crashes when `max_data_loader_n_workers` is 0.
- `max_data_loader_n_workers` が0の時に学習スクリプトがエラーとなる不具合を修正しました。
- 27 Mar. 2023, 2023/3/27:
- Fix issues when `--persistent_data_loader_workers` is specified.
- The batch members of the bucket are not shuffled.
- `--caption_dropout_every_n_epochs` does not work.
- These issues occurred because the epoch transition was not recognized correctly. Thanks to u-haru for reporting the issue.
- Fix an issue that images are loaded twice in Windows environment.
- Add Min-SNR Weighting strategy. Details are in [#308](https://github.com/kohya-ss/sd-scripts/pull/308). Thank you to AI-Casanova for this great work!
- Add `--min_snr_gamma` option to training scripts; 5 is recommended by the paper.
- Add tag warmup. Details are in [#322](https://github.com/kohya-ss/sd-scripts/pull/322). Thanks to u-haru!
- Add `token_warmup_min` and `token_warmup_step` to dataset settings.
- The number of tokens used gradually increases from `token_warmup_min` until the step specified by `token_warmup_step`.
- For example, if `token_warmup_min` is `3` and `token_warmup_step` is `10`, the first step will use the first 3 tokens, and the 10th step will use all tokens.
- Fix a bug in `resize_lora.py`. Thanks to mgz-dev! [#328](https://github.com/kohya-ss/sd-scripts/pull/328)
- The `--debug_dataset` option now accepts the `S` key to step to the next step and the `E` key to move to the next epoch.
- Fix other bugs.
- `--persistent_data_loader_workers` を指定した時の各種不具合を修正しました。
- `--caption_dropout_every_n_epochs` が効かない。
- バケットのバッチメンバーがシャッフルされない。
- エポックの遷移が正しく認識されないために発生していました。ご指摘いただいたu-haru氏に感謝します。
- Windows環境で画像が二重に読み込まれる不具合を修正しました。
- Min-SNR Weighting strategyを追加しました。 詳細は [#308](https://github.com/kohya-ss/sd-scripts/pull/308) をご参照ください。AI-Casanova氏の素晴らしい貢献に感謝します。
- `--min_snr_gamma` オプションを学習スクリプトに追加しました。論文では5が推奨されています。
- タグのウォームアップを追加しました。詳細は [#322](https://github.com/kohya-ss/sd-scripts/pull/322) をご参照ください。u-haru氏に感謝します。
- データセット設定に `token_warmup_min` と `token_warmup_step` を追加しました。
- `token_warmup_min` で指定した数のトークン(カンマ区切りの文字列)から、`token_warmup_step` で指定したステップまで、段階的にトークンを増やしていきます。
- たとえば `token_warmup_min`に `3` を、`token_warmup_step` に `10` を指定すると、最初のステップでは最初から3個のトークンが使われ、10ステップ目では全てのトークンが使われます。
- `resize_lora.py` の不具合を修正しました。mgz-dev氏に感謝します。[#328](https://github.com/kohya-ss/sd-scripts/pull/328)
- `--debug_dataset` オプションで、`S`キーで次のステップへ、`E`キーで次のエポックへ進めるようにしました。
- その他の不具合を修正しました。
- 21 Mar. 2023, 2023/3/21:
- Add `--vae_batch_size` for faster latents caching to each training script. This batches VAE calls.
- Please start with `2` or `4` depending on the size of VRAM.
@@ -143,50 +180,49 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser
- Windows以外の環境向けに、画像ファイルの大文字の拡張子をサポートしました。
- `resize_lora.py` を dynamic rank(rankが各LoRAモジュールで異なる場合、`conv_dim` が `network_dim` と異なる場合も含む)の時に正しく動作しない不具合を修正しました。toshiaki氏に感謝します。
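As a quick aside on the token warmup option described in the 27 Mar. entry above, the schedule can be summarized with a small sketch. This is an illustration of the arithmetic only; the helper function and the example tag list are hypothetical and not part of the scripts.

```
import math

def tokens_to_use(tags, current_step, token_warmup_min=3, token_warmup_step=10, max_train_steps=1000):
    """Illustrative only: how many leading tags (comma-separated tokens) are kept at a given step."""
    if token_warmup_step < 1:  # a fractional value is interpreted as a fraction of max_train_steps
        token_warmup_step = math.floor(token_warmup_step * max_train_steps)
    if token_warmup_step and current_step < token_warmup_step:
        n = math.floor(current_step * (len(tags) - token_warmup_min) / token_warmup_step) + token_warmup_min
        return tags[:n]
    return tags  # after the warmup, all tags are used

tags = ["1girl", "white shirt", "upper body", "looking at viewer", "simple background"]
print(tokens_to_use(tags, current_step=0))   # -> first 3 tags
print(tokens_to_use(tags, current_step=10))  # -> all 5 tags
```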
## Sample image generation during training
A prompt file might look like this, for example

```
# prompt 1
masterpiece, best quality, (1girl), in white shirts, upper body, looking at viewer, simple background --n low quality, worst quality, bad anatomy,bad composition, poor, low effort --w 768 --h 768 --d 1 --l 7.5 --s 28

# prompt 2
masterpiece, best quality, 1boy, in business suit, standing at street, looking back --n (low quality, worst quality), bad anatomy,bad composition, poor, low effort --w 576 --h 832 --d 2 --l 5.5 --s 40
```

Lines beginning with `#` are comments. You can specify options for the generated image with options like `--n` after the prompt. The following can be used.

* `--n` Negative prompt up to the next option.
* `--w` Specifies the width of the generated image.
* `--h` Specifies the height of the generated image.
* `--d` Specifies the seed of the generated image.
* `--l` Specifies the CFG scale of the generated image.
* `--s` Specifies the number of steps in the generation.

Prompt weighting such as `( )` and `[ ]` works.

## サンプル画像生成
プロンプトファイルは例えば以下のようになります。

```
# prompt 1
masterpiece, best quality, (1girl), in white shirts, upper body, looking at viewer, simple background --n low quality, worst quality, bad anatomy,bad composition, poor, low effort --w 768 --h 768 --d 1 --l 7.5 --s 28

# prompt 2
masterpiece, best quality, 1boy, in business suit, standing at street, looking back --n (low quality, worst quality), bad anatomy,bad composition, poor, low effort --w 576 --h 832 --d 2 --l 5.5 --s 40
```

`#` で始まる行はコメントになります。`--n` のように「ハイフン二個+英小文字」の形でオプションを指定できます。以下が使用できます。

* `--n` Negative prompt up to the next option.
* `--w` Specifies the width of the generated image.
* `--h` Specifies the height of the generated image.
* `--d` Specifies the seed of the generated image.
* `--l` Specifies the CFG scale of the generated image.
* `--s` Specifies the number of steps in the generation.

`( )` や `[ ]` などの重みづけも動作します。

Please read [Releases](https://github.com/kohya-ss/sd-scripts/releases) for recent updates.
最近の更新情報は [Release](https://github.com/kohya-ss/sd-scripts/releases) をご覧ください。
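To make the option list above concrete, here is a small, hypothetical sketch of how one prompt line can be split into the prompt and its `--` options. The option letters match the list above; the parser itself is illustrative and is not the scripts' actual prompt-file reader.

```
def parse_prompt_line(line):
    """Illustrative only: split a prompt line into the positive prompt and its -- options."""
    parts = line.split(" --")
    opts = {"prompt": parts[0].strip()}
    for part in parts[1:]:
        key, _, value = part.partition(" ")
        opts[key] = value.strip()  # e.g. opts["n"] = negative prompt, opts["w"] = "768"
    return opts

line = "masterpiece, best quality, (1girl) --n low quality, worst quality --w 768 --h 768 --d 1 --l 7.5 --s 28"
print(parse_prompt_line(line))
```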


@@ -6,6 +6,7 @@ import gc
import math
import os
import toml
from multiprocessing import Value
from tqdm import tqdm
import torch
@@ -19,10 +20,8 @@ from library.config_util import (
ConfigSanitizer,
BlueprintGenerator,
)
import library.custom_train_functions as custom_train_functions
from library.custom_train_functions import apply_snr_weight
def collate_fn(examples):
return examples[0]
def train(args):
@@ -64,6 +63,11 @@ def train(args):
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
current_epoch = Value("i", 0)
current_step = Value("i", 0)
ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None
collater = train_util.collater_class(current_epoch, current_step, ds_for_collater)
if args.debug_dataset:
train_util.debug_dataset(train_dataset_group)
return
@@ -187,16 +191,21 @@ def train(args):
train_dataset_group,
batch_size=1,
shuffle=True,
collate_fn=collater,
num_workers=n_workers,
persistent_workers=args.persistent_data_loader_workers,
)
# 学習ステップ数を計算する
if args.max_train_epochs is not None:
args.max_train_steps = args.max_train_epochs * math.ceil(
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
)
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
# データセット側にも学習ステップを送信
train_dataset_group.set_max_train_steps(args.max_train_steps)
# lr schedulerを用意する
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
@@ -255,13 +264,14 @@ def train(args):
for epoch in range(num_train_epochs):
print(f"epoch {epoch+1}/{num_train_epochs}")
current_epoch.value = epoch + 1
for m in training_models:
m.train()
loss_total = 0
for step, batch in enumerate(train_dataloader):
current_step.value = global_step
with accelerator.accumulate(training_models[0]):  # 複数モデルに対応していない模様だがとりあえずこうしておく
with torch.no_grad():
if "latents" in batch and batch["latents"] is not None:
@@ -302,7 +312,14 @@ def train(args):
else:
target = noise
if args.min_snr_gamma:
# do not mean over batch dimension for snr weight
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = loss.mean([1, 2, 3])
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
loss = loss.mean() # mean over batch dimension
else:
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean")
accelerator.backward(loss)
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
@@ -396,6 +413,7 @@ def setup_parser() -> argparse.ArgumentParser:
train_util.add_sd_saving_arguments(parser)
train_util.add_optimizer_arguments(parser)
config_util.add_config_arguments(parser)
custom_train_functions.add_custom_train_arguments(parser)
parser.add_argument("--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する")
parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する")

File diff suppressed because it is too large.


@@ -4,6 +4,7 @@ from dataclasses import (
dataclass,
)
import functools
import random
from textwrap import dedent, indent
import json
from pathlib import Path
@@ -56,6 +57,8 @@ class BaseSubsetParams:
caption_dropout_rate: float = 0.0
caption_dropout_every_n_epochs: int = 0
caption_tag_dropout_rate: float = 0.0
token_warmup_min: int = 1
token_warmup_step: float = 0
@dataclass
class DreamBoothSubsetParams(BaseSubsetParams):
@@ -137,6 +140,8 @@ class ConfigSanitizer:
"random_crop": bool,
"shuffle_caption": bool,
"keep_tokens": int,
"token_warmup_min": int,
"token_warmup_step": Any(float,int),
}
# DO means DropOut
DO_SUBSET_ASCENDABLE_SCHEMA = {
@@ -406,6 +411,8 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
flip_aug: {subset.flip_aug}
face_crop_aug_range: {subset.face_crop_aug_range}
random_crop: {subset.random_crop}
token_warmup_min: {subset.token_warmup_min},
token_warmup_step: {subset.token_warmup_step},
"""), " ")
if is_dreambooth:
@@ -422,9 +429,12 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
print(info)
# make buckets first because it determines the length of dataset
# and set the same seed for all datasets
seed = random.randint(0, 2**31) # actual seed is seed + epoch_no
for i, dataset in enumerate(datasets):
print(f"[Dataset {i}]")
dataset.make_buckets()
dataset.set_seed(seed)
return DatasetGroup(datasets)
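The shared-seed change above exists so that every dataset in a group reshuffles its buckets in the same order when the epoch changes. A minimal sketch of the idea (illustrative only, not the library code):

```
import random

seed = random.randint(0, 2**31)  # one seed per dataset group; the actual seed is seed + epoch_no

def shuffle_for_epoch(indices, epoch):
    random.seed(seed + epoch)  # every dataset that knows the epoch reproduces the same order
    shuffled = indices[:]
    random.shuffle(shuffled)
    return shuffled

print(shuffle_for_epoch(list(range(5)), epoch=1))
print(shuffle_for_epoch(list(range(5)), epoch=1))  # identical to the line above
print(shuffle_for_epoch(list(range(5)), epoch=2))  # reshuffled for the next epoch
```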
@@ -491,7 +501,6 @@ def load_user_config(file: str) -> dict:
return config
# for config test
if __name__ == "__main__":
parser = argparse.ArgumentParser()


@@ -0,0 +1,18 @@
import torch
import argparse
def apply_snr_weight(loss, timesteps, noise_scheduler, gamma):
alphas_cumprod = noise_scheduler.alphas_cumprod
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
alpha = sqrt_alphas_cumprod
sigma = sqrt_one_minus_alphas_cumprod
all_snr = (alpha / sigma) ** 2
snr = torch.stack([all_snr[t] for t in timesteps])
gamma_over_snr = torch.div(torch.ones_like(snr)*gamma,snr)
snr_weight = torch.minimum(gamma_over_snr,torch.ones_like(gamma_over_snr)).float() #from paper
loss = loss * snr_weight
return loss
def add_custom_train_arguments(parser: argparse.ArgumentParser):
parser.add_argument("--min_snr_gamma", type=float, default=None, help="gamma for reducing the weight of high loss timesteps. Lower numbers have stronger effect. 5 is recommended by paper. / 低いタイムステップでの高いlossに対して重みを減らすためのgamma値、低いほど効果が強く、論文では5が推奨")
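For orientation, here is a minimal sketch of how `apply_snr_weight` is wired into a per-sample loss, mirroring the training-script hunks elsewhere in this commit. The tensor shapes and the `DDPMScheduler` construction are illustrative assumptions, and the import of `library.custom_train_functions` assumes the repository root is on `PYTHONPATH`.

```
import torch
from diffusers import DDPMScheduler
from library.custom_train_functions import apply_snr_weight  # assumes the sd-scripts repo root is importable

noise_scheduler = DDPMScheduler(num_train_timesteps=1000)

batch_size = 4
noise_pred = torch.randn(batch_size, 4, 64, 64)  # stand-in for the U-Net prediction
target = torch.randn(batch_size, 4, 64, 64)      # stand-in for the noise target
timesteps = torch.randint(0, 1000, (batch_size,))

# per-sample loss: do not mean over the batch dimension before weighting
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = loss.mean([1, 2, 3])
loss = apply_snr_weight(loss, timesteps, noise_scheduler, 5.0)  # weight = min(gamma / SNR(t), 1)
loss = loss.mean()  # now reduce over the batch
```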


@@ -1046,10 +1046,14 @@ def save_stable_diffusion_checkpoint(v2, output_file, text_encoder, unet, ckpt_p
key_count = len(state_dict.keys())
new_ckpt = {'state_dict': state_dict}
# epoch and global_step are sometimes not int
try:
if 'epoch' in checkpoint:
epochs += checkpoint['epoch']
if 'global_step' in checkpoint:
steps += checkpoint['global_step']
except:
pass
new_ckpt['epoch'] = epochs
new_ckpt['global_step'] = steps


@@ -276,6 +276,8 @@ class BaseSubset:
caption_dropout_rate: float,
caption_dropout_every_n_epochs: int,
caption_tag_dropout_rate: float,
token_warmup_min: int,
token_warmup_step: Union[float, int],
) -> None:
self.image_dir = image_dir
self.num_repeats = num_repeats
@@ -289,6 +291,9 @@ class BaseSubset:
self.caption_dropout_every_n_epochs = caption_dropout_every_n_epochs
self.caption_tag_dropout_rate = caption_tag_dropout_rate
self.token_warmup_min = token_warmup_min # step=0におけるタグの数
self.token_warmup_step = token_warmup_step  # N(N<1ならN*max_train_steps)ステップ目でタグの数が最大になる
self.img_count = 0
@@ -309,6 +314,8 @@ class DreamBoothSubset(BaseSubset):
caption_dropout_rate,
caption_dropout_every_n_epochs,
caption_tag_dropout_rate,
token_warmup_min,
token_warmup_step,
) -> None:
assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です"
@@ -324,6 +331,8 @@ class DreamBoothSubset(BaseSubset):
caption_dropout_rate,
caption_dropout_every_n_epochs,
caption_tag_dropout_rate,
token_warmup_min,
token_warmup_step,
)
self.is_reg = is_reg
@@ -351,6 +360,8 @@ class FineTuningSubset(BaseSubset):
caption_dropout_rate,
caption_dropout_every_n_epochs,
caption_tag_dropout_rate,
token_warmup_min,
token_warmup_step,
) -> None:
assert metadata_file is not None, "metadata_file must be specified / metadata_fileは指定が必須です"
@@ -366,6 +377,8 @@ class FineTuningSubset(BaseSubset):
caption_dropout_rate,
caption_dropout_every_n_epochs,
caption_tag_dropout_rate,
token_warmup_min,
token_warmup_step,
)
self.metadata_file = metadata_file
@@ -406,6 +419,10 @@ class BaseDataset(torch.utils.data.Dataset):
self.current_epoch: int = 0  # インスタンスがepochごとに新しく作られるようなので外側から渡さないとダメ
self.current_step: int = 0
self.max_train_steps: int = 0
self.seed: int = 0
# augmentation
self.aug_helper = AugHelper()
@@ -421,9 +438,19 @@ class BaseDataset(torch.utils.data.Dataset):
self.replacements = {}
def set_seed(self, seed):
self.seed = seed
def set_current_epoch(self, epoch):
if not self.current_epoch == epoch:  # epochが切り替わったらバケツをシャッフルする
self.shuffle_buckets()
self.current_epoch = epoch
def set_current_step(self, step):
self.current_step = step
def set_max_train_steps(self, max_train_steps):
self.max_train_steps = max_train_steps
def set_tag_frequency(self, dir_name, captions):
frequency_for_dir = self.tag_frequency.get(dir_name, {})
@@ -458,7 +485,16 @@ class BaseDataset(torch.utils.data.Dataset):
if is_drop_out:
caption = ""
else:
if subset.shuffle_caption or subset.token_warmup_step > 0 or subset.caption_tag_dropout_rate > 0:
tokens = [t.strip() for t in caption.strip().split(",")]
if subset.token_warmup_step < 1: # 初回に上書きする
subset.token_warmup_step = math.floor(subset.token_warmup_step * self.max_train_steps)
if subset.token_warmup_step and self.current_step < subset.token_warmup_step:
tokens_len = (
math.floor((self.current_step) * ((len(tokens) - subset.token_warmup_min) / (subset.token_warmup_step)))
+ subset.token_warmup_min
)
tokens = tokens[:tokens_len]
def dropout_tags(tokens):
if subset.caption_tag_dropout_rate <= 0:
@@ -470,10 +506,10 @@ class BaseDataset(torch.utils.data.Dataset):
return l
fixed_tokens = []
flex_tokens = tokens[:]
if subset.keep_tokens > 0:
fixed_tokens = flex_tokens[: subset.keep_tokens]
flex_tokens = tokens[subset.keep_tokens :]
if subset.shuffle_caption:
random.shuffle(flex_tokens)
@@ -643,6 +679,9 @@ class BaseDataset(torch.utils.data.Dataset):
self._length = len(self.buckets_indices)
def shuffle_buckets(self):
# set random seed for this epoch
random.seed(self.seed + self.current_epoch)
random.shuffle(self.buckets_indices)
self.bucket_manager.shuffle()
@@ -1062,7 +1101,7 @@ class DreamBoothDataset(BaseDataset):
self.register_image(info, subset)
n += info.num_repeats
else:
info.num_repeats += 1  # rewrite registered info
n += 1
if n >= num_train_images:
break
@@ -1123,6 +1162,8 @@ class FineTuningDataset(BaseDataset):
# path情報を作る
if os.path.exists(image_key):
abs_path = image_key
elif os.path.exists(os.path.splitext(image_key)[0] + ".npz"):
abs_path = os.path.splitext(image_key)[0] + ".npz"
else:
npz_path = os.path.join(subset.image_dir, image_key + ".npz")
if os.path.exists(npz_path):
@@ -1308,6 +1349,14 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
for dataset in self.datasets:
dataset.set_current_epoch(epoch)
def set_current_step(self, step):
for dataset in self.datasets:
dataset.set_current_step(step)
def set_max_train_steps(self, max_train_steps):
for dataset in self.datasets:
dataset.set_max_train_steps(max_train_steps)
def disable_token_padding(self):
for dataset in self.datasets:
dataset.disable_token_padding()
@@ -1315,37 +1364,55 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
def debug_dataset(train_dataset, show_input_ids=False):
    print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}")
    print("`S` for next step, `E` for next epoch no. , Escape for exit. / Sキーで次のステップ、Eキーで次のエポック、Escキーで中断、終了します")

    epoch = 1
    while True:
        print(f"epoch: {epoch}")

        steps = (epoch - 1) * len(train_dataset) + 1
        indices = list(range(len(train_dataset)))
        random.shuffle(indices)

        k = 0
        for i, idx in enumerate(indices):
            train_dataset.set_current_epoch(epoch)
            train_dataset.set_current_step(steps)
            print(f"steps: {steps} ({i + 1}/{len(train_dataset)})")

            example = train_dataset[idx]
            if example["latents"] is not None:
                print(f"sample has latents from npz file: {example['latents'].size()}")
            for j, (ik, cap, lw, iid) in enumerate(
                zip(example["image_keys"], example["captions"], example["loss_weights"], example["input_ids"])
            ):
                print(f'{ik}, size: {train_dataset.image_data[ik].image_size}, loss weight: {lw}, caption: "{cap}"')
                if show_input_ids:
                    print(f"input ids: {iid}")
                if example["images"] is not None:
                    im = example["images"][j]
                    print(f"image size: {im.size()}")
                    im = ((im.numpy() + 1.0) * 127.5).astype(np.uint8)
                    im = np.transpose(im, (1, 2, 0))  # c,H,W -> H,W,c
                    im = im[:, :, ::-1]  # RGB -> BGR (OpenCV)
                    if os.name == "nt":  # only windows
                        cv2.imshow("img", im)
                        k = cv2.waitKey()
                        cv2.destroyAllWindows()
                    if k == 27 or k == ord("s") or k == ord("e"):
                        break

            steps += 1

            if k == ord("e"):
                break
            if k == 27 or (example["images"] is None and i >= 8):
                k = 27
                break
        if k == 27:
            break

        epoch += 1
def glob_images(directory, base="*"):
img_paths = []
@@ -1354,8 +1421,8 @@ def glob_images(directory, base="*"):
img_paths.extend(glob.glob(os.path.join(glob.escape(directory), base + ext)))
else:
img_paths.extend(glob.glob(glob.escape(os.path.join(directory, base + ext))))
img_paths = list(set(img_paths))  # 重複を排除
img_paths.sort()
return img_paths
@@ -1367,8 +1434,8 @@ def glob_images_pathlib(dir_path, recursive):
else:
for ext in IMAGE_EXTENSIONS:
image_paths += list(dir_path.glob("*" + ext))
image_paths = list(set(image_paths))  # 重複を排除
image_paths.sort()
return image_paths
@@ -2061,6 +2128,20 @@ def add_dataset_arguments(
"--bucket_no_upscale", action="store_true", help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します"
)
parser.add_argument(
"--token_warmup_min",
type=int,
default=1,
help="start learning at N tags (token means comma separated strings) / タグ数をN個から増やしながら学習する",
)
parser.add_argument(
"--token_warmup_step",
type=float,
default=0,
help="tag length reaches maximum on N steps (or N*max_train_steps if N<1) / N(N<1ならN*max_train_steps)ステップでタグ長が最大になる。デフォルトは0(最初から最大)",
)
if support_caption_dropout:
# Textual Inversion はcaptionのdropoutをsupportしない
# いわゆるtensorのDropoutと紛らわしいのでprefixにcaptionを付けておく every_n_epochsは他と平仄を合わせてdefault Noneに
@@ -2995,3 +3076,24 @@ class ImageLoadingDataset(torch.utils.data.Dataset):
# endregion
# collate_fn用 epoch,stepはmultiprocessing.Value
class collater_class:
def __init__(self, epoch, step, dataset):
self.current_epoch = epoch
self.current_step = step
self.dataset = dataset # not used if worker_info is not None, in case of multiprocessing
def __call__(self, examples):
worker_info = torch.utils.data.get_worker_info()
# worker_info is None in the main process
if worker_info is not None:
dataset = worker_info.dataset
else:
dataset = self.dataset
# set epoch and step
dataset.set_current_epoch(self.current_epoch.value)
dataset.set_current_step(self.current_step.value)
return examples[0]
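A self-contained toy version of the mechanism above (the classes here are stand-ins, not the library's): each DataLoader worker holds its own copy of the dataset, so the collater copies the shared `multiprocessing.Value` counters into whichever dataset copy it happens to run in.

```
from multiprocessing import Value
from torch.utils.data import DataLoader, Dataset, get_worker_info

class ToyDataset(Dataset):
    """Stand-in for the real dataset group: only the pieces the collater touches."""
    def __init__(self):
        self.current_epoch, self.current_step = 0, 0
    def set_current_epoch(self, epoch):
        self.current_epoch = epoch
    def set_current_step(self, step):
        self.current_step = step
    def __len__(self):
        return 4
    def __getitem__(self, idx):
        return {"idx": idx, "epoch": self.current_epoch, "step": self.current_step}

class ToyCollater:
    """Same idea as train_util.collater_class above."""
    def __init__(self, epoch, step, dataset):
        self.current_epoch, self.current_step, self.dataset = epoch, step, dataset
    def __call__(self, examples):
        info = get_worker_info()
        dataset = info.dataset if info is not None else self.dataset  # worker's copy vs. main-process dataset
        dataset.set_current_epoch(self.current_epoch.value)
        dataset.set_current_step(self.current_step.value)
        return examples[0]

dataset = ToyDataset()
current_epoch, current_step = Value("i", 0), Value("i", 0)
loader = DataLoader(dataset, batch_size=1, collate_fn=ToyCollater(current_epoch, current_step, dataset), num_workers=0)

current_epoch.value = 1
for step, batch in enumerate(loader):
    current_step.value = step
    print(batch)
```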


@@ -11,6 +11,8 @@ import numpy as np
MIN_SV = 1e-6
# Model save and load functions
def load_state_dict(file_name, dtype):
if model_util.is_safetensors(file_name):
sd = load_file(file_name)
@@ -39,12 +41,13 @@ def save_to_file(file_name, model, state_dict, dtype, metadata):
torch.save(model, file_name)
# Indexing functions
def index_sv_cumulative(S, target):
original_sum = float(torch.sum(S))
cumulative_sums = torch.cumsum(S, dim=0)/original_sum
index = int(torch.searchsorted(cumulative_sums, target)) + 1
index = max(1, min(index, len(S)-1))
return index
@@ -54,8 +57,16 @@ def index_sv_fro(S, target):
s_fro_sq = float(torch.sum(S_squared))
sum_S_squared = torch.cumsum(S_squared, dim=0)/s_fro_sq
index = int(torch.searchsorted(sum_S_squared, target**2)) + 1
index = max(1, min(index, len(S)-1))
return index
def index_sv_ratio(S, target):
max_sv = S[0]
min_sv = max_sv/target
index = int(torch.sum(S > min_sv).item())
index = max(1, min(index, len(S)-1))
return index
@@ -125,26 +136,24 @@ def merge_linear(lora_down, lora_up, device):
return weight
# Calculate new rank
def rank_resize(S, rank, dynamic_method, dynamic_param, scale=1):
param_dict = {}
if dynamic_method=="sv_ratio":
# Calculate new dim and alpha based off ratio
new_rank = index_sv_ratio(S, dynamic_param) + 1
new_alpha = float(scale*new_rank)
elif dynamic_method=="sv_cumulative":
# Calculate new dim and alpha based off cumulative sum
new_rank = index_sv_cumulative(S, dynamic_param) + 1
new_alpha = float(scale*new_rank)
elif dynamic_method=="sv_fro":
# Calculate new dim and alpha based off sqrt sum of squares
new_rank = index_sv_fro(S, dynamic_param) + 1
new_alpha = float(scale*new_rank)
else:
new_rank = rank
@@ -172,7 +181,7 @@ def rank_resize(S, rank, dynamic_method, dynamic_param, scale=1):
param_dict["new_alpha"] = new_alpha
param_dict["sum_retained"] = (s_rank)/s_sum
param_dict["fro_retained"] = fro_percent
param_dict["max_ratio"] = S[0]/S[new_rank - 1]
return param_dict
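A toy illustration (not part of the diff) of how the cumulative-sum index selection behaves; the function body restates `index_sv_cumulative` from the hunk above and the singular values are made up.

```
import torch

def index_sv_cumulative(S, target):  # restated from the diff above
    original_sum = float(torch.sum(S))
    cumulative_sums = torch.cumsum(S, dim=0) / original_sum
    index = int(torch.searchsorted(cumulative_sums, target)) + 1
    index = max(1, min(index, len(S) - 1))
    return index

S = torch.tensor([4.0, 2.0, 1.0, 0.5, 0.25])  # descending singular values
print(index_sv_cumulative(S, 0.9))  # -> 3: three leading values are enough to cover 90% of the sum
# rank_resize then adds 1 to this index to obtain the new rank (see the hunk above)
```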


@@ -8,6 +8,7 @@ import itertools
import math
import os
import toml
from multiprocessing import Value
from tqdm import tqdm
import torch
@@ -21,10 +22,8 @@ from library.config_util import (
ConfigSanitizer,
BlueprintGenerator,
)
import library.custom_train_functions as custom_train_functions
from library.custom_train_functions import apply_snr_weight
def collate_fn(examples):
return examples[0]
def train(args):
@@ -59,6 +58,11 @@ def train(args):
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
current_epoch = Value("i", 0)
current_step = Value("i", 0)
ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None
collater = train_util.collater_class(current_epoch, current_step, ds_for_collater)
if args.no_token_padding:
train_dataset_group.disable_token_padding()
@@ -152,16 +156,21 @@ def train(args):
train_dataset_group,
batch_size=1,
shuffle=True,
collate_fn=collater,
num_workers=n_workers,
persistent_workers=args.persistent_data_loader_workers,
)
# 学習ステップ数を計算する
if args.max_train_epochs is not None:
args.max_train_steps = args.max_train_epochs * math.ceil(
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
)
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
# データセット側にも学習ステップを送信
train_dataset_group.set_max_train_steps(args.max_train_steps)
if args.stop_text_encoder_training is None:
args.stop_text_encoder_training = args.max_train_steps + 1  # do not stop until end
@@ -229,7 +238,7 @@ def train(args):
loss_total = 0.0
for epoch in range(num_train_epochs):
print(f"epoch {epoch+1}/{num_train_epochs}")
current_epoch.value = epoch + 1
# 指定したステップ数までText Encoderを学習する(epoch最初の状態)
unet.train()
@@ -238,6 +247,7 @@ def train(args):
text_encoder.train()
for step, batch in enumerate(train_dataloader):
current_step.value = global_step
# 指定したステップ数でText Encoderの学習を止める
if global_step == args.stop_text_encoder_training:
print(f"stop text encoder training at step {global_step}")
@@ -291,6 +301,9 @@ def train(args):
loss_weights = batch["loss_weights"]  # 各sampleごとのweight
loss = loss * loss_weights
if args.min_snr_gamma:
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
loss = loss.mean()  # 平均なのでbatch_sizeで割る必要なし
accelerator.backward(loss)
@@ -390,6 +403,7 @@ def setup_parser() -> argparse.ArgumentParser:
train_util.add_sd_saving_arguments(parser)
train_util.add_optimizer_arguments(parser)
config_util.add_config_arguments(parser)
custom_train_functions.add_custom_train_arguments(parser)
parser.add_argument(
"--no_token_padding",


@@ -8,6 +8,7 @@ import random
import time
import json
import toml
from multiprocessing import Value
from tqdm import tqdm
import torch
@@ -23,10 +24,8 @@ from library.config_util import (
ConfigSanitizer,
BlueprintGenerator,
)
import library.custom_train_functions as custom_train_functions
from library.custom_train_functions import apply_snr_weight
def collate_fn(examples):
return examples[0]
# TODO 他のスクリプトと共通化する
@@ -100,6 +99,11 @@ def train(args):
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
current_epoch = Value('i',0)
current_step = Value('i',0)
ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None
collater = train_util.collater_class(current_epoch,current_step, ds_for_collater)
if args.debug_dataset:
train_util.debug_dataset(train_dataset_group)
return
@@ -185,11 +189,12 @@ def train(args):
# dataloaderを準備する
# DataLoaderのプロセス数0はメインプロセスになる
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1)  # cpu_count-1 ただし最大で指定された数まで
train_dataloader = torch.utils.data.DataLoader(
train_dataset_group,
batch_size=1,
shuffle=True,
collate_fn=collater,
num_workers=n_workers,
persistent_workers=args.persistent_data_loader_workers,
)
@@ -200,6 +205,9 @@ def train(args):
if is_main_process:
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
# データセット側にも学習ステップを送信
train_dataset_group.set_max_train_steps(args.max_train_steps)
# lr schedulerを用意する
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
@@ -488,22 +496,23 @@ def train(args):
noise_scheduler = DDPMScheduler(
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
)
if accelerator.is_main_process:
accelerator.init_trackers("network_train")
loss_list = []
loss_total = 0.0
del train_dataset_group
for epoch in range(num_train_epochs):
if is_main_process:
print(f"epoch {epoch+1}/{num_train_epochs}")
current_epoch.value = epoch + 1
metadata["ss_epoch"] = str(epoch + 1)
network.on_epoch_start(text_encoder, unet)
for step, batch in enumerate(train_dataloader):
current_step.value = global_step
with accelerator.accumulate(network):
with torch.no_grad():
if "latents" in batch and batch["latents"] is not None:
@@ -528,7 +537,6 @@ def train(args):
# Sample a random timestep for each image
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device)
timesteps = timesteps.long()
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
@@ -549,6 +557,9 @@ def train(args):
loss_weights = batch["loss_weights"]  # 各sampleごとのweight
loss = loss * loss_weights
if args.min_snr_gamma:
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
loss = loss.mean()  # 平均なのでbatch_sizeで割る必要なし
accelerator.backward(loss)
@@ -652,6 +663,7 @@ def setup_parser() -> argparse.ArgumentParser:
train_util.add_training_arguments(parser, True)
train_util.add_optimizer_arguments(parser)
config_util.add_config_arguments(parser)
custom_train_functions.add_custom_train_arguments(parser)
parser.add_argument("--no_metadata", action="store_true", help="do not save metadata in output model / メタデータを出力先モデルに保存しない")
parser.add_argument(


@@ -4,6 +4,7 @@ import gc
import math
import os
import toml
from multiprocessing import Value
from tqdm import tqdm
import torch
@@ -17,6 +18,8 @@ from library.config_util import (
ConfigSanitizer,
BlueprintGenerator,
)
import library.custom_train_functions as custom_train_functions
from library.custom_train_functions import apply_snr_weight
imagenet_templates_small = [
"a photo of a {}",
@@ -71,10 +74,6 @@ imagenet_style_templates_small = [
]
def collate_fn(examples):
return examples[0]
def train(args):
if args.output_name is None:
args.output_name = args.token_string
@@ -185,6 +184,11 @@ def train(args):
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
current_epoch = Value('i',0)
current_step = Value('i',0)
ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None
collater = train_util.collater_class(current_epoch,current_step, ds_for_collater)
# make captions: tokenstring tokenstring1 tokenstring2 ...tokenstringn という文字列に書き換える超乱暴な実装
if use_template:
print("use template for training captions. is object: {args.use_object_template}")
@@ -250,7 +254,7 @@ def train(args):
train_dataset_group,
batch_size=1,
shuffle=True,
collate_fn=collater,
num_workers=n_workers,
persistent_workers=args.persistent_data_loader_workers,
)
@@ -260,6 +264,9 @@ def train(args):
args.max_train_steps = args.max_train_epochs * math.ceil(len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps)
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
# データセット側にも学習ステップを送信
train_dataset_group.set_max_train_steps(args.max_train_steps)
# lr schedulerを用意する
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
@@ -331,12 +338,14 @@ def train(args):
for epoch in range(num_train_epochs):
print(f"epoch {epoch+1}/{num_train_epochs}")
current_epoch.value = epoch + 1
text_encoder.train()
loss_total = 0
for step, batch in enumerate(train_dataloader):
current_step.value = global_step
with accelerator.accumulate(text_encoder):
with torch.no_grad():
if "latents" in batch and batch["latents"] is not None:
@@ -378,6 +387,9 @@ def train(args):
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = loss.mean([1, 2, 3])
if args.min_snr_gamma:
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
loss_weights = batch["loss_weights"]  # 各sampleごとのweight
loss = loss * loss_weights
@@ -534,6 +546,7 @@ def setup_parser() -> argparse.ArgumentParser:
train_util.add_training_arguments(parser, True)
train_util.add_optimizer_arguments(parser)
config_util.add_config_arguments(parser)
custom_train_functions.add_custom_train_arguments(parser)
parser.add_argument(
"--save_model_as",