mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 14:34:23 +00:00
feat: add LoRA training support for Chroma
This commit is contained in:
@@ -23,6 +23,7 @@ from library.utils import load_safetensors
|
||||
MODEL_VERSION_FLUX_V1 = "flux1"
|
||||
MODEL_NAME_DEV = "dev"
|
||||
MODEL_NAME_SCHNELL = "schnell"
|
||||
MODEL_VERSION_CHROMA = "chroma"
|
||||
|
||||
|
||||
def analyze_checkpoint_state(ckpt_path: str) -> Tuple[bool, bool, Tuple[int, int], List[str]]:
|
||||
@@ -97,7 +98,7 @@ def load_flow_model(
|
||||
device: Union[str, torch.device],
|
||||
disable_mmap: bool = False,
|
||||
model_type: str = "flux",
|
||||
) -> Tuple[str, bool, flux_models.Flux]:
|
||||
) -> Tuple[bool, flux_models.Flux]:
|
||||
if model_type == "flux":
|
||||
is_diffusers, is_schnell, (num_double_blocks, num_single_blocks), ckpt_paths = analyze_checkpoint_state(ckpt_path)
|
||||
name = MODEL_NAME_DEV if not is_schnell else MODEL_NAME_SCHNELL
|
||||
@@ -140,7 +141,7 @@ def load_flow_model(
|
||||
|
||||
info = model.load_state_dict(sd, strict=False, assign=True)
|
||||
logger.info(f"Loaded Flux: {info}")
|
||||
return model_type, is_schnell, model
|
||||
return is_schnell, model
|
||||
|
||||
elif model_type == "chroma":
|
||||
from . import chroma_models
|
||||
@@ -166,7 +167,7 @@ def load_flow_model(
|
||||
info = model.load_state_dict(sd, strict=False, assign=True)
|
||||
logger.info(f"Loaded Chroma: {info}")
|
||||
is_schnell = False # Chroma is not schnell
|
||||
return model_type, is_schnell, model
|
||||
return is_schnell, model
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unsupported model_type: {model_type}. Supported types are 'flux' and 'chroma'.")
|
||||
@@ -203,6 +204,42 @@ def load_controlnet(
|
||||
return controlnet
|
||||
|
||||
|
||||
def dummy_clip_l() -> torch.nn.Module:
|
||||
"""
|
||||
Returns a dummy CLIP-L model with the output shape of (N, 77, 768).
|
||||
"""
|
||||
return DummyCLIPL()
|
||||
|
||||
|
||||
class DummyTextModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.embeddings = torch.nn.Parameter(torch.zeros(1))
|
||||
|
||||
|
||||
class DummyCLIPL(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.output_shape = (77, 1) # Note: The original code had (77, 768), but we use (77, 1) for the dummy output
|
||||
self.dummy_param = torch.nn.Parameter(torch.zeros(1)) # get dtype and device from this parameter
|
||||
self.text_model = DummyTextModel()
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return self.dummy_param.device
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return self.dummy_param.dtype
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
"""
|
||||
Returns a dummy output with the shape of (N, 77, 768).
|
||||
"""
|
||||
batch_size = args[0].shape[0] if args else 1
|
||||
return {"pooler_output": torch.zeros(batch_size, *self.output_shape, device=self.device, dtype=self.dtype)}
|
||||
|
||||
|
||||
def load_clip_l(
|
||||
ckpt_path: Optional[str],
|
||||
dtype: torch.dtype,
|
||||
|
||||
Reference in New Issue
Block a user