From b886d0a359526f5715f3ced05697d406a169055b Mon Sep 17 00:00:00 2001 From: Maatra Date: Sat, 20 Apr 2024 14:36:47 +0100 Subject: [PATCH] Cleaned typing to be in line with accelerate hyperparameters type resctrictions --- library/train_util.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/library/train_util.py b/library/train_util.py index 40be2b05..75b3420d 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3390,7 +3390,20 @@ def filter_sensitive_args(args: argparse.Namespace): "output_dir", "logging_dir", ] - filtered_args = {k: v for k, v in vars(args).items() if k not in sensitive_args + sensitive_path_args} + filtered_args = {} + for k, v in vars(args).items(): + # filter out sensitive values + if k not in sensitive_args + sensitive_path_args: + #Accelerate values need to have type `bool`,`str`, `float`, `int`, or `None`. + if v is None or isinstance(v, bool) or isinstance(v, str) or isinstance(v, float) or isinstance(v, int): + filtered_args[k] = v + # accelerate does not support lists + elif isinstance(v, list): + filtered_args[k] = f"{v}" + # accelerate does not support objects + elif isinstance(v, object): + filtered_args[k] = f"{v}" + return filtered_args # verify command line args for training