diff --git a/sd3_train.py b/sd3_train.py index 8216a62b..ea9a1104 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -60,9 +60,19 @@ def train(args): assert ( not args.weighted_captions ), "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません" + # assert ( + # not args.train_text_encoder or not args.cache_text_encoder_outputs + # ), "cache_text_encoder_outputs is not supported when training text encoder / text encoderを学習するときはcache_text_encoder_outputsはサポートされていません" + + # training text encoder is not supported assert ( - not args.train_text_encoder or not args.cache_text_encoder_outputs - ), "cache_text_encoder_outputs is not supported when training text encoder / text encoderを学習するときはcache_text_encoder_outputsはサポートされていません" + not args.train_text_encoder + ), "training text encoder is not supported currently / text encoderの学習は現在サポートされていません" + + # training without text encoder cache is not supported + assert ( + args.cache_text_encoder_outputs + ), "training without text encoder cache is not supported currently / text encoderのキャッシュなしの学習は現在サポートされていません" # if args.block_lr: # block_lrs = [float(lr) for lr in args.block_lr.split(",")]