mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
update for corner cases
This commit is contained in:
@@ -663,6 +663,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
for _ in range(num_epochs):
|
||||
self.current_epoch += 1
|
||||
self.shuffle_buckets()
|
||||
# self.current_epoch seems to be reset to 0 in the next epoch; this may be caused by skipped_dataloader?
|
||||
else:
|
||||
logger.warning("epoch is not incremented. current_epoch: {}, epoch: {}".format(self.current_epoch, epoch))
|
||||
self.current_epoch = epoch
|
||||
@@ -5560,6 +5561,8 @@ class LossRecorder:
|
||||
if epoch == 0:
|
||||
self.loss_list.append(loss)
|
||||
else:
|
||||
while len(self.loss_list) <= step:
|
||||
self.loss_list.append(0.0)
|
||||
self.loss_total -= self.loss_list[step]
|
||||
self.loss_list[step] = loss
|
||||
self.loss_total += loss
|
||||
|
||||
@@ -493,12 +493,14 @@ class NetworkTrainer:
|
||||
# before resuming, make hooks for saving/loading so that only the network weights are saved/loaded
|
||||
def save_model_hook(models, weights, output_dir):
|
||||
# pop the weights of models other than the network so that only the network weights are saved
|
||||
if accelerator.is_main_process:
|
||||
# only main process or deepspeed https://github.com/huggingface/diffusers/issues/2606
|
||||
if accelerator.is_main_process or args.deepspeed:
|
||||
remove_indices = []
|
||||
for i, model in enumerate(models):
|
||||
if not isinstance(model, type(accelerator.unwrap_model(network))):
|
||||
remove_indices.append(i)
|
||||
for i in reversed(remove_indices):
|
||||
if len(weights) > i:
|
||||
weights.pop(i)
|
||||
# print(f"save model hook: {len(weights)} weights will be saved")
|
||||
|
||||
@@ -813,11 +815,12 @@ class NetworkTrainer:
|
||||
)
|
||||
logger.info(f"skipping {initial_step} steps / {initial_step}ステップをスキップします")
|
||||
initial_step *= args.gradient_accumulation_steps
|
||||
|
||||
# set epoch to start to make initial_step less than len(train_dataloader)
|
||||
epoch_to_start = initial_step // math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
else:
|
||||
# if not, only the epoch number is skipped, for informational purposes
|
||||
epoch_to_start = initial_step // math.ceil(
|
||||
len(train_dataloader) / args.gradient_accumulation_steps
|
||||
)
|
||||
epoch_to_start = initial_step // math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
initial_step = 0 # do not skip
|
||||
|
||||
global_step = 0
|
||||
@@ -878,9 +881,11 @@ class NetworkTrainer:
|
||||
self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
||||
|
||||
# training loop
|
||||
if initial_step > 0: # only if skip_until_initial_step is specified
|
||||
for skip_epoch in range(epoch_to_start): # skip epochs
|
||||
logger.info(f"skipping epoch {skip_epoch+1} because initial_step (multiplied) is {initial_step}")
|
||||
initial_step -= len(train_dataloader)
|
||||
global_step = initial_step
|
||||
|
||||
for epoch in range(epoch_to_start, num_train_epochs):
|
||||
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
|
||||
|
||||
Reference in New Issue
Block a user