mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 14:34:23 +00:00
feat: add --save_precision args
This commit is contained in:
@@ -12,12 +12,12 @@ def convert(args):
|
||||
# 引数を確認する
|
||||
load_dtype = torch.float16 if args.fp16 else None
|
||||
|
||||
save_dtype = None
|
||||
if args.fp16:
|
||||
save_dtype = None
|
||||
if args.fp16 or args.save_precision_as == "fp16":
|
||||
save_dtype = torch.float16
|
||||
elif args.bf16:
|
||||
elif args.bf16 or args.save_precision_as == "bf16":
|
||||
save_dtype = torch.bfloat16
|
||||
elif args.float:
|
||||
elif args.float or args.save_precision_as == "float":
|
||||
save_dtype = torch.float
|
||||
|
||||
is_load_ckpt = os.path.isfile(args.model_to_load)
|
||||
@@ -72,6 +72,8 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
parser.add_argument("--bf16", action='store_true', help='save as bf16 (checkpoint only) / bf16形式で保存する(checkpointのみ対応)')
|
||||
parser.add_argument("--float", action='store_true',
|
||||
help='save as float (checkpoint only) / float(float32)形式で保存する(checkpointのみ対応)')
|
||||
parser.add_argument("--save_precision_as", type=str, default="no", choices=["fp16", "bf16", "float"],
|
||||
help="save precision")
|
||||
parser.add_argument("--epoch", type=int, default=0, help='epoch to write to checkpoint / checkpointに記録するepoch数の値')
|
||||
parser.add_argument("--global_step", type=int, default=0,
|
||||
help='global_step to write to checkpoint / checkpointに記録するglobal_stepの値')
|
||||
|
||||
Reference in New Issue
Block a user