mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
add callback for step start
This commit is contained in:
@@ -525,10 +525,11 @@ def train(args):
|
|||||||
loss_total = 0.0
|
loss_total = 0.0
|
||||||
del train_dataset_group
|
del train_dataset_group
|
||||||
|
|
||||||
# if hasattr(network, "on_step_start"):
|
# callback for step start
|
||||||
# on_step_start = network.on_step_start
|
if hasattr(network, "on_step_start"):
|
||||||
# else:
|
on_step_start = network.on_step_start
|
||||||
# on_step_start = lambda *args, **kwargs: None
|
else:
|
||||||
|
on_step_start = lambda *args, **kwargs: None
|
||||||
|
|
||||||
# function for saving/removing
|
# function for saving/removing
|
||||||
def save_model(ckpt_name, unwrapped_nw, steps, epoch_no, force_sync_upload=False):
|
def save_model(ckpt_name, unwrapped_nw, steps, epoch_no, force_sync_upload=False):
|
||||||
@@ -563,7 +564,7 @@ def train(args):
|
|||||||
for step, batch in enumerate(train_dataloader):
|
for step, batch in enumerate(train_dataloader):
|
||||||
current_step.value = global_step
|
current_step.value = global_step
|
||||||
with accelerator.accumulate(network):
|
with accelerator.accumulate(network):
|
||||||
# on_step_start(text_encoder, unet)
|
on_step_start(text_encoder, unet)
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
if "latents" in batch and batch["latents"] is not None:
|
if "latents" in batch and batch["latents"] is not None:
|
||||||
|
|||||||
Reference in New Issue
Block a user