diff --git a/train_network.py b/train_network.py index 6b8ed9bd..e0d2f8ad 100644 --- a/train_network.py +++ b/train_network.py @@ -1323,10 +1323,9 @@ class NetworkTrainer: # training loop if initial_step > 0: # only if skip_until_initial_step is specified - for skip_epoch in range(epoch_to_start): # skip epochs - logger.info(f"skipping epoch {skip_epoch+1} because initial_step (multiplied) is {initial_step}") - initial_step -= len(train_dataloader) global_step = initial_step + logger.info(f"skipping epoch {epoch_to_start} because initial_step (multiplied) is {initial_step}") + initial_step -= epoch_to_start * len(train_dataloader) # log device and dtype for each model logger.info(f"unet dtype: {unet_weight_dtype}, device: {unet.device}") @@ -1394,14 +1393,11 @@ class NetworkTrainer: # TRAINING skipped_dataloader = None if initial_step > 0: - skipped_dataloader = accelerator.skip_first_batches(train_dataloader, initial_step - 1) - initial_step = 1 + skipped_dataloader = accelerator.skip_first_batches(train_dataloader, initial_step) + initial_step = 0 for step, batch in enumerate(skipped_dataloader or train_dataloader): current_step.value = global_step - if initial_step > 0: - initial_step -= 1 - continue with accelerator.accumulate(training_model): on_step_start_for_network(text_encoder, unet)