diff --git a/library/train_util.py b/library/train_util.py
index 102f9f03..4736ff4f 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -663,6 +663,7 @@ class BaseDataset(torch.utils.data.Dataset):
                 for _ in range(num_epochs):
                     self.current_epoch += 1
                     self.shuffle_buckets()
+                # self.current_epoch seems to be set to 0 again in the next epoch; it may be caused by skipped_dataloader?
             else:
                 logger.warning("epoch is not incremented. current_epoch: {}, epoch: {}".format(self.current_epoch, epoch))
                 self.current_epoch = epoch
@@ -5560,6 +5561,8 @@ class LossRecorder:
         if epoch == 0:
             self.loss_list.append(loss)
         else:
+            while len(self.loss_list) <= step:
+                self.loss_list.append(0.0)
             self.loss_total -= self.loss_list[step]
             self.loss_list[step] = loss
         self.loss_total += loss
diff --git a/train_network.py b/train_network.py
index d1f02d53..7ba07385 100644
--- a/train_network.py
+++ b/train_network.py
@@ -493,13 +493,15 @@ class NetworkTrainer:
         # before resuming make hook for saving/loading to save/load the network weights only
         def save_model_hook(models, weights, output_dir):
             # pop weights of other models than network to save only network weights
-            if accelerator.is_main_process:
+            # only main process or deepspeed https://github.com/huggingface/diffusers/issues/2606
+            if accelerator.is_main_process or args.deepspeed:
                 remove_indices = []
                 for i, model in enumerate(models):
                     if not isinstance(model, type(accelerator.unwrap_model(network))):
                         remove_indices.append(i)
                 for i in reversed(remove_indices):
-                    weights.pop(i)
+                    if len(weights) > i:
+                        weights.pop(i)
                 # print(f"save model hook: {len(weights)} weights will be saved")

             # save current ecpoch and step
@@ -813,11 +815,12 @@ class NetworkTrainer:
                 )
                 logger.info(f"skipping {initial_step} steps / {initial_step}ステップをスキップします")
                 initial_step *= args.gradient_accumulation_steps
+
+                # set epoch_to_start so that initial_step becomes less than len(train_dataloader)
+                epoch_to_start = initial_step // math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
             else:
                 # if not, only epoch no is skipped for informative purpose
-                epoch_to_start = initial_step // math.ceil(
-                    len(train_dataloader) / args.gradient_accumulation_steps
-                )
+                epoch_to_start = initial_step // math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
                 initial_step = 0  # do not skip

         global_step = 0
@@ -878,9 +881,11 @@ class NetworkTrainer:
             self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)

         # training loop
-        for skip_epoch in range(epoch_to_start):  # skip epochs
-            logger.info(f"skipping epoch {skip_epoch+1} because initial_step (multiplied) is {initial_step}")
-            initial_step -= len(train_dataloader)
+        if initial_step > 0:  # only if skip_until_initial_step is specified
+            for skip_epoch in range(epoch_to_start):  # skip epochs
+                logger.info(f"skipping epoch {skip_epoch+1} because initial_step (multiplied) is {initial_step}")
+                initial_step -= len(train_dataloader)
+            global_step = initial_step

         for epoch in range(epoch_to_start, num_train_epochs):
             accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
@@ -892,7 +897,7 @@ class NetworkTrainer:

             skipped_dataloader = None
             if initial_step > 0:
-                skipped_dataloader = accelerator.skip_first_batches(train_dataloader, initial_step-1)
+                skipped_dataloader = accelerator.skip_first_batches(train_dataloader, initial_step - 1)
                 initial_step = 1

             for step, batch in enumerate(skipped_dataloader or train_dataloader):
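
For context, a minimal standalone sketch of the `LossRecorder` change above. The class body mirrors the patched `library/train_util.py`; the resume scenario at the bottom is hypothetical and not part of the patch.

```python
from typing import List


class LossRecorder:
    def __init__(self):
        self.loss_list: List[float] = []
        self.loss_total: float = 0.0

    def add(self, *, epoch: int, step: int, loss: float) -> None:
        if epoch == 0:
            self.loss_list.append(loss)
        else:
            # When a run resumes mid-epoch (e.g. after skip_first_batches),
            # `step` can point past the recorded history. Pad with zeros so
            # loss_list[step] exists; without this, the next line raises
            # IndexError on the first post-resume call.
            while len(self.loss_list) <= step:
                self.loss_list.append(0.0)
            self.loss_total -= self.loss_list[step]
            self.loss_list[step] = loss
        self.loss_total += loss

    @property
    def moving_average(self) -> float:
        return self.loss_total / len(self.loss_list)


# Hypothetical resumed run: epoch 1 starts at step 5 with an empty history.
recorder = LossRecorder()
recorder.add(epoch=1, step=5, loss=0.25)  # padded to length 6; no IndexError
print(f"{recorder.moving_average:.4f}")  # 0.0417
```

One trade-off worth noting: the zero padding briefly dilutes `moving_average` right after a resume, since the skipped steps contribute 0.0 until they are overwritten with real losses.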