diff --git a/library/model_util.py b/library/model_util.py index d1020c05..3d8e7539 100644 --- a/library/model_util.py +++ b/library/model_util.py @@ -1046,10 +1046,14 @@ def save_stable_diffusion_checkpoint(v2, output_file, text_encoder, unet, ckpt_p key_count = len(state_dict.keys()) new_ckpt = {'state_dict': state_dict} - if 'epoch' in checkpoint: - epochs += checkpoint['epoch'] - if 'global_step' in checkpoint: - steps += checkpoint['global_step'] + # epoch and global_step are sometimes not int + try: + if 'epoch' in checkpoint: + epochs += checkpoint['epoch'] + if 'global_step' in checkpoint: + steps += checkpoint['global_step'] + except: + pass new_ckpt['epoch'] = epochs new_ckpt['global_step'] = steps