mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Merge branch 'dev' into gradual_latent_hires_fix
This commit is contained in:
@@ -57,7 +57,7 @@ import library.train_util as train_util
|
||||
import library.sdxl_model_util as sdxl_model_util
|
||||
import library.sdxl_train_util as sdxl_train_util
|
||||
from networks.lora import LoRANetwork
|
||||
from library.sdxl_original_unet import SdxlUNet2DConditionModel
|
||||
from library.sdxl_original_unet import InferSdxlUNet2DConditionModel
|
||||
from library.original_unet import FlashAttentionFunction
|
||||
from networks.control_net_lllite import ControlNetLLLite
|
||||
|
||||
@@ -290,7 +290,7 @@ class PipelineLike:
|
||||
vae: AutoencoderKL,
|
||||
text_encoders: List[CLIPTextModel],
|
||||
tokenizers: List[CLIPTokenizer],
|
||||
unet: SdxlUNet2DConditionModel,
|
||||
unet: InferSdxlUNet2DConditionModel,
|
||||
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
|
||||
clip_skip: int,
|
||||
):
|
||||
@@ -328,7 +328,7 @@ class PipelineLike:
|
||||
self.vae = vae
|
||||
self.text_encoders = text_encoders
|
||||
self.tokenizers = tokenizers
|
||||
self.unet: SdxlUNet2DConditionModel = unet
|
||||
self.unet: InferSdxlUNet2DConditionModel = unet
|
||||
self.scheduler = scheduler
|
||||
self.safety_checker = None
|
||||
|
||||
@@ -1611,6 +1611,7 @@ def main(args):
|
||||
(_, text_encoder1, text_encoder2, vae, unet, _, _) = sdxl_train_util._load_target_model(
|
||||
args.ckpt, args.vae, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, dtype
|
||||
)
|
||||
unet: InferSdxlUNet2DConditionModel = InferSdxlUNet2DConditionModel(unet)
|
||||
|
||||
# xformers、Hypernetwork対応
|
||||
if not args.diffusers_xformers:
|
||||
@@ -1766,10 +1767,14 @@ def main(args):
|
||||
print("set vae_dtype to float32")
|
||||
vae_dtype = torch.float32
|
||||
vae.to(vae_dtype).to(device)
|
||||
vae.eval()
|
||||
|
||||
text_encoder1.to(dtype).to(device)
|
||||
text_encoder2.to(dtype).to(device)
|
||||
unet.to(dtype).to(device)
|
||||
text_encoder1.eval()
|
||||
text_encoder2.eval()
|
||||
unet.eval()
|
||||
|
||||
# networkを組み込む
|
||||
if args.network_module:
|
||||
|
||||
Reference in New Issue
Block a user