feat: add LoRA training support for Chroma

This commit is contained in:
Kohya S
2025-07-20 19:00:09 +09:00
parent c4958b5dca
commit b4e862626a
10 changed files with 158 additions and 260 deletions

View File

@@ -23,6 +23,7 @@ from library.utils import load_safetensors
MODEL_VERSION_FLUX_V1 = "flux1"
MODEL_NAME_DEV = "dev"
MODEL_NAME_SCHNELL = "schnell"
MODEL_VERSION_CHROMA = "chroma"
def analyze_checkpoint_state(ckpt_path: str) -> Tuple[bool, bool, Tuple[int, int], List[str]]:
@@ -97,7 +98,7 @@ def load_flow_model(
device: Union[str, torch.device],
disable_mmap: bool = False,
model_type: str = "flux",
) -> Tuple[str, bool, flux_models.Flux]:
) -> Tuple[bool, flux_models.Flux]:
if model_type == "flux":
is_diffusers, is_schnell, (num_double_blocks, num_single_blocks), ckpt_paths = analyze_checkpoint_state(ckpt_path)
name = MODEL_NAME_DEV if not is_schnell else MODEL_NAME_SCHNELL
@@ -140,7 +141,7 @@ def load_flow_model(
info = model.load_state_dict(sd, strict=False, assign=True)
logger.info(f"Loaded Flux: {info}")
return model_type, is_schnell, model
return is_schnell, model
elif model_type == "chroma":
from . import chroma_models
@@ -166,7 +167,7 @@ def load_flow_model(
info = model.load_state_dict(sd, strict=False, assign=True)
logger.info(f"Loaded Chroma: {info}")
is_schnell = False # Chroma is not schnell
return model_type, is_schnell, model
return is_schnell, model
else:
raise ValueError(f"Unsupported model_type: {model_type}. Supported types are 'flux' and 'chroma'.")
@@ -203,6 +204,42 @@ def load_controlnet(
return controlnet
def dummy_clip_l() -> torch.nn.Module:
"""
Returns a dummy CLIP-L model with the output shape of (N, 77, 768).
"""
return DummyCLIPL()
class DummyTextModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.embeddings = torch.nn.Parameter(torch.zeros(1))
class DummyCLIPL(torch.nn.Module):
def __init__(self):
super().__init__()
self.output_shape = (77, 1) # Note: The original code had (77, 768), but we use (77, 1) for the dummy output
self.dummy_param = torch.nn.Parameter(torch.zeros(1)) # get dtype and device from this parameter
self.text_model = DummyTextModel()
@property
def device(self):
return self.dummy_param.device
@property
def dtype(self):
return self.dummy_param.dtype
def forward(self, *args, **kwargs):
"""
Returns a dummy output with the shape of (N, 77, 768).
"""
batch_size = args[0].shape[0] if args else 1
return {"pooler_output": torch.zeros(batch_size, *self.output_shape, device=self.device, dtype=self.dtype)}
def load_clip_l(
ckpt_path: Optional[str],
dtype: torch.dtype,