mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
@@ -4707,6 +4707,15 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int):
|
|||||||
**lr_scheduler_kwargs,
|
**lr_scheduler_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# these schedulers do not require `num_decay_steps`
|
||||||
|
if name == SchedulerType.LINEAR or name == SchedulerType.COSINE:
|
||||||
|
return schedule_func(
|
||||||
|
optimizer,
|
||||||
|
num_warmup_steps=num_warmup_steps,
|
||||||
|
num_training_steps=num_training_steps,
|
||||||
|
**lr_scheduler_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
# All other schedulers require `num_decay_steps`
|
# All other schedulers require `num_decay_steps`
|
||||||
if num_decay_steps is None:
|
if num_decay_steps is None:
|
||||||
raise ValueError(f"{name} requires `num_decay_steps`, please provide that argument.")
|
raise ValueError(f"{name} requires `num_decay_steps`, please provide that argument.")
|
||||||
@@ -5837,14 +5846,9 @@ def sample_image_inference(
|
|||||||
wandb_tracker = accelerator.get_tracker("wandb")
|
wandb_tracker = accelerator.get_tracker("wandb")
|
||||||
|
|
||||||
import wandb
|
import wandb
|
||||||
|
|
||||||
# not to commit images to avoid inconsistency between training and logging steps
|
# not to commit images to avoid inconsistency between training and logging steps
|
||||||
wandb_tracker.log(
|
wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False) # positive prompt as a caption
|
||||||
{f"sample_{i}": wandb.Image(
|
|
||||||
image,
|
|
||||||
caption=prompt # positive prompt as a caption
|
|
||||||
)},
|
|
||||||
commit=False
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# endregion
|
# endregion
|
||||||
|
|||||||
Reference in New Issue
Block a user