Tried policy with diffusion model
This commit is contained in:
@@ -94,7 +94,7 @@ class NrvDataset(Dataset):
|
||||
# get indices of all 00:15 timestamps
|
||||
if self.full_day_skip:
|
||||
start_of_day_indices = dataframe[
|
||||
dataframe["datetime"].dt.time != pd.Timestamp("00:15:00").time()
|
||||
dataframe["datetime"].dt.time != pd.Timestamp("00:00:00").time()
|
||||
].index
|
||||
skip_indices.extend(start_of_day_indices)
|
||||
skip_indices = list(set(skip_indices))
|
||||
|
||||
File diff suppressed because one or more lines are too long
@@ -48,6 +48,7 @@ class DiffusionTrainer:
|
||||
"""
|
||||
return torch.randint(low=1, high=self.noise_steps, size=(n,))
|
||||
|
||||
|
||||
def sample(self, model: DiffusionModel, n: int, inputs: torch.tensor):
|
||||
inputs = inputs.repeat(n, 1).to(self.device)
|
||||
model.eval()
|
||||
|
||||
Reference in New Issue
Block a user