diff --git a/library/lumina_models.py b/library/lumina_models.py index f819b68f..365453c1 100644 --- a/library/lumina_models.py +++ b/library/lumina_models.py @@ -1,9 +1,19 @@ +# Copyright Alpha VLLM/Lumina Image 2.0 and contributors # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. - -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. -# -------------------------------------------------------- +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# # References: # GLIDE: https://github.com/openai/glide-text2im # MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py @@ -13,8 +23,6 @@ import math from typing import List, Optional, Tuple from dataclasses import dataclass -from einops import rearrange - import torch from torch import Tensor from torch.utils.checkpoint import checkpoint @@ -25,6 +33,7 @@ try: from flash_attn import flash_attn_varlen_func from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa except: + # flash_attn may not be available but it is not required pass try: @@ -34,6 +43,58 @@ except: warnings.warn("Cannot import apex RMSNorm, switch to vanilla implementation") + ############################################################################# + # RMSNorm # + ############################################################################# + + class RMSNorm(torch.nn.Module): + def __init__(self, dim: int, eps: float = 1e-6): + """ + Initialize the RMSNorm normalization layer. + + Args: + dim (int): The dimension of the input tensor. + eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. + + Attributes: + eps (float): A small value added to the denominator for numerical stability. + weight (nn.Parameter): Learnable scaling parameter. + + """ + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def _norm(self, x) -> Tensor: + """ + Apply the RMSNorm normalization to the input tensor. + + Args: + x (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The normalized tensor. + + """ + return x * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x: Tensor): + """ + Apply RMSNorm to the input tensor. + + Args: + x (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The normalized tensor. + """ + x_dtype = x.dtype + # To handle float8 we need to convert the tensor to float + x = x.float() + rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6) + return ((x * rrms) * self.weight.float()).to(dtype=x_dtype) + + @dataclass class LuminaParams: @@ -111,58 +172,6 @@ class GradientCheckpointMixin(nn.Module): return self._forward(*args, **kwargs) -############################################################################# -# RMSNorm # -############################################################################# - - -class RMSNorm(torch.nn.Module): - def __init__(self, dim: int, eps: float = 1e-6): - """ - Initialize the RMSNorm normalization layer. - - Args: - dim (int): The dimension of the input tensor. - eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. - - Attributes: - eps (float): A small value added to the denominator for numerical stability. - weight (nn.Parameter): Learnable scaling parameter. - - """ - super().__init__() - self.eps = eps - self.weight = nn.Parameter(torch.ones(dim)) - - def _norm(self, x) -> Tensor: - """ - Apply the RMSNorm normalization to the input tensor. - - Args: - x (torch.Tensor): The input tensor. - - Returns: - torch.Tensor: The normalized tensor. - - """ - return x * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + self.eps) - - def forward(self, x: Tensor): - """ - Apply RMSNorm to the input tensor. - - Args: - x (torch.Tensor): The input tensor. - - Returns: - torch.Tensor: The normalized tensor. - """ - x_dtype = x.dtype - # To handle float8 we need to convert the tensor to float - x = x.float() - rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6) - return ((x * rrms) * self.weight.float()).to(dtype=x_dtype) - def modulate(x, scale): return x * (1 + scale.unsqueeze(1)) diff --git a/library/lumina_train_util.py b/library/lumina_train_util.py index 487ae2f9..172d09ea 100644 --- a/library/lumina_train_util.py +++ b/library/lumina_train_util.py @@ -878,8 +878,8 @@ def add_lumina_train_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--discrete_flow_shift", type=float, - default=3.0, - help="Discrete flow shift for the Euler Discrete Scheduler, default is 3.0. / Euler Discrete Schedulerの離散フローシフト、デフォルトは3.0。", + default=6.0, + help="Discrete flow shift for the Euler Discrete Scheduler, default is 6.0. / Euler Discrete Schedulerの離散フローシフト、デフォルトは6.0。", ) parser.add_argument( "--use_flash_attn", @@ -889,6 +889,6 @@ def add_lumina_train_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--system_prompt", type=str, - default="You are an assistant designed to generate high-quality images based on user prompts. ", + default="", help="System prompt to add to the prompt. / プロンプトに追加するシステムプロンプト。", )