This commit is contained in:
Benjamin Hanes
2026-04-01 11:41:51 +00:00
committed by GitHub

View File

@@ -1323,10 +1323,9 @@ class NetworkTrainer:
# training loop # training loop
if initial_step > 0: # only if skip_until_initial_step is specified if initial_step > 0: # only if skip_until_initial_step is specified
for skip_epoch in range(epoch_to_start): # skip epochs
logger.info(f"skipping epoch {skip_epoch+1} because initial_step (multiplied) is {initial_step}")
initial_step -= len(train_dataloader)
global_step = initial_step global_step = initial_step
logger.info(f"skipping epoch {epoch_to_start} because initial_step (multiplied) is {initial_step}")
initial_step -= epoch_to_start * len(train_dataloader)
# log device and dtype for each model # log device and dtype for each model
logger.info(f"unet dtype: {unet_weight_dtype}, device: {unet.device}") logger.info(f"unet dtype: {unet_weight_dtype}, device: {unet.device}")
@@ -1394,14 +1393,11 @@ class NetworkTrainer:
# TRAINING # TRAINING
skipped_dataloader = None skipped_dataloader = None
if initial_step > 0: if initial_step > 0:
skipped_dataloader = accelerator.skip_first_batches(train_dataloader, initial_step - 1) skipped_dataloader = accelerator.skip_first_batches(train_dataloader, initial_step)
initial_step = 1 initial_step = 0
for step, batch in enumerate(skipped_dataloader or train_dataloader): for step, batch in enumerate(skipped_dataloader or train_dataloader):
current_step.value = global_step current_step.value = global_step
if initial_step > 0:
initial_step -= 1
continue
with accelerator.accumulate(training_model): with accelerator.accumulate(training_model):
on_step_start_for_network(text_encoder, unet) on_step_start_for_network(text_encoder, unet)