From f6a679e4913ba94de7829a2f5a1f9ba0ff9b9aa8 Mon Sep 17 00:00:00 2001 From: Kohya S <52813779+kohya-ss@users.noreply.github.com> Date: Mon, 23 Feb 2026 18:43:46 +0900 Subject: [PATCH] fix: enhance architecture detection to support InferSdxlUNet2DConditionModel for gen_img.py --- networks/network_base.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/networks/network_base.py b/networks/network_base.py index ab8d25ab..b2d8a9fd 100644 --- a/networks/network_base.py +++ b/networks/network_base.py @@ -9,6 +9,7 @@ from dataclasses import dataclass, field from typing import Dict, List, Optional, Tuple, Type, Union import torch +from library.sdxl_original_unet import InferSdxlUNet2DConditionModel from library.utils import setup_logging setup_logging() @@ -33,7 +34,9 @@ def detect_arch_config(unet, text_encoders) -> ArchConfig: from library.sdxl_original_unet import SdxlUNet2DConditionModel # Check SDXL first - if unet is not None and issubclass(unet.__class__, SdxlUNet2DConditionModel): + if unet is not None and ( + issubclass(unet.__class__, SdxlUNet2DConditionModel) or issubclass(unet.__class__, InferSdxlUNet2DConditionModel) + ): return ArchConfig( unet_target_modules=["Transformer2DModel"], te_target_modules=["CLIPAttention", "CLIPSdpaAttention", "CLIPMLP"],