Mirror of https://github.com/kohya-ss/sd-scripts.git, synced 2026-04-18 01:30:02 +00:00
feat: add selective torch compile and activation memory budget to Lumina 2
@@ -20,6 +20,7 @@
 # --------------------------------------------------------

 import math
+import os
 from typing import List, Optional, Tuple
 from dataclasses import dataclass

@@ -31,6 +32,15 @@ import torch.nn.functional as F

 from library import custom_offloading_utils

+disable_selective_torch_compile = (
+    os.getenv("SDSCRIPTS_SELECTIVE_TORCH_COMPILE", "0") == "0"
+)
+memory_budget = float(
+    os.getenv("SDSCRIPTS_TORCH_COMPILE_ACTIVATION_MEMORY_BUDGET", "0")
+)
+if memory_budget > 0:
+    torch._functorch.config.activation_memory_budget = memory_budget
+
 try:
     from flash_attn import flash_attn_varlen_func
     from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
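
Both switches are read once at import time: disable_selective_torch_compile is True unless SDSCRIPTS_SELECTIVE_TORCH_COMPILE is set to something other than "0", and a positive SDSCRIPTS_TORCH_COMPILE_ACTIVATION_MEMORY_BUDGET is written into torch._functorch.config.activation_memory_budget. A minimal usage sketch follows; the module path in the last line is an assumption, not part of this commit:

import os

# Both variables must be set before the model module is imported, because the
# decorators and the memory-budget assignment below run at import time.
os.environ["SDSCRIPTS_SELECTIVE_TORCH_COMPILE"] = "1"                    # any value other than "0" enables torch.compile on the decorated methods
os.environ["SDSCRIPTS_TORCH_COMPILE_ACTIVATION_MEMORY_BUDGET"] = "0.5"   # 0 (the default) leaves PyTorch's setting untouched

from library import lumina_models  # assumed module path for the Lumina 2 model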
@@ -553,7 +563,7 @@ class JointAttention(nn.Module):
                 f"Could not load flash attention. Please install flash_attn. / フラッシュアテンションを読み込めませんでした。flash_attn をインストールしてください。 / {e}"
             )

-
+    @torch.compiler.disable
     def apply_rope(
         x_in: torch.Tensor,
         freqs_cis: torch.Tensor,
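
torch.compiler.disable goes the opposite way from the decorators in the following hunks: it marks apply_rope so that Dynamo never traces it, and a compiled caller simply falls back to eager around the call. A minimal sketch of the pattern; the function body is a hypothetical stand-in, not the commit's apply_rope:

import torch

@torch.compiler.disable  # Dynamo skips this function; compiled callers graph-break around it
def rope_like_op(x: torch.Tensor) -> torch.Tensor:
    # Complex-valued view tricks like these are a common reason to keep a
    # function in eager mode (hypothetical stand-in for apply_rope).
    xc = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    return torch.view_as_real(xc).flatten(-2)

@torch.compile
def compiled_fn(x: torch.Tensor) -> torch.Tensor:
    # Everything except the rope_like_op call is compiled.
    return rope_like_op(x) * 2.0

print(compiled_fn(torch.randn(2, 8)).shape)  # torch.Size([2, 8])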
@@ -633,7 +643,8 @@ class FeedForward(nn.Module):
     # @torch.compile
     def _forward_silu_gating(self, x1, x3):
         return F.silu(x1) * x3

+    @torch.compile(disable=disable_selective_torch_compile)
     def forward(self, x):
         return self.w2(self._forward_silu_gating(self.w1(x), self.w3(x)))

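
Because the disable= argument is evaluated when the class body is executed, leaving the environment variable unset (or "0") turns each @torch.compile(...) into a no-op and keeps the previous eager behaviour; setting it compiles only these small, shape-stable methods rather than the whole model. A self-contained sketch of the same pattern; the module below is a hypothetical stand-in for the FeedForward block:

import os

import torch
import torch.nn as nn
import torch.nn.functional as F

# Same switch as in the diff: compilation is off unless the variable is set to a non-"0" value.
disable_selective_torch_compile = os.getenv("SDSCRIPTS_SELECTIVE_TORCH_COMPILE", "0") == "0"


class TinyFeedForward(nn.Module):  # hypothetical stand-in for the Lumina 2 FeedForward
    def __init__(self, dim: int = 64, hidden: int = 256):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden, bias=False)
        self.w2 = nn.Linear(hidden, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden, bias=False)

    @torch.compile(disable=disable_selective_torch_compile)  # no-op when disabled
    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


ff = TinyFeedForward()
print(ff(torch.randn(2, 64)).shape)  # torch.Size([2, 64])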
@@ -701,6 +712,7 @@ class JointTransformerBlock(GradientCheckpointMixin):
         nn.init.zeros_(self.adaLN_modulation[1].weight)
         nn.init.zeros_(self.adaLN_modulation[1].bias)

+    @torch.compile(disable=disable_selective_torch_compile)
     def _forward(
         self,
         x: torch.Tensor,
@@ -792,6 +804,7 @@ class FinalLayer(GradientCheckpointMixin):
         nn.init.zeros_(self.adaLN_modulation[1].weight)
         nn.init.zeros_(self.adaLN_modulation[1].bias)

+    @torch.compile(disable=disable_selective_torch_compile)
     def forward(self, x, c):
         scale = self.adaLN_modulation(c)
         x = modulate(self.norm_final(x), scale)
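
The activation memory budget set in the header only matters for regions that actually go through torch.compile, i.e. the FeedForward, JointTransformerBlock._forward and FinalLayer.forward paths decorated above. To my understanding the value is a fraction in [0, 1]: 1.0 is the partitioner's default behaviour, and smaller values ask it to recompute more activations during backward, trading compute for memory, which is why the commit only overrides the setting when a positive value is supplied. A small sketch for inspecting or overriding the knob directly (the 0.5 is illustrative):

import torch._functorch.config as functorch_config

# The same knob the diff sets from the environment variable.
print(functorch_config.activation_memory_budget)

# Allow roughly half of the usual activation memory inside compiled regions;
# the partitioner recomputes the rest in backward (illustrative value).
functorch_config.activation_memory_budget = 0.5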
|