mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
enable full bf16 training in train_network
This commit is contained in:
@@ -2687,7 +2687,9 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
|||||||
"--mixed_precision", type=str, default="no", choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度"
|
"--mixed_precision", type=str, default="no", choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度"
|
||||||
)
|
)
|
||||||
parser.add_argument("--full_fp16", action="store_true", help="fp16 training including gradients / 勾配も含めてfp16で学習する")
|
parser.add_argument("--full_fp16", action="store_true", help="fp16 training including gradients / 勾配も含めてfp16で学習する")
|
||||||
parser.add_argument("--full_bf16", action="store_true", help="bf16 training including gradients / 勾配も含めてbf16で学習する")
|
parser.add_argument(
|
||||||
|
"--full_bf16", action="store_true", help="bf16 training including gradients / 勾配も含めてbf16で学習する"
|
||||||
|
) # TODO move to SDXL training, because it is not supported by SD1/2
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--clip_skip",
|
"--clip_skip",
|
||||||
type=int,
|
type=int,
|
||||||
|
|||||||
@@ -273,7 +273,7 @@ def train(args):
|
|||||||
# lr schedulerを用意する
|
# lr schedulerを用意する
|
||||||
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
|
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
|
||||||
|
|
||||||
# 実験的機能:勾配も含めたfp16/bf16学習を行う モデル全体をfp16にする
|
# 実験的機能:勾配も含めたfp16/bf16学習を行う モデル全体をfp16/bf16にする
|
||||||
if args.full_fp16:
|
if args.full_fp16:
|
||||||
assert (
|
assert (
|
||||||
args.mixed_precision == "fp16"
|
args.mixed_precision == "fp16"
|
||||||
|
|||||||
@@ -350,13 +350,19 @@ class NetworkTrainer:
|
|||||||
# lr schedulerを用意する
|
# lr schedulerを用意する
|
||||||
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
|
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
|
||||||
|
|
||||||
# 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
|
# 実験的機能:勾配も含めたfp16/bf16学習を行う モデル全体をfp16/bf16にする
|
||||||
if args.full_fp16:
|
if args.full_fp16:
|
||||||
assert (
|
assert (
|
||||||
args.mixed_precision == "fp16"
|
args.mixed_precision == "fp16"
|
||||||
), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
|
), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
|
||||||
accelerator.print("enable full fp16 training.")
|
accelerator.print("enable full fp16 training.")
|
||||||
network.to(weight_dtype)
|
network.to(weight_dtype)
|
||||||
|
elif args.full_bf16:
|
||||||
|
assert (
|
||||||
|
args.mixed_precision == "bf16"
|
||||||
|
), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。"
|
||||||
|
accelerator.print("enable full bf16 training.")
|
||||||
|
network.to(weight_dtype)
|
||||||
|
|
||||||
unet.requires_grad_(False)
|
unet.requires_grad_(False)
|
||||||
unet.to(dtype=weight_dtype)
|
unet.to(dtype=weight_dtype)
|
||||||
|
|||||||
Reference in New Issue
Block a user