diff --git a/library/train_util.py b/library/train_util.py index 40bcfc6f..7a285a94 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -362,7 +362,7 @@ class BaseDataset(torch.utils.data.Dataset): fixed_tokens = [] flex_tokens = [t.strip() for t in caption.strip().split(",")] - if subset.keep_tokens >= 0: + if subset.keep_tokens > 0: fixed_tokens = flex_tokens[:subset.keep_tokens] flex_tokens = flex_tokens[subset.keep_tokens:]