From de9543189593598d8038e3abec0e64acd5da23c6 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 19 Mar 2023 22:09:36 +0900 Subject: [PATCH] support win with diffusers, fix extra args eval --- library/train_util.py | 62 ++++++++++++++++++++++++------------------- 1 file changed, 34 insertions(+), 28 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index b23eeec1..7d311827 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1351,8 +1351,10 @@ def model_hash(filename): return m.hexdigest()[0:8] except FileNotFoundError: return "NOFILE" - except IsADirectoryError: - return 'IsADirectory' + except IsADirectoryError: # Linux? + return "IsADirectory" + except PermissionError: # Windows + return "IsADirectory" def calculate_sha256(filename): @@ -1365,11 +1367,13 @@ def calculate_sha256(filename): for chunk in iter(lambda: f.read(blksize), b""): hash_sha256.update(chunk) - return hash_sha256.hexdigest() + return hash_sha256.hexdigest() except FileNotFoundError: - return 'NOFILE' - except IsADirectoryError: - return 'IsADirectory' + return "NOFILE" + except IsADirectoryError: # Linux? + return "IsADirectory" + except PermissionError: # Windows + return "IsADirectory" def precalculate_safetensors_hashes(tensors, metadata): @@ -1728,7 +1732,7 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser): help='additional arguments for optimizer (like "weight_decay=0.01 betas=0.9,0.999 ...") / オプティマイザの追加引数(例: "weight_decay=0.01 betas=0.9,0.999 ...")', ) - parser.add_argument("--lr_scheduler_type", type=str, default="", help="custom scheduler module") + parser.add_argument("--lr_scheduler_type", type=str, default="", help="custom scheduler module / 使用するスケジューラ") parser.add_argument( "--lr_scheduler_args", type=str, @@ -2139,22 +2143,23 @@ def get_optimizer(args, trainable_params): optimizer_type = "AdamW" optimizer_type = optimizer_type.lower() - # 引数を分解する:boolとfloat、tupleのみ対応 + # 引数を分解する optimizer_kwargs = {} if args.optimizer_args is not None and len(args.optimizer_args) > 0: for arg in args.optimizer_args: key, value = arg.split("=") + value = ast.literal_eval(value) - value = value.split(",") - for i in range(len(value)): - if value[i].lower() == "true" or value[i].lower() == "false": - value[i] = value[i].lower() == "true" - else: - value[i] = float(value[i]) - if len(value) == 1: - value = value[0] - else: - value = tuple(value) + # value = value.split(",") + # for i in range(len(value)): + # if value[i].lower() == "true" or value[i].lower() == "false": + # value[i] = value[i].lower() == "true" + # else: + # value[i] = ast.float(value[i]) + # if len(value) == 1: + # value = value[0] + # else: + # value = tuple(value) optimizer_kwargs[key] = value # print("optkwargs:", optimizer_kwargs) @@ -2324,16 +2329,17 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int): for arg in args.lr_scheduler_args: key, value = arg.split("=") - value = value.split(",") - for i in range(len(value)): - if value[i].lower() == "true" or value[i].lower() == "false": - value[i] = value[i].lower() == "true" - else: - value[i] = ast.literal_eval(value[i]) - if len(value) == 1: - value = value[0] - else: - value = list(value) # some may use list? + value = ast.literal_eval(value) + # value = value.split(",") + # for i in range(len(value)): + # if value[i].lower() == "true" or value[i].lower() == "false": + # value[i] = value[i].lower() == "true" + # else: + # value[i] = ast.literal_eval(value[i]) + # if len(value) == 1: + # value = value[0] + # else: + # value = list(value) # some may use list? lr_scheduler_kwargs[key] = value