Set default discrete_flow_shift to 6.0. Remove default system prompt.

This commit is contained in:
rockerBOO
2025-02-23 18:01:09 -05:00
parent 42a801514c
commit ba725a84e9
2 changed files with 70 additions and 61 deletions

View File

@@ -1,9 +1,19 @@
# Copyright Alpha VLLM/Lumina Image 2.0 and contributors
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# References:
# GLIDE: https://github.com/openai/glide-text2im
# MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
@@ -13,8 +23,6 @@ import math
from typing import List, Optional, Tuple
from dataclasses import dataclass
from einops import rearrange
import torch
from torch import Tensor
from torch.utils.checkpoint import checkpoint
@@ -25,6 +33,7 @@ try:
from flash_attn import flash_attn_varlen_func
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
except:
# flash_attn may not be available but it is not required
pass
try:
@@ -34,6 +43,58 @@ except:
warnings.warn("Cannot import apex RMSNorm, switch to vanilla implementation")
#############################################################################
# RMSNorm #
#############################################################################
class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        """
        Initialize the RMSNorm normalization layer.

        Args:
            dim (int): The dimension of the input tensor.
            eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.

        Attributes:
            eps (float): A small value added to the denominator for numerical stability.
            weight (nn.Parameter): Learnable scaling parameter.
        """
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x) -> Tensor:
        """
        Apply the RMS normalization to the input tensor (no learnable scale).

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The normalized tensor (computed in float32).
        """
        return x * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x: Tensor):
        """
        Apply RMSNorm to the input tensor.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The normalized, scaled tensor, cast back to the input dtype.
        """
        x_dtype = x.dtype
        # Compute in float32 so low-precision (e.g. float8/half) inputs stay
        # numerically stable, then cast back to the caller's dtype.
        # Fix: reuse _norm so the configured self.eps is honored — the previous
        # body hard-coded 1e-6 and silently ignored the constructor argument.
        normed = self._norm(x.float())
        return (normed * self.weight.float()).to(dtype=x_dtype)
@dataclass
class LuminaParams:
@@ -111,58 +172,6 @@ class GradientCheckpointMixin(nn.Module):
return self._forward(*args, **kwargs)
#############################################################################
# RMSNorm #
#############################################################################
class RMSNorm(torch.nn.Module):
    """Root-mean-square layer normalization with a learnable scale."""

    def __init__(self, dim: int, eps: float = 1e-6):
        """
        Initialize the RMSNorm normalization layer.

        Args:
            dim (int): The dimension of the input tensor.
            eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.

        Attributes:
            eps (float): A small value added to the denominator for numerical stability.
            weight (nn.Parameter): Learnable scaling parameter.
        """
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x) -> Tensor:
        """
        Apply the RMS normalization to the input tensor (no learnable scale).

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The normalized tensor (computed in float32).
        """
        return x * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x: Tensor):
        """
        Apply RMSNorm to the input tensor.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The normalized, scaled tensor, cast back to the input dtype.
        """
        # To handle float8 we need to compute in float, then restore the dtype.
        # Fix: delegate to _norm so self.eps is used — the old body hard-coded
        # 1e-6, ignoring the eps passed to the constructor.
        x_dtype = x.dtype
        return (self._norm(x.float()) * self.weight.float()).to(dtype=x_dtype)
def modulate(x, scale):
    """Modulate x by a per-sample factor of (1 + scale).

    `scale` gains a broadcast dimension at axis 1 so it applies across the
    sequence dimension of `x`.
    """
    shift = scale.unsqueeze(1) + 1
    return x * shift

View File

@@ -878,8 +878,8 @@ def add_lumina_train_arguments(parser: argparse.ArgumentParser):
parser.add_argument(
"--discrete_flow_shift",
type=float,
default=3.0,
help="Discrete flow shift for the Euler Discrete Scheduler, default is 3.0. / Euler Discrete Schedulerの離散フローシフト、デフォルトは3.0。",
default=6.0,
help="Discrete flow shift for the Euler Discrete Scheduler, default is 6.0. / Euler Discrete Schedulerの離散フローシフト、デフォルトは6.0。",
)
parser.add_argument(
"--use_flash_attn",
@@ -889,6 +889,6 @@ def add_lumina_train_arguments(parser: argparse.ArgumentParser):
parser.add_argument(
"--system_prompt",
type=str,
default="You are an assistant designed to generate high-quality images based on user prompts. <Prompt Start> ",
default="",
help="System prompt to add to the prompt. / プロンプトに追加するシステムプロンプト。",
)