From dccdb8771c6facfb40087fc670e2698a2deb3bce Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 7 Jun 2023 08:12:52 +0900 Subject: [PATCH] support sample generation in training --- library/original_unet.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/library/original_unet.py b/library/original_unet.py index 47b751c1..e8920727 100644 --- a/library/original_unet.py +++ b/library/original_unet.py @@ -107,6 +107,7 @@ v2.1 """ import math +from types import SimpleNamespace from typing import Dict, Optional, Tuple, Union import torch from torch import nn @@ -134,6 +135,10 @@ def get_parameter_dtype(parameter: torch.nn.Module): return next(parameter.parameters()).dtype +def get_parameter_device(parameter: torch.nn.Module): + return next(parameter.parameters()).device + + def get_timestep_embedding( timesteps: torch.Tensor, embedding_dim: int, @@ -1068,6 +1073,7 @@ class UNet2DConditionModel(nn.Module): self.out_channels = OUT_CHANNELS self.sample_size = sample_size + self.prepare_config() # state_dictの書式が変わるのでmoduleの持ち方は変えられない @@ -1154,13 +1160,19 @@ class UNet2DConditionModel(nn.Module): self.conv_out = nn.Conv2d(BLOCK_OUT_CHANNELS[0], OUT_CHANNELS, kernel_size=3, padding=1) # region diffusers compatibility + def prepare_config(self): + self.config = SimpleNamespace() + @property def dtype(self) -> torch.dtype: - """ - `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype). - """ + # `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype). return get_parameter_dtype(self) + @property + def device(self) -> torch.device: + # `torch.device`: The device on which the module is (assuming that all the module parameters are on the same device). + return get_parameter_device(self) + def set_attention_slice(self, slice_size): raise NotImplementedError("Attention slicing is not supported for this model.")