minor fix in token shuffling

This commit is contained in:
Kohya S
2023-03-02 20:31:07 +09:00
parent c3024be8bf
commit 859f8361bb

View File

@@ -362,7 +362,7 @@ class BaseDataset(torch.utils.data.Dataset):
 fixed_tokens = []
 flex_tokens = [t.strip() for t in caption.strip().split(",")]
-if subset.keep_tokens >= 0:
+if subset.keep_tokens > 0:
 fixed_tokens = flex_tokens[:subset.keep_tokens]
 flex_tokens = flex_tokens[subset.keep_tokens:]