Mirror of https://github.com/kohya-ss/sd-scripts.git
implement token warmup
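This commit implements token warmup: captions are comma-separated tag lists, and instead of always training on the full caption, the number of tags grows linearly from token_warmup_min at step 0 up to the full tag count at token_warmup_step (interpreted as a fraction of max_train_steps when below 1). The diff threads the two new per-subset options through the config schema and subset classes, teaches the datasets to track the current step and the total step count, and wires each training script to push those values into the dataset group.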
fine_tune.py
@@ -197,6 +197,9 @@ def train(args):
     args.max_train_steps = args.max_train_epochs * len(train_dataloader)
     print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
 
+  # データセット側にも学習ステップを送信
+  train_dataset_group.set_max_train_steps(args.max_train_steps)
+
   # lr schedulerを用意する
   lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
 
@@ -263,6 +266,7 @@ def train(args):
     loss_total = 0
     for step, batch in enumerate(train_dataloader):
       with accelerator.accumulate(training_models[0]):  # 複数モデルに対応していない模様だがとりあえずこうしておく
+        train_dataset_group.set_current_step(step + 1)
         with torch.no_grad():
           if "latents" in batch and batch["latents"] is not None:
             latents = batch["latents"].to(accelerator.device)
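The added comment (# データセット側にも学習ステップを送信, "also send the training step count to the dataset side") marks the wiring pattern every training script in this commit repeats: push the total step count once before the loop, then push the 1-based current step each batch so the dataset can compute the warmup ratio when it builds a caption. A minimal sketch of the call order, using the names from the diff with the loop body elided:

  train_dataset_group.set_max_train_steps(args.max_train_steps)  # once, before training
  for step, batch in enumerate(train_dataloader):
      train_dataset_group.set_current_step(step + 1)  # 1-based, once per batch
      ...  # forward/backward as in the surrounding loop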
library/config_util.py
@@ -56,6 +56,8 @@ class BaseSubsetParams:
   caption_dropout_rate: float = 0.0
   caption_dropout_every_n_epochs: int = 0
   caption_tag_dropout_rate: float = 0.0
+  token_warmup_min: int = 1
+  token_warmup_step: Union[float,int] = 0
 
 @dataclass
 class DreamBoothSubsetParams(BaseSubsetParams):
@@ -137,6 +139,8 @@ class ConfigSanitizer:
       "random_crop": bool,
       "shuffle_caption": bool,
       "keep_tokens": int,
+      "token_warmup_min": int,
+      "token_warmup_step": Union[float,int],
   }
   # DO means DropOut
   DO_SUBSET_ASCENDABLE_SCHEMA = {
@@ -406,6 +410,8 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
       flip_aug: {subset.flip_aug}
       face_crop_aug_range: {subset.face_crop_aug_range}
       random_crop: {subset.random_crop}
+      token_warmup_min: {subset.token_warmup_min},
+      token_warmup_step: {subset.token_warmup_step},
     """), "  ")
 
     if is_dreambooth:
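Both keys are added to what appears to be the subset-ascendable schema (the dict that precedes DO_SUBSET_ASCENDABLE_SCHEMA), so in a dataset config file they can be set per subset or inherited from the dataset/general level. A hypothetical sketch of the per-subset configuration ConfigSanitizer would now accept, shown as the dict it sees after TOML parsing; the path and values are illustrative:

  user_config = {
      "datasets": [{
          "subsets": [{
              "image_dir": "/path/to/images",  # illustrative
              "shuffle_caption": True,
              "token_warmup_min": 2,     # start with 2 tags per caption
              "token_warmup_step": 0.5,  # full tag count at 50% of max_train_steps
          }]
      }]
  }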
library/train_util.py
@@ -277,6 +277,8 @@ class BaseSubset:
                caption_dropout_rate: float,
                caption_dropout_every_n_epochs: int,
                caption_tag_dropout_rate: float,
+               token_warmup_min: int,
+               token_warmup_step: Union[float,int],
                ) -> None:
     self.image_dir = image_dir
     self.num_repeats = num_repeats
@@ -290,6 +292,9 @@ class BaseSubset:
     self.caption_dropout_every_n_epochs = caption_dropout_every_n_epochs
     self.caption_tag_dropout_rate = caption_tag_dropout_rate
 
+    self.token_warmup_min = token_warmup_min  # step=0におけるタグの数
+    self.token_warmup_step = token_warmup_step  # N(N<1ならN*max_train_steps)ステップ目でタグの数が最大になる
+
     self.img_count = 0
 
 
@@ -310,6 +315,8 @@ class DreamBoothSubset(BaseSubset):
                caption_dropout_rate,
                caption_dropout_every_n_epochs,
                caption_tag_dropout_rate,
+               token_warmup_min,
+               token_warmup_step,
                ) -> None:
     assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です"
 
@@ -325,6 +332,8 @@ class DreamBoothSubset(BaseSubset):
                      caption_dropout_rate,
                      caption_dropout_every_n_epochs,
                      caption_tag_dropout_rate,
+                     token_warmup_min,
+                     token_warmup_step,
                      )
 
     self.is_reg = is_reg
@@ -352,6 +361,8 @@ class FineTuningSubset(BaseSubset):
                caption_dropout_rate,
                caption_dropout_every_n_epochs,
                caption_tag_dropout_rate,
+               token_warmup_min,
+               token_warmup_step,
                ) -> None:
     assert metadata_file is not None, "metadata_file must be specified / metadata_fileは指定が必須です"
 
@@ -367,6 +378,8 @@ class FineTuningSubset(BaseSubset):
                      caption_dropout_rate,
                      caption_dropout_every_n_epochs,
                      caption_tag_dropout_rate,
+                     token_warmup_min,
+                     token_warmup_step,
                      )
 
     self.metadata_file = metadata_file
@@ -405,6 +418,9 @@ class BaseDataset(torch.utils.data.Dataset):
 
     self.current_epoch: int = 0  # インスタンスがepochごとに新しく作られるようなので外側から渡さないとダメ
 
+    self.current_step: int = 0
+    self.max_train_steps: int = 0
+
     # augmentation
     self.aug_helper = AugHelper()
 
@@ -424,6 +440,12 @@ class BaseDataset(torch.utils.data.Dataset):
     self.current_epoch = epoch
     self.shuffle_buckets()
 
+  def set_current_step(self, step):
+    self.current_step = step
+
+  def set_max_train_steps(self, max_train_steps):
+    self.max_train_steps = max_train_steps
+
   def set_tag_frequency(self, dir_name, captions):
     frequency_for_dir = self.tag_frequency.get(dir_name, {})
     self.tag_frequency[dir_name] = frequency_for_dir
@@ -453,7 +475,7 @@ class BaseDataset(torch.utils.data.Dataset):
       if is_drop_out:
         caption = ""
       else:
-        if subset.shuffle_caption or subset.caption_tag_dropout_rate > 0:
+        if subset.shuffle_caption or subset.token_warmup_step > 0 or subset.caption_tag_dropout_rate > 0:
 
           def dropout_tags(tokens):
             if subset.caption_tag_dropout_rate <= 0:
@@ -474,8 +496,15 @@ class BaseDataset(torch.utils.data.Dataset):
             random.shuffle(flex_tokens)
 
           flex_tokens = dropout_tags(flex_tokens)
+          tokens = fixed_tokens + flex_tokens
 
-          caption = ", ".join(fixed_tokens + flex_tokens)
+          if subset.token_warmup_step < 1:
+            subset.token_warmup_step = math.floor(subset.token_warmup_step * self.max_train_steps)
+          if subset.token_warmup_step and self.current_step < subset.token_warmup_step:
+            tokens_len = math.floor((self.current_step) * ((len(tokens) - subset.token_warmup_min) / (subset.token_warmup_step))) + subset.token_warmup_min
+            tokens = tokens[:tokens_len]
+
+          caption = ", ".join(tokens)
 
       # textual inversion対応
       for str_from, str_to in self.replacements.items():
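The Japanese comments on the new fields translate as: token_warmup_min is "the number of tags at step 0", and token_warmup_step is "the step N (N*max_train_steps if N<1) at which the tag count reaches its maximum". The warmup itself is the linear ramp above: after shuffling and tag dropout, the tag list is truncated to a length that grows with current_step. Note that the diff rewrites subset.token_warmup_step in place on first use, so a fractional value is converted to an absolute step count only once. A standalone sketch of that schedule (the function name and values are mine; the arithmetic mirrors the diff):

  import math

  def warmup_token_count(num_tokens, current_step, warmup_min, warmup_step, max_train_steps):
      # A fractional warmup_step means a fraction of the total training steps.
      if warmup_step < 1:
          warmup_step = math.floor(warmup_step * max_train_steps)
      if warmup_step and current_step < warmup_step:
          # Linear ramp: warmup_min tags at step 0, all tags from warmup_step on.
          return math.floor(current_step * (num_tokens - warmup_min) / warmup_step) + warmup_min
      return num_tokens

  # With 21 tags, warmup_min=1 and warmup_step=100: step 0 keeps 1 tag,
  # step 50 keeps floor(50 * 20 / 100) + 1 = 11, and step 100 onward keeps all 21.
  for step in (0, 50, 100):
      print(step, warmup_token_count(21, step, 1, 100, max_train_steps=1000))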
library/train_util.py (continued)
@@ -1249,6 +1278,14 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
     for dataset in self.datasets:
       dataset.set_current_epoch(epoch)
 
+  def set_current_step(self, step):
+    for dataset in self.datasets:
+      dataset.set_current_step(step)
+
+  def set_max_train_steps(self, max_train_steps):
+    for dataset in self.datasets:
+      dataset.set_max_train_steps(max_train_steps)
+
   def disable_token_padding(self):
     for dataset in self.datasets:
       dataset.disable_token_padding()
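DatasetGroup is the ConcatDataset wrapper the training scripts hold, so the new setters simply broadcast to every wrapped dataset; that is why each script makes a single call per batch. A hypothetical usage sketch (the dataset names are illustrative):

  group = DatasetGroup([dreambooth_dataset, finetuning_dataset])
  group.set_max_train_steps(1000)  # forwarded to every dataset once
  group.set_current_step(1)        # forwarded to every dataset each batch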
library/train_util.py (continued)
@@ -2001,6 +2038,20 @@ def add_dataset_arguments(
       "--bucket_no_upscale", action="store_true", help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します"
   )
 
+  parser.add_argument(
+      "--token_warmup_min",
+      type=int,
+      default=1,
+      help="start learning at N tags (token means comma separated strings) / タグ数をN個から増やしながら学習する",
+  )
+
+  parser.add_argument(
+      "--token_warmup_steps",
+      type=float,
+      default=0,
+      help="tag length reaches maximum on N steps (or N*max_train_steps if N<1) / N(N<1ならN*max_train_steps)ステップでタグ長が最大になる。デフォルトは0(最初から最大)",
+  )
+
   if support_caption_dropout:
     # Textual Inversion はcaptionのdropoutをsupportしない
     # いわゆるtensorのDropoutと紛らわしいのでprefixにcaptionを付けておく every_n_epochsは他と平仄を合わせてdefault Noneに
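On the command line the feature is exposed as --token_warmup_min and --token_warmup_steps (note the plural, while the config key above is token_warmup_step). A minimal sketch of how the two flags parse, with illustrative values:

  import argparse

  parser = argparse.ArgumentParser()
  parser.add_argument("--token_warmup_min", type=int, default=1)
  parser.add_argument("--token_warmup_steps", type=float, default=0)

  args = parser.parse_args(["--token_warmup_min", "2", "--token_warmup_steps", "0.5"])
  print(args.token_warmup_min, args.token_warmup_steps)  # -> 2 0.5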
train_db.py
@@ -162,6 +162,9 @@ def train(args):
     args.max_train_steps = args.max_train_epochs * len(train_dataloader)
     print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
 
+  # データセット側にも学習ステップを送信
+  train_dataset_group.set_max_train_steps(args.max_train_steps)
+
   if args.stop_text_encoder_training is None:
     args.stop_text_encoder_training = args.max_train_steps + 1  # do not stop until end
 
@@ -246,6 +249,7 @@ def train(args):
           text_encoder.requires_grad_(False)
 
       with accelerator.accumulate(unet):
+        train_dataset_group.set_current_step(step + 1)
         with torch.no_grad():
           # latentに変換
          if cache_latents:
train_network.py
@@ -200,6 +200,9 @@ def train(args):
     if is_main_process:
       print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
 
+  # データセット側にも学習ステップを送信
+  train_dataset_group.set_max_train_steps(args.max_train_steps)
+
   # lr schedulerを用意する
   lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
 
@@ -505,6 +508,7 @@ def train(args):
 
     for step, batch in enumerate(train_dataloader):
       with accelerator.accumulate(network):
+        train_dataset_group.set_current_step(step + 1)
         with torch.no_grad():
           if "latents" in batch and batch["latents"] is not None:
             latents = batch["latents"].to(accelerator.device)
train_textual_inversion.py
@@ -260,6 +260,9 @@ def train(args):
     args.max_train_steps = args.max_train_epochs * len(train_dataloader)
     print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
 
+  # データセット側にも学習ステップを送信
+  train_dataset_group.set_max_train_steps(args.max_train_steps)
+
   # lr schedulerを用意する
   lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
 
@@ -338,6 +341,7 @@ def train(args):
   loss_total = 0
   for step, batch in enumerate(train_dataloader):
     with accelerator.accumulate(text_encoder):
+      train_dataset_group.set_current_step(step + 1)
       with torch.no_grad():
         if "latents" in batch and batch["latents"] is not None:
           latents = batch["latents"].to(accelerator.device)