Fix loss recorder on 0. Fix validation for cached runs. Assert on validation dataset

This commit is contained in:
rockerBOO
2025-01-23 09:57:24 -05:00
parent b489082495
commit c04e5dfe92
8 changed files with 46 additions and 20 deletions

View File

@@ -2893,6 +2893,9 @@ class MinimalDataset(BaseDataset):
"""
raise NotImplementedError
def get_resolutions(self) -> List[Tuple[int, int]]:
return []
def load_arbitrary_dataset(args, tokenizer=None) -> MinimalDataset:
module = ".".join(args.dataset_class.split(".")[:-1])
@@ -6520,4 +6523,7 @@ class LossRecorder:
@property
def moving_average(self) -> float:
return self.loss_total / len(self.loss_list)
losses = len(self.loss_list)
if losses == 0:
return 0
return self.loss_total / losses