Merge branch 'main' into original-u-net

This commit is contained in:
Kohya S
2023-06-16 20:59:34 +09:00
10 changed files with 532 additions and 247 deletions

View File

@@ -162,6 +162,42 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser
## Change History
### 15 Jun. 2023, 2023/06/15
- The Prodigy optimizer is supported in each training script. It belongs to the D-Adaptation family and is reportedly effective for DyLoRA training. [PR #585](https://github.com/kohya-ss/sd-scripts/pull/585) Please see the PR for details. Thanks to sdbds!
- Install the package with `pip install prodigyopt`. Then specify the option like `--optimizer_type="prodigy"`.
- An arbitrary dataset is supported in each training script (except XTI). You can use it by defining a Dataset class that returns images and captions.
- Prepare a Python script and define a class that inherits `train_util.MinimalDataset`. Then specify the option like `--dataset_class package.module.DatasetClass` in each training script.
- Please refer to `MinimalDataset` for implementation. I will prepare a sample later.
- The following features have been added to the generation script.
  - Added an option `--highres_fix_disable_control_net` to disable ControlNet in the 2nd stage of Highres. fix. Please try it if the image is distorted by some ControlNet such as Canny.
  - Added variants similar to sd-dynamic-prompts to the prompt.
- If you specify `{spring|summer|autumn|winter}`, one of them will be randomly selected.
- If you specify `{2$$chocolate|vanilla|strawberry}`, two of them will be randomly selected.
- If you specify `{1-2$$ and $$chocolate|vanilla|strawberry}`, one or two of them will be randomly selected and connected by ` and `.
- You can specify the number of candidates in the range `0-2`. You cannot omit one side like `-2` or `1-`.
- It can also be specified for the prompt option.
    - If you specify `e` or `E` like `{e$$chocolate|vanilla|strawberry}`, all candidates will be selected and the prompt will be repeated multiple times (`--images_per_prompt` is ignored). It may be useful for creating X/Y plots.
    - You can also specify it for the prompt options, e.g. `--am {e$$0.2|0.4|0.6|0.8|1.0},{e$$0.4|0.7|1.0} --d 1234`. In this case, 15 prompts will be generated (5*3).
- There is no weighting function.
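The enumeration behavior above can be sketched with a small standalone re-implementation (an illustration only, not the script's actual parser; `expand_enumerated` and its simplified regex are assumptions for this sketch, and random `N$$` selection is not handled here):

```python
import re

# one variant block: optional "e$$" flag, optional count range, optional separator
RE_VARIANT = re.compile(r"\{((e|E)\$\$)?(([\d\-]+)\$\$)?(([^\|\}]+?)\$\$)?([^\}]+)\}")

def expand_enumerated(prompt):
    """Expand every {e$$a|b|...} block into all combinations of its variants."""
    prompts = [prompt]
    for m in RE_VARIANT.finditer(prompt):
        if m.group(2) is None:
            continue  # not an enumerating block; the real script picks randomly
        variants = m.group(7).split("|")
        # replace this block with each variant, in every prompt built so far
        prompts = [p.replace(m.group(0), v, 1) for p in prompts for v in variants]
    return prompts

# two enumerating blocks with 5 and 3 variants -> 5 * 3 = 15 prompts
result = expand_enumerated("{e$$0.2|0.4|0.6|0.8|1.0},{e$$0.4|0.7|1.0}")
```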
- 各学習スクリプトでProdigyオプティマイザがサポートされました。D-Adaptationの仲間でDyLoRAの学習に有効とのことです。 [PR #585](https://github.com/kohya-ss/sd-scripts/pull/585) 詳細はPRをご覧ください。sdbds氏に感謝します。
- `pip install prodigyopt` としてパッケージをインストールしてください。また `--optimizer_type="prodigy"` のようにオプションを指定します。
- 各学習スクリプトで任意のDatasetをサポートしました(XTIを除く)。画像とキャプションを返すDatasetクラスを定義することで、学習スクリプトから利用できます。
- Pythonスクリプトを用意し、`train_util.MinimalDataset`を継承するクラスを定義してください。そして各学習スクリプトのオプションで `--dataset_class package.module.DatasetClass` のように指定してください。
- 実装方法は `MinimalDataset` を参考にしてください。のちほどサンプルを用意します。
- 生成スクリプトに以下の機能追加を行いました。
- Highres. Fixの2nd stageでControlNetを無効化するオプション `--highres_fix_disable_control_net` を追加しました。Canny等一部のControlNetで画像が乱れる場合にお試しください。
  - プロンプトでsd-dynamic-promptsに似たVariantをサポートしました。
- `{spring|summer|autumn|winter}` のように指定すると、いずれかがランダムに選択されます。
- `{2$$chocolate|vanilla|strawberry}` のように指定すると、いずれか2個がランダムに選択されます。
- `{1-2$$ and $$chocolate|vanilla|strawberry}` のように指定すると、1個か2個がランダムに選択され ` and ` で接続されます。
- 個数のレンジ指定では`0-2`のように0個も指定可能です。`-2`や`1-`のような片側の省略はできません。
- プロンプトオプションに対しても指定可能です。
    - `{e$$chocolate|vanilla|strawberry}` のように`e`または`E`を指定すると、すべての候補が選択されプロンプトが複数回繰り返されます(`--images_per_prompt`は無視されます)。X/Y plotの作成に便利かもしれません。
- `--am {e$$0.2|0.4|0.6|0.8|1.0},{e$$0.4|0.7|1.0} --d 1234`のような指定も可能です。この場合、5*3で15回のプロンプトが生成されます。
- Weightingの機能はありません。
### 8 Jun. 2023, 2023/06/08
- Fixed a bug where clip skip did not work when training with weighted captions (`--weighted_captions` specified) and when generating sample images during training.

View File

@@ -622,6 +622,7 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b
- DAdaptAdanIP : 引数は同上
- DAdaptLion : 引数は同上
- DAdaptSGD : 引数は同上
- Prodigy : https://github.com/konstmish/prodigy
- AdaFactor : [Transformers AdaFactor](https://huggingface.co/docs/transformers/main_classes/optimizer_schedules)
- 任意のオプティマイザ

View File

@@ -555,9 +555,10 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b
- DAdaptAdam : 参数同上
- DAdaptAdaGrad : 参数同上
- DAdaptAdan : 参数同上
- DAdaptAdanIP : 参数同上
- DAdaptLion : 参数同上
- DAdaptSGD : 参数同上
- Prodigy : https://github.com/konstmish/prodigy
- AdaFactor : [Transformers AdaFactor](https://huggingface.co/docs/transformers/main_classes/optimizer_schedules)
- 任何优化器

View File

@@ -42,6 +42,8 @@ def train(args):
    tokenizer = train_util.load_tokenizer(args)

    # prepare the dataset
    if args.dataset_class is None:
        blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, True, True))
        if args.dataset_config is not None:
            print(f"Load dataset config from {args.dataset_config}")
@@ -69,6 +71,8 @@ def train(args):
        blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
        train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
    else:
        train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer)

    current_epoch = Value("i", 0)
    current_step = Value("i", 0)
@@ -393,7 +397,7 @@ def train(args):
            current_loss = loss.detach().item()  # this is an average, so batch size should not matter
            if args.logging_dir is not None:
                logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
                if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower():  # tracking d*lr value
                    logs["lr/d*lr"] = (
                        lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"]
                    )
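The D-Adaptation/Prodigy check above recurs in every training script. As a standalone sketch (`tracks_d_times_lr` is a hypothetical helper, not part of the scripts), the condition amounts to:

```python
def tracks_d_times_lr(optimizer_type: str) -> bool:
    """D-Adaptation variants and Prodigy adapt an internal step size `d`,
    so the scripts log d*lr instead of the nominal learning rate."""
    t = optimizer_type.lower()
    return t.startswith("dadapt") or t == "prodigy"

# examples
tracks_d_times_lr("DAdaptAdam")  # True
tracks_d_times_lr("Prodigy")     # True
tracks_d_times_lr("AdamW8bit")   # False
```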

View File

@@ -46,6 +46,7 @@ VGG(
)
"""
import itertools
import json
from typing import Any, List, NamedTuple, Optional, Tuple, Union, Callable
import glob
@@ -441,11 +442,15 @@ class PipelineLike:
        # ControlNet
        self.control_nets: List[ControlNetInfo] = []
        self.control_net_enabled = True  # if control_nets is empty, ControlNet does not run regardless of True/False

    # Textual Inversion
    def add_token_replacement(self, target_token_id, rep_token_ids):
        self.token_replacements[target_token_id] = rep_token_ids

    def set_enable_control_net(self, en: bool):
        self.control_net_enabled = en

    def replace_token(self, tokens, layer=None):
        new_tokens = []
        for token in tokens:
@@ -938,7 +943,7 @@ class PipelineLike:
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # predict the noise residual
            if self.control_nets and self.control_net_enabled:
                if reginonal_network:
                    num_sub_and_neg_prompts = len(text_embeddings) // batch_size
                    text_emb_last = text_embeddings[num_sub_and_neg_prompts - 2 :: num_sub_and_neg_prompts]  # last subprompt
@@ -1986,6 +1991,110 @@ def preprocess_mask(mask):
    return mask
# regular expression for dynamic prompt:
#   starts and ends with "{" and "}"
#   contains at least one variant divided by "|"
#   optional fragments divided by "$$" at start
#   if the first fragment is "E" or "e", enumerate all variants
#   if the second fragment is a number or two numbers, choose that many variants from the range
#   if the third fragment is a string, use it as a separator
RE_DYNAMIC_PROMPT = re.compile(r"\{((e|E)\$\$)?(([\d\-]+)\$\$)?(([^\|\}]+?)\$\$)?(.+?((\|).+?)*?)\}")
def handle_dynamic_prompt_variants(prompt, repeat_count):
    founds = list(RE_DYNAMIC_PROMPT.finditer(prompt))
    if not founds:
        return [prompt]

    # make a replacement for each variant
    enumerating = False
    replacers = []
    for found in founds:
        # if "e$$" is found, enumerate all variants
        found_enumerating = found.group(2) is not None
        enumerating = enumerating or found_enumerating

        separator = ", " if found.group(6) is None else found.group(6)
        variants = found.group(7).split("|")

        # parse count range
        count_range = found.group(4)
        if count_range is None:
            count_range = [1, 1]
        else:
            count_range = count_range.split("-")
            if len(count_range) == 1:
                count_range = [int(count_range[0]), int(count_range[0])]
            elif len(count_range) == 2:
                count_range = [int(count_range[0]), int(count_range[1])]
            else:
                print(f"invalid count range: {count_range}")
                count_range = [1, 1]
        if count_range[0] > count_range[1]:
            count_range = [count_range[1], count_range[0]]
        if count_range[0] < 0:
            count_range[0] = 0
        if count_range[1] > len(variants):
            count_range[1] = len(variants)

        if found_enumerating:
            # make a function to enumerate all combinations
            def make_replacer_enum(vari, cr, sep):
                def replacer():
                    values = []
                    for count in range(cr[0], cr[1] + 1):
                        for comb in itertools.combinations(vari, count):
                            values.append(sep.join(comb))
                    return values

                return replacer

            replacers.append(make_replacer_enum(variants, count_range, separator))
        else:
            # make a function to choose a random combination
            def make_replacer_single(vari, cr, sep):
                def replacer():
                    count = random.randint(cr[0], cr[1])
                    comb = random.sample(vari, count)
                    return [sep.join(comb)]

                return replacer

            replacers.append(make_replacer_single(variants, count_range, separator))

    # make each prompt
    if not enumerating:
        # if not enumerating, repeat the prompt and replace each variant randomly
        prompts = []
        for _ in range(repeat_count):
            current = prompt
            for found, replacer in zip(founds, replacers):
                current = current.replace(found.group(0), replacer()[0], 1)
            prompts.append(current)
    else:
        # if enumerating, iterate all combinations over the previous prompts
        prompts = [prompt]

        for found, replacer in zip(founds, replacers):
            if found.group(2) is not None:
                # make all combinations for existing prompts
                new_prompts = []
                for current in prompts:
                    replacements = replacer()
                    for replacement in replacements:
                        new_prompts.append(current.replace(found.group(0), replacement, 1))
                prompts = new_prompts

        for found, replacer in zip(founds, replacers):
            # make a random selection for existing prompts
            if found.group(2) is None:
                for i in range(len(prompts)):
                    prompts[i] = prompts[i].replace(found.group(0), replacer()[0], 1)

    return prompts
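For the `{1-2$$...}` count-range form, the enumerating branch walks `itertools.combinations` for each count in the range. The core of that replacer can be checked in isolation (a standalone sketch; `enumerate_range` is an illustrative name, not a function in the script):

```python
import itertools

def enumerate_range(variants, lo, hi, sep=", "):
    """All separator-joined combinations for each count in [lo, hi],
    matching what the enumerating replacer above produces."""
    values = []
    for count in range(lo, hi + 1):
        for comb in itertools.combinations(variants, count):
            values.append(sep.join(comb))
    return values

# {1-2$$ and $$chocolate|vanilla|strawberry} -> 3 singles + 3 pairs
enumerate_range(["chocolate", "vanilla", "strawberry"], 1, 2, " and ")
```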
# endregion
@@ -2615,6 +2724,7 @@ def main(args):
    # decide the seeds in advance when a seed is specified
    if args.seed is not None:
        # dynamic prompts may exhaust these; have the user set images_per_prompt suitably large
        random.seed(args.seed)
        predefined_seeds = [random.randint(0, 0x7FFFFFFF) for _ in range(args.n_iter * len(prompt_list) * args.images_per_prompt)]
        if len(predefined_seeds) == 1:
@@ -2666,6 +2776,8 @@ def main(args):
                ext.num_sub_prompts,
            )
            batch_1st.append(BatchData(is_1st_latent, base, ext_1st))

        pipe.set_enable_control_net(True)  # enable ControlNet in the 1st stage
        images_1st = process_batch(batch_1st, True, True)

        # create the 2nd stage batches and process them as follows
@@ -2709,6 +2821,9 @@ def main(args):
            batch_2nd.append(bd_2nd)
        batch = batch_2nd

        if args.highres_fix_disable_control_net:
            pipe.set_enable_control_net(False)  # when the option is specified, disable ControlNet in the 2nd stage

    # extract the information for this batch
    (
        return_latents,
@@ -2897,29 +3012,39 @@ def main(args):
            while not valid:
                print("\nType prompt:")
                try:
                    raw_prompt = input()
                except EOFError:
                    break

                valid = len(raw_prompt.strip().split(" --")[0].strip()) > 0
            if not valid:  # EOF, end app
                break
        else:
            raw_prompt = prompt_list[prompt_index]

        # sd-dynamic-prompts like variants:
        #   count is 1 (not dynamic) or images_per_prompt (no enumeration) or arbitrary (enumeration)
        raw_prompts = handle_dynamic_prompt_variants(raw_prompt, args.images_per_prompt)

        # repeat prompt
        for pi in range(args.images_per_prompt if len(raw_prompts) == 1 else len(raw_prompts)):
            raw_prompt = raw_prompts[pi] if len(raw_prompts) > 1 else raw_prompts[0]

            if pi == 0 or len(raw_prompts) > 1:
                # parse prompt: if the prompt has not changed, skip parsing
                width = args.W
                height = args.H
                scale = args.scale
                negative_scale = args.negative_scale
                steps = args.steps
                seed = None
                seeds = None
                strength = 0.8 if args.strength is None else args.strength
                negative_prompt = ""
                clip_prompt = None
                network_muls = None

                prompt_args = raw_prompt.strip().split(" --")
                prompt = prompt_args[0]
                print(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}")
@@ -2994,24 +3119,28 @@ def main(args):
                        print(f"Exception in parsing / 解析エラー: {parg}")
                        print(ex)

            # prepare seed
            if seeds is not None:  # given in prompt
                # if they run out, keep using the previous one
                if len(seeds) > 0:
                    seed = seeds.pop(0)
            else:
                if predefined_seeds is not None:
                    if len(predefined_seeds) > 0:
                        seed = predefined_seeds.pop(0)
                    else:
                        print("predefined seeds are exhausted")
                        seed = None
                elif args.iter_same_seed:
                    seed = iter_seed

            if seed is None:
                seed = random.randint(0, 0x7FFFFFFF)

            if args.interactive:
                print(f"seed: {seed}")

            # prepare init image, guide image and mask
            init_image = mask_image = guide_image = None

            # when the same image is reused, converting it to latents once would avoid waste, but for simplicity it is processed every time
            if init_images is not None:
                init_image = init_images[global_step % len(init_images)]
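The pop-based seed selection above can be condensed into a small sketch (`next_seed` is a hypothetical helper for illustration; the real script mutates `seeds` and `predefined_seeds` in place inside the loop):

```python
import random

def next_seed(prompt_seeds, predefined_seeds, iter_seed=None):
    """Pick the next seed: prompt-given seeds first, then the predefined
    pool from --seed, then --iter_same_seed, else a fresh random seed."""
    if prompt_seeds:                      # seeds given in the prompt
        return prompt_seeds.pop(0)
    if predefined_seeds:                  # pool fixed in advance by --seed
        return predefined_seeds.pop(0)
    if iter_seed is not None:             # --iter_same_seed
        return iter_seed
    return random.randint(0, 0x7FFFFFFF)  # fall back to a fresh random seed

seeds = [10, 20]
assert next_seed(seeds, [1, 2]) == 10  # prompt seeds win
assert next_seed(seeds, [1, 2]) == 20
assert next_seed(seeds, [1, 2]) == 1   # prompt seeds exhausted, pool used
```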
@@ -3294,6 +3423,11 @@ def setup_parser() -> argparse.ArgumentParser:
        default=None,
        help="additional arguments for upscaler (key=value) / upscalerへの追加の引数",
    )
    parser.add_argument(
        "--highres_fix_disable_control_net",
        action="store_true",
        help="disable ControlNet for highres fix / highres fixでControlNetを使わない",
    )
    parser.add_argument(
        "--negative_scale", type=float, default=None, help="set another guidance scale for negative prompt / ネガティブプロンプトのscaleを指定する"

View File

@@ -1519,6 +1519,76 @@ def glob_images_pathlib(dir_path, recursive):
    return image_paths
class MinimalDataset(BaseDataset):
    def __init__(self, tokenizer, max_token_length, resolution, debug_dataset=False):
        super().__init__(tokenizer, max_token_length, resolution, debug_dataset)

        self.num_train_images = 0  # update in subclass
        self.num_reg_images = 0  # update in subclass
        self.datasets = [self]
        self.batch_size = 1  # update in subclass

        self.subsets = [self]
        self.num_repeats = 1  # update in subclass if needed
        self.img_count = 1  # update in subclass if needed
        self.bucket_info = {}
        self.is_reg = False
        self.image_dir = "dummy"  # for metadata

    def is_latent_cacheable(self) -> bool:
        return False

    def __len__(self):
        raise NotImplementedError

    # override to avoid shuffling buckets
    def set_current_epoch(self, epoch):
        self.current_epoch = epoch

    def __getitem__(self, idx):
        r"""
        The subclass may have image_data for debug_dataset, which is a dict of ImageInfo objects.

        Returns an example like this:

            for i in range(batch_size):
                image_key = ...  # whatever hashable
                image_keys.append(image_key)

                image = ...  # PIL Image
                img_tensor = self.image_transforms(image)
                images.append(img_tensor)

                caption = ...  # str
                input_ids = self.get_input_ids(caption)
                input_ids_list.append(input_ids)

                captions.append(caption)

            images = torch.stack(images, dim=0)
            input_ids_list = torch.stack(input_ids_list, dim=0)
            example = {
                "images": images,
                "input_ids": input_ids_list,
                "captions": captions,  # for debug_dataset
                "latents": None,
                "image_keys": image_keys,  # for debug_dataset
                "loss_weights": torch.ones(batch_size, dtype=torch.float32),
            }
            return example
        """
        raise NotImplementedError
def load_arbitrary_dataset(args, tokenizer) -> MinimalDataset:
    module = ".".join(args.dataset_class.split(".")[:-1])
    dataset_class = args.dataset_class.split(".")[-1]
    module = importlib.import_module(module)
    dataset_class = getattr(module, dataset_class)
    train_dataset_group: MinimalDataset = dataset_class(tokenizer, args.max_token_length, args.resolution, args.debug_dataset)
    return train_dataset_group
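As a shape-only illustration of the `__getitem__` contract described in the docstring above (a hypothetical, dependency-free stand-in: a real implementation must inherit `train_util.MinimalDataset` and return torch tensors where `None` placeholders appear here):

```python
class ToyCaptionDataset:
    """Sketch of the batch-per-item contract: __getitem__ returns a whole batch."""

    def __init__(self, items, batch_size=2):
        self.items = items            # list of (image_key, caption) pairs
        self.batch_size = batch_size

    def __len__(self):
        # length is the number of batches (steps per epoch), not the number of images
        return (len(self.items) + self.batch_size - 1) // self.batch_size

    def __getitem__(self, idx):
        batch = self.items[idx * self.batch_size : (idx + 1) * self.batch_size]
        image_keys = [k for k, _ in batch]
        captions = [c for _, c in batch]
        return {
            "images": None,            # torch.stack of transformed images in real code
            "input_ids": None,         # torch.stack of tokenized captions in real code
            "captions": captions,      # for debug_dataset
            "latents": None,
            "image_keys": image_keys,  # for debug_dataset
            "loss_weights": [1.0] * len(batch),  # torch.ones(batch_size) in real code
        }

ds = ToyCaptionDataset([("a.png", "a cat"), ("b.png", "a dog"), ("c.png", "a bird")])
example = ds[0]
```

Such a class would then be selected with `--dataset_class package.module.DatasetClass`, which `load_arbitrary_dataset` resolves via `importlib`.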
# endregion

# region module replacement (モジュール入れ替え部)
@@ -2321,7 +2391,6 @@ def add_dataset_arguments(
        default=1,
        help="start learning at N tags (token means comma separated strings) / タグ数をN個から増やしながら学習する",
    )

    parser.add_argument(
        "--token_warmup_step",
        type=float,
@@ -2329,6 +2398,13 @@ def add_dataset_arguments(
        help="tag length reaches maximum on N steps (or N*max_train_steps if N<1) / N(N<1ならN*max_train_steps)ステップでタグ長が最大になる。デフォルトは0(最初から最大)",
    )
    parser.add_argument(
        "--dataset_class",
        type=str,
        default=None,
        help="dataset class for arbitrary dataset (package.module.Class) / 任意のデータセットを用いるときのクラス名 (package.module.Class)",
    )

    if support_caption_dropout:
        # Textual Inversion does not support caption dropout
        # the "caption" prefix avoids confusion with tensor Dropout; every_n_epochs defaults to None for consistency with the others
@@ -2603,15 +2679,7 @@ def get_optimizer(args, trainable_params):
        optimizer_class = torch.optim.SGD
        optimizer = optimizer_class(trainable_params, lr=lr, nesterov=True, **optimizer_kwargs)

    elif optimizer_type.startswith("DAdapt".lower()) or optimizer_type == "Prodigy".lower():
        # check lr and lr_count, and print warning
        actual_lr = lr
        lr_count = 1
@@ -2624,14 +2692,23 @@ def get_optimizer(args, trainable_params):
        if actual_lr <= 0.1:
            print(
                f"learning rate is too low. If using D-Adaptation or Prodigy, set learning rate around 1.0 / 学習率が低すぎるようです。D-AdaptationまたはProdigyの使用時は1.0前後の値を指定してください: lr={actual_lr}"
            )
            print("recommend option: lr=1.0 / 推奨は1.0です")
        if lr_count > 1:
            print(
                f"when multiple learning rates are specified with D-Adaptation or Prodigy (e.g. for Text Encoder and U-Net), only the first one will take effect / D-AdaptationまたはProdigyで複数の学習率を指定した場合(Text EncoderとU-Netなど)、最初の学習率のみが有効になります: lr={actual_lr}"
            )

        if optimizer_type.startswith("DAdapt".lower()):
            # DAdaptation family
            # check dadaptation is installed
            try:
                import dadaptation
                import dadaptation.experimental as experimental
            except ImportError:
                raise ImportError("No dadaptation / dadaptation がインストールされていないようです")

            # set optimizer
            if optimizer_type == "DAdaptation".lower() or optimizer_type == "DAdaptAdamPreprint".lower():
                optimizer_class = experimental.DAdaptAdamPreprint
@@ -2658,6 +2735,17 @@ def get_optimizer(args, trainable_params):
            else:
                raise ValueError(f"Unknown optimizer type: {optimizer_type}")
            optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
        else:
            # Prodigy
            # check Prodigy is installed
            try:
                import prodigyopt
            except ImportError:
                raise ImportError("No Prodigy / Prodigy がインストールされていないようです")

            print(f"use Prodigy optimizer | {optimizer_kwargs}")
            optimizer_class = prodigyopt.Prodigy
            optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)

    elif optimizer_type == "Adafactor".lower():
        # check the arguments and correct them as appropriate

View File

@@ -46,6 +46,8 @@ def train(args):
    tokenizer = train_util.load_tokenizer(args)

    # prepare the dataset
    if args.dataset_class is None:
        blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, False, True))
        if args.dataset_config is not None:
            print(f"Load dataset config from {args.dataset_config}")
@@ -66,6 +68,8 @@ def train(args):
        blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
        train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
    else:
        train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer)

    current_epoch = Value("i", 0)
    current_step = Value("i", 0)
@@ -380,7 +384,7 @@ def train(args):
            current_loss = loss.detach().item()
            if args.logging_dir is not None:
                logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
                if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower():  # tracking d*lr value
                    logs["lr/d*lr"] = (
                        lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"]
                    )

View File

@@ -57,7 +57,7 @@ def generate_step_logs(
        logs["lr/textencoder"] = float(lrs[0])
        logs["lr/unet"] = float(lrs[-1])  # may be same as textencoder

        if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower():  # tracking d*lr value of unet.
            logs["lr/d*lr"] = lr_scheduler.optimizers[-1].param_groups[0]["d"] * lr_scheduler.optimizers[-1].param_groups[0]["lr"]
    else:
        idx = 0
@@ -67,7 +67,7 @@ def generate_step_logs(
        for i in range(idx, len(lrs)):
            logs[f"lr/group{i}"] = float(lrs[i])
            if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower():
                logs[f"lr/d*lr/group{i}"] = (
                    lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"]
                )
@@ -92,6 +92,7 @@ def train(args):
    tokenizer = train_util.load_tokenizer(args)

    # prepare the dataset
    if args.dataset_class is None:
        blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, True))
        if use_user_config:
            print(f"Loading dataset config from {args.dataset_config}")
@@ -108,7 +109,11 @@ def train(args):
            print("Using DreamBooth method.")
            user_config = {
                "datasets": [
                    {
                        "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(
                            args.train_data_dir, args.reg_data_dir
                        )
                    }
                ]
            }
        else:
else: else:
@@ -128,6 +133,9 @@ def train(args):
        blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
        train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
    else:
        # use arbitrary dataset class
        train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer)

    current_epoch = Value("i", 0)
    current_step = Value("i", 0)
@@ -185,6 +193,7 @@ def train(args):
            module.merge_to(text_encoder, unet, weights_sd, weight_dtype, accelerator.device if args.lowram else "cpu")
            print(f"all weights merged: {', '.join(args.base_weights)}")

    # prepare for training
    if cache_latents:
        vae.to(accelerator.device, dtype=weight_dtype)

View File

@@ -153,6 +153,7 @@ def train(args):
    print(f"create embeddings for {args.num_vectors_per_token} tokens, for {args.token_string}")

    # prepare the dataset
    if args.dataset_class is None:
        blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False))
        if args.dataset_config is not None:
            print(f"Load dataset config from {args.dataset_config}")
@@ -190,6 +191,8 @@ def train(args):
        blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
        train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
    else:
        train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer)

    current_epoch = Value("i", 0)
    current_step = Value("i", 0)
@@ -475,7 +478,7 @@ def train(args):
                current_loss = loss.detach().item()
                if args.logging_dir is not None:
                    logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
                    if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower():  # tracking d*lr value
                        logs["lr/d*lr"] = (
                            lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"]
                        )

View File

@@ -96,6 +96,9 @@ def train(args):
        print(
            "sample_every_n_steps and sample_every_n_epochs are not supported in this script currently / sample_every_n_stepsとsample_every_n_epochsは現在このスクリプトではサポートされていません"
        )
    assert (
        args.dataset_class is None
    ), "dataset_class is not supported in this script currently / dataset_classは現在このスクリプトではサポートされていません"

    cache_latents = args.cache_latents
@@ -520,7 +523,9 @@ def train(args):
                current_loss = loss.detach().item()
                if args.logging_dir is not None:
                    logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
                    if (
                        args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower()
                    ):  # tracking d*lr value
                        logs["lr/d*lr"] = (
                            lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"]
                        )