Merge branch 'dev' into min-SNR
Commit in kohya-ss/sd-scripts (mirror of https://github.com/kohya-ss/sd-scripts.git)
fine_tune.py (17 lines changed)

@@ -6,6 +6,7 @@ import gc
 import math
 import os
 import toml
+from multiprocessing import Value
 
 from tqdm import tqdm
 import torch
@@ -22,10 +23,6 @@ from library.config_util import (
 import library.custom_train_functions as custom_train_functions
 from library.custom_train_functions import apply_snr_weight
 
-def collate_fn(examples):
-    return examples[0]
-
-
 def train(args):
     train_util.verify_training_args(args)
     train_util.prepare_dataset_args(args, True)
@@ -65,6 +62,10 @@ def train(args):
     blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
     train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
 
+    current_epoch = Value('i',0)
+    current_step = Value('i',0)
+    collater = train_util.collater_class(current_epoch,current_step)
+
     if args.debug_dataset:
         train_util.debug_dataset(train_dataset_group)
         return
@@ -188,7 +189,7 @@ def train(args):
         train_dataset_group,
         batch_size=1,
         shuffle=True,
-        collate_fn=collate_fn,
+        collate_fn=collater,
         num_workers=n_workers,
         persistent_workers=args.persistent_data_loader_workers,
     )
@@ -198,6 +199,9 @@ def train(args):
         args.max_train_steps = args.max_train_epochs * math.ceil(len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps)
         print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
 
+    # データセット側にも学習ステップを送信 (also send the training step count to the dataset side)
+    train_dataset_group.set_max_train_steps(args.max_train_steps)
+
     # lr schedulerを用意する (prepare the lr scheduler)
     lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
@@ -256,13 +260,14 @@ def train(args):
 
     for epoch in range(num_train_epochs):
         print(f"epoch {epoch+1}/{num_train_epochs}")
-        train_dataset_group.set_current_epoch(epoch + 1)
+        current_epoch.value = epoch+1
 
         for m in training_models:
             m.train()
 
         loss_total = 0
         for step, batch in enumerate(train_dataloader):
+            current_step.value = global_step
             with accelerator.accumulate(training_models[0]):  # 複数モデルに対応していない模様だがとりあえずこうしておく (accumulate does not seem to support multiple models, but this will do for now)
                 with torch.no_grad():
                     if "latents" in batch and batch["latents"] is not None:
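Note on the pattern introduced above: DataLoader worker processes each hold their own pickled copy of the dataset, so attributes set on the main-process dataset object never reach them, and the old module-level collate_fn could not see training state at all. The two multiprocessing.Value slots are shared across processes, and the picklable collater object reads them inside the worker at collate time. A minimal, self-contained sketch of the mechanism — toy names, not code from this commit:

from multiprocessing import Value

import torch


class ToyDataset(torch.utils.data.Dataset):
    def __init__(self):
        self.current_epoch = 0  # updated inside each worker by the collater

    def __len__(self):
        return 4

    def __getitem__(self, idx):
        return idx

    def set_current_epoch(self, epoch):
        self.current_epoch = epoch


class Collater:
    """Picklable (module-level class, unlike a closure), so it survives
    the trip into DataLoader worker processes."""

    def __init__(self, epoch):
        self.epoch = epoch  # a multiprocessing.Value('i', ...)

    def __call__(self, examples):
        info = torch.utils.data.get_worker_info()
        if info is not None:  # None when num_workers=0 (main process)
            info.dataset.set_current_epoch(self.epoch.value)
        return examples


if __name__ == "__main__":
    current_epoch = Value("i", 0)
    loader = torch.utils.data.DataLoader(
        ToyDataset(), batch_size=2, num_workers=2, collate_fn=Collater(current_epoch)
    )
    for epoch in range(2):
        current_epoch.value = epoch + 1  # main process writes, workers read
        for _batch in loader:
            pass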
gen_img_diffusers.py (4912 lines changed)
(file diff suppressed because it is too large)
library/config_util.py

@@ -56,6 +56,8 @@ class BaseSubsetParams:
     caption_dropout_rate: float = 0.0
     caption_dropout_every_n_epochs: int = 0
     caption_tag_dropout_rate: float = 0.0
+    token_warmup_min: int = 1
+    token_warmup_step: float = 0
 
 @dataclass
 class DreamBoothSubsetParams(BaseSubsetParams):
@@ -137,6 +139,8 @@ class ConfigSanitizer:
         "random_crop": bool,
         "shuffle_caption": bool,
         "keep_tokens": int,
+        "token_warmup_min": int,
+        "token_warmup_step": Any(float,int),
     }
 
     # DO means DropOut
     DO_SUBSET_ASCENDABLE_SCHEMA = {
@@ -406,6 +410,8 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
       flip_aug: {subset.flip_aug}
       face_crop_aug_range: {subset.face_crop_aug_range}
       random_crop: {subset.random_crop}
+      token_warmup_min: {subset.token_warmup_min},
+      token_warmup_step: {subset.token_warmup_step},
     """), "  ")
 
     if is_dreambooth:
@@ -491,7 +497,6 @@ def load_user_config(file: str) -> dict:
 
     return config
 
-
 # for config test
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
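Note: the Any in "token_warmup_step": Any(float,int) above is the voluptuous validator (config_util builds its sanitizer schemas with voluptuous), so the option accepts either an absolute integer step count or a float ratio. A small sketch of that behaviour, assuming only that voluptuous is installed:

from voluptuous import Any, Schema

# Mirrors the two schema entries added above: an int tag count, and a
# warmup step that may be an absolute int or a float ratio of max_train_steps.
subset_schema = Schema({"token_warmup_min": int, "token_warmup_step": Any(float, int)})

subset_schema({"token_warmup_min": 3, "token_warmup_step": 0.5})  # accepted: ratio
subset_schema({"token_warmup_min": 3, "token_warmup_step": 200})  # accepted: step count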
library/train_util.py

@@ -276,6 +276,8 @@ class BaseSubset:
         caption_dropout_rate: float,
         caption_dropout_every_n_epochs: int,
         caption_tag_dropout_rate: float,
+        token_warmup_min: int,
+        token_warmup_step: Union[float,int],
     ) -> None:
         self.image_dir = image_dir
         self.num_repeats = num_repeats
@@ -289,6 +291,9 @@ class BaseSubset:
         self.caption_dropout_every_n_epochs = caption_dropout_every_n_epochs
         self.caption_tag_dropout_rate = caption_tag_dropout_rate
 
+        self.token_warmup_min = token_warmup_min  # step=0におけるタグの数 (number of tags at step 0)
+        self.token_warmup_step = token_warmup_step  # N(N<1ならN*max_train_steps)ステップ目でタグの数が最大になる (the tag count reaches its maximum at step N, or N*max_train_steps if N<1)
+
         self.img_count = 0
 
@@ -309,6 +314,8 @@ class DreamBoothSubset(BaseSubset):
         caption_dropout_rate,
         caption_dropout_every_n_epochs,
         caption_tag_dropout_rate,
+        token_warmup_min,
+        token_warmup_step,
     ) -> None:
         assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です"
 
@@ -324,6 +331,8 @@ class DreamBoothSubset(BaseSubset):
             caption_dropout_rate,
             caption_dropout_every_n_epochs,
             caption_tag_dropout_rate,
+            token_warmup_min,
+            token_warmup_step,
         )
 
         self.is_reg = is_reg
@@ -351,6 +360,8 @@ class FineTuningSubset(BaseSubset):
         caption_dropout_rate,
         caption_dropout_every_n_epochs,
         caption_tag_dropout_rate,
+        token_warmup_min,
+        token_warmup_step,
     ) -> None:
         assert metadata_file is not None, "metadata_file must be specified / metadata_fileは指定が必須です"
 
@@ -366,6 +377,8 @@ class FineTuningSubset(BaseSubset):
             caption_dropout_rate,
             caption_dropout_every_n_epochs,
             caption_tag_dropout_rate,
+            token_warmup_min,
+            token_warmup_step,
         )
 
         self.metadata_file = metadata_file
@@ -404,6 +417,9 @@ class BaseDataset(torch.utils.data.Dataset):
 
         self.current_epoch: int = 0  # インスタンスがepochごとに新しく作られるようなので外側から渡さないとダメ (a new instance seems to be created every epoch, so this has to be passed in from outside)
 
+        self.current_step: int = 0
+        self.max_train_steps: int = 0
+
         # augmentation
         self.aug_helper = AugHelper()
 
@@ -420,8 +436,15 @@ class BaseDataset(torch.utils.data.Dataset):
         self.replacements = {}
 
     def set_current_epoch(self, epoch):
+        if not self.current_epoch == epoch:
+            self.shuffle_buckets()
         self.current_epoch = epoch
-        self.shuffle_buckets()
+
+    def set_current_step(self, step):
+        self.current_step = step
+
+    def set_max_train_steps(self, max_train_steps):
+        self.max_train_steps = max_train_steps
 
     def set_tag_frequency(self, dir_name, captions):
         frequency_for_dir = self.tag_frequency.get(dir_name, {})
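Note: set_current_epoch used to shuffle buckets unconditionally on every call. Because the new collater now invokes it on every batch, the shuffle is guarded so buckets are reshuffled only when the epoch value actually changes, and the shuffle runs before current_epoch is updated. set_current_step and set_max_train_steps simply record the shared counters on the worker-side dataset instance.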
@@ -452,7 +475,14 @@ class BaseDataset(torch.utils.data.Dataset):
         if is_drop_out:
             caption = ""
         else:
-            if subset.shuffle_caption or subset.caption_tag_dropout_rate > 0:
+            if subset.shuffle_caption or subset.token_warmup_step > 0 or subset.caption_tag_dropout_rate > 0:
 
+                tokens = [t.strip() for t in caption.strip().split(",")]
+                if subset.token_warmup_step < 1:
+                    subset.token_warmup_step = math.floor(subset.token_warmup_step * self.max_train_steps)
+                if subset.token_warmup_step and self.current_step < subset.token_warmup_step:
+                    tokens_len = math.floor((self.current_step)*((len(tokens)-subset.token_warmup_min)/(subset.token_warmup_step)))+subset.token_warmup_min
+                    tokens = tokens[:tokens_len]
+
                 def dropout_tags(tokens):
                     if subset.caption_tag_dropout_rate <= 0:
@@ -464,10 +494,10 @@ class BaseDataset(torch.utils.data.Dataset):
                     return l
 
                 fixed_tokens = []
-                flex_tokens = [t.strip() for t in caption.strip().split(",")]
+                flex_tokens = tokens[:]
                 if subset.keep_tokens > 0:
                     fixed_tokens = flex_tokens[: subset.keep_tokens]
-                    flex_tokens = flex_tokens[subset.keep_tokens :]
+                    flex_tokens = tokens[subset.keep_tokens :]
 
                 if subset.shuffle_caption:
                     random.shuffle(flex_tokens)
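Note: the warmup above interpolates linearly from token_warmup_min tags at step 0 up to the full tag list at token_warmup_step; a token_warmup_step below 1 is rewritten once into that fraction of max_train_steps. A worked example with hypothetical numbers:

import math

token_warmup_min = 2     # tags visible at step 0
token_warmup_step = 100  # step at which all tags become visible
tags = ["1girl", "solo", "smile", "long hair", "outdoors", "school uniform"]

for current_step in (0, 25, 50, 99, 100):
    if current_step < token_warmup_step:
        # Same formula as the diff above, with local names.
        tokens_len = (
            math.floor(current_step * ((len(tags) - token_warmup_min) / token_warmup_step))
            + token_warmup_min
        )
        visible = tags[:tokens_len]
    else:
        visible = tags  # warmup finished: all tags are used
    print(current_step, len(visible))  # 0->2, 25->3, 50->4, 99->5, 100->6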
@@ -1285,6 +1315,14 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
         for dataset in self.datasets:
             dataset.set_current_epoch(epoch)
 
+    def set_current_step(self, step):
+        for dataset in self.datasets:
+            dataset.set_current_step(step)
+
+    def set_max_train_steps(self, max_train_steps):
+        for dataset in self.datasets:
+            dataset.set_max_train_steps(max_train_steps)
+
     def disable_token_padding(self):
         for dataset in self.datasets:
             dataset.disable_token_padding()
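Note: DatasetGroup is a torch.utils.data.ConcatDataset over several datasets, so the two new setters just fan the shared counters out to every child dataset, mirroring the existing set_current_epoch.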
@@ -2038,6 +2076,20 @@ def add_dataset_arguments(
         "--bucket_no_upscale", action="store_true", help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します"
     )
 
+    parser.add_argument(
+        "--token_warmup_min",
+        type=int,
+        default=1,
+        help="start learning at N tags (token means comma separated strinfloatgs) / タグ数をN個から増やしながら学習する",
+    )
+
+    parser.add_argument(
+        "--token_warmup_step",
+        type=float,
+        default=0,
+        help="tag length reaches maximum on N steps (or N*max_train_steps if N<1) / N(N<1ならN*max_train_steps)ステップでタグ長が最大になる。デフォルトは0(最初から最大)",
+    )
+
     if support_caption_dropout:
         # Textual Inversion はcaptionのdropoutをsupportしない (Textual Inversion does not support caption dropout)
         # いわゆるtensorのDropoutと紛らわしいのでprefixにcaptionを付けておく every_n_epochsは他と平仄を合わせてdefault Noneに (prefixed with "caption" to avoid confusion with tensor Dropout; every_n_epochs defaults to None to match the others)
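Note: with the defaults (token_warmup_min=1, token_warmup_step=0) the warmup branch is skipped and all tags are used from the first step. Passing e.g. --token_warmup_step 200 ramps the tag count up over 200 steps, while a value below 1 such as 0.1 is interpreted as that fraction of max_train_steps.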
@@ -2972,3 +3024,14 @@ class ImageLoadingDataset(torch.utils.data.Dataset):
 
 
 # endregion
+
+# colalte_fn用 epoch,stepはmultiprocessing.Value (for collate_fn; epoch and step are multiprocessing.Value)
+class collater_class:
+    def __init__(self,epoch,step):
+        self.current_epoch=epoch
+        self.current_step=step
+
+    def __call__(self, examples):
+        dataset = torch.utils.data.get_worker_info().dataset
+        dataset.set_current_epoch(self.current_epoch.value)
+        dataset.set_current_step(self.current_step.value)
+        return examples[0]
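Note: collater_class is defined at module level so instances are picklable and survive the trip into DataLoader worker processes; inside __call__, torch.utils.data.get_worker_info().dataset is the worker's own copy of the dataset (for DatasetGroup, the ConcatDataset whose setters fan out as shown above). get_worker_info() returns None in the main process, so this collater assumes the loaders run with num_workers > 0.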
train_db.py (17 lines changed)

@@ -8,6 +8,7 @@ import itertools
 import math
 import os
 import toml
+from multiprocessing import Value
 
 from tqdm import tqdm
 import torch
@@ -24,10 +25,6 @@ from library.config_util import (
 import library.custom_train_functions as custom_train_functions
 from library.custom_train_functions import apply_snr_weight
 
-def collate_fn(examples):
-    return examples[0]
-
-
 def train(args):
     train_util.verify_training_args(args)
     train_util.prepare_dataset_args(args, False)
@@ -60,6 +57,10 @@ def train(args):
     blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
     train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
 
+    current_epoch = Value('i',0)
+    current_step = Value('i',0)
+    collater = train_util.collater_class(current_epoch,current_step)
+
     if args.no_token_padding:
         train_dataset_group.disable_token_padding()
 
@@ -153,7 +154,7 @@ def train(args):
         train_dataset_group,
         batch_size=1,
         shuffle=True,
-        collate_fn=collate_fn,
+        collate_fn=collater,
         num_workers=n_workers,
         persistent_workers=args.persistent_data_loader_workers,
     )
@@ -163,6 +164,9 @@ def train(args):
         args.max_train_steps = args.max_train_epochs * math.ceil(len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps)
         print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
 
+    # データセット側にも学習ステップを送信 (also send the training step count to the dataset side)
+    train_dataset_group.set_max_train_steps(args.max_train_steps)
+
     if args.stop_text_encoder_training is None:
         args.stop_text_encoder_training = args.max_train_steps + 1  # do not stop until end
 
@@ -230,7 +234,7 @@ def train(args):
     loss_total = 0.0
     for epoch in range(num_train_epochs):
         print(f"epoch {epoch+1}/{num_train_epochs}")
-        train_dataset_group.set_current_epoch(epoch + 1)
+        current_epoch.value = epoch+1
 
         # 指定したステップ数までText Encoderを学習する:epoch最初の状態 (train the Text Encoder up to the specified step count: state at the start of the epoch)
         unet.train()
@@ -239,6 +243,7 @@ def train(args):
             text_encoder.train()
 
         for step, batch in enumerate(train_dataloader):
+            current_step.value = global_step
             # 指定したステップ数でText Encoderの学習を止める (stop Text Encoder training at the specified step)
             if global_step == args.stop_text_encoder_training:
                 print(f"stop text encoder training at step {global_step}")
train_network.py

@@ -8,6 +8,7 @@ import random
 import time
 import json
 import toml
+from multiprocessing import Value
 
 from tqdm import tqdm
 import torch
@@ -26,9 +27,6 @@ from library.config_util import (
 import library.custom_train_functions as custom_train_functions
 from library.custom_train_functions import apply_snr_weight
 
-def collate_fn(examples):
-    return examples[0]
-
 
 # TODO 他のスクリプトと共通化する (TODO: unify this with the other scripts)
 def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler):
@@ -101,6 +99,10 @@ def train(args):
     blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
     train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
 
+    current_epoch = Value('i',0)
+    current_step = Value('i',0)
+    collater = train_util.collater_class(current_epoch,current_step)
+
     if args.debug_dataset:
         train_util.debug_dataset(train_dataset_group)
         return
@@ -186,11 +188,12 @@ def train(args):
     # dataloaderを準備する (prepare the dataloader)
     # DataLoaderのプロセス数:0はメインプロセスになる (number of DataLoader worker processes; 0 runs loading in the main process)
     n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1)  # cpu_count-1 ただし最大で指定された数まで (cpu_count-1, but at most the specified number)
+
     train_dataloader = torch.utils.data.DataLoader(
         train_dataset_group,
         batch_size=1,
         shuffle=True,
-        collate_fn=collate_fn,
+        collate_fn=collater,
         num_workers=n_workers,
         persistent_workers=args.persistent_data_loader_workers,
     )
@@ -201,6 +204,9 @@ def train(args):
     if is_main_process:
         print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
 
+    # データセット側にも学習ステップを送信 (also send the training step count to the dataset side)
+    train_dataset_group.set_max_train_steps(args.max_train_steps)
+
     # lr schedulerを用意する (prepare the lr scheduler)
     lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
@@ -494,16 +500,18 @@ def train(args):
 
     loss_list = []
     loss_total = 0.0
+    del train_dataset_group
     for epoch in range(num_train_epochs):
         if is_main_process:
             print(f"epoch {epoch+1}/{num_train_epochs}")
-        train_dataset_group.set_current_epoch(epoch + 1)
+        current_epoch.value = epoch+1
 
         metadata["ss_epoch"] = str(epoch + 1)
 
         network.on_epoch_start(text_encoder, unet)
 
        for step, batch in enumerate(train_dataloader):
+            current_step.value = global_step
             with accelerator.accumulate(network):
                 with torch.no_grad():
                     if "latents" in batch and batch["latents"] is not None:
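Note: train_network.py additionally drops its direct reference to the dataset group (del train_dataset_group) before the epoch loop. With epoch and step now flowing through the shared values, the loop no longer calls methods on the group directly, and the DataLoader keeps its own reference to it.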
train_textual_inversion.py

@@ -4,6 +4,7 @@ import gc
 import math
 import os
 import toml
+from multiprocessing import Value
 
 from tqdm import tqdm
 import torch
@@ -73,10 +74,6 @@ imagenet_style_templates_small = [
 ]
 
 
-def collate_fn(examples):
-    return examples[0]
-
-
 def train(args):
     if args.output_name is None:
         args.output_name = args.token_string
@@ -187,6 +184,10 @@ def train(args):
     blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
     train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
 
+    current_epoch = Value('i',0)
+    current_step = Value('i',0)
+    collater = train_util.collater_class(current_epoch,current_step)
+
     # make captions: tokenstring tokenstring1 tokenstring2 ...tokenstringn という文字列に書き換える超乱暴な実装 (a very crude implementation that rewrites every caption into this string)
     if use_template:
         print("use template for training captions. is object: {args.use_object_template}")
@@ -252,7 +253,7 @@ def train(args):
         train_dataset_group,
         batch_size=1,
         shuffle=True,
-        collate_fn=collate_fn,
+        collate_fn=collater,
         num_workers=n_workers,
         persistent_workers=args.persistent_data_loader_workers,
     )
@@ -262,6 +263,9 @@ def train(args):
         args.max_train_steps = args.max_train_epochs * math.ceil(len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps)
         print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
 
+    # データセット側にも学習ステップを送信 (also send the training step count to the dataset side)
+    train_dataset_group.set_max_train_steps(args.max_train_steps)
+
     # lr schedulerを用意する (prepare the lr scheduler)
     lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
@@ -333,12 +337,14 @@ def train(args):
 
     for epoch in range(num_train_epochs):
         print(f"epoch {epoch+1}/{num_train_epochs}")
-        train_dataset_group.set_current_epoch(epoch + 1)
+        current_epoch.value = epoch+1
 
         text_encoder.train()
 
         loss_total = 0
         for step, batch in enumerate(train_dataloader):
+            current_step.value = global_step
             with accelerator.accumulate(text_encoder):
                 with torch.no_grad():
                     if "latents" in batch and batch["latents"] is not None: