Tried policy with diffusion model

This commit is contained in:
Victor Mylle
2023-12-29 12:30:30 +00:00
parent da3ab3d5b3
commit ef8b5f49ac
3 changed files with 113 additions and 118 deletions

View File

@@ -94,7 +94,7 @@ class NrvDataset(Dataset):
# get indices of all 00:15 timestamps
if self.full_day_skip:
start_of_day_indices = dataframe[
dataframe["datetime"].dt.time != pd.Timestamp("00:15:00").time()
dataframe["datetime"].dt.time != pd.Timestamp("00:00:00").time()
].index
skip_indices.extend(start_of_day_indices)
skip_indices = list(set(skip_indices))

File diff suppressed because one or more lines are too long

View File

@@ -48,6 +48,7 @@ class DiffusionTrainer:
"""
return torch.randint(low=1, high=self.noise_steps, size=(n,))
def sample(self, model: DiffusionModel, n: int, inputs: torch.tensor):
inputs = inputs.repeat(n, 1).to(self.device)
model.eval()