Update sdxl_train.py

This commit is contained in:
Kohaku-Blueleaf
2023-07-09 12:46:35 +08:00
committed by GitHub
parent 8371a7a3aa
commit 5f348579d1

View File

@@ -267,6 +267,14 @@ def train(args):
unet.to(weight_dtype) unet.to(weight_dtype)
text_encoder1.to(weight_dtype) text_encoder1.to(weight_dtype)
text_encoder2.to(weight_dtype) text_encoder2.to(weight_dtype)
elif args.full_bf16:
assert (
args.mixed_precision == "bf16"
), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。"
accelerator.print("enable full bf16 training.")
unet.to(weight_dtype)
text_encoder1.to(weight_dtype)
text_encoder2.to(weight_dtype)
# acceleratorがなんかよろしくやってくれるらしい # acceleratorがなんかよろしくやってくれるらしい
if args.train_text_encoder: if args.train_text_encoder: