From 7889a52f959ca8d7350b0f6951690a984a68f38d Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 11 May 2023 22:00:41 +0900 Subject: [PATCH] add callback for step start --- train_network.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/train_network.py b/train_network.py index b5cdfea1..b331e92c 100644 --- a/train_network.py +++ b/train_network.py @@ -525,10 +525,11 @@ def train(args): loss_total = 0.0 del train_dataset_group - # if hasattr(network, "on_step_start"): - # on_step_start = network.on_step_start - # else: - # on_step_start = lambda *args, **kwargs: None + # callback for step start + if hasattr(network, "on_step_start"): + on_step_start = network.on_step_start + else: + on_step_start = lambda *args, **kwargs: None # function for saving/removing def save_model(ckpt_name, unwrapped_nw, steps, epoch_no, force_sync_upload=False): @@ -563,7 +564,7 @@ def train(args): for step, batch in enumerate(train_dataloader): current_step.value = global_step with accelerator.accumulate(network): - # on_step_start(text_encoder, unet) + on_step_start(text_encoder, unet) with torch.no_grad(): if "latents" in batch and batch["latents"] is not None: