mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
support sample generation in training
This commit is contained in:
@@ -107,6 +107,7 @@ v2.1
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import math
|
import math
|
||||||
|
from types import SimpleNamespace
|
||||||
from typing import Dict, Optional, Tuple, Union
|
from typing import Dict, Optional, Tuple, Union
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
@@ -134,6 +135,10 @@ def get_parameter_dtype(parameter: torch.nn.Module):
|
|||||||
return next(parameter.parameters()).dtype
|
return next(parameter.parameters()).dtype
|
||||||
|
|
||||||
|
|
||||||
|
def get_parameter_device(parameter: torch.nn.Module):
|
||||||
|
return next(parameter.parameters()).device
|
||||||
|
|
||||||
|
|
||||||
def get_timestep_embedding(
|
def get_timestep_embedding(
|
||||||
timesteps: torch.Tensor,
|
timesteps: torch.Tensor,
|
||||||
embedding_dim: int,
|
embedding_dim: int,
|
||||||
@@ -1068,6 +1073,7 @@ class UNet2DConditionModel(nn.Module):
|
|||||||
self.out_channels = OUT_CHANNELS
|
self.out_channels = OUT_CHANNELS
|
||||||
|
|
||||||
self.sample_size = sample_size
|
self.sample_size = sample_size
|
||||||
|
self.prepare_config()
|
||||||
|
|
||||||
# state_dictの書式が変わるのでmoduleの持ち方は変えられない
|
# state_dictの書式が変わるのでmoduleの持ち方は変えられない
|
||||||
|
|
||||||
@@ -1154,13 +1160,19 @@ class UNet2DConditionModel(nn.Module):
|
|||||||
self.conv_out = nn.Conv2d(BLOCK_OUT_CHANNELS[0], OUT_CHANNELS, kernel_size=3, padding=1)
|
self.conv_out = nn.Conv2d(BLOCK_OUT_CHANNELS[0], OUT_CHANNELS, kernel_size=3, padding=1)
|
||||||
|
|
||||||
# region diffusers compatibility
|
# region diffusers compatibility
|
||||||
|
def prepare_config(self):
|
||||||
|
self.config = SimpleNamespace()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def dtype(self) -> torch.dtype:
|
def dtype(self) -> torch.dtype:
|
||||||
"""
|
# `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
|
||||||
`torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
|
|
||||||
"""
|
|
||||||
return get_parameter_dtype(self)
|
return get_parameter_dtype(self)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def device(self) -> torch.device:
|
||||||
|
# `torch.device`: The device on which the module is (assuming that all the module parameters are on the same device).
|
||||||
|
return get_parameter_device(self)
|
||||||
|
|
||||||
def set_attention_slice(self, slice_size):
|
def set_attention_slice(self, slice_size):
|
||||||
raise NotImplementedError("Attention slicing is not supported for this model.")
|
raise NotImplementedError("Attention slicing is not supported for this model.")
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user