update for corner cases

This commit is contained in:
Kohya S
2024-06-04 21:26:55 +09:00
parent 321e24d83b
commit 4dbcef429b
2 changed files with 17 additions and 9 deletions

View File

@@ -663,6 +663,7 @@ class BaseDataset(torch.utils.data.Dataset):
for _ in range(num_epochs): for _ in range(num_epochs):
self.current_epoch += 1 self.current_epoch += 1
self.shuffle_buckets() self.shuffle_buckets()
# self.current_epoch seems to be set to 0 again in the next epoch. it may be caused by skipped_dataloader?
else: else:
logger.warning("epoch is not incremented. current_epoch: {}, epoch: {}".format(self.current_epoch, epoch)) logger.warning("epoch is not incremented. current_epoch: {}, epoch: {}".format(self.current_epoch, epoch))
self.current_epoch = epoch self.current_epoch = epoch
@@ -5560,6 +5561,8 @@ class LossRecorder:
if epoch == 0: if epoch == 0:
self.loss_list.append(loss) self.loss_list.append(loss)
else: else:
while len(self.loss_list) <= step:
self.loss_list.append(0.0)
self.loss_total -= self.loss_list[step] self.loss_total -= self.loss_list[step]
self.loss_list[step] = loss self.loss_list[step] = loss
self.loss_total += loss self.loss_total += loss

View File

@@ -493,13 +493,15 @@ class NetworkTrainer:
# before resuming make hook for saving/loading to save/load the network weights only # before resuming make hook for saving/loading to save/load the network weights only
def save_model_hook(models, weights, output_dir): def save_model_hook(models, weights, output_dir):
# pop weights of other models than network to save only network weights # pop weights of other models than network to save only network weights
if accelerator.is_main_process: # only main process or deepspeed https://github.com/huggingface/diffusers/issues/2606
if accelerator.is_main_process or args.deepspeed:
remove_indices = [] remove_indices = []
for i, model in enumerate(models): for i, model in enumerate(models):
if not isinstance(model, type(accelerator.unwrap_model(network))): if not isinstance(model, type(accelerator.unwrap_model(network))):
remove_indices.append(i) remove_indices.append(i)
for i in reversed(remove_indices): for i in reversed(remove_indices):
weights.pop(i) if len(weights) > i:
weights.pop(i)
# print(f"save model hook: {len(weights)} weights will be saved") # print(f"save model hook: {len(weights)} weights will be saved")
# save current epoch and step # save current epoch and step
@@ -813,11 +815,12 @@ class NetworkTrainer:
) )
logger.info(f"skipping {initial_step} steps / {initial_step}ステップをスキップします") logger.info(f"skipping {initial_step} steps / {initial_step}ステップをスキップします")
initial_step *= args.gradient_accumulation_steps initial_step *= args.gradient_accumulation_steps
# set epoch to start to make initial_step less than len(train_dataloader)
epoch_to_start = initial_step // math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
else: else:
# if not, only epoch no is skipped for informative purpose # if not, only epoch no is skipped for informative purpose
epoch_to_start = initial_step // math.ceil( epoch_to_start = initial_step // math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
len(train_dataloader) / args.gradient_accumulation_steps
)
initial_step = 0 # do not skip initial_step = 0 # do not skip
global_step = 0 global_step = 0
@@ -878,9 +881,11 @@ class NetworkTrainer:
self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
# training loop # training loop
for skip_epoch in range(epoch_to_start): # skip epochs if initial_step > 0: # only if skip_until_initial_step is specified
logger.info(f"skipping epoch {skip_epoch+1} because initial_step (multiplied) is {initial_step}") for skip_epoch in range(epoch_to_start): # skip epochs
initial_step -= len(train_dataloader) logger.info(f"skipping epoch {skip_epoch+1} because initial_step (multiplied) is {initial_step}")
initial_step -= len(train_dataloader)
global_step = initial_step
for epoch in range(epoch_to_start, num_train_epochs): for epoch in range(epoch_to_start, num_train_epochs):
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
@@ -892,7 +897,7 @@ class NetworkTrainer:
skipped_dataloader = None skipped_dataloader = None
if initial_step > 0: if initial_step > 0:
skipped_dataloader = accelerator.skip_first_batches(train_dataloader, initial_step-1) skipped_dataloader = accelerator.skip_first_batches(train_dataloader, initial_step - 1)
initial_step = 1 initial_step = 1
for step, batch in enumerate(skipped_dataloader or train_dataloader): for step, batch in enumerate(skipped_dataloader or train_dataloader):