enable full bf16 trainint in train_network

This commit is contained in:
Kohya S
2023-07-19 08:41:42 +09:00
parent 7875ca8fb5
commit 225e871819
3 changed files with 11 additions and 3 deletions

View File

@@ -273,7 +273,7 @@ def train(args):
# lr schedulerを用意する
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
# 実験的機能勾配も含めたfp16/bf16学習を行う モデル全体をfp16にする
# 実験的機能勾配も含めたfp16/bf16学習を行う モデル全体をfp16/bf16にする
if args.full_fp16:
assert (
args.mixed_precision == "fp16"