mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
update for corner cases
This commit is contained in:
@@ -663,6 +663,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
|||||||
for _ in range(num_epochs):
|
for _ in range(num_epochs):
|
||||||
self.current_epoch += 1
|
self.current_epoch += 1
|
||||||
self.shuffle_buckets()
|
self.shuffle_buckets()
|
||||||
|
# self.current_epoch seem to be set to 0 again in the next epoch. it may be caused by skipped_dataloader?
|
||||||
else:
|
else:
|
||||||
logger.warning("epoch is not incremented. current_epoch: {}, epoch: {}".format(self.current_epoch, epoch))
|
logger.warning("epoch is not incremented. current_epoch: {}, epoch: {}".format(self.current_epoch, epoch))
|
||||||
self.current_epoch = epoch
|
self.current_epoch = epoch
|
||||||
@@ -5560,6 +5561,8 @@ class LossRecorder:
|
|||||||
if epoch == 0:
|
if epoch == 0:
|
||||||
self.loss_list.append(loss)
|
self.loss_list.append(loss)
|
||||||
else:
|
else:
|
||||||
|
while len(self.loss_list) <= step:
|
||||||
|
self.loss_list.append(0.0)
|
||||||
self.loss_total -= self.loss_list[step]
|
self.loss_total -= self.loss_list[step]
|
||||||
self.loss_list[step] = loss
|
self.loss_list[step] = loss
|
||||||
self.loss_total += loss
|
self.loss_total += loss
|
||||||
|
|||||||
@@ -493,12 +493,14 @@ class NetworkTrainer:
|
|||||||
# before resuming make hook for saving/loading to save/load the network weights only
|
# before resuming make hook for saving/loading to save/load the network weights only
|
||||||
def save_model_hook(models, weights, output_dir):
|
def save_model_hook(models, weights, output_dir):
|
||||||
# pop weights of other models than network to save only network weights
|
# pop weights of other models than network to save only network weights
|
||||||
if accelerator.is_main_process:
|
# only main process or deepspeed https://github.com/huggingface/diffusers/issues/2606
|
||||||
|
if accelerator.is_main_process or args.deepspeed:
|
||||||
remove_indices = []
|
remove_indices = []
|
||||||
for i, model in enumerate(models):
|
for i, model in enumerate(models):
|
||||||
if not isinstance(model, type(accelerator.unwrap_model(network))):
|
if not isinstance(model, type(accelerator.unwrap_model(network))):
|
||||||
remove_indices.append(i)
|
remove_indices.append(i)
|
||||||
for i in reversed(remove_indices):
|
for i in reversed(remove_indices):
|
||||||
|
if len(weights) > i:
|
||||||
weights.pop(i)
|
weights.pop(i)
|
||||||
# print(f"save model hook: {len(weights)} weights will be saved")
|
# print(f"save model hook: {len(weights)} weights will be saved")
|
||||||
|
|
||||||
@@ -813,11 +815,12 @@ class NetworkTrainer:
|
|||||||
)
|
)
|
||||||
logger.info(f"skipping {initial_step} steps / {initial_step}ステップをスキップします")
|
logger.info(f"skipping {initial_step} steps / {initial_step}ステップをスキップします")
|
||||||
initial_step *= args.gradient_accumulation_steps
|
initial_step *= args.gradient_accumulation_steps
|
||||||
|
|
||||||
|
# set epoch to start to make initial_step less than len(train_dataloader)
|
||||||
|
epoch_to_start = initial_step // math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||||
else:
|
else:
|
||||||
# if not, only epoch no is skipped for informative purpose
|
# if not, only epoch no is skipped for informative purpose
|
||||||
epoch_to_start = initial_step // math.ceil(
|
epoch_to_start = initial_step // math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||||
len(train_dataloader) / args.gradient_accumulation_steps
|
|
||||||
)
|
|
||||||
initial_step = 0 # do not skip
|
initial_step = 0 # do not skip
|
||||||
|
|
||||||
global_step = 0
|
global_step = 0
|
||||||
@@ -878,9 +881,11 @@ class NetworkTrainer:
|
|||||||
self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
||||||
|
|
||||||
# training loop
|
# training loop
|
||||||
|
if initial_step > 0: # only if skip_until_initial_step is specified
|
||||||
for skip_epoch in range(epoch_to_start): # skip epochs
|
for skip_epoch in range(epoch_to_start): # skip epochs
|
||||||
logger.info(f"skipping epoch {skip_epoch+1} because initial_step (multiplied) is {initial_step}")
|
logger.info(f"skipping epoch {skip_epoch+1} because initial_step (multiplied) is {initial_step}")
|
||||||
initial_step -= len(train_dataloader)
|
initial_step -= len(train_dataloader)
|
||||||
|
global_step = initial_step
|
||||||
|
|
||||||
for epoch in range(epoch_to_start, num_train_epochs):
|
for epoch in range(epoch_to_start, num_train_epochs):
|
||||||
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
|
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
|
||||||
@@ -892,7 +897,7 @@ class NetworkTrainer:
|
|||||||
|
|
||||||
skipped_dataloader = None
|
skipped_dataloader = None
|
||||||
if initial_step > 0:
|
if initial_step > 0:
|
||||||
skipped_dataloader = accelerator.skip_first_batches(train_dataloader, initial_step-1)
|
skipped_dataloader = accelerator.skip_first_batches(train_dataloader, initial_step - 1)
|
||||||
initial_step = 1
|
initial_step = 1
|
||||||
|
|
||||||
for step, batch in enumerate(skipped_dataloader or train_dataloader):
|
for step, batch in enumerate(skipped_dataloader or train_dataloader):
|
||||||
|
|||||||
Reference in New Issue
Block a user