Added policy executer file for remotely executing

This commit is contained in:
Victor Mylle
2024-01-15 21:19:33 +00:00
parent 67cc6d4bb9
commit 428e3d9e4b
7 changed files with 323 additions and 110 deletions

View File

@@ -12,6 +12,34 @@ import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib.patches as mpatches
def sample_diffusion(model: DiffusionModel, n: int, inputs: torch.tensor, noise_steps=1000, beta_start=1e-4, beta_end=0.02, ts_length=96):
device = next(model.parameters()).device
beta = torch.linspace(beta_start, beta_end, noise_steps).to(device)
alpha = 1. - beta
alpha_hat = torch.cumprod(alpha, dim=0)
inputs = inputs.repeat(n, 1).to(device)
model.eval()
with torch.no_grad():
x = torch.randn(inputs.shape[0], ts_length).to(device)
for i in reversed(range(1, noise_steps)):
t = (torch.ones(inputs.shape[0]) * i).long().to(device)
predicted_noise = model(x, t, inputs)
_alpha = alpha[t][:, None]
_alpha_hat = alpha_hat[t][:, None]
_beta = beta[t][:, None]
if i > 1:
noise = torch.randn_like(x)
else:
noise = torch.zeros_like(x)
x = 1/torch.sqrt(_alpha) * (x-((1-_alpha) / (torch.sqrt(1 - _alpha_hat))) * predicted_noise) + torch.sqrt(_beta) * noise
return x
class DiffusionTrainer:
def __init__(self, model: nn.Module, data_processor: DataProcessor, device: torch.device):
self.model = model
@@ -50,23 +78,7 @@ class DiffusionTrainer:
def sample(self, model: DiffusionModel, n: int, inputs: torch.tensor):
inputs = inputs.repeat(n, 1).to(self.device)
model.eval()
with torch.no_grad():
x = torch.randn(inputs.shape[0], self.ts_length).to(self.device)
for i in reversed(range(1, self.noise_steps)):
t = (torch.ones(inputs.shape[0]) * i).long().to(self.device)
predicted_noise = model(x, t, inputs)
alpha = self.alpha[t][:, None]
alpha_hat = self.alpha_hat[t][:, None]
beta = self.beta[t][:, None]
if i > 1:
noise = torch.randn_like(x)
else:
noise = torch.zeros_like(x)
x = 1/torch.sqrt(alpha) * (x-((1-alpha) / (torch.sqrt(1 - alpha_hat))) * predicted_noise) + torch.sqrt(beta) * noise
x = sample_diffusion(model, n, inputs, self.noise_steps, self.beta_start, self.beta_end, self.ts_length)
model.train()
return x