add callback for step start

This commit is contained in:
Kohya S
2023-05-11 22:00:41 +09:00
parent 8d562ecf48
commit 7889a52f95

View File

@@ -525,10 +525,11 @@ def train(args):
loss_total = 0.0 loss_total = 0.0
del train_dataset_group del train_dataset_group
# if hasattr(network, "on_step_start"): # callback for step start
# on_step_start = network.on_step_start if hasattr(network, "on_step_start"):
# else: on_step_start = network.on_step_start
# on_step_start = lambda *args, **kwargs: None else:
on_step_start = lambda *args, **kwargs: None
# function for saving/removing # function for saving/removing
def save_model(ckpt_name, unwrapped_nw, steps, epoch_no, force_sync_upload=False): def save_model(ckpt_name, unwrapped_nw, steps, epoch_no, force_sync_upload=False):
@@ -563,7 +564,7 @@ def train(args):
for step, batch in enumerate(train_dataloader): for step, batch in enumerate(train_dataloader):
current_step.value = global_step current_step.value = global_step
with accelerator.accumulate(network): with accelerator.accumulate(network):
# on_step_start(text_encoder, unet) on_step_start(text_encoder, unet)
with torch.no_grad(): with torch.no_grad():
if "latents" in batch and batch["latents"] is not None: if "latents" in batch and batch["latents"] is not None: