mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-10 15:00:23 +00:00
Set default discrete_flow_shift to 6.0. Remove default system prompt.
This commit is contained in:
@@ -1,9 +1,19 @@
|
||||
# Copyright Alpha VLLM/Lumina Image 2.0 and contributors
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
# --------------------------------------------------------
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# References:
|
||||
# GLIDE: https://github.com/openai/glide-text2im
|
||||
# MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
|
||||
@@ -13,8 +23,6 @@ import math
|
||||
from typing import List, Optional, Tuple
|
||||
from dataclasses import dataclass
|
||||
|
||||
from einops import rearrange
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
@@ -25,6 +33,7 @@ try:
|
||||
from flash_attn import flash_attn_varlen_func
|
||||
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
|
||||
except:
|
||||
# flash_attn may not be available but it is not required
|
||||
pass
|
||||
|
||||
try:
|
||||
@@ -34,6 +43,58 @@ except:
|
||||
|
||||
warnings.warn("Cannot import apex RMSNorm, switch to vanilla implementation")
|
||||
|
||||
#############################################################################
|
||||
# RMSNorm #
|
||||
#############################################################################
|
||||
|
||||
class RMSNorm(torch.nn.Module):
|
||||
def __init__(self, dim: int, eps: float = 1e-6):
|
||||
"""
|
||||
Initialize the RMSNorm normalization layer.
|
||||
|
||||
Args:
|
||||
dim (int): The dimension of the input tensor.
|
||||
eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
|
||||
|
||||
Attributes:
|
||||
eps (float): A small value added to the denominator for numerical stability.
|
||||
weight (nn.Parameter): Learnable scaling parameter.
|
||||
|
||||
"""
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.weight = nn.Parameter(torch.ones(dim))
|
||||
|
||||
def _norm(self, x) -> Tensor:
|
||||
"""
|
||||
Apply the RMSNorm normalization to the input tensor.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): The input tensor.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The normalized tensor.
|
||||
|
||||
"""
|
||||
return x * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + self.eps)
|
||||
|
||||
def forward(self, x: Tensor):
|
||||
"""
|
||||
Apply RMSNorm to the input tensor.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): The input tensor.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The normalized tensor.
|
||||
"""
|
||||
x_dtype = x.dtype
|
||||
# To handle float8 we need to convert the tensor to float
|
||||
x = x.float()
|
||||
rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
|
||||
return ((x * rrms) * self.weight.float()).to(dtype=x_dtype)
|
||||
|
||||
|
||||
|
||||
@dataclass
|
||||
class LuminaParams:
|
||||
@@ -111,58 +172,6 @@ class GradientCheckpointMixin(nn.Module):
|
||||
return self._forward(*args, **kwargs)
|
||||
|
||||
|
||||
#############################################################################
|
||||
# RMSNorm #
|
||||
#############################################################################
|
||||
|
||||
|
||||
class RMSNorm(torch.nn.Module):
|
||||
def __init__(self, dim: int, eps: float = 1e-6):
|
||||
"""
|
||||
Initialize the RMSNorm normalization layer.
|
||||
|
||||
Args:
|
||||
dim (int): The dimension of the input tensor.
|
||||
eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
|
||||
|
||||
Attributes:
|
||||
eps (float): A small value added to the denominator for numerical stability.
|
||||
weight (nn.Parameter): Learnable scaling parameter.
|
||||
|
||||
"""
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.weight = nn.Parameter(torch.ones(dim))
|
||||
|
||||
def _norm(self, x) -> Tensor:
|
||||
"""
|
||||
Apply the RMSNorm normalization to the input tensor.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): The input tensor.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The normalized tensor.
|
||||
|
||||
"""
|
||||
return x * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + self.eps)
|
||||
|
||||
def forward(self, x: Tensor):
|
||||
"""
|
||||
Apply RMSNorm to the input tensor.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): The input tensor.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The normalized tensor.
|
||||
"""
|
||||
x_dtype = x.dtype
|
||||
# To handle float8 we need to convert the tensor to float
|
||||
x = x.float()
|
||||
rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
|
||||
return ((x * rrms) * self.weight.float()).to(dtype=x_dtype)
|
||||
|
||||
|
||||
def modulate(x, scale):
|
||||
return x * (1 + scale.unsqueeze(1))
|
||||
|
||||
@@ -878,8 +878,8 @@ def add_lumina_train_arguments(parser: argparse.ArgumentParser):
|
||||
parser.add_argument(
|
||||
"--discrete_flow_shift",
|
||||
type=float,
|
||||
default=3.0,
|
||||
help="Discrete flow shift for the Euler Discrete Scheduler, default is 3.0. / Euler Discrete Schedulerの離散フローシフト、デフォルトは3.0。",
|
||||
default=6.0,
|
||||
help="Discrete flow shift for the Euler Discrete Scheduler, default is 6.0. / Euler Discrete Schedulerの離散フローシフト、デフォルトは6.0。",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_flash_attn",
|
||||
@@ -889,6 +889,6 @@ def add_lumina_train_arguments(parser: argparse.ArgumentParser):
|
||||
parser.add_argument(
|
||||
"--system_prompt",
|
||||
type=str,
|
||||
default="You are an assistant designed to generate high-quality images based on user prompts. <Prompt Start> ",
|
||||
default="",
|
||||
help="System prompt to add to the prompt. / プロンプトに追加するシステムプロンプト。",
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user