mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 06:28:48 +00:00
Merge pull request #628 from KohakuBlueleaf/full_bf16
Full bf16 support
This commit is contained in:
@@ -2422,6 +2422,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
||||
"--mixed_precision", type=str, default="no", choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度"
|
||||
)
|
||||
parser.add_argument("--full_fp16", action="store_true", help="fp16 training including gradients / 勾配も含めてfp16で学習する")
|
||||
parser.add_argument("--full_bf16", action="store_true", help="bf16 training including gradients / 勾配も含めてbf16で学習する")
|
||||
parser.add_argument(
|
||||
"--clip_skip",
|
||||
type=int,
|
||||
|
||||
@@ -267,6 +267,14 @@ def train(args):
|
||||
unet.to(weight_dtype)
|
||||
text_encoder1.to(weight_dtype)
|
||||
text_encoder2.to(weight_dtype)
|
||||
elif args.full_bf16:
|
||||
assert (
|
||||
args.mixed_precision == "bf16"
|
||||
), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。"
|
||||
accelerator.print("enable full bf16 training.")
|
||||
unet.to(weight_dtype)
|
||||
text_encoder1.to(weight_dtype)
|
||||
text_encoder2.to(weight_dtype)
|
||||
|
||||
# acceleratorがなんかよろしくやってくれるらしい
|
||||
if args.train_text_encoder:
|
||||
|
||||
Reference in New Issue
Block a user