Saving whole model instead of weights only
This commit is contained in:
@@ -236,7 +236,7 @@ class DiffusionTrainer:
|
||||
)
|
||||
|
||||
def save_checkpoint(self, val_loss, task, iteration: int):
|
||||
torch.save(self.model.state_dict(), "checkpoint.pt")
|
||||
torch.save(self.model, "checkpoint.pt")
|
||||
task.update_output_model(
|
||||
model_path="checkpoint.pt", iteration=iteration, auto_delete_file=False
|
||||
)
|
||||
|
||||
@@ -279,7 +279,7 @@ class Trainer:
|
||||
return test_loss
|
||||
|
||||
def save_checkpoint(self, val_loss, task, iteration: int):
|
||||
torch.save(self.model.state_dict(), "checkpoint.pt")
|
||||
torch.save(self.model, "checkpoint.pt")
|
||||
task.update_output_model(
|
||||
model_path="checkpoint.pt", iteration=iteration, auto_delete_file=False
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user