From 225e8718194e4a7c3f3ca51df4a8ae5a2cea49dc Mon Sep 17 00:00:00 2001
From: Kohya S
Date: Wed, 19 Jul 2023 08:41:42 +0900
Subject: [PATCH] enable full bf16 training in train_network

---
 library/train_util.py | 4 +++-
 sdxl_train.py         | 2 +-
 train_network.py      | 8 +++++++-
 3 files changed, 11 insertions(+), 3 deletions(-)

diff --git a/library/train_util.py b/library/train_util.py
index e6e5c3c4..f5d5288b 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -2687,7 +2687,9 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
         "--mixed_precision", type=str, default="no", choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度"
     )
     parser.add_argument("--full_fp16", action="store_true", help="fp16 training including gradients / 勾配も含めてfp16で学習する")
-    parser.add_argument("--full_bf16", action="store_true", help="bf16 training including gradients / 勾配も含めてbf16で学習する")
+    parser.add_argument(
+        "--full_bf16", action="store_true", help="bf16 training including gradients / 勾配も含めてbf16で学習する"
+    )  # TODO move to SDXL training, because it is not supported by SD1/2
     parser.add_argument(
         "--clip_skip",
         type=int,
diff --git a/sdxl_train.py b/sdxl_train.py
index 630d8832..4dbed79a 100644
--- a/sdxl_train.py
+++ b/sdxl_train.py
@@ -273,7 +273,7 @@ def train(args):
     # lr schedulerを用意する
     lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
 
-    # 実験的機能:勾配も含めたfp16/bf16学習を行う モデル全体をfp16にする
+    # 実験的機能:勾配も含めたfp16/bf16学習を行う モデル全体をfp16/bf16にする
     if args.full_fp16:
         assert (
             args.mixed_precision == "fp16"
diff --git a/train_network.py b/train_network.py
index 3f78f159..a55339c4 100644
--- a/train_network.py
+++ b/train_network.py
@@ -350,13 +350,19 @@ class NetworkTrainer:
         # lr schedulerを用意する
         lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
 
-        # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
+        # 実験的機能:勾配も含めたfp16/bf16学習を行う モデル全体をfp16/bf16にする
         if args.full_fp16:
             assert (
                 args.mixed_precision == "fp16"
             ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
             accelerator.print("enable full fp16 training.")
             network.to(weight_dtype)
+        elif args.full_bf16:
+            assert (
+                args.mixed_precision == "bf16"
+            ), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。"
+            accelerator.print("enable full bf16 training.")
+            network.to(weight_dtype)
 
         unet.requires_grad_(False)
         unet.to(dtype=weight_dtype)