feat: support Chroma model in loading and inference processes

This commit is contained in:
kohya-ss
2025-07-20 12:56:42 +09:00
parent a96d684ffa
commit 24d2ea86c7
6 changed files with 123 additions and 133 deletions

View File

@@ -270,8 +270,8 @@ def train(args):
clean_memory_on_device(accelerator.device)
# load FLUX
_, flux = flux_utils.load_flow_model(
args.pretrained_model_name_or_path, weight_dtype, "cpu", args.disable_mmap_load_safetensors
model_type, _, flux = flux_utils.load_flow_model(
args.pretrained_model_name_or_path, weight_dtype, "cpu", args.disable_mmap_load_safetensors, model_type="flux"
)
if args.gradient_checkpointing: