passing filtered hyperparameters to accelerate

This commit is contained in:
Maatra
2024-04-20 14:11:43 +01:00
parent 71e2c91330
commit 2c9db5d9f2
10 changed files with 23 additions and 9 deletions

View File

@@ -290,7 +290,7 @@ def train(args):
init_kwargs["wandb"] = {"name": args.wandb_run_name}
if args.log_tracker_config is not None:
init_kwargs = toml.load(args.log_tracker_config)
accelerator.init_trackers("dreambooth" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
accelerator.init_trackers("dreambooth" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.filter_sensitive_args(args), init_kwargs=init_kwargs)
# For --sample_at_first
train_util.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)