From 4dc1124f9339af9886f4a73db49e9e7c9cf17a23 Mon Sep 17 00:00:00 2001 From: u-haru <40634644+u-haru@users.noreply.github.com> Date: Sun, 26 Mar 2023 02:19:55 +0900 Subject: [PATCH] =?UTF-8?q?lora=E4=BB=A5=E5=A4=96=E3=82=82=E5=AF=BE?= =?UTF-8?q?=E5=BF=9C?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fine_tune.py | 15 ++++++++------- library/train_util.py | 11 +++++++++++ train_db.py | 15 ++++++++------- train_network.py | 13 +------------ train_textual_inversion.py | 16 +++++++++------- 5 files changed, 37 insertions(+), 33 deletions(-) diff --git a/fine_tune.py b/fine_tune.py index ff580435..96ec3d96 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -6,6 +6,7 @@ import gc import math import os import toml +from multiprocessing import Value from tqdm import tqdm import torch @@ -21,10 +22,6 @@ from library.config_util import ( ) -def collate_fn(examples): - return examples[0] - - def train(args): train_util.verify_training_args(args) train_util.prepare_dataset_args(args, True) @@ -65,6 +62,10 @@ def train(args): config_util.blueprint_args_conflict(args,blueprint) train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + current_epoch = Value('i',0) + current_step = Value('i',0) + collater = train_util.collater_class(current_epoch,current_step) + if args.debug_dataset: train_util.debug_dataset(train_dataset_group) return @@ -188,7 +189,7 @@ def train(args): train_dataset_group, batch_size=1, shuffle=True, - collate_fn=collate_fn, + collate_fn=collater, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers, ) @@ -259,14 +260,14 @@ def train(args): for epoch in range(num_train_epochs): print(f"epoch {epoch+1}/{num_train_epochs}") - train_dataset_group.set_current_epoch(epoch + 1) - train_dataset_group.set_current_step(global_step) + current_epoch.value = epoch+1 for m in training_models: m.train() loss_total = 0 for step, batch in enumerate(train_dataloader): 
+ current_step.value = global_step with accelerator.accumulate(training_models[0]): # 複数モデルに対応していない模様だがとりあえずこうしておく with torch.no_grad(): if "latents" in batch and batch["latents"] is not None: diff --git a/library/train_util.py b/library/train_util.py index 223e403b..994201fc 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2987,3 +2987,14 @@ class ImageLoadingDataset(torch.utils.data.Dataset): # endregion + +# collate_fn用 epoch,stepはmultiprocessing.Value +class collater_class: + def __init__(self,epoch,step): + self.current_epoch=epoch + self.current_step=step + def __call__(self, examples): + dataset = torch.utils.data.get_worker_info().dataset + dataset.set_current_epoch(self.current_epoch.value) + dataset.set_current_step(self.current_step.value) + return examples[0] \ No newline at end of file diff --git a/train_db.py b/train_db.py index 87fe771b..50d50345 100644 --- a/train_db.py +++ b/train_db.py @@ -8,6 +8,7 @@ import itertools import math import os import toml +from multiprocessing import Value from tqdm import tqdm import torch @@ -23,10 +24,6 @@ from library.config_util import ( ) -def collate_fn(examples): - return examples[0] - - def train(args): train_util.verify_training_args(args) train_util.prepare_dataset_args(args, False) @@ -60,6 +57,10 @@ def train(args): config_util.blueprint_args_conflict(args,blueprint) train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + current_epoch = Value('i',0) + current_step = Value('i',0) + collater = train_util.collater_class(current_epoch,current_step) + if args.no_token_padding: train_dataset_group.disable_token_padding() @@ -153,7 +154,7 @@ def train(args): train_dataset_group, batch_size=1, shuffle=True, - collate_fn=collate_fn, + collate_fn=collater, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers, ) @@ -233,8 +234,7 @@ def train(args): loss_total = 0.0 for epoch in range(num_train_epochs): print(f"epoch 
{epoch+1}/{num_train_epochs}") - train_dataset_group.set_current_epoch(epoch + 1) - train_dataset_group.set_current_step(global_step) + current_epoch.value = epoch+1 # 指定したステップ数までText Encoderを学習する:epoch最初の状態 unet.train() @@ -243,6 +243,7 @@ def train(args): text_encoder.train() for step, batch in enumerate(train_dataloader): + current_step.value = global_step # 指定したステップ数でText Encoderの学習を止める if global_step == args.stop_text_encoder_training: print(f"stop text encoder training at step {global_step}") diff --git a/train_network.py b/train_network.py index dd1fb748..79d118d0 100644 --- a/train_network.py +++ b/train_network.py @@ -25,17 +25,6 @@ from library.config_util import ( BlueprintGenerator, ) -class collater_class: - def __init__(self,epoch,step): - self.current_epoch=epoch - self.current_step=step - def __call__(self, examples): - dataset = torch.utils.data.get_worker_info().dataset - dataset.set_current_epoch(self.current_epoch.value) - dataset.set_current_step(self.current_step.value) - return examples[0] - - # TODO 他のスクリプトと共通化する def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler): logs = {"loss/current": current_loss, "loss/average": avr_loss} @@ -110,7 +99,7 @@ def train(args): current_epoch = Value('i',0) current_step = Value('i',0) - collater = collater_class(current_epoch,current_step) + collater = train_util.collater_class(current_epoch,current_step) if args.debug_dataset: train_util.debug_dataset(train_dataset_group) diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 63b63426..4f2e2724 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -4,6 +4,7 @@ import gc import math import os import toml +from multiprocessing import Value from tqdm import tqdm import torch @@ -71,10 +72,6 @@ imagenet_style_templates_small = [ ] -def collate_fn(examples): - return examples[0] - - def train(args): if args.output_name is None: args.output_name = args.token_string @@ -186,6 +183,10 @@ 
def train(args): config_util.blueprint_args_conflict(args,blueprint) train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + current_epoch = Value('i',0) + current_step = Value('i',0) + collater = train_util.collater_class(current_epoch,current_step) + # make captions: tokenstring tokenstring1 tokenstring2 ...tokenstringn という文字列に書き換える超乱暴な実装 if use_template: print("use template for training captions. is object: {args.use_object_template}") @@ -251,7 +252,7 @@ def train(args): train_dataset_group, batch_size=1, shuffle=True, - collate_fn=collate_fn, + collate_fn=collater, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers, ) @@ -335,13 +336,14 @@ def train(args): for epoch in range(num_train_epochs): print(f"epoch {epoch+1}/{num_train_epochs}") - train_dataset_group.set_current_epoch(epoch + 1) - train_dataset_group.set_current_step(global_step) + current_epoch.value = epoch+1 text_encoder.train() loss_total = 0 + for step, batch in enumerate(train_dataloader): + current_step.value = global_step with accelerator.accumulate(text_encoder): with torch.no_grad(): if "latents" in batch and batch["latents"] is not None: