Saving whole model instead of weights only

This commit is contained in:
Victor Mylle
2024-01-15 11:07:57 +00:00
parent c26ae76951
commit a977021dfc
2 changed files with 2 additions and 2 deletions

View File

@@ -236,7 +236,7 @@ class DiffusionTrainer:
)
def save_checkpoint(self, val_loss, task, iteration: int):
torch.save(self.model.state_dict(), "checkpoint.pt")
torch.save(self.model, "checkpoint.pt")
task.update_output_model(
model_path="checkpoint.pt", iteration=iteration, auto_delete_file=False
)

View File

@@ -279,7 +279,7 @@ class Trainer:
return test_loss
def save_checkpoint(self, val_loss, task, iteration: int):
torch.save(self.model.state_dict(), "checkpoint.pt")
torch.save(self.model, "checkpoint.pt")
task.update_output_model(
model_path="checkpoint.pt", iteration=iteration, auto_delete_file=False
)