diff --git a/sdxl_train.py b/sdxl_train.py index 9cf20252..06cbc571 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -267,6 +267,14 @@ def train(args): unet.to(weight_dtype) text_encoder1.to(weight_dtype) text_encoder2.to(weight_dtype) + elif args.full_bf16: + assert ( + args.mixed_precision == "bf16" + ), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。" + accelerator.print("enable full bf16 training.") + unet.to(weight_dtype) + text_encoder1.to(weight_dtype) + text_encoder2.to(weight_dtype) # acceleratorがなんかよろしくやってくれるらしい if args.train_text_encoder: