add schnell option to load_cn

This commit is contained in:
minux302
2024-11-30 00:08:21 +09:00
parent 575f583fd9
commit be5860f8e2
2 changed files with 8 additions and 10 deletions

View File

@@ -259,14 +259,14 @@ def train(args):
clean_memory_on_device(accelerator.device) clean_memory_on_device(accelerator.device)
# load FLUX # load FLUX
_, flux = flux_utils.load_flow_model( is_schnell, flux = flux_utils.load_flow_model(
args.pretrained_model_name_or_path, weight_dtype, "cpu", args.disable_mmap_load_safetensors args.pretrained_model_name_or_path, weight_dtype, "cpu", args.disable_mmap_load_safetensors
) )
flux.requires_grad_(False) flux.requires_grad_(False)
flux.to(accelerator.device) flux.to(accelerator.device)
# load controlnet # load controlnet
controlnet = flux_utils.load_controlnet(args.controlnet, torch.float32, accelerator.device, args.disable_mmap_load_safetensors) controlnet = flux_utils.load_controlnet(args.controlnet, is_schnell, torch.float32, accelerator.device, args.disable_mmap_load_safetensors)
controlnet.train() controlnet.train()
if args.gradient_checkpointing: if args.gradient_checkpointing:

View File

@@ -1,14 +1,14 @@
from dataclasses import replace
import json import json
import os import os
from dataclasses import replace
from typing import List, Optional, Tuple, Union from typing import List, Optional, Tuple, Union
import einops import einops
import torch import torch
from safetensors.torch import load_file
from safetensors import safe_open
from accelerate import init_empty_weights from accelerate import init_empty_weights
from transformers import CLIPTextModel, CLIPConfig, T5EncoderModel, T5Config from safetensors import safe_open
from safetensors.torch import load_file
from transformers import CLIPConfig, CLIPTextModel, T5Config, T5EncoderModel
from library.utils import setup_logging from library.utils import setup_logging
@@ -154,11 +154,9 @@ def load_ae(
def load_controlnet( def load_controlnet(
ckpt_path: Optional[str], dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False ckpt_path: Optional[str], is_schnell: bool, dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False
): ):
logger.info("Building ControlNet") logger.info("Building ControlNet")
# is_diffusers, is_schnell, (num_double_blocks, num_single_blocks), ckpt_paths = analyze_checkpoint_state(ckpt_path)
is_schnell = False
name = MODEL_NAME_DEV if not is_schnell else MODEL_NAME_SCHNELL name = MODEL_NAME_DEV if not is_schnell else MODEL_NAME_SCHNELL
with torch.device(device): with torch.device(device):
controlnet = flux_models.ControlNetFlux(flux_models.configs[name].params).to(dtype) controlnet = flux_models.ControlNetFlux(flux_models.configs[name].params).to(dtype)