diff --git a/src/training_scripts/diffusion_training.py b/src/training_scripts/diffusion_training.py index f68e786..9a1c2a3 100644 --- a/src/training_scripts/diffusion_training.py +++ b/src/training_scripts/diffusion_training.py @@ -54,8 +54,8 @@ model_parameters = { model_parameters = task.connect(model_parameters, name="model_parameters") #### Model #### -model = SimpleDiffusionModel(96, model_parameters["hidden_sizes"], other_inputs_dim=inputDim[1], time_dim=model_parameters["time_dim"]) -# model = GRUDiffusionModel(96, model_parameters["hidden_sizes"], other_inputs_dim=inputDim[2], time_dim=model_parameters["time_dim"], gru_hidden_size=256) +# model = SimpleDiffusionModel(96, model_parameters["hidden_sizes"], other_inputs_dim=inputDim[1], time_dim=model_parameters["time_dim"]) +model = GRUDiffusionModel(96, model_parameters["hidden_sizes"], other_inputs_dim=inputDim[2], time_dim=model_parameters["time_dim"], gru_hidden_size=256) print("Starting training ...")