Add Gemma2 caching, add gradient checkpointing, refactor Lumina model code

This commit is contained in:
rockerBOO
2025-02-16 01:06:34 -05:00
parent a00b06bc97
commit 60a76ebb72
4 changed files with 304 additions and 225 deletions

@@ -462,7 +462,7 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, lumina, wei
 class LoRANetwork(torch.nn.Module):
     LUMINA_TARGET_REPLACE_MODULE = ["JointTransformerBlock"]
-    TEXT_ENCODER_TARGET_REPLACE_MODULE = ["GemmaAttention", "GemmaDecoderLayer", "GemmaMLP"]
+    TEXT_ENCODER_TARGET_REPLACE_MODULE = ["Gemma2Attention", "Gemma2MLP"]
     LORA_PREFIX_LUMINA = "lora_unet"
     LORA_PREFIX_TEXT_ENCODER = "lora_te"  # Simplified prefix since we only have one text encoder
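Note on the hunk above: the replace-module list now names the Gemma 2 classes from Hugging Face Transformers (Gemma2Attention, Gemma2MLP) rather than the Gemma v1 classes. A minimal sketch of how such a list is typically consumed when deciding which modules get LoRA wrappers; find_lora_targets is an illustrative name, not a function in this repo:

    import torch
    from typing import List, Tuple

    def find_lora_targets(
        model: torch.nn.Module, target_replace_modules: List[str]
    ) -> List[Tuple[str, torch.nn.Module]]:
        # collect the Linear/Conv2d children of every block whose class name
        # appears in the target list (e.g. "Gemma2Attention", "Gemma2MLP")
        targets = []
        for name, module in model.named_modules():
            if module.__class__.__name__ in target_replace_modules:
                for child_name, child in module.named_modules():
                    if isinstance(child, (torch.nn.Linear, torch.nn.Conv2d)):
                        targets.append((f"{name}.{child_name}", child))
        return targets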
@@ -533,7 +533,7 @@ class LoRANetwork(torch.nn.Module):
             filter: Optional[str] = None,
             default_dim: Optional[int] = None,
         ) -> List[LoRAModule]:
-            prefix = self.LORA_PREFIX_FLUX if is_lumina else self.LORA_PREFIX_TEXT_ENCODER
+            prefix = self.LORA_PREFIX_LUMINA if is_lumina else self.LORA_PREFIX_TEXT_ENCODER
            loras = []
            skipped = []
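The hunk above fixes a leftover from the FLUX code this file was refactored from: the class defines LORA_PREFIX_LUMINA, so referencing self.LORA_PREFIX_FLUX on the is_lumina branch would raise AttributeError at runtime. For reference, a hedged sketch of how such prefixes conventionally become LoRA key names in kohya-style code; the helper is illustrative, not from this repo:

    def make_lora_name(prefix: str, module_path: str) -> str:
        # join the prefix and the module's dot-path, then replace dots so the
        # result is a single flat key usable in a state dict
        return (prefix + "." + module_path).replace(".", "_")

    # make_lora_name("lora_unet", "layers.0.attention.qkv")
    #   -> "lora_unet_layers_0_attention_qkv"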
@@ -611,7 +611,7 @@ class LoRANetwork(torch.nn.Module):
         skipped_te = []
         logger.info(f"create LoRA for Gemma2 Text Encoder:")
-        text_encoder_loras, skipped = create_modules(False, text_encoders, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
+        text_encoder_loras, skipped = create_modules(False, text_encoders[0], LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
         logger.info(f"create LoRA for Gemma2 Text Encoder: {len(text_encoder_loras)} modules.")
         self.text_encoder_loras.extend(text_encoder_loras)
         skipped_te += skipped
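The hunk above passes the single Gemma2 encoder instead of the list that wraps it. Assuming create_modules walks the module tree (which the fix implies), this matters because named_modules() exists on nn.Module but not on a plain Python list. A minimal runnable sketch of the distinction; the names gemma2 and text_encoders are stand-ins:

    import torch

    gemma2 = torch.nn.Linear(8, 8)  # stand-in for the real Gemma2 model
    text_encoders = [gemma2]        # loaders commonly return a list of encoders

    # works: nn.Module exposes named_modules()
    print([name for name, _ in text_encoders[0].named_modules()])

    # would fail: a list has no named_modules attribute
    # text_encoders.named_modules()  # AttributeError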
@@ -718,10 +718,10 @@ class LoRANetwork(torch.nn.Module):
     def state_dict(self, destination=None, prefix="", keep_vars=False):
         if not self.split_qkv:
-            return super().state_dict(destination, prefix, keep_vars)
+            return super().state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
         # merge qkv
-        state_dict = super().state_dict(destination, prefix, keep_vars)
+        state_dict = super().state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
         new_state_dict = {}
         for key in list(state_dict.keys()):
             if "double" in key and "qkv" in key: