mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-15 16:39:42 +00:00
fix: enhance architecture detection to support InferSdxlUNet2DConditionModel for gen_img.py
This commit is contained in:
@@ -9,6 +9,7 @@ from dataclasses import dataclass, field
|
|||||||
from typing import Dict, List, Optional, Tuple, Type, Union
|
from typing import Dict, List, Optional, Tuple, Type, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from library.sdxl_original_unet import InferSdxlUNet2DConditionModel
|
||||||
from library.utils import setup_logging
|
from library.utils import setup_logging
|
||||||
|
|
||||||
setup_logging()
|
setup_logging()
|
||||||
@@ -33,7 +34,9 @@ def detect_arch_config(unet, text_encoders) -> ArchConfig:
|
|||||||
from library.sdxl_original_unet import SdxlUNet2DConditionModel
|
from library.sdxl_original_unet import SdxlUNet2DConditionModel
|
||||||
|
|
||||||
# Check SDXL first
|
# Check SDXL first
|
||||||
if unet is not None and issubclass(unet.__class__, SdxlUNet2DConditionModel):
|
if unet is not None and (
|
||||||
|
issubclass(unet.__class__, SdxlUNet2DConditionModel) or issubclass(unet.__class__, InferSdxlUNet2DConditionModel)
|
||||||
|
):
|
||||||
return ArchConfig(
|
return ArchConfig(
|
||||||
unet_target_modules=["Transformer2DModel"],
|
unet_target_modules=["Transformer2DModel"],
|
||||||
te_target_modules=["CLIPAttention", "CLIPSdpaAttention", "CLIPMLP"],
|
te_target_modules=["CLIPAttention", "CLIPSdpaAttention", "CLIPMLP"],
|
||||||
|
|||||||
Reference in New Issue
Block a user