support win with diffusers, fix extra args eval

This commit is contained in:
Kohya S
2023-03-19 22:09:36 +09:00
parent c86bf213d1
commit de95431895

View File

@@ -1351,8 +1351,10 @@ def model_hash(filename):
return m.hexdigest()[0:8]
except FileNotFoundError:
return "NOFILE"
except IsADirectoryError:
return 'IsADirectory'
except IsADirectoryError: # Linux?
return "IsADirectory"
except PermissionError: # Windows
return "IsADirectory"
def calculate_sha256(filename):
@@ -1365,11 +1367,13 @@ def calculate_sha256(filename):
for chunk in iter(lambda: f.read(blksize), b""):
hash_sha256.update(chunk)
return hash_sha256.hexdigest()
return hash_sha256.hexdigest()
except FileNotFoundError:
return 'NOFILE'
except IsADirectoryError:
return 'IsADirectory'
return "NOFILE"
except IsADirectoryError: # Linux?
return "IsADirectory"
except PermissionError: # Windows
return "IsADirectory"
def precalculate_safetensors_hashes(tensors, metadata):
@@ -1728,7 +1732,7 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser):
help='additional arguments for optimizer (like "weight_decay=0.01 betas=0.9,0.999 ...") / オプティマイザの追加引数(例: "weight_decay=0.01 betas=0.9,0.999 ..."',
)
parser.add_argument("--lr_scheduler_type", type=str, default="", help="custom scheduler module")
parser.add_argument("--lr_scheduler_type", type=str, default="", help="custom scheduler module / 使用するスケジューラ")
parser.add_argument(
"--lr_scheduler_args",
type=str,
@@ -2139,22 +2143,23 @@ def get_optimizer(args, trainable_params):
optimizer_type = "AdamW"
optimizer_type = optimizer_type.lower()
# 引数を分解するboolとfloat、tupleのみ対応
# 引数を分解する
optimizer_kwargs = {}
if args.optimizer_args is not None and len(args.optimizer_args) > 0:
for arg in args.optimizer_args:
key, value = arg.split("=")
value = ast.literal_eval(value)
value = value.split(",")
for i in range(len(value)):
if value[i].lower() == "true" or value[i].lower() == "false":
value[i] = value[i].lower() == "true"
else:
value[i] = float(value[i])
if len(value) == 1:
value = value[0]
else:
value = tuple(value)
# value = value.split(",")
# for i in range(len(value)):
# if value[i].lower() == "true" or value[i].lower() == "false":
# value[i] = value[i].lower() == "true"
# else:
# value[i] = ast.float(value[i])
# if len(value) == 1:
# value = value[0]
# else:
# value = tuple(value)
optimizer_kwargs[key] = value
# print("optkwargs:", optimizer_kwargs)
@@ -2324,16 +2329,17 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int):
for arg in args.lr_scheduler_args:
key, value = arg.split("=")
value = value.split(",")
for i in range(len(value)):
if value[i].lower() == "true" or value[i].lower() == "false":
value[i] = value[i].lower() == "true"
else:
value[i] = ast.literal_eval(value[i])
if len(value) == 1:
value = value[0]
else:
value = list(value) # some may use list?
value = ast.literal_eval(value)
# value = value.split(",")
# for i in range(len(value)):
# if value[i].lower() == "true" or value[i].lower() == "false":
# value[i] = value[i].lower() == "true"
# else:
# value[i] = ast.literal_eval(value[i])
# if len(value) == 1:
# value = value[0]
# else:
# value = list(value) # some may use list?
lr_scheduler_kwargs[key] = value