mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-10 06:54:17 +00:00
Adjustments to resuming training
Currently skips the resumed epoch if partway through These changes make it resume mid epoch on the appropriate step
This commit is contained in:
@@ -882,10 +882,9 @@ class NetworkTrainer:
|
||||
|
||||
# training loop
|
||||
if initial_step > 0: # only if skip_until_initial_step is specified
|
||||
for skip_epoch in range(epoch_to_start): # skip epochs
|
||||
logger.info(f"skipping epoch {skip_epoch+1} because initial_step (multiplied) is {initial_step}")
|
||||
initial_step -= len(train_dataloader)
|
||||
global_step = initial_step
|
||||
logger.info(f"skipping epoch {epoch_to_start} because initial_step (multiplied) is {initial_step}")
|
||||
initial_step -= epoch_to_start * len(train_dataloader)
|
||||
|
||||
for epoch in range(epoch_to_start, num_train_epochs):
|
||||
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
|
||||
@@ -897,14 +896,11 @@ class NetworkTrainer:
|
||||
|
||||
skipped_dataloader = None
|
||||
if initial_step > 0:
|
||||
skipped_dataloader = accelerator.skip_first_batches(train_dataloader, initial_step - 1)
|
||||
initial_step = 1
|
||||
skipped_dataloader = accelerator.skip_first_batches(train_dataloader, initial_step)
|
||||
initial_step = 0
|
||||
|
||||
for step, batch in enumerate(skipped_dataloader or train_dataloader):
|
||||
current_step.value = global_step
|
||||
if initial_step > 0:
|
||||
initial_step -= 1
|
||||
continue
|
||||
|
||||
with accelerator.accumulate(training_model):
|
||||
on_step_start(text_encoder, unet)
|
||||
|
||||
Reference in New Issue
Block a user