Adjustments to resuming training

Currently, resuming training skips the rest of the resumed epoch if the saved state was partway through it.
These changes make training resume mid-epoch, at the appropriate step.
This commit is contained in:
Cauldrath
2024-06-30 23:40:53 -04:00
parent 0b3e4f7ab6
commit e66f94a76c

View File

@@ -882,10 +882,9 @@ class NetworkTrainer:
# training loop
if initial_step > 0: # only if skip_until_initial_step is specified
for skip_epoch in range(epoch_to_start): # skip epochs
logger.info(f"skipping epoch {skip_epoch+1} because initial_step (multiplied) is {initial_step}")
initial_step -= len(train_dataloader)
global_step = initial_step
logger.info(f"skipping epoch {epoch_to_start} because initial_step (multiplied) is {initial_step}")
initial_step -= epoch_to_start * len(train_dataloader)
for epoch in range(epoch_to_start, num_train_epochs):
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
@@ -897,14 +896,11 @@ class NetworkTrainer:
skipped_dataloader = None
if initial_step > 0:
skipped_dataloader = accelerator.skip_first_batches(train_dataloader, initial_step - 1)
initial_step = 1
skipped_dataloader = accelerator.skip_first_batches(train_dataloader, initial_step)
initial_step = 0
for step, batch in enumerate(skipped_dataloader or train_dataloader):
current_step.value = global_step
if initial_step > 0:
initial_step -= 1
continue
with accelerator.accumulate(training_model):
on_step_start(text_encoder, unet)