From 96c677b4594ed6f28f3ef896f6deca7c3aced25d Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 16 Sep 2024 10:42:09 +0900 Subject: [PATCH] fix to work lienar/cosine lr scheduler closes #1602 ref #1393 --- library/train_util.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 742d057e..60afd421 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -4707,6 +4707,15 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int): **lr_scheduler_kwargs, ) + # these schedulers do not require `num_decay_steps` + if name == SchedulerType.LINEAR or name == SchedulerType.COSINE: + return schedule_func( + optimizer, + num_warmup_steps=num_warmup_steps, + num_training_steps=num_training_steps, + **lr_scheduler_kwargs, + ) + # All other schedulers require `num_decay_steps` if num_decay_steps is None: raise ValueError(f"{name} requires `num_decay_steps`, please provide that argument.") @@ -5837,14 +5846,9 @@ def sample_image_inference( wandb_tracker = accelerator.get_tracker("wandb") import wandb + # not to commit images to avoid inconsistency between training and logging steps - wandb_tracker.log( - {f"sample_{i}": wandb.Image( - image, - caption=prompt # positive prompt as a caption - )}, - commit=False - ) + wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False) # positive prompt as a caption # endregion