mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
support sample generation in training
This commit is contained in:
@@ -107,6 +107,7 @@ v2.1
|
||||
"""
|
||||
|
||||
import math
|
||||
from types import SimpleNamespace
|
||||
from typing import Dict, Optional, Tuple, Union
|
||||
import torch
|
||||
from torch import nn
|
||||
@@ -134,6 +135,10 @@ def get_parameter_dtype(parameter: torch.nn.Module):
|
||||
return next(parameter.parameters()).dtype
|
||||
|
||||
|
||||
def get_parameter_device(parameter: torch.nn.Module):
|
||||
return next(parameter.parameters()).device
|
||||
|
||||
|
||||
def get_timestep_embedding(
|
||||
timesteps: torch.Tensor,
|
||||
embedding_dim: int,
|
||||
@@ -1068,6 +1073,7 @@ class UNet2DConditionModel(nn.Module):
|
||||
self.out_channels = OUT_CHANNELS
|
||||
|
||||
self.sample_size = sample_size
|
||||
self.prepare_config()
|
||||
|
||||
# state_dictの書式が変わるのでmoduleの持ち方は変えられない
|
||||
|
||||
@@ -1154,13 +1160,19 @@ class UNet2DConditionModel(nn.Module):
|
||||
self.conv_out = nn.Conv2d(BLOCK_OUT_CHANNELS[0], OUT_CHANNELS, kernel_size=3, padding=1)
|
||||
|
||||
# region diffusers compatibility
|
||||
def prepare_config(self):
|
||||
self.config = SimpleNamespace()
|
||||
|
||||
@property
|
||||
def dtype(self) -> torch.dtype:
|
||||
"""
|
||||
`torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
|
||||
"""
|
||||
# `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
|
||||
return get_parameter_dtype(self)
|
||||
|
||||
@property
|
||||
def device(self) -> torch.device:
|
||||
# `torch.device`: The device on which the module is (assuming that all the module parameters are on the same device).
|
||||
return get_parameter_device(self)
|
||||
|
||||
def set_attention_slice(self, slice_size):
|
||||
raise NotImplementedError("Attention slicing is not supported for this model.")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user