feat: add --save_precision args

Author: Linaqruf
Date: 2023-04-12 05:43:19 +07:00
parent c316c63dff
commit 7f8e05ccad


@@ -12,12 +12,12 @@ def convert(args):
     # 引数を確認する
     load_dtype = torch.float16 if args.fp16 else None
 
     save_dtype = None
-    if args.fp16:
+    if args.fp16 or args.save_precision_as == "fp16":
         save_dtype = torch.float16
-    elif args.bf16:
+    elif args.bf16 or args.save_precision_as == "bf16":
         save_dtype = torch.bfloat16
-    elif args.float:
+    elif args.float or args.save_precision_as == "float":
         save_dtype = torch.float
 
     is_load_ckpt = os.path.isfile(args.model_to_load)
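
For reference, the dtype selection after this change can be reproduced with the minimal standalone sketch below; the argparse.Namespace stand-in for args is an assumption for illustration and is not part of the commit.

import argparse
import torch

# Stand-in for the parsed CLI arguments; in the script they come from setup_parser().
args = argparse.Namespace(fp16=False, bf16=False, float=False, save_precision_as="bf16")

save_dtype = None  # None means the loaded precision is kept
if args.fp16 or args.save_precision_as == "fp16":
    save_dtype = torch.float16
elif args.bf16 or args.save_precision_as == "bf16":
    save_dtype = torch.bfloat16
elif args.float or args.save_precision_as == "float":
    save_dtype = torch.float

print(save_dtype)  # torch.bfloat16 for this Namespace

With the default save_precision_as="no" and none of the boolean flags set, save_dtype stays None.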
@@ -72,6 +72,8 @@ def setup_parser() -> argparse.ArgumentParser:
     parser.add_argument("--bf16", action='store_true', help='save as bf16 (checkpoint only) / bf16形式で保存する(checkpointのみ対応)')
     parser.add_argument("--float", action='store_true',
                         help='save as float (checkpoint only) / float(float32)形式で保存する(checkpointのみ対応)')
+    parser.add_argument("--save_precision_as", type=str, default="no", choices=["fp16", "bf16", "float"],
+                        help="save precision")
     parser.add_argument("--epoch", type=int, default=0, help='epoch to write to checkpoint / checkpointに記録するepoch数の値')
     parser.add_argument("--global_step", type=int, default=0,
                         help='global_step to write to checkpoint / checkpointに記録するglobal_stepの値')
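
A hedged usage sketch of the new option; the converter's filename and its model path arguments are not shown in this diff, so they are left as placeholders:

# Placeholder invocation; substitute the real script name and model paths.
python convert_model.py ... --save_precision_as fp16 --epoch 10 --global_step 4000

Note that default="no" is not among the listed choices; argparse validates choices only for values supplied on the command line, so the default acts as a plain "no override" sentinel and precision then falls back to --fp16/--bf16/--float.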