diff --git a/fine_tune.py b/fine_tune.py
index 1acf478f..a4486841 100644
--- a/fine_tune.py
+++ b/fine_tune.py
@@ -6,6 +6,7 @@ import gc
 import math
 import os
 import toml
+from multiprocessing import Value
 from tqdm import tqdm
 
 import torch
@@ -21,10 +22,6 @@ from library.config_util import (
 )
 
 
-def collate_fn(examples):
-    return examples[0]
-
-
 def train(args):
     train_util.verify_training_args(args)
     train_util.prepare_dataset_args(args, True)
@@ -64,6 +61,10 @@ def train(args):
     blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
     train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
 
+    current_epoch = Value('i', 0)
+    current_step = Value('i', 0)
+    collater = train_util.collater_class(current_epoch, current_step)
+
     if args.debug_dataset:
         train_util.debug_dataset(train_dataset_group)
         return
@@ -187,7 +188,7 @@ def train(args):
         train_dataset_group,
         batch_size=1,
         shuffle=True,
-        collate_fn=collate_fn,
+        collate_fn=collater,
         num_workers=n_workers,
         persistent_workers=args.persistent_data_loader_workers,
     )
@@ -197,6 +198,9 @@ def train(args):
         args.max_train_steps = args.max_train_epochs * math.ceil(len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps)
         print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
 
+    # データセット側にも学習ステップを送信
+    train_dataset_group.set_max_train_steps(args.max_train_steps)
+
     # lr schedulerを用意する
     lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
 
@@ -255,13 +259,14 @@ def train(args):
 
     for epoch in range(num_train_epochs):
         print(f"epoch {epoch+1}/{num_train_epochs}")
-        train_dataset_group.set_current_epoch(epoch + 1)
+        current_epoch.value = epoch + 1
 
         for m in training_models:
             m.train()
 
         loss_total = 0
         for step, batch in enumerate(train_dataloader):
+            current_step.value = global_step
             with accelerator.accumulate(training_models[0]):  # 複数モデルに対応していない模様だがとりあえずこうしておく
                 with torch.no_grad():
                     if "latents" in batch and batch["latents"] is not None:
diff --git a/library/config_util.py b/library/config_util.py
index e62bfb89..b1543f63 100644
--- a/library/config_util.py
+++ b/library/config_util.py
@@ -56,6 +56,8 @@ class BaseSubsetParams:
   caption_dropout_rate: float = 0.0
   caption_dropout_every_n_epochs: int = 0
   caption_tag_dropout_rate: float = 0.0
+  token_warmup_min: int = 1
+  token_warmup_step: float = 0
 
 @dataclass
 class DreamBoothSubsetParams(BaseSubsetParams):
@@ -137,6 +139,8 @@ class ConfigSanitizer:
     "random_crop": bool,
     "shuffle_caption": bool,
     "keep_tokens": int,
+    "token_warmup_min": int,
+    "token_warmup_step": Any(float, int),
   }
   # DO means DropOut
   DO_SUBSET_ASCENDABLE_SCHEMA = {
@@ -406,6 +410,8 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
       flip_aug: {subset.flip_aug}
       face_crop_aug_range: {subset.face_crop_aug_range}
       random_crop: {subset.random_crop}
+      token_warmup_min: {subset.token_warmup_min},
+      token_warmup_step: {subset.token_warmup_step},
     """), "  ")
 
     if is_dreambooth:
@@ -491,7 +497,6 @@ def load_user_config(file: str) -> dict:
 
   return config
 
-
 # for config test
 if __name__ == "__main__":
   parser = argparse.ArgumentParser()
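The new per-subset options defined above ride through the blueprint pipeline like any other subset setting. A minimal sketch, assuming the `datasets`/`subsets` config layout accepted by `ConfigSanitizer` and the DreamBooth-style sanitizer flags as used in `train_db.py`; the image directory is a placeholder, and `args`/`tokenizer` are the objects already in scope in the training scripts:

# Sketch only: exercises token_warmup_min / token_warmup_step end to end.
from library import config_util
from library.config_util import BlueprintGenerator, ConfigSanitizer

user_config = {
    "datasets": [
        {
            "subsets": [
                {
                    "image_dir": "/path/to/train/images",  # placeholder path
                    "token_warmup_min": 3,     # train on the first 3 tags at step 0
                    "token_warmup_step": 0.5,  # full tag length at 50% of max_train_steps
                }
            ]
        }
    ]
}

# args and tokenizer are assumed to be provided by the surrounding training script
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, False, True))
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)

Note the deliberate dual typing of `token_warmup_step` (`Any(float, int)` in the schema): an integer is an absolute step count, while a float below 1 is resolved against `max_train_steps` once training begins.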
diff --git a/library/train_util.py b/library/train_util.py
index 97f5a702..131a263b 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -276,6 +276,8 @@ class BaseSubset:
         caption_dropout_rate: float,
         caption_dropout_every_n_epochs: int,
         caption_tag_dropout_rate: float,
+        token_warmup_min: int,
+        token_warmup_step: Union[float, int],
     ) -> None:
         self.image_dir = image_dir
         self.num_repeats = num_repeats
@@ -289,6 +291,9 @@ class BaseSubset:
         self.caption_dropout_every_n_epochs = caption_dropout_every_n_epochs
         self.caption_tag_dropout_rate = caption_tag_dropout_rate
 
+        self.token_warmup_min = token_warmup_min  # step=0におけるタグの数
+        self.token_warmup_step = token_warmup_step  # N(N<1ならN*max_train_steps)ステップ目でタグの数が最大になる
+
         self.img_count = 0
 
 
@@ -309,6 +314,8 @@ class DreamBoothSubset(BaseSubset):
         caption_dropout_rate,
         caption_dropout_every_n_epochs,
         caption_tag_dropout_rate,
+        token_warmup_min,
+        token_warmup_step,
     ) -> None:
         assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です"
 
@@ -324,6 +331,8 @@ class DreamBoothSubset(BaseSubset):
             caption_dropout_rate,
             caption_dropout_every_n_epochs,
             caption_tag_dropout_rate,
+            token_warmup_min,
+            token_warmup_step,
         )
 
         self.is_reg = is_reg
@@ -351,6 +360,8 @@ class FineTuningSubset(BaseSubset):
         caption_dropout_rate,
         caption_dropout_every_n_epochs,
         caption_tag_dropout_rate,
+        token_warmup_min,
+        token_warmup_step,
     ) -> None:
         assert metadata_file is not None, "metadata_file must be specified / metadata_fileは指定が必須です"
 
@@ -366,6 +377,8 @@ class FineTuningSubset(BaseSubset):
             caption_dropout_rate,
             caption_dropout_every_n_epochs,
             caption_tag_dropout_rate,
+            token_warmup_min,
+            token_warmup_step,
         )
 
         self.metadata_file = metadata_file
@@ -404,6 +417,9 @@ class BaseDataset(torch.utils.data.Dataset):
 
         self.current_epoch: int = 0  # インスタンスがepochごとに新しく作られるようなので外側から渡さないとダメ
 
+        self.current_step: int = 0
+        self.max_train_steps: int = 0
+
         # augmentation
         self.aug_helper = AugHelper()
@@ -420,8 +436,15 @@ class BaseDataset(torch.utils.data.Dataset):
         self.replacements = {}
 
     def set_current_epoch(self, epoch):
+        if not self.current_epoch == epoch:
+            self.shuffle_buckets()
         self.current_epoch = epoch
-        self.shuffle_buckets()
+
+    def set_current_step(self, step):
+        self.current_step = step
+
+    def set_max_train_steps(self, max_train_steps):
+        self.max_train_steps = max_train_steps
 
     def set_tag_frequency(self, dir_name, captions):
         frequency_for_dir = self.tag_frequency.get(dir_name, {})
@@ -452,7 +475,14 @@ class BaseDataset(torch.utils.data.Dataset):
         if is_drop_out:
             caption = ""
         else:
-            if subset.shuffle_caption or subset.caption_tag_dropout_rate > 0:
+            if subset.shuffle_caption or subset.token_warmup_step > 0 or subset.caption_tag_dropout_rate > 0:
+
+                tokens = [t.strip() for t in caption.strip().split(",")]
+                if subset.token_warmup_step < 1:
+                    subset.token_warmup_step = math.floor(subset.token_warmup_step * self.max_train_steps)
+                if subset.token_warmup_step and self.current_step < subset.token_warmup_step:
+                    tokens_len = math.floor(self.current_step * ((len(tokens) - subset.token_warmup_min) / subset.token_warmup_step)) + subset.token_warmup_min
+                    tokens = tokens[:tokens_len]
 
                 def dropout_tags(tokens):
                     if subset.caption_tag_dropout_rate <= 0:
@@ -464,10 +494,10 @@ class BaseDataset(torch.utils.data.Dataset):
                         return l
 
                 fixed_tokens = []
-                flex_tokens = [t.strip() for t in caption.strip().split(",")]
+                flex_tokens = tokens[:]
                 if subset.keep_tokens > 0:
                     fixed_tokens = flex_tokens[: subset.keep_tokens]
-                    flex_tokens = flex_tokens[subset.keep_tokens :]
+                    flex_tokens = tokens[subset.keep_tokens :]
 
                 if subset.shuffle_caption:
                     random.shuffle(flex_tokens)
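Isolated from `process_caption`, the schedule above is a plain linear ramp. A standalone sketch of the same arithmetic (the function name is hypothetical, not part of the patch):

import math

def warmup_token_count(num_tokens, current_step, warmup_min, warmup_step):
    # Same linear ramp as process_caption: warmup_min tags visible at step 0,
    # growing to all num_tokens tags once current_step reaches warmup_step.
    if warmup_step and current_step < warmup_step:
        return math.floor(current_step * (num_tokens - warmup_min) / warmup_step) + warmup_min
    return num_tokens  # past the warmup window: use every tag

Because the truncation runs before the `keep_tokens` split, shuffling, and tag dropout, early steps always train on the leading tags of each caption. Note also that the committed code rewrites `subset.token_warmup_step` in place the first time it sees a fractional value, converting it to the absolute step count `floor(N * max_train_steps)`.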
@@ -1285,6 +1315,14 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
         for dataset in self.datasets:
             dataset.set_current_epoch(epoch)
 
+    def set_current_step(self, step):
+        for dataset in self.datasets:
+            dataset.set_current_step(step)
+
+    def set_max_train_steps(self, max_train_steps):
+        for dataset in self.datasets:
+            dataset.set_max_train_steps(max_train_steps)
+
     def disable_token_padding(self):
         for dataset in self.datasets:
             dataset.disable_token_padding()
@@ -2038,6 +2076,20 @@ def add_dataset_arguments(
         "--bucket_no_upscale", action="store_true", help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します"
     )
 
+    parser.add_argument(
+        "--token_warmup_min",
+        type=int,
+        default=1,
+        help="start learning at N tags (token means comma separated strings) / タグ数をN個から増やしながら学習する",
+    )
+
+    parser.add_argument(
+        "--token_warmup_step",
+        type=float,
+        default=0,
+        help="tag length reaches maximum on N steps (or N*max_train_steps if N<1) / N(N<1ならN*max_train_steps)ステップでタグ長が最大になる。デフォルトは0(最初から最大)",
+    )
+
     if support_caption_dropout:
         # Textual Inversion はcaptionのdropoutをsupportしない
         # いわゆるtensorのDropoutと紛らわしいのでprefixにcaptionを付けておく every_n_epochsは他と平仄を合わせてdefault Noneに
@@ -2972,3 +3024,14 @@ class ImageLoadingDataset(torch.utils.data.Dataset):
 
 
 # endregion
+
+# collate_fn用 epoch,stepはmultiprocessing.Value
+class collater_class:
+    def __init__(self, epoch, step):
+        self.current_epoch = epoch
+        self.current_step = step
+    def __call__(self, examples):
+        dataset = torch.utils.data.get_worker_info().dataset
+        dataset.set_current_epoch(self.current_epoch.value)
+        dataset.set_current_step(self.current_step.value)
+        return examples[0]
\ No newline at end of file
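A caveat on `collater_class` as committed: `torch.utils.data.get_worker_info()` returns a `WorkerInfo` object only inside a DataLoader worker process and returns `None` in the main process, so the collator above assumes `num_workers > 0`. A sketch that also tolerates `num_workers == 0` might look like this (`SafeCollater` is a hypothetical name, not part of the patch):

import torch

class SafeCollater:
    # epoch and step are multiprocessing.Value('i', ...) shared integers,
    # written by the training loop and read here inside the worker process.
    def __init__(self, epoch, step):
        self.current_epoch = epoch
        self.current_step = step

    def __call__(self, examples):
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is not None:  # None when num_workers == 0
            dataset = worker_info.dataset
            dataset.set_current_epoch(self.current_epoch.value)
            dataset.set_current_step(self.current_step.value)
        return examples[0]

The `multiprocessing.Value` pair is the load-bearing piece: with `persistent_workers`, each worker holds its own copy of the dataset object, so plain attributes set on the main-process instance would never reach the workers, while shared memory does.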
diff --git a/train_db.py b/train_db.py
index 527f8e9b..904fcc60 100644
--- a/train_db.py
+++ b/train_db.py
@@ -8,6 +8,7 @@ import itertools
 import math
 import os
 import toml
+from multiprocessing import Value
 from tqdm import tqdm
 
 import torch
@@ -23,10 +24,6 @@ from library.config_util import (
 )
 
 
-def collate_fn(examples):
-    return examples[0]
-
-
 def train(args):
     train_util.verify_training_args(args)
     train_util.prepare_dataset_args(args, False)
@@ -59,6 +56,10 @@ def train(args):
     blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
     train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
 
+    current_epoch = Value('i', 0)
+    current_step = Value('i', 0)
+    collater = train_util.collater_class(current_epoch, current_step)
+
     if args.no_token_padding:
         train_dataset_group.disable_token_padding()
 
@@ -152,7 +153,7 @@ def train(args):
         train_dataset_group,
         batch_size=1,
         shuffle=True,
-        collate_fn=collate_fn,
+        collate_fn=collater,
         num_workers=n_workers,
         persistent_workers=args.persistent_data_loader_workers,
     )
@@ -162,6 +163,9 @@ def train(args):
         args.max_train_steps = args.max_train_epochs * math.ceil(len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps)
         print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
 
+    # データセット側にも学習ステップを送信
+    train_dataset_group.set_max_train_steps(args.max_train_steps)
+
     if args.stop_text_encoder_training is None:
         args.stop_text_encoder_training = args.max_train_steps + 1  # do not stop until end
@@ -229,7 +233,7 @@ def train(args):
     loss_total = 0.0
     for epoch in range(num_train_epochs):
         print(f"epoch {epoch+1}/{num_train_epochs}")
-        train_dataset_group.set_current_epoch(epoch + 1)
+        current_epoch.value = epoch + 1
 
         # 指定したステップ数までText Encoderを学習する:epoch最初の状態
         unet.train()
@@ -238,6 +242,7 @@ def train(args):
             text_encoder.train()
 
         for step, batch in enumerate(train_dataloader):
+            current_step.value = global_step
             # 指定したステップ数でText Encoderの学習を止める
             if global_step == args.stop_text_encoder_training:
                 print(f"stop text encoder training at step {global_step}")
diff --git a/train_network.py b/train_network.py
index 083aad67..8d1502ca 100644
--- a/train_network.py
+++ b/train_network.py
@@ -8,6 +8,7 @@ import random
 import time
 import json
 import toml
+from multiprocessing import Value
 from tqdm import tqdm
 
 import torch
@@ -24,11 +25,6 @@ from library.config_util import (
     BlueprintGenerator,
 )
 
-
-def collate_fn(examples):
-    return examples[0]
-
-
 # TODO 他のスクリプトと共通化する
 def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler):
     logs = {"loss/current": current_loss, "loss/average": avr_loss}
@@ -100,6 +96,10 @@ def train(args):
     blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
     train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
 
+    current_epoch = Value('i', 0)
+    current_step = Value('i', 0)
+    collater = train_util.collater_class(current_epoch, current_step)
+
     if args.debug_dataset:
         train_util.debug_dataset(train_dataset_group)
         return
@@ -185,11 +185,12 @@ def train(args):
     # dataloaderを準備する
     # DataLoaderのプロセス数:0はメインプロセスになる
     n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1)  # cpu_count-1 ただし最大で指定された数まで
+
     train_dataloader = torch.utils.data.DataLoader(
         train_dataset_group,
         batch_size=1,
         shuffle=True,
-        collate_fn=collate_fn,
+        collate_fn=collater,
         num_workers=n_workers,
         persistent_workers=args.persistent_data_loader_workers,
     )
@@ -200,6 +201,9 @@ def train(args):
         if is_main_process:
             print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
 
+    # データセット側にも学習ステップを送信
+    train_dataset_group.set_max_train_steps(args.max_train_steps)
+
     # lr schedulerを用意する
     lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
 
@@ -494,16 +498,18 @@ def train(args):
     loss_list = []
     loss_total = 0.0
+    del train_dataset_group
 
     for epoch in range(num_train_epochs):
         if is_main_process:
             print(f"epoch {epoch+1}/{num_train_epochs}")
-        train_dataset_group.set_current_epoch(epoch + 1)
+        current_epoch.value = epoch + 1
 
         metadata["ss_epoch"] = str(epoch + 1)
 
         network.on_epoch_start(text_encoder, unet)
 
         for step, batch in enumerate(train_dataloader):
+            current_step.value = global_step
             with accelerator.accumulate(network):
                 with torch.no_grad():
                     if "latents" in batch and batch["latents"] is not None:
diff --git a/train_textual_inversion.py b/train_textual_inversion.py
index 85f0d57c..a4ef45e3 100644
--- a/train_textual_inversion.py
+++ b/train_textual_inversion.py
@@ -4,6 +4,7 @@ import gc
 import math
 import os
 import toml
+from multiprocessing import Value
 from tqdm import tqdm
 
 import torch
@@ -71,10 +72,6 @@ imagenet_style_templates_small = [
 ]
 
 
-def collate_fn(examples):
-    return examples[0]
-
-
 def train(args):
     if args.output_name is None:
         args.output_name = args.token_string
@@ -185,6 +182,10 @@ def train(args):
     blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
     train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
 
+    current_epoch = Value('i', 0)
+    current_step = Value('i', 0)
+    collater = train_util.collater_class(current_epoch, current_step)
+
     # make captions: tokenstring tokenstring1 tokenstring2 ...tokenstringn という文字列に書き換える超乱暴な実装
     if use_template:
         print(f"use template for training captions. is object: {args.use_object_template}")
@@ -250,7 +251,7 @@ def train(args):
         train_dataset_group,
         batch_size=1,
         shuffle=True,
-        collate_fn=collate_fn,
+        collate_fn=collater,
         num_workers=n_workers,
         persistent_workers=args.persistent_data_loader_workers,
     )
@@ -260,6 +261,9 @@ def train(args):
         args.max_train_steps = args.max_train_epochs * math.ceil(len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps)
         print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
 
+    # データセット側にも学習ステップを送信
+    train_dataset_group.set_max_train_steps(args.max_train_steps)
+
     # lr schedulerを用意する
     lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
 
@@ -331,12 +335,14 @@ def train(args):
 
     for epoch in range(num_train_epochs):
         print(f"epoch {epoch+1}/{num_train_epochs}")
-        train_dataset_group.set_current_epoch(epoch + 1)
+        current_epoch.value = epoch + 1
 
         text_encoder.train()
 
         loss_total = 0
+
         for step, batch in enumerate(train_dataloader):
+            current_step.value = global_step
             with accelerator.accumulate(text_encoder):
                 with torch.no_grad():
                     if "latents" in batch and batch["latents"] is not None:
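As a worked example of the warmup schedule under assumed numbers (`max_train_steps=1000`, `token_warmup_min=3`, `token_warmup_step=0.5`, and an 11-tag caption):

import math

max_train_steps = 1000
warmup_min = 3
warmup_step = math.floor(0.5 * max_train_steps)  # fractional 0.5 resolves to 500
num_tags = 11

for step in (0, 250, 499, 500):
    # Same branch structure as process_caption in library/train_util.py.
    if warmup_step and step < warmup_step:
        n = math.floor(step * (num_tags - warmup_min) / warmup_step) + warmup_min
    else:
        n = num_tags
    print(step, n)  # prints: 0 3, 250 7, 499 10, 500 11

So the dataset sees only the first three tags at step 0, roughly two-thirds of them halfway through the warmup window, and the full caption from step 500 onward.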