mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-18 01:30:02 +00:00
Merge e66f94a76c into 1dae34b0af
This commit is contained in:
@@ -1323,10 +1323,9 @@ class NetworkTrainer:
|
|||||||
|
|
||||||
# training loop
|
# training loop
|
||||||
if initial_step > 0: # only if skip_until_initial_step is specified
|
if initial_step > 0: # only if skip_until_initial_step is specified
|
||||||
for skip_epoch in range(epoch_to_start): # skip epochs
|
|
||||||
logger.info(f"skipping epoch {skip_epoch+1} because initial_step (multiplied) is {initial_step}")
|
|
||||||
initial_step -= len(train_dataloader)
|
|
||||||
global_step = initial_step
|
global_step = initial_step
|
||||||
|
logger.info(f"skipping epoch {epoch_to_start} because initial_step (multiplied) is {initial_step}")
|
||||||
|
initial_step -= epoch_to_start * len(train_dataloader)
|
||||||
|
|
||||||
# log device and dtype for each model
|
# log device and dtype for each model
|
||||||
logger.info(f"unet dtype: {unet_weight_dtype}, device: {unet.device}")
|
logger.info(f"unet dtype: {unet_weight_dtype}, device: {unet.device}")
|
||||||
@@ -1394,14 +1393,11 @@ class NetworkTrainer:
|
|||||||
# TRAINING
|
# TRAINING
|
||||||
skipped_dataloader = None
|
skipped_dataloader = None
|
||||||
if initial_step > 0:
|
if initial_step > 0:
|
||||||
skipped_dataloader = accelerator.skip_first_batches(train_dataloader, initial_step - 1)
|
skipped_dataloader = accelerator.skip_first_batches(train_dataloader, initial_step)
|
||||||
initial_step = 1
|
initial_step = 0
|
||||||
|
|
||||||
for step, batch in enumerate(skipped_dataloader or train_dataloader):
|
for step, batch in enumerate(skipped_dataloader or train_dataloader):
|
||||||
current_step.value = global_step
|
current_step.value = global_step
|
||||||
if initial_step > 0:
|
|
||||||
initial_step -= 1
|
|
||||||
continue
|
|
||||||
|
|
||||||
with accelerator.accumulate(training_model):
|
with accelerator.accumulate(training_model):
|
||||||
on_step_start_for_network(text_encoder, unet)
|
on_step_start_for_network(text_encoder, unet)
|
||||||
|
|||||||
Reference in New Issue
Block a user