add schnell option to load_cn

This commit is contained in:
minux302
2024-11-30 00:08:21 +09:00
parent 575f583fd9
commit be5860f8e2
2 changed files with 8 additions and 10 deletions

View File

@@ -1,14 +1,14 @@
from dataclasses import replace
import json
import os
from dataclasses import replace
from typing import List, Optional, Tuple, Union
import einops
import torch
from safetensors.torch import load_file
from safetensors import safe_open
from accelerate import init_empty_weights
from transformers import CLIPTextModel, CLIPConfig, T5EncoderModel, T5Config
from safetensors import safe_open
from safetensors.torch import load_file
from transformers import CLIPConfig, CLIPTextModel, T5Config, T5EncoderModel
from library.utils import setup_logging
@@ -154,11 +154,9 @@ def load_ae(
def load_controlnet(
ckpt_path: Optional[str], dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False
ckpt_path: Optional[str], is_schnell: bool, dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False
):
logger.info("Building ControlNet")
# is_diffusers, is_schnell, (num_double_blocks, num_single_blocks), ckpt_paths = analyze_checkpoint_state(ckpt_path)
is_schnell = False
name = MODEL_NAME_DEV if not is_schnell else MODEL_NAME_SCHNELL
with torch.device(device):
controlnet = flux_models.ControlNetFlux(flux_models.configs[name].params).to(dtype)