Support sample generation during training

This commit is contained in:
Kohya S
2023-06-07 08:12:52 +09:00
parent d4b5cab7f7
commit dccdb8771c

View File

@@ -107,6 +107,7 @@ v2.1
"""
import math
from types import SimpleNamespace
from typing import Dict, Optional, Tuple, Union
import torch
from torch import nn
@@ -134,6 +135,10 @@ def get_parameter_dtype(parameter: torch.nn.Module):
return next(parameter.parameters()).dtype
def get_parameter_device(parameter: torch.nn.Module):
    """Return the device of the module, taken from its first parameter.

    Assumes all parameters of the module live on the same device.
    """
    first_param = next(parameter.parameters())
    return first_param.device
def get_timestep_embedding(
timesteps: torch.Tensor,
embedding_dim: int,
@@ -1068,6 +1073,7 @@ class UNet2DConditionModel(nn.Module):
self.out_channels = OUT_CHANNELS
self.sample_size = sample_size
self.prepare_config()
# The state_dict format must stay the same, so the way modules are held cannot be changed
@@ -1154,13 +1160,19 @@ class UNet2DConditionModel(nn.Module):
self.conv_out = nn.Conv2d(BLOCK_OUT_CHANNELS[0], OUT_CHANNELS, kernel_size=3, padding=1)
# region diffusers compatibility
def prepare_config(self):
    """Create an empty `config` namespace for diffusers compatibility.

    Diffusers pipelines expect models to expose a `config` attribute;
    an empty SimpleNamespace satisfies that interface here.
    """
    config = SimpleNamespace()
    self.config = config
@property
def dtype(self) -> torch.dtype:
    """`torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype)."""
    # Delegates to the module-level helper, which reads the first parameter's dtype.
    return get_parameter_dtype(self)
@property
def device(self) -> torch.device:
    """`torch.device`: The device on which the module is (assuming that all the module parameters are on the same device)."""
    # Delegates to the module-level helper, which reads the first parameter's device.
    return get_parameter_device(self)
def set_attention_slice(self, slice_size):
    """Attention slicing is not implemented for this model; always raises.

    Kept for interface compatibility with diffusers' UNet2DConditionModel.
    """
    message = "Attention slicing is not supported for this model."
    raise NotImplementedError(message)