Fix --debug_dataset to work.

This commit is contained in:
Kohya S
2024-08-22 19:55:31 +09:00
parent 98c91a7625
commit a4d27a232b

View File

@@ -142,6 +142,12 @@ def train(args):
args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, False, False
)
)
name = "schnell" if "schnell" in args.pretrained_model_name_or_path else "dev"
t5xxl_max_token_length = (
args.t5xxl_max_token_length if args.t5xxl_max_token_length is not None else (256 if name == "schnell" else 512)
)
strategy_base.TokenizeStrategy.set_strategy(strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length))
train_dataset_group.set_current_strategies()
train_util.debug_dataset(train_dataset_group, True)
return