diff --git a/src/trainers/diffusion_trainer.py b/src/trainers/diffusion_trainer.py index 00eaae6..37b6e1c 100644 --- a/src/trainers/diffusion_trainer.py +++ b/src/trainers/diffusion_trainer.py @@ -236,7 +236,7 @@ class DiffusionTrainer: ) def save_checkpoint(self, val_loss, task, iteration: int): - torch.save(self.model.state_dict(), "checkpoint.pt") + torch.save(self.model, "checkpoint.pt") task.update_output_model( model_path="checkpoint.pt", iteration=iteration, auto_delete_file=False ) diff --git a/src/trainers/trainer.py b/src/trainers/trainer.py index 23ef867..18f0a88 100644 --- a/src/trainers/trainer.py +++ b/src/trainers/trainer.py @@ -279,7 +279,7 @@ class Trainer: return test_loss def save_checkpoint(self, val_loss, task, iteration: int): - torch.save(self.model.state_dict(), "checkpoint.pt") + torch.save(self.model, "checkpoint.pt") task.update_output_model( model_path="checkpoint.pt", iteration=iteration, auto_delete_file=False )