support win with diffusers, fix extra args eval

This commit is contained in:
Kohya S
2023-03-19 22:09:36 +09:00
parent c86bf213d1
commit de95431895

View File

@@ -1351,8 +1351,10 @@ def model_hash(filename):
return m.hexdigest()[0:8] return m.hexdigest()[0:8]
except FileNotFoundError: except FileNotFoundError:
return "NOFILE" return "NOFILE"
except IsADirectoryError: except IsADirectoryError: # Linux?
return 'IsADirectory' return "IsADirectory"
except PermissionError: # Windows
return "IsADirectory"
def calculate_sha256(filename): def calculate_sha256(filename):
@@ -1365,11 +1367,13 @@ def calculate_sha256(filename):
for chunk in iter(lambda: f.read(blksize), b""): for chunk in iter(lambda: f.read(blksize), b""):
hash_sha256.update(chunk) hash_sha256.update(chunk)
return hash_sha256.hexdigest() return hash_sha256.hexdigest()
except FileNotFoundError: except FileNotFoundError:
return 'NOFILE' return "NOFILE"
except IsADirectoryError: except IsADirectoryError: # Linux?
return 'IsADirectory' return "IsADirectory"
except PermissionError: # Windows
return "IsADirectory"
def precalculate_safetensors_hashes(tensors, metadata): def precalculate_safetensors_hashes(tensors, metadata):
@@ -1728,7 +1732,7 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser):
help='additional arguments for optimizer (like "weight_decay=0.01 betas=0.9,0.999 ...") / オプティマイザの追加引数(例: "weight_decay=0.01 betas=0.9,0.999 ..."', help='additional arguments for optimizer (like "weight_decay=0.01 betas=0.9,0.999 ...") / オプティマイザの追加引数(例: "weight_decay=0.01 betas=0.9,0.999 ..."',
) )
parser.add_argument("--lr_scheduler_type", type=str, default="", help="custom scheduler module") parser.add_argument("--lr_scheduler_type", type=str, default="", help="custom scheduler module / 使用するスケジューラ")
parser.add_argument( parser.add_argument(
"--lr_scheduler_args", "--lr_scheduler_args",
type=str, type=str,
@@ -2139,22 +2143,23 @@ def get_optimizer(args, trainable_params):
optimizer_type = "AdamW" optimizer_type = "AdamW"
optimizer_type = optimizer_type.lower() optimizer_type = optimizer_type.lower()
# 引数を分解するboolとfloat、tupleのみ対応 # 引数を分解する
optimizer_kwargs = {} optimizer_kwargs = {}
if args.optimizer_args is not None and len(args.optimizer_args) > 0: if args.optimizer_args is not None and len(args.optimizer_args) > 0:
for arg in args.optimizer_args: for arg in args.optimizer_args:
key, value = arg.split("=") key, value = arg.split("=")
value = ast.literal_eval(value)
value = value.split(",") # value = value.split(",")
for i in range(len(value)): # for i in range(len(value)):
if value[i].lower() == "true" or value[i].lower() == "false": # if value[i].lower() == "true" or value[i].lower() == "false":
value[i] = value[i].lower() == "true" # value[i] = value[i].lower() == "true"
else: # else:
value[i] = float(value[i]) # value[i] = ast.float(value[i])
if len(value) == 1: # if len(value) == 1:
value = value[0] # value = value[0]
else: # else:
value = tuple(value) # value = tuple(value)
optimizer_kwargs[key] = value optimizer_kwargs[key] = value
# print("optkwargs:", optimizer_kwargs) # print("optkwargs:", optimizer_kwargs)
@@ -2324,16 +2329,17 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int):
for arg in args.lr_scheduler_args: for arg in args.lr_scheduler_args:
key, value = arg.split("=") key, value = arg.split("=")
value = value.split(",") value = ast.literal_eval(value)
for i in range(len(value)): # value = value.split(",")
if value[i].lower() == "true" or value[i].lower() == "false": # for i in range(len(value)):
value[i] = value[i].lower() == "true" # if value[i].lower() == "true" or value[i].lower() == "false":
else: # value[i] = value[i].lower() == "true"
value[i] = ast.literal_eval(value[i]) # else:
if len(value) == 1: # value[i] = ast.literal_eval(value[i])
value = value[0] # if len(value) == 1:
else: # value = value[0]
value = list(value) # some may use list? # else:
# value = list(value) # some may use list?
lr_scheduler_kwargs[key] = value lr_scheduler_kwargs[key] = value