mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Add dropout options
This commit is contained in:
@@ -171,6 +171,10 @@ def train(args):
|
|||||||
args.max_train_steps = args.max_train_epochs * len(train_dataloader)
|
args.max_train_steps = args.max_train_epochs * len(train_dataloader)
|
||||||
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
|
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
|
||||||
|
|
||||||
|
# 学習データのdropout率を設定する
|
||||||
|
train_dataset.dropout_rate = args.dropout_rate
|
||||||
|
train_dataset.dropout_every_n_epochs = args.dropout_every_n_epochs
|
||||||
|
|
||||||
# lr schedulerを用意する
|
# lr schedulerを用意する
|
||||||
lr_scheduler = diffusers.optimization.get_scheduler(
|
lr_scheduler = diffusers.optimization.get_scheduler(
|
||||||
args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps)
|
args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps)
|
||||||
@@ -226,6 +230,9 @@ def train(args):
|
|||||||
|
|
||||||
for epoch in range(num_train_epochs):
|
for epoch in range(num_train_epochs):
|
||||||
print(f"epoch {epoch+1}/{num_train_epochs}")
|
print(f"epoch {epoch+1}/{num_train_epochs}")
|
||||||
|
|
||||||
|
train_dataset.epoch_current = epoch + 1
|
||||||
|
|
||||||
for m in training_models:
|
for m in training_models:
|
||||||
m.train()
|
m.train()
|
||||||
|
|
||||||
|
|||||||
@@ -223,6 +223,10 @@ class BaseDataset(torch.utils.data.Dataset):
|
|||||||
|
|
||||||
self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2
|
self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2
|
||||||
|
|
||||||
|
self.epoch_current:int = int(0)
|
||||||
|
self.dropout_rate:float = 0
|
||||||
|
self.dropout_every_n_epochs:int = 0
|
||||||
|
|
||||||
# augmentation
|
# augmentation
|
||||||
flip_p = 0.5 if flip_aug else 0.0
|
flip_p = 0.5 if flip_aug else 0.0
|
||||||
if color_aug:
|
if color_aug:
|
||||||
@@ -598,6 +602,16 @@ class BaseDataset(torch.utils.data.Dataset):
|
|||||||
images.append(image)
|
images.append(image)
|
||||||
latents_list.append(latents)
|
latents_list.append(latents)
|
||||||
|
|
||||||
|
# Decide whether the caption of this sample is dropped.
is_drop_out = False
# Random dropout: random.random() is uniform on [0.0, 1.0), so comparing it
# against dropout_rate drops the caption with probability dropout_rate.
# (The reversed comparison would drop with probability 1 - dropout_rate.)
if self.dropout_rate > 0 and random.random() < self.dropout_rate:
    is_drop_out = True
# Periodic dropout: drop every caption whenever the current epoch number
# is a multiple of dropout_every_n_epochs (0 disables this rule).
if self.dropout_every_n_epochs > 0 and self.epoch_current % self.dropout_every_n_epochs == 0:
    is_drop_out = True
|
||||||
|
|
||||||
|
if is_drop_out:
|
||||||
|
caption = ""
|
||||||
|
else:
|
||||||
caption = self.process_caption(image_info.caption)
|
caption = self.process_caption(image_info.caption)
|
||||||
captions.append(caption)
|
captions.append(caption)
|
||||||
if not self.token_padding_disabled: # this option might be omitted in future
|
if not self.token_padding_disabled: # this option might be omitted in future
|
||||||
@@ -1407,6 +1421,10 @@ def add_dataset_arguments(parser: argparse.ArgumentParser, support_dreambooth: b
|
|||||||
help="steps of resolution for buckets, divisible by 8 is recommended / bucketの解像度の単位、8で割り切れる値を推奨します")
|
help="steps of resolution for buckets, divisible by 8 is recommended / bucketの解像度の単位、8で割り切れる値を推奨します")
|
||||||
parser.add_argument("--bucket_no_upscale", action="store_true",
|
parser.add_argument("--bucket_no_upscale", action="store_true",
|
||||||
help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します")
|
help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します")
|
||||||
|
# Caption dropout options: a per-sample random rate and a periodic
# "drop every caption on every N-th epoch" switch. 0 disables either one.
parser.add_argument("--dropout_rate", type=float, default=0,
                    help="Rate of caption dropout (0.0~1.0) / captionをdropoutする割合")
parser.add_argument("--dropout_every_n_epochs", type=int, default=0,
                    help="Dropout all captions every N epochs / captionを指定エポックごとにdropoutする")
|
||||||
|
|
||||||
if support_dreambooth:
|
if support_dreambooth:
|
||||||
# DreamBooth dataset
|
# DreamBooth dataset
|
||||||
|
|||||||
@@ -136,6 +136,10 @@ def train(args):
|
|||||||
train_dataloader = torch.utils.data.DataLoader(
|
train_dataloader = torch.utils.data.DataLoader(
|
||||||
train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)
|
train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)
|
||||||
|
|
||||||
|
# 学習データのdropout率を設定する
|
||||||
|
train_dataset.dropout_rate = args.dropout_rate
|
||||||
|
train_dataset.dropout_every_n_epochs = args.dropout_every_n_epochs
|
||||||
|
|
||||||
# 学習ステップ数を計算する
|
# 学習ステップ数を計算する
|
||||||
if args.max_train_epochs is not None:
|
if args.max_train_epochs is not None:
|
||||||
args.max_train_steps = args.max_train_epochs * len(train_dataloader)
|
args.max_train_steps = args.max_train_epochs * len(train_dataloader)
|
||||||
@@ -204,6 +208,8 @@ def train(args):
|
|||||||
for epoch in range(num_train_epochs):
|
for epoch in range(num_train_epochs):
|
||||||
print(f"epoch {epoch+1}/{num_train_epochs}")
|
print(f"epoch {epoch+1}/{num_train_epochs}")
|
||||||
|
|
||||||
|
train_dataset.epoch_current = epoch + 1
|
||||||
|
|
||||||
# 指定したステップ数までText Encoderを学習する:epoch最初の状態
|
# 指定したステップ数までText Encoderを学習する:epoch最初の状態
|
||||||
unet.train()
|
unet.train()
|
||||||
# train==True is required to enable gradient_checkpointing
|
# train==True is required to enable gradient_checkpointing
|
||||||
|
|||||||
@@ -219,6 +219,10 @@ def train(args):
|
|||||||
train_dataloader = torch.utils.data.DataLoader(
|
train_dataloader = torch.utils.data.DataLoader(
|
||||||
train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)
|
train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)
|
||||||
|
|
||||||
|
# 学習データのdropout率を設定する
|
||||||
|
train_dataset.dropout_rate = args.dropout_rate
|
||||||
|
train_dataset.dropout_every_n_epochs = args.dropout_every_n_epochs
|
||||||
|
|
||||||
# 学習ステップ数を計算する
|
# 学習ステップ数を計算する
|
||||||
if args.max_train_epochs is not None:
|
if args.max_train_epochs is not None:
|
||||||
args.max_train_steps = args.max_train_epochs * len(train_dataloader)
|
args.max_train_steps = args.max_train_epochs * len(train_dataloader)
|
||||||
@@ -376,6 +380,9 @@ def train(args):
|
|||||||
|
|
||||||
for epoch in range(num_train_epochs):
|
for epoch in range(num_train_epochs):
|
||||||
print(f"epoch {epoch+1}/{num_train_epochs}")
|
print(f"epoch {epoch+1}/{num_train_epochs}")
|
||||||
|
|
||||||
|
train_dataset.epoch_current = epoch + 1
|
||||||
|
|
||||||
metadata["ss_epoch"] = str(epoch+1)
|
metadata["ss_epoch"] = str(epoch+1)
|
||||||
|
|
||||||
network.on_epoch_start(text_encoder, unet)
|
network.on_epoch_start(text_encoder, unet)
|
||||||
|
|||||||
Reference in New Issue
Block a user