Fix saving issue if epoch/step not in checkpoint

This commit is contained in:
Kohya S
2023-03-27 21:22:32 +09:00
parent 238f01bc9c
commit 895b0b6ca7

View File

@@ -1046,10 +1046,14 @@ def save_stable_diffusion_checkpoint(v2, output_file, text_encoder, unet, ckpt_p
key_count = len(state_dict.keys()) key_count = len(state_dict.keys())
new_ckpt = {'state_dict': state_dict} new_ckpt = {'state_dict': state_dict}
if 'epoch' in checkpoint: # epoch and global_step are sometimes not int
epochs += checkpoint['epoch'] try:
if 'global_step' in checkpoint: if 'epoch' in checkpoint:
steps += checkpoint['global_step'] epochs += checkpoint['epoch']
if 'global_step' in checkpoint:
steps += checkpoint['global_step']
except:
pass
new_ckpt['epoch'] = epochs new_ckpt['epoch'] = epochs
new_ckpt['global_step'] = steps new_ckpt['global_step'] = steps