diff --git a/tools/convert_diffusers20_original_sd.py b/tools/convert_diffusers20_original_sd.py index 7c7cc1c5..c33ca202 100644 --- a/tools/convert_diffusers20_original_sd.py +++ b/tools/convert_diffusers20_original_sd.py @@ -12,12 +12,12 @@ def convert(args): # 引数を確認する load_dtype = torch.float16 if args.fp16 else None - save_dtype = None - if args.fp16: + save_dtype = None + if args.fp16 or args.save_precision_as == "fp16": save_dtype = torch.float16 - elif args.bf16: + elif args.bf16 or args.save_precision_as == "bf16": save_dtype = torch.bfloat16 - elif args.float: + elif args.float or args.save_precision_as == "float": save_dtype = torch.float is_load_ckpt = os.path.isfile(args.model_to_load) @@ -72,6 +72,8 @@ def setup_parser() -> argparse.ArgumentParser: parser.add_argument("--bf16", action='store_true', help='save as bf16 (checkpoint only) / bf16形式で保存する(checkpointのみ対応)') parser.add_argument("--float", action='store_true', help='save as float (checkpoint only) / float(float32)形式で保存する(checkpointのみ対応)') + parser.add_argument("--save_precision_as", type=str, default="no", choices=["fp16", "bf16", "float"], + help="save precision") parser.add_argument("--epoch", type=int, default=0, help='epoch to write to checkpoint / checkpointに記録するepoch数の値') parser.add_argument("--global_step", type=int, default=0, help='global_step to write to checkpoint / checkpointに記録するglobal_stepの値')