mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-15 08:36:41 +00:00
fix: enhance architecture detection to support InferSdxlUNet2DConditionModel for gen_img.py
This commit is contained in:
@@ -9,6 +9,7 @@ from dataclasses import dataclass, field
|
||||
from typing import Dict, List, Optional, Tuple, Type, Union
|
||||
|
||||
import torch
|
||||
from library.sdxl_original_unet import InferSdxlUNet2DConditionModel
|
||||
from library.utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
@@ -33,7 +34,9 @@ def detect_arch_config(unet, text_encoders) -> ArchConfig:
|
||||
from library.sdxl_original_unet import SdxlUNet2DConditionModel
|
||||
|
||||
# Check SDXL first
|
||||
if unet is not None and issubclass(unet.__class__, SdxlUNet2DConditionModel):
|
||||
if unet is not None and (
|
||||
issubclass(unet.__class__, SdxlUNet2DConditionModel) or issubclass(unet.__class__, InferSdxlUNet2DConditionModel)
|
||||
):
|
||||
return ArchConfig(
|
||||
unet_target_modules=["Transformer2DModel"],
|
||||
te_target_modules=["CLIPAttention", "CLIPSdpaAttention", "CLIPMLP"],
|
||||
|
||||
Reference in New Issue
Block a user