From e66f94a76c32e1bacfcd9ab782f7f725e5aba207 Mon Sep 17 00:00:00 2001 From: Cauldrath Date: Sun, 30 Jun 2024 23:40:53 -0400 Subject: [PATCH] Adjustments to resuming training Currently, training skips the resumed epoch entirely if it stopped partway through. These changes make it resume mid-epoch on the appropriate step. --- train_network.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/train_network.py b/train_network.py index 7ba07385..090691d7 100644 --- a/train_network.py +++ b/train_network.py @@ -882,10 +882,9 @@ class NetworkTrainer: # training loop if initial_step > 0: # only if skip_until_initial_step is specified - for skip_epoch in range(epoch_to_start): # skip epochs - logger.info(f"skipping epoch {skip_epoch+1} because initial_step (multiplied) is {initial_step}") - initial_step -= len(train_dataloader) global_step = initial_step + logger.info(f"skipping epoch {epoch_to_start} because initial_step (multiplied) is {initial_step}") + initial_step -= epoch_to_start * len(train_dataloader) for epoch in range(epoch_to_start, num_train_epochs): accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") @@ -897,14 +896,11 @@ class NetworkTrainer: skipped_dataloader = None if initial_step > 0: - skipped_dataloader = accelerator.skip_first_batches(train_dataloader, initial_step - 1) - initial_step = 1 + skipped_dataloader = accelerator.skip_first_batches(train_dataloader, initial_step) + initial_step = 0 for step, batch in enumerate(skipped_dataloader or train_dataloader): current_step.value = global_step - if initial_step > 0: - initial_step -= 1 - continue with accelerator.accumulate(training_model): on_step_start(text_encoder, unet)