Fix saving issue if epoch/step not in checkpoint

This commit is contained in:
Kohya S
2023-03-27 21:22:32 +09:00
parent 238f01bc9c
commit 895b0b6ca7

View File

@@ -1046,10 +1046,14 @@ def save_stable_diffusion_checkpoint(v2, output_file, text_encoder, unet, ckpt_p
key_count = len(state_dict.keys())
new_ckpt = {'state_dict': state_dict}
if 'epoch' in checkpoint:
epochs += checkpoint['epoch']
if 'global_step' in checkpoint:
steps += checkpoint['global_step']
# epoch and global_step are sometimes not int
try:
if 'epoch' in checkpoint:
epochs += checkpoint['epoch']
if 'global_step' in checkpoint:
steps += checkpoint['global_step']
except:
pass
new_ckpt['epoch'] = epochs
new_ckpt['global_step'] = steps