mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 14:34:23 +00:00
Merge e66f94a76c into 1dae34b0af
This commit is contained in:
@@ -1323,10 +1323,9 @@ class NetworkTrainer:
|
||||
|
||||
# training loop
|
||||
if initial_step > 0: # only if skip_until_initial_step is specified
|
||||
for skip_epoch in range(epoch_to_start): # skip epochs
|
||||
logger.info(f"skipping epoch {skip_epoch+1} because initial_step (multiplied) is {initial_step}")
|
||||
initial_step -= len(train_dataloader)
|
||||
global_step = initial_step
|
||||
logger.info(f"skipping epoch {epoch_to_start} because initial_step (multiplied) is {initial_step}")
|
||||
initial_step -= epoch_to_start * len(train_dataloader)
|
||||
|
||||
# log device and dtype for each model
|
||||
logger.info(f"unet dtype: {unet_weight_dtype}, device: {unet.device}")
|
||||
@@ -1394,14 +1393,11 @@ class NetworkTrainer:
|
||||
# TRAINING
|
||||
skipped_dataloader = None
|
||||
if initial_step > 0:
|
||||
skipped_dataloader = accelerator.skip_first_batches(train_dataloader, initial_step - 1)
|
||||
initial_step = 1
|
||||
skipped_dataloader = accelerator.skip_first_batches(train_dataloader, initial_step)
|
||||
initial_step = 0
|
||||
|
||||
for step, batch in enumerate(skipped_dataloader or train_dataloader):
|
||||
current_step.value = global_step
|
||||
if initial_step > 0:
|
||||
initial_step -= 1
|
||||
continue
|
||||
|
||||
with accelerator.accumulate(training_model):
|
||||
on_step_start_for_network(text_encoder, unet)
|
||||
|
||||
Reference in New Issue
Block a user