add schnell option to load_cn

This commit is contained in:
minux302
2024-11-30 00:08:21 +09:00
parent 575f583fd9
commit be5860f8e2
2 changed files with 8 additions and 10 deletions

View File

@@ -259,14 +259,14 @@ def train(args):
clean_memory_on_device(accelerator.device)
# load FLUX
_, flux = flux_utils.load_flow_model(
is_schnell, flux = flux_utils.load_flow_model(
args.pretrained_model_name_or_path, weight_dtype, "cpu", args.disable_mmap_load_safetensors
)
flux.requires_grad_(False)
flux.to(accelerator.device)
# load controlnet
controlnet = flux_utils.load_controlnet(args.controlnet, torch.float32, accelerator.device, args.disable_mmap_load_safetensors)
controlnet = flux_utils.load_controlnet(args.controlnet, is_schnell, torch.float32, accelerator.device, args.disable_mmap_load_safetensors)
controlnet.train()
if args.gradient_checkpointing: