fix: enhance architecture detection to support InferSdxlUNet2DConditionModel for gen_img.py

This commit is contained in:
Kohya S
2026-02-23 18:43:46 +09:00
parent 5978694223
commit f6a679e491

View File

@@ -9,6 +9,7 @@ from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple, Type, Union
import torch
from library.sdxl_original_unet import InferSdxlUNet2DConditionModel
from library.utils import setup_logging
setup_logging()
@@ -33,7 +34,9 @@ def detect_arch_config(unet, text_encoders) -> ArchConfig:
from library.sdxl_original_unet import SdxlUNet2DConditionModel
# Check SDXL first
if unet is not None and issubclass(unet.__class__, SdxlUNet2DConditionModel):
if unet is not None and (
issubclass(unet.__class__, SdxlUNet2DConditionModel) or issubclass(unet.__class__, InferSdxlUNet2DConditionModel)
):
return ArchConfig(
unet_target_modules=["Transformer2DModel"],
te_target_modules=["CLIPAttention", "CLIPSdpaAttention", "CLIPMLP"],