kohya-ss/sd-scripts (mirror of https://github.com/kohya-ss/sd-scripts.git)
Merge 3bbfa9b258 into d633b51126
@@ -20,6 +20,7 @@
 # --------------------------------------------------------
 
 import math
+import os
 from typing import List, Optional, Tuple
 from dataclasses import dataclass
 
@@ -31,6 +32,10 @@ import torch.nn.functional as F
 
 from library import custom_offloading_utils
 
+disable_selective_torch_compile = (
+    os.getenv("SDSCRIPTS_SELECTIVE_TORCH_COMPILE", "0") == "0"
+)
+
 try:
     from flash_attn import flash_attn_varlen_func
     from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
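As background (not part of the diff): the gate above reads SDSCRIPTS_SELECTIVE_TORCH_COMPILE once at import time and feeds it into torch.compile's disable flag, so per-method compilation stays off unless the user opts in. A minimal sketch of the same pattern, with a made-up module name and assuming a working torch.compile backend:

import os

import torch
import torch.nn as nn
import torch.nn.functional as F

# Same gate as in the diff: compilation stays disabled unless the user
# exports SDSCRIPTS_SELECTIVE_TORCH_COMPILE with a value other than "0".
disable_selective_torch_compile = (
    os.getenv("SDSCRIPTS_SELECTIVE_TORCH_COMPILE", "0") == "0"
)


class TinyGatedMLP(nn.Module):  # hypothetical module, not from the repository
    def __init__(self, dim: int = 16):
        super().__init__()
        self.w1 = nn.Linear(dim, dim, bias=False)
        self.w2 = nn.Linear(dim, dim, bias=False)

    # With disable=True the decorator is a no-op and the method runs eagerly;
    # with the env var set, only this method is compiled, the rest stays eager.
    @torch.compile(disable=disable_selective_torch_compile)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(F.silu(self.w1(x)))


if __name__ == "__main__":
    print(TinyGatedMLP()(torch.randn(2, 16)).shape)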
@@ -549,7 +554,7 @@ class JointAttention(nn.Module):
                 f"Could not load flash attention. Please install flash_attn. / フラッシュアテンションを読み込めませんでした。flash_attn をインストールしてください。 / {e}"
             )
 
+@torch.compiler.disable(reason="complex ops inside")
 def apply_rope(
     x_in: torch.Tensor,
     freqs_cis: torch.Tensor,
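As background (not part of the diff): torch.compiler.disable marks a function as a graph break, so the complex-number arithmetic in the rotary embedding runs eagerly even when its caller is compiled. The reason= keyword used above requires a relatively recent PyTorch; the sketch below, with made-up names and a toy rotation, omits it:

import torch


@torch.compiler.disable  # runs eagerly; Dynamo graph-breaks around this call
def toy_apply_rope(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    # hypothetical stand-in for the real apply_rope: rotate pairs via complex multiply
    x_c = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    out = torch.view_as_real(x_c * freqs_cis)
    return out.reshape(*x.shape).type_as(x)


@torch.compile
def compiled_caller(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    return toy_apply_rope(x, freqs_cis) * 2.0


if __name__ == "__main__":
    x = torch.randn(1, 4, 8)
    freqs = torch.polar(torch.ones(1, 4, 4), torch.randn(1, 4, 4))
    print(compiled_caller(x, freqs).shape)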
@@ -625,13 +630,10 @@ class FeedForward(nn.Module):
             bias=False,
         )
         nn.init.xavier_uniform_(self.w3.weight)
 
-    # @torch.compile
-    def _forward_silu_gating(self, x1, x3):
-        return F.silu(x1) * x3
-
+    @torch.compile(disable=disable_selective_torch_compile)
     def forward(self, x):
-        return self.w2(self._forward_silu_gating(self.w1(x), self.w3(x)))
+        return self.w2(F.silu(self.w1(x)) * self.w3(x))
 
 
 class JointTransformerBlock(GradientCheckpointMixin):
@@ -697,6 +699,7 @@ class JointTransformerBlock(GradientCheckpointMixin):
         nn.init.zeros_(self.adaLN_modulation[1].weight)
         nn.init.zeros_(self.adaLN_modulation[1].bias)
 
+    @torch.compile(disable=disable_selective_torch_compile)
     def _forward(
         self,
         x: torch.Tensor,
@@ -788,6 +791,7 @@ class FinalLayer(GradientCheckpointMixin):
         nn.init.zeros_(self.adaLN_modulation[1].weight)
         nn.init.zeros_(self.adaLN_modulation[1].bias)
 
+    @torch.compile(disable=disable_selective_torch_compile)
     def forward(self, x, c):
         scale = self.adaLN_modulation(c)
         x = modulate(self.norm_final(x), scale)
@@ -808,6 +812,7 @@ class RopeEmbedder:
         self.axes_lens = axes_lens
         self.freqs_cis = NextDiT.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta)
 
+    @torch.compiler.disable(reason="complex ops inside")
     def __call__(self, ids: torch.Tensor):
         device = ids.device
         self.freqs_cis = [freqs_cis.to(ids.device) for freqs_cis in self.freqs_cis]
@@ -1219,6 +1224,7 @@ class NextDiT(nn.Module):
         return output
 
     @staticmethod
+    @torch.compiler.disable(reason="complex ops inside")
     def precompute_freqs_cis(
         dim: List[int],
         end: List[int],
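As background (not part of the diff): the decorator ordering above matters because decorators apply bottom-up. Placing torch.compiler.disable below staticmethod lets it wrap the plain function first, and staticmethod then wraps the already-disabled function. A toy sketch with made-up names:

import torch


class ToyEmbedder:  # hypothetical class, not from the repository
    @staticmethod
    @torch.compiler.disable  # applied first (bottom-up), then wrapped by staticmethod
    def precompute(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
        # complex-valued frequency table, the kind of op Dynamo is asked to skip
        freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
        t = torch.arange(end).float()
        return torch.polar(torch.ones(end, dim // 2), torch.outer(t, freqs))


if __name__ == "__main__":
    print(ToyEmbedder.precompute(8, 16).shape)  # torch.Size([16, 4])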
@@ -3936,6 +3936,12 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
         ],
         help="dynamo backend type (default is inductor) / dynamoのbackendの種類(デフォルトは inductor)",
     )
+    parser.add_argument(
+        "--activation_memory_budget",
+        type=float,
+        default=None,
+        help="activation memory budget setting for torch.compile (range: 0 to 1). A smaller value saves more memory at the cost of speed. If set, using --torch_compile without --gradient_checkpointing is recommended. Requires PyTorch 2.4 or later. / torch.compileのactivation memory budget設定(0~1の値)。この値を小さくするとメモリ使用量を節約できますが、処理速度は低下します。この設定を行う場合は、--gradient_checkpointing オプションを指定せずに --torch_compile を使用することをお勧めします。PyTorch 2.4以降が必要です。",
+    )
     parser.add_argument("--xformers", action="store_true", help="use xformers for CrossAttention / CrossAttentionにxformersを使う")
     parser.add_argument(
         "--sdpa",
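Illustrative usage (not from this commit), with the training script and its other required options elided: passing `--torch_compile --activation_memory_budget 0.5` enables compilation and asks the compiled backward pass to recompute more activations instead of saving them, trading some speed for lower memory use. Per the validation added below, the value must lie in [0, 1], and 0 is accepted as the most aggressive setting.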
@@ -5468,6 +5474,19 @@ def prepare_accelerator(args: argparse.Namespace):
     if args.torch_compile:
         dynamo_backend = args.dynamo_backend
 
+        if args.activation_memory_budget is not None:  # Note: 0 is a valid value.
+            if 0 <= args.activation_memory_budget <= 1:
+                logger.info(
+                    f"set torch compile activation memory budget to {args.activation_memory_budget}"
+                )
+                torch._functorch.config.activation_memory_budget = (  # type: ignore
+                    args.activation_memory_budget
+                )
+            else:
+                raise ValueError(
+                    "activation_memory_budget must be between 0 and 1 (inclusive)"
+                )
+
     kwargs_handlers = [
         (
             InitProcessGroupKwargs(
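As background (not part of the diff): torch._functorch.config.activation_memory_budget is the knob behind the partitioner's memory-budget feature that the help text says requires PyTorch 2.4; 1.0 keeps the default activation saving and smaller values force more recomputation in backward, which is why the code above also accepts 0. A minimal standalone sketch, assuming PyTorch 2.4 or later and a working inductor backend:

import torch
import torch.nn as nn

# The same config the diff sets from --activation_memory_budget: a float in [0, 1].
# Smaller values make the compiled backward recompute more activations
# instead of saving them, trading speed for memory.
torch._functorch.config.activation_memory_budget = 0.5  # type: ignore

model = nn.Sequential(nn.Linear(64, 256), nn.GELU(), nn.Linear(256, 64))
compiled_model = torch.compile(model)

x = torch.randn(8, 64)
loss = compiled_model(x).sum()
loss.backward()  # the budget only influences what is saved for this backward pass
print(model[0].weight.grad.shape)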