Mirror of https://github.com/kohya-ss/sd-scripts.git
Add caching gemma2, add gradient checkpointing, refactor lumina model code
@@ -462,7 +462,7 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, lumina, wei
 
 
 class LoRANetwork(torch.nn.Module):
     LUMINA_TARGET_REPLACE_MODULE = ["JointTransformerBlock"]
-    TEXT_ENCODER_TARGET_REPLACE_MODULE = ["GemmaAttention", "GemmaDecoderLayer", "GemmaMLP"]
+    TEXT_ENCODER_TARGET_REPLACE_MODULE = ["Gemma2Attention", "Gemma2MLP"]
     LORA_PREFIX_LUMINA = "lora_unet"
     LORA_PREFIX_TEXT_ENCODER = "lora_te"  # Simplified prefix since we only have one text encoder
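Note: LoRA targets are selected by matching each submodule's class name against TEXT_ENCODER_TARGET_REPLACE_MODULE. The old "GemmaAttention" / "GemmaDecoderLayer" / "GemmaMLP" entries therefore matched nothing for a Gemma2 text encoder, whose transformers classes are named Gemma2Attention and Gemma2MLP, leaving zero text-encoder LoRA modules with no error. A minimal sketch of that matching pattern (the helper iter_target_linears is hypothetical, not from this repo):

import torch

TEXT_ENCODER_TARGET_REPLACE_MODULE = ["Gemma2Attention", "Gemma2MLP"]

def iter_target_linears(root: torch.nn.Module, target_classes):
    # Matching is by class name, so a stale entry fails silently:
    # it simply never matches, and no LoRA modules are created.
    for name, module in root.named_modules():
        if module.__class__.__name__ in target_classes:
            for child_name, child in module.named_modules():
                if isinstance(child, torch.nn.Linear):
                    yield f"{name}.{child_name}", child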
@@ -533,7 +533,7 @@ class LoRANetwork(torch.nn.Module):
         filter: Optional[str] = None,
         default_dim: Optional[int] = None,
     ) -> List[LoRAModule]:
-        prefix = self.LORA_PREFIX_FLUX if is_lumina else self.LORA_PREFIX_TEXT_ENCODER
+        prefix = self.LORA_PREFIX_LUMINA if is_lumina else self.LORA_PREFIX_TEXT_ENCODER
 
         loras = []
         skipped = []
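Note: self.LORA_PREFIX_FLUX is apparently a leftover from the FLUX network module this file was adapted from; the class above only defines LORA_PREFIX_LUMINA, so the is_lumina branch would raise AttributeError. The prefix heads every saved LoRA weight key, with module-path dots flattened to underscores, per the sd-scripts key convention. A sketch (make_lora_name and the module path are hypothetical stand-ins):

LORA_PREFIX_LUMINA = "lora_unet"
LORA_PREFIX_TEXT_ENCODER = "lora_te"

def make_lora_name(is_lumina: bool, module_name: str) -> str:
    # The prefix heads every LoRA weight key; dots in the module path
    # become underscores so keys are flat identifiers in the saved file.
    prefix = LORA_PREFIX_LUMINA if is_lumina else LORA_PREFIX_TEXT_ENCODER
    return (prefix + "." + module_name).replace(".", "_")

# illustrative module path, not taken from the repo
assert make_lora_name(True, "layers.0.attention.qkv") == "lora_unet_layers_0_attention_qkv"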
@@ -611,7 +611,7 @@ class LoRANetwork(torch.nn.Module):
         skipped_te = []
 
-        logger.info(f"create LoRA for Gemma2 Text Encoder:")
-        text_encoder_loras, skipped = create_modules(False, text_encoders, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
+        text_encoder_loras, skipped = create_modules(False, text_encoders[0], LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
+        logger.info(f"create LoRA for Gemma2 Text Encoder: {len(text_encoder_loras)} modules.")
         self.text_encoder_loras.extend(text_encoder_loras)
         skipped_te += skipped
 
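Note: create_modules traverses named_modules() on the object it is given. text_encoders is a Python list holding Lumina's single Gemma2 encoder, and a plain list has no named_modules(), so the encoder must be unwrapped with [0]. The log call also moves below create_modules so it can report the module count. The failure mode, sketched with a stand-in encoder:

import torch

text_encoders = [torch.nn.Sequential(torch.nn.Linear(8, 8))]  # one Gemma2 encoder

# dict(text_encoders.named_modules())  # AttributeError: 'list' object has no attribute 'named_modules'
modules = dict(text_encoders[0].named_modules())  # traverse the model itself
print(len(modules))  # 2: the Sequential and its Linear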
@@ -718,10 +718,10 @@ class LoRANetwork(torch.nn.Module):
 
     def state_dict(self, destination=None, prefix="", keep_vars=False):
         if not self.split_qkv:
-            return super().state_dict(destination, prefix, keep_vars)
+            return super().state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
 
         # merge qkv
-        state_dict = super().state_dict(destination, prefix, keep_vars)
+        state_dict = super().state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
         new_state_dict = {}
         for key in list(state_dict.keys()):
             if "double" in key and "qkv" in key:
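Note: recent PyTorch versions deprecate positional arguments to nn.Module.state_dict(); the keyword form is the supported usage, hence the change at both call sites. A quick check of the keyword form:

import torch

m = torch.nn.Linear(4, 4)

# keyword form -- positional state_dict(destination, prefix, keep_vars)
# emits a deprecation warning in recent PyTorch
sd = m.state_dict(prefix="lora_unet_", keep_vars=False)
print(sorted(sd.keys()))  # ['lora_unet_bias', 'lora_unet_weight']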