diff --git a/library/lumina_models.py b/library/lumina_models.py
index 43b1e9c6..3f2e854e 100644
--- a/library/lumina_models.py
+++ b/library/lumina_models.py
@@ -21,7 +21,8 @@ import torch.nn.functional as F
 
 try:
     from apex.normalization import FusedRMSNorm as RMSNorm
-except ImportError:
+except ModuleNotFoundError:
+    import warnings
     warnings.warn("Cannot import apex RMSNorm, switch to vanilla implementation")
 
 memory_efficient_attention = None
@@ -39,17 +40,20 @@ except:
 class LuminaParams:
     """Parameters for Lumina model configuration"""
     patch_size: int = 2
-    dim: int = 2592
+    in_channels: int = 4
+    dim: int = 4096
     n_layers: int = 30
+    n_refiner_layers: int = 2
     n_heads: int = 24
     n_kv_heads: int = 8
+    multiple_of: int = 256
     axes_dims: List[int] = None
     axes_lens: List[int] = None
-    qk_norm: bool = False,
-    ffn_dim_multiplier: Optional[float] = None,
-    norm_eps: float = 1e-5,
-    scaling_factor: float = 1.0,
-    cap_feat_dim: int = 32,
+    qk_norm: bool = False
+    ffn_dim_multiplier: Optional[float] = None
+    norm_eps: float = 1e-5
+    scaling_factor: float = 1.0
+    cap_feat_dim: int = 32
 
     def __post_init__(self):
         if self.axes_dims is None:
@@ -62,12 +66,15 @@ class LuminaParams:
         """Returns the configuration for the 2B parameter model"""
         return cls(
             patch_size=2,
-            dim=2592,
-            n_layers=30,
+            in_channels=16,
+            dim=2304,
+            n_layers=26,
             n_heads=24,
             n_kv_heads=8,
-            axes_dims=[36, 36, 36],
-            axes_lens=[300, 512, 512]
+            axes_dims=[32, 32, 32],
+            axes_lens=[300, 512, 512],
+            qk_norm=True,
+            cap_feat_dim=2304
         )
 
     @classmethod
@@ -696,8 +703,8 @@ class NextDiT(nn.Module):
         norm_eps: float = 1e-5,
         qk_norm: bool = False,
         cap_feat_dim: int = 5120,
-        axes_dims: List[int] = (16, 56, 56),
-        axes_lens: List[int] = (1, 512, 512),
+        axes_dims: List[int] = [16, 56, 56],
+        axes_lens: List[int] = [1, 512, 512],
     ) -> None:
         super().__init__()
         self.in_channels = in_channels
@@ -1090,6 +1097,7 @@ def NextDiT_2B_GQA_patch2_Adaln_Refiner(params: Optional[LuminaParams] = None, *
 
     return NextDiT(
         patch_size=params.patch_size,
+        in_channels=params.in_channels,
         dim=params.dim,
         n_layers=params.n_layers,
         n_heads=params.n_heads,
@@ -1099,7 +1107,6 @@ def NextDiT_2B_GQA_patch2_Adaln_Refiner(params: Optional[LuminaParams] = None, *
         qk_norm=params.qk_norm,
         ffn_dim_multiplier=params.ffn_dim_multiplier,
         norm_eps=params.norm_eps,
-        scaling_factor=params.scaling_factor,
         cap_feat_dim=params.cap_feat_dim,
         **kwargs,
     )
diff --git a/library/lumina_util.py b/library/lumina_util.py
index b47e057a..f8e3f7db 100644
--- a/library/lumina_util.py
+++ b/library/lumina_util.py
@@ -27,14 +27,14 @@ def load_lumina_model(
     dtype: torch.dtype,
     device: Union[str, torch.device],
     disable_mmap: bool = False,
-) -> lumina_models.Lumina:
+):
     logger.info("Building Lumina")
     with torch.device("meta"):
         model = lumina_models.NextDiT_2B_GQA_patch2_Adaln_Refiner().to(dtype)
 
     logger.info(f"Loading state dict from {ckpt_path}")
     state_dict = load_safetensors(
-        ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype
+        ckpt_path, device=device, disable_mmap=disable_mmap, dtype=dtype
     )
     info = model.load_state_dict(state_dict, strict=False, assign=True)
     logger.info(f"Loaded Lumina: {info}")
@@ -69,30 +69,39 @@ def load_gemma2(
 ) -> Gemma2Model:
     logger.info("Building Gemma2")
     GEMMA2_CONFIG = {
-        "_name_or_path": "google/gemma-2b",
-        "attention_bias": false,
-        "attention_dropout": 0.0,
-        "bos_token_id": 2,
-        "eos_token_id": 1,
-        "head_dim": 256,
-        "hidden_act": "gelu",
-        "hidden_size": 2048,
-        "initializer_range": 0.02,
-        "intermediate_size": 16384,
-        "max_position_embeddings": 8192,
-        "model_type": "gemma",
-        "num_attention_heads": 8,
-        "num_hidden_layers": 18,
-        "num_key_value_heads": 1,
-        "pad_token_id": 0,
-        "rms_norm_eps": 1e-06,
-        "rope_scaling": null,
-        "rope_theta": 10000.0,
-        "torch_dtype": "bfloat16",
-        "transformers_version": "4.38.0.dev0",
-        "use_cache": true,
-        "vocab_size": 256000
+        "_name_or_path": "google/gemma-2-2b",
+        "architectures": [
+            "Gemma2Model"
+        ],
+        "attention_bias": False,
+        "attention_dropout": 0.0,
+        "attn_logit_softcapping": 50.0,
+        "bos_token_id": 2,
+        "cache_implementation": "hybrid",
+        "eos_token_id": 1,
+        "final_logit_softcapping": 30.0,
+        "head_dim": 256,
+        "hidden_act": "gelu_pytorch_tanh",
+        "hidden_activation": "gelu_pytorch_tanh",
+        "hidden_size": 2304,
+        "initializer_range": 0.02,
+        "intermediate_size": 9216,
+        "max_position_embeddings": 8192,
+        "model_type": "gemma2",
+        "num_attention_heads": 8,
+        "num_hidden_layers": 26,
+        "num_key_value_heads": 4,
+        "pad_token_id": 0,
+        "query_pre_attn_scalar": 256,
+        "rms_norm_eps": 1e-06,
+        "rope_theta": 10000.0,
+        "sliding_window": 4096,
+        "torch_dtype": "float32",
+        "transformers_version": "4.44.2",
+        "use_cache": True,
+        "vocab_size": 256000
     }
+
     config = Gemma2Config(**GEMMA2_CONFIG)
     with init_empty_weights():
         gemma2 = Gemma2Model._from_config(config)
@@ -104,6 +113,13 @@ def load_gemma2(
     sd = load_safetensors(
         ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype
    )
+
+    for key in list(sd.keys()):
+        new_key = key.replace("model.", "")
+        if new_key == key:
+            break  # the model doesn't have the prefix, nothing to rename
+        sd[new_key] = sd.pop(key)
+
     info = gemma2.load_state_dict(sd, strict=False, assign=True)
     logger.info(f"Loaded Gemma2: {info}")
     return gemma2
diff --git a/library/strategy_lumina.py b/library/strategy_lumina.py
index 622c019a..615f6e00 100644
--- a/library/strategy_lumina.py
+++ b/library/strategy_lumina.py
@@ -9,7 +9,9 @@ from library.strategy_base import (
     LatentsCachingStrategy,
     TokenizeStrategy,
     TextEncodingStrategy,
+    TextEncoderOutputsCachingStrategy
 )
+import numpy as np
 from library.utils import setup_logging
 
 setup_logging()
diff --git a/lumina_train_network.py b/lumina_train_network.py
index db329a9b..1f8ba613 100644
--- a/lumina_train_network.py
+++ b/lumina_train_network.py
@@ -345,7 +345,7 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
 
 
 def setup_parser() -> argparse.ArgumentParser:
     parser = train_network.setup_parser()
     train_util.add_dit_training_arguments(parser)
-    lumina_train_utils.add_lumina_train_arguments(parser)
+    lumina_train_util.add_lumina_train_arguments(parser)
     return parser