Added trainer for Diffusion model
This commit is contained in:
47
src/models/diffusion_model.py
Normal file
47
src/models/diffusion_model.py
Normal file
@@ -0,0 +1,47 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
class DiffusionModel(nn.Module):
|
||||
def __init__(self, time_dim: int = 64):
|
||||
super(DiffusionModel, self).__init__()
|
||||
self.time_dim = time_dim
|
||||
self.layers = nn.ModuleList()
|
||||
|
||||
def pos_encoding(self, t, channels):
|
||||
inv_freq = 1.0 / (
|
||||
10000 ** (torch.arange(0, channels, 2).float() / channels)
|
||||
).to(t.device)
|
||||
pos_enc_a = torch.sin(t.repeat(1, channels // 2) * inv_freq)
|
||||
pos_enc_b = torch.cos(t.repeat(1, channels // 2) * inv_freq)
|
||||
pos_enc = torch.cat((pos_enc_a, pos_enc_b), dim=-1)
|
||||
return pos_enc
|
||||
|
||||
|
||||
def forward(self, x, t, inputs):
|
||||
t = t.unsqueeze(-1).type(torch.float)
|
||||
t = self.pos_encoding(t, self.time_dim)
|
||||
|
||||
x = torch.cat((x, t, inputs), dim=-1)
|
||||
|
||||
for layer in self.layers[:-1]:
|
||||
x = layer(x)
|
||||
if not isinstance(layer, nn.ReLU):
|
||||
x = torch.cat((x, t, inputs), dim=-1)
|
||||
|
||||
x = self.layers[-1](x)
|
||||
return x
|
||||
|
||||
class SimpleDiffusionModel(DiffusionModel):
|
||||
def __init__(self, input_size: int, hidden_sizes: list, other_inputs_dim: int, time_dim: int = 64):
|
||||
super(SimpleDiffusionModel, self).__init__(time_dim)
|
||||
|
||||
self.other_inputs_dim = other_inputs_dim
|
||||
|
||||
self.layers.append(nn.Linear(input_size + time_dim + other_inputs_dim, hidden_sizes[0]))
|
||||
self.layers.append(nn.ReLU())
|
||||
|
||||
for i in range(1, len(hidden_sizes)):
|
||||
self.layers.append(nn.Linear(hidden_sizes[i - 1] + time_dim + other_inputs_dim, hidden_sizes[i]))
|
||||
self.layers.append(nn.ReLU())
|
||||
|
||||
self.layers.append(nn.Linear(hidden_sizes[-1] + time_dim + other_inputs_dim, input_size))
|
||||
Reference in New Issue
Block a user