mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
add schnell option to load_cn
This commit is contained in:
@@ -259,14 +259,14 @@ def train(args):
|
||||
clean_memory_on_device(accelerator.device)
|
||||
|
||||
# load FLUX
|
||||
_, flux = flux_utils.load_flow_model(
|
||||
is_schnell, flux = flux_utils.load_flow_model(
|
||||
args.pretrained_model_name_or_path, weight_dtype, "cpu", args.disable_mmap_load_safetensors
|
||||
)
|
||||
flux.requires_grad_(False)
|
||||
flux.to(accelerator.device)
|
||||
|
||||
# load controlnet
|
||||
controlnet = flux_utils.load_controlnet(args.controlnet, torch.float32, accelerator.device, args.disable_mmap_load_safetensors)
|
||||
controlnet = flux_utils.load_controlnet(args.controlnet, is_schnell, torch.float32, accelerator.device, args.disable_mmap_load_safetensors)
|
||||
controlnet.train()
|
||||
|
||||
if args.gradient_checkpointing:
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
from dataclasses import replace
|
||||
import json
|
||||
import os
|
||||
from dataclasses import replace
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import einops
|
||||
import torch
|
||||
|
||||
from safetensors.torch import load_file
|
||||
from safetensors import safe_open
|
||||
from accelerate import init_empty_weights
|
||||
from transformers import CLIPTextModel, CLIPConfig, T5EncoderModel, T5Config
|
||||
from safetensors import safe_open
|
||||
from safetensors.torch import load_file
|
||||
from transformers import CLIPConfig, CLIPTextModel, T5Config, T5EncoderModel
|
||||
|
||||
from library.utils import setup_logging
|
||||
|
||||
@@ -154,11 +154,9 @@ def load_ae(
|
||||
|
||||
|
||||
def load_controlnet(
|
||||
ckpt_path: Optional[str], dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False
|
||||
ckpt_path: Optional[str], is_schnell: bool, dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False
|
||||
):
|
||||
logger.info("Building ControlNet")
|
||||
# is_diffusers, is_schnell, (num_double_blocks, num_single_blocks), ckpt_paths = analyze_checkpoint_state(ckpt_path)
|
||||
is_schnell = False
|
||||
name = MODEL_NAME_DEV if not is_schnell else MODEL_NAME_SCHNELL
|
||||
with torch.device(device):
|
||||
controlnet = flux_models.ControlNetFlux(flux_models.configs[name].params).to(dtype)
|
||||
|
||||
Reference in New Issue
Block a user