Support sample generation in training

This commit is contained in:
Kohya S
2023-06-07 08:12:52 +09:00
parent d4b5cab7f7
commit dccdb8771c

View File

@@ -107,6 +107,7 @@ v2.1
""" """
import math import math
from types import SimpleNamespace
from typing import Dict, Optional, Tuple, Union from typing import Dict, Optional, Tuple, Union
import torch import torch
from torch import nn from torch import nn
def get_parameter_dtype(parameter: torch.nn.Module):
    """Return the dtype of the module's first parameter tensor.

    Assumes all parameters of *parameter* share a single dtype.
    """
    first_param = next(parameter.parameters())
    return first_param.dtype
def get_parameter_device(parameter: torch.nn.Module):
    """Return the device of the module's first parameter tensor.

    Assumes all parameters of *parameter* live on a single device.
    """
    first_param = next(parameter.parameters())
    return first_param.device
def get_timestep_embedding( def get_timestep_embedding(
timesteps: torch.Tensor, timesteps: torch.Tensor,
embedding_dim: int, embedding_dim: int,
@@ -1068,6 +1073,7 @@ class UNet2DConditionModel(nn.Module):
self.out_channels = OUT_CHANNELS self.out_channels = OUT_CHANNELS
self.sample_size = sample_size self.sample_size = sample_size
self.prepare_config()
# state_dictの書式が変わるのでmoduleの持ち方は変えられない # state_dictの書式が変わるのでmoduleの持ち方は変えられない
@@ -1154,13 +1160,19 @@ class UNet2DConditionModel(nn.Module):
self.conv_out = nn.Conv2d(BLOCK_OUT_CHANNELS[0], OUT_CHANNELS, kernel_size=3, padding=1) self.conv_out = nn.Conv2d(BLOCK_OUT_CHANNELS[0], OUT_CHANNELS, kernel_size=3, padding=1)
# region diffusers compatibility # region diffusers compatibility
def prepare_config(self):
    # Diffusers-compatibility shim: expose a `config` attribute like
    # diffusers' ModelMixin does.  It starts out empty — NOTE(review):
    # presumably callers populate it or only check for its presence;
    # confirm at call sites.
    self.config = SimpleNamespace()
@property
def dtype(self) -> torch.dtype:
    """`torch.dtype`: dtype of this module's parameters (assumes all parameters share one dtype)."""
    return get_parameter_dtype(self)
@property
def device(self) -> torch.device:
    """`torch.device`: device of this module's parameters (assumes all parameters share one device)."""
    return get_parameter_device(self)
def set_attention_slice(self, slice_size):
    """Unsupported operation: this UNet implementation does not perform attention slicing."""
    raise NotImplementedError("Attention slicing is not supported for this model.")