mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
fix an error when keep_tokens_separator is not set ref #975
This commit is contained in:
@@ -664,7 +664,11 @@ class BaseDataset(torch.utils.data.Dataset):
|
|||||||
if subset.shuffle_caption or subset.token_warmup_step > 0 or subset.caption_tag_dropout_rate > 0:
|
if subset.shuffle_caption or subset.token_warmup_step > 0 or subset.caption_tag_dropout_rate > 0:
|
||||||
fixed_tokens = []
|
fixed_tokens = []
|
||||||
flex_tokens = []
|
flex_tokens = []
|
||||||
if hasattr(subset, 'keep_tokens_separator') and subset.keep_tokens_separator in caption:
|
if (
|
||||||
|
hasattr(subset, "keep_tokens_separator")
|
||||||
|
and subset.keep_tokens_separator
|
||||||
|
and subset.keep_tokens_separator in caption
|
||||||
|
):
|
||||||
fixed_part, flex_part = caption.split(subset.keep_tokens_separator, 1)
|
fixed_part, flex_part = caption.split(subset.keep_tokens_separator, 1)
|
||||||
fixed_tokens = [t.strip() for t in fixed_part.split(subset.caption_separator) if t.strip()]
|
fixed_tokens = [t.strip() for t in fixed_part.split(subset.caption_separator) if t.strip()]
|
||||||
flex_tokens = [t.strip() for t in flex_part.split(subset.caption_separator) if t.strip()]
|
flex_tokens = [t.strip() for t in flex_part.split(subset.caption_separator) if t.strip()]
|
||||||
@@ -675,17 +679,17 @@ class BaseDataset(torch.utils.data.Dataset):
|
|||||||
fixed_tokens = flex_tokens[: subset.keep_tokens]
|
fixed_tokens = flex_tokens[: subset.keep_tokens]
|
||||||
flex_tokens = tokens[subset.keep_tokens :]
|
flex_tokens = tokens[subset.keep_tokens :]
|
||||||
|
|
||||||
|
|
||||||
if subset.token_warmup_step < 1: # 初回に上書きする
|
if subset.token_warmup_step < 1: # 初回に上書きする
|
||||||
subset.token_warmup_step = math.floor(subset.token_warmup_step * self.max_train_steps)
|
subset.token_warmup_step = math.floor(subset.token_warmup_step * self.max_train_steps)
|
||||||
if subset.token_warmup_step and self.current_step < subset.token_warmup_step:
|
if subset.token_warmup_step and self.current_step < subset.token_warmup_step:
|
||||||
tokens_len = (
|
tokens_len = (
|
||||||
math.floor((self.current_step) * ((len(flex_tokens) - subset.token_warmup_min) / (subset.token_warmup_step)))
|
math.floor(
|
||||||
|
(self.current_step) * ((len(flex_tokens) - subset.token_warmup_min) / (subset.token_warmup_step))
|
||||||
|
)
|
||||||
+ subset.token_warmup_min
|
+ subset.token_warmup_min
|
||||||
)
|
)
|
||||||
flex_tokens = flex_tokens[:tokens_len]
|
flex_tokens = flex_tokens[:tokens_len]
|
||||||
|
|
||||||
|
|
||||||
def dropout_tags(tokens):
|
def dropout_tags(tokens):
|
||||||
if subset.caption_tag_dropout_rate <= 0:
|
if subset.caption_tag_dropout_rate <= 0:
|
||||||
return tokens
|
return tokens
|
||||||
@@ -3152,7 +3156,8 @@ def add_dataset_arguments(
|
|||||||
"--keep_tokens_separator",
|
"--keep_tokens_separator",
|
||||||
type=str,
|
type=str,
|
||||||
default="",
|
default="",
|
||||||
help="A custom separator to divide the caption into fixed and flexible parts. Tokens before this separator will not be shuffled. If not specified, '--keep_tokens' will be used to determine the fixed number of tokens.",
|
help="A custom separator to divide the caption into fixed and flexible parts. Tokens before this separator will not be shuffled. If not specified, '--keep_tokens' will be used to determine the fixed number of tokens."
|
||||||
|
+ " / captionを固定部分と可変部分に分けるためのカスタム区切り文字。この区切り文字より前のトークンはシャッフルされない。指定しない場合、'--keep_tokens'が固定部分のトークン数として使用される。",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--caption_prefix",
|
"--caption_prefix",
|
||||||
|
|||||||
Reference in New Issue
Block a user