From 828a581e2968935c00d22e7e03ca32c1281aa5dd Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 26 Jun 2024 20:43:31 +0900 Subject: [PATCH] fix assertion for experimental impl ref #1389 --- sd3_train.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/sd3_train.py b/sd3_train.py index 8216a62b..ea9a1104 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -60,9 +60,19 @@ def train(args): assert ( not args.weighted_captions ), "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません" + # assert ( + # not args.train_text_encoder or not args.cache_text_encoder_outputs + # ), "cache_text_encoder_outputs is not supported when training text encoder / text encoderを学習するときはcache_text_encoder_outputsはサポートされていません" + + # training text encoder is not supported assert ( - not args.train_text_encoder or not args.cache_text_encoder_outputs - ), "cache_text_encoder_outputs is not supported when training text encoder / text encoderを学習するときはcache_text_encoder_outputsはサポートされていません" + not args.train_text_encoder + ), "training text encoder is not supported currently / text encoderの学習は現在サポートされていません" + + # training without text encoder cache is not supported + assert ( + args.cache_text_encoder_outputs + ), "training without text encoder cache is not supported currently / text encoderのキャッシュなしの学習は現在サポートされていません" # if args.block_lr: # block_lrs = [float(lr) for lr in args.block_lr.split(",")]