This commit is contained in:
Benjamin Hanes
2026-04-01 11:41:51 +00:00
committed by GitHub

View File

@@ -1323,10 +1323,9 @@ class NetworkTrainer:
# training loop
if initial_step > 0: # only if skip_until_initial_step is specified
for skip_epoch in range(epoch_to_start): # skip epochs
logger.info(f"skipping epoch {skip_epoch+1} because initial_step (multiplied) is {initial_step}")
initial_step -= len(train_dataloader)
global_step = initial_step
logger.info(f"skipping epoch {epoch_to_start} because initial_step (multiplied) is {initial_step}")
initial_step -= epoch_to_start * len(train_dataloader)
# log device and dtype for each model
logger.info(f"unet dtype: {unet_weight_dtype}, device: {unet.device}")
@@ -1394,14 +1393,11 @@ class NetworkTrainer:
# TRAINING
skipped_dataloader = None
if initial_step > 0:
skipped_dataloader = accelerator.skip_first_batches(train_dataloader, initial_step - 1)
initial_step = 1
skipped_dataloader = accelerator.skip_first_batches(train_dataloader, initial_step)
initial_step = 0
for step, batch in enumerate(skipped_dataloader or train_dataloader):
current_step.value = global_step
if initial_step > 0:
initial_step -= 1
continue
with accelerator.accumulate(training_model):
on_step_start_for_network(text_encoder, unet)