Merge branch 'dev' into min-SNR

This commit is contained in:
Kohya S
2023-03-26 17:10:53 +09:00
committed by GitHub
7 changed files with 2713 additions and 2347 deletions

View File

@@ -56,6 +56,8 @@ class BaseSubsetParams:
caption_dropout_rate: float = 0.0
caption_dropout_every_n_epochs: int = 0
caption_tag_dropout_rate: float = 0.0
token_warmup_min: int = 1
token_warmup_step: float = 0
@dataclass
class DreamBoothSubsetParams(BaseSubsetParams):
@@ -137,6 +139,8 @@ class ConfigSanitizer:
"random_crop": bool,
"shuffle_caption": bool,
"keep_tokens": int,
"token_warmup_min": int,
"token_warmup_step": Any(float,int),
}
# DO means DropOut
DO_SUBSET_ASCENDABLE_SCHEMA = {
@@ -406,6 +410,8 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
flip_aug: {subset.flip_aug}
face_crop_aug_range: {subset.face_crop_aug_range}
random_crop: {subset.random_crop}
token_warmup_min: {subset.token_warmup_min},
token_warmup_step: {subset.token_warmup_step},
"""), " ")
if is_dreambooth:
@@ -491,7 +497,6 @@ def load_user_config(file: str) -> dict:
return config
# for config test
if __name__ == "__main__":
parser = argparse.ArgumentParser()

View File

@@ -276,6 +276,8 @@ class BaseSubset:
caption_dropout_rate: float,
caption_dropout_every_n_epochs: int,
caption_tag_dropout_rate: float,
token_warmup_min: int,
token_warmup_step: Union[float,int],
) -> None:
self.image_dir = image_dir
self.num_repeats = num_repeats
@@ -289,6 +291,9 @@ class BaseSubset:
self.caption_dropout_every_n_epochs = caption_dropout_every_n_epochs
self.caption_tag_dropout_rate = caption_tag_dropout_rate
self.token_warmup_min = token_warmup_min # step=0におけるタグの数
self.token_warmup_step = token_warmup_step # NN<1ならN*max_train_stepsステップ目でタグの数が最大になる
self.img_count = 0
@@ -309,6 +314,8 @@ class DreamBoothSubset(BaseSubset):
caption_dropout_rate,
caption_dropout_every_n_epochs,
caption_tag_dropout_rate,
token_warmup_min,
token_warmup_step,
) -> None:
assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です"
@@ -324,6 +331,8 @@ class DreamBoothSubset(BaseSubset):
caption_dropout_rate,
caption_dropout_every_n_epochs,
caption_tag_dropout_rate,
token_warmup_min,
token_warmup_step,
)
self.is_reg = is_reg
@@ -351,6 +360,8 @@ class FineTuningSubset(BaseSubset):
caption_dropout_rate,
caption_dropout_every_n_epochs,
caption_tag_dropout_rate,
token_warmup_min,
token_warmup_step,
) -> None:
assert metadata_file is not None, "metadata_file must be specified / metadata_fileは指定が必須です"
@@ -366,6 +377,8 @@ class FineTuningSubset(BaseSubset):
caption_dropout_rate,
caption_dropout_every_n_epochs,
caption_tag_dropout_rate,
token_warmup_min,
token_warmup_step,
)
self.metadata_file = metadata_file
@@ -404,6 +417,9 @@ class BaseDataset(torch.utils.data.Dataset):
self.current_epoch: int = 0 # インスタンスがepochごとに新しく作られるようなので外側から渡さないとダメ
self.current_step: int = 0
self.max_train_steps: int = 0
# augmentation
self.aug_helper = AugHelper()
@@ -420,8 +436,15 @@ class BaseDataset(torch.utils.data.Dataset):
self.replacements = {}
def set_current_epoch(self, epoch):
if not self.current_epoch == epoch:
self.shuffle_buckets()
self.current_epoch = epoch
self.shuffle_buckets()
def set_current_step(self, step):
self.current_step = step
def set_max_train_steps(self, max_train_steps):
self.max_train_steps = max_train_steps
def set_tag_frequency(self, dir_name, captions):
frequency_for_dir = self.tag_frequency.get(dir_name, {})
@@ -452,7 +475,14 @@ class BaseDataset(torch.utils.data.Dataset):
if is_drop_out:
caption = ""
else:
if subset.shuffle_caption or subset.caption_tag_dropout_rate > 0:
if subset.shuffle_caption or subset.token_warmup_step > 0 or subset.caption_tag_dropout_rate > 0:
tokens = [t.strip() for t in caption.strip().split(",")]
if subset.token_warmup_step < 1:
subset.token_warmup_step = math.floor(subset.token_warmup_step * self.max_train_steps)
if subset.token_warmup_step and self.current_step < subset.token_warmup_step:
tokens_len = math.floor((self.current_step)*((len(tokens)-subset.token_warmup_min)/(subset.token_warmup_step)))+subset.token_warmup_min
tokens = tokens[:tokens_len]
def dropout_tags(tokens):
if subset.caption_tag_dropout_rate <= 0:
@@ -464,10 +494,10 @@ class BaseDataset(torch.utils.data.Dataset):
return l
fixed_tokens = []
flex_tokens = [t.strip() for t in caption.strip().split(",")]
flex_tokens = tokens[:]
if subset.keep_tokens > 0:
fixed_tokens = flex_tokens[: subset.keep_tokens]
flex_tokens = flex_tokens[subset.keep_tokens :]
flex_tokens = tokens[subset.keep_tokens :]
if subset.shuffle_caption:
random.shuffle(flex_tokens)
@@ -1285,6 +1315,14 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
for dataset in self.datasets:
dataset.set_current_epoch(epoch)
def set_current_step(self, step):
for dataset in self.datasets:
dataset.set_current_step(step)
def set_max_train_steps(self, max_train_steps):
for dataset in self.datasets:
dataset.set_max_train_steps(max_train_steps)
def disable_token_padding(self):
for dataset in self.datasets:
dataset.disable_token_padding()
@@ -2038,6 +2076,20 @@ def add_dataset_arguments(
"--bucket_no_upscale", action="store_true", help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します"
)
parser.add_argument(
"--token_warmup_min",
type=int,
default=1,
help="start learning at N tags (token means comma separated strinfloatgs) / タグ数をN個から増やしながら学習する",
)
parser.add_argument(
"--token_warmup_step",
type=float,
default=0,
help="tag length reaches maximum on N steps (or N*max_train_steps if N<1) / NN<1ならN*max_train_stepsステップでタグ長が最大になる。デフォルトは0最初から最大",
)
if support_caption_dropout:
# Textual Inversion はcaptionのdropoutをsupportしない
# いわゆるtensorのDropoutと紛らわしいのでprefixにcaptionを付けておく every_n_epochsは他と平仄を合わせてdefault Noneに
@@ -2972,3 +3024,14 @@ class ImageLoadingDataset(torch.utils.data.Dataset):
# endregion
# colalte_fn用 epoch,stepはmultiprocessing.Value
class collater_class:
def __init__(self,epoch,step):
self.current_epoch=epoch
self.current_step=step
def __call__(self, examples):
dataset = torch.utils.data.get_worker_info().dataset
dataset.set_current_epoch(self.current_epoch.value)
dataset.set_current_step(self.current_step.value)
return examples[0]