diff --git a/library/config_util.py b/library/config_util.py
index e8e0fda7..af4eedaa 100644
--- a/library/config_util.py
+++ b/library/config_util.py
@@ -51,6 +51,7 @@ class BaseSubsetParams:
     image_dir: Optional[str] = None
     num_repeats: int = 1
     shuffle_caption: bool = False
+    caption_seperator: str = ","
     keep_tokens: int = 0
     color_aug: bool = False
     flip_aug: bool = False
diff --git a/library/train_util.py b/library/train_util.py
index 51610e70..c04ad9a9 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -341,6 +341,7 @@ class BaseSubset:
         image_dir: Optional[str],
         num_repeats: int,
         shuffle_caption: bool,
+        caption_seperator: str,
         keep_tokens: int,
         color_aug: bool,
         flip_aug: bool,
@@ -357,6 +358,7 @@ class BaseSubset:
         self.image_dir = image_dir
         self.num_repeats = num_repeats
         self.shuffle_caption = shuffle_caption
+        self.caption_seperator = caption_seperator
         self.keep_tokens = keep_tokens
         self.color_aug = color_aug
         self.flip_aug = flip_aug
@@ -383,6 +385,7 @@ class DreamBoothSubset(BaseSubset):
         caption_extension: str,
         num_repeats,
         shuffle_caption,
+        caption_seperator: str,
         keep_tokens,
         color_aug,
         flip_aug,
@@ -402,6 +405,7 @@ class DreamBoothSubset(BaseSubset):
             image_dir,
             num_repeats,
             shuffle_caption,
+            caption_seperator,
             keep_tokens,
             color_aug,
             flip_aug,
@@ -435,6 +439,7 @@ class FineTuningSubset(BaseSubset):
         metadata_file: str,
         num_repeats,
         shuffle_caption,
+        caption_seperator,
         keep_tokens,
         color_aug,
         flip_aug,
@@ -454,6 +459,7 @@ class FineTuningSubset(BaseSubset):
             image_dir,
             num_repeats,
             shuffle_caption,
+            caption_seperator,
             keep_tokens,
             color_aug,
             flip_aug,
@@ -484,6 +490,7 @@ class ControlNetSubset(BaseSubset):
         caption_extension: str,
         num_repeats,
         shuffle_caption,
+        caption_seperator,
         keep_tokens,
         color_aug,
         flip_aug,
@@ -503,6 +510,7 @@ class ControlNetSubset(BaseSubset):
             image_dir,
             num_repeats,
             shuffle_caption,
+            caption_seperator,
             keep_tokens,
             color_aug,
             flip_aug,
@@ -638,7 +646,7 @@ class BaseDataset(torch.utils.data.Dataset):
             caption = ""
         else:
             if subset.shuffle_caption or subset.token_warmup_step > 0 or subset.caption_tag_dropout_rate > 0:
-                tokens = [t.strip() for t in caption.strip().split(",")]
+                tokens = [t.strip() for t in caption.strip().split(subset.caption_seperator)]
                 if subset.token_warmup_step < 1:  # 初回に上書きする
                     subset.token_warmup_step = math.floor(subset.token_warmup_step * self.max_train_steps)
                 if subset.token_warmup_step and self.current_step < subset.token_warmup_step:
@@ -3091,7 +3099,10 @@ def add_dataset_arguments(
     # dataset common
     parser.add_argument("--train_data_dir", type=str, default=None, help="directory for train images / 学習画像データのディレクトリ")
     parser.add_argument(
-        "--shuffle_caption", action="store_true", help="shuffle comma-separated caption / コンマで区切られたcaptionの各要素をshuffleする"
+        "--shuffle_caption", action="store_true", help="shuffle separated caption / 区切られたcaptionの各要素をshuffleする"
+    )
+    parser.add_argument(
+        "--caption_seperator", type=str, default=",", help="seperator for caption / captionの区切り文字"
     )
     parser.add_argument(
         "--caption_extension", type=str, default=".caption", help="extension of caption files / 読み込むcaptionファイルの拡張子"