mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Update sdxl_train.py
This commit is contained in:
@@ -267,6 +267,14 @@ def train(args):
|
|||||||
unet.to(weight_dtype)
|
unet.to(weight_dtype)
|
||||||
text_encoder1.to(weight_dtype)
|
text_encoder1.to(weight_dtype)
|
||||||
text_encoder2.to(weight_dtype)
|
text_encoder2.to(weight_dtype)
|
||||||
|
elif args.full_bf16:
|
||||||
|
assert (
|
||||||
|
args.mixed_precision == "bf16"
|
||||||
|
), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。"
|
||||||
|
accelerator.print("enable full bf16 training.")
|
||||||
|
unet.to(weight_dtype)
|
||||||
|
text_encoder1.to(weight_dtype)
|
||||||
|
text_encoder2.to(weight_dtype)
|
||||||
|
|
||||||
# acceleratorがなんかよろしくやってくれるらしい
|
# acceleratorがなんかよろしくやってくれるらしい
|
||||||
if args.train_text_encoder:
|
if args.train_text_encoder:
|
||||||
|
|||||||
Reference in New Issue
Block a user