Fix default initialization

This commit is contained in:
rockerBOO
2025-05-07 23:33:08 -04:00
parent ef8371243b
commit 554674909a
7 changed files with 469 additions and 8 deletions

View File

@@ -141,6 +141,10 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
# Apply partitioned for Diffusion4k
if args.partitioned_vae:
ae.decoder.partitioned = True
return flux_utils.MODEL_VERSION_FLUX_V1, [clip_l, t5xxl], ae, model
def get_tokenize_strategy(self, args):
@@ -360,7 +364,13 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
# pack latents and get img_ids
packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input) # b, c, h*2, w*2 -> b, h*w, c*4
packed_latent_height, packed_latent_width = noisy_model_input.shape[2] // 2, noisy_model_input.shape[3] // 2
img_ids = flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width).to(device=accelerator.device)
if args.partitioned_vae:
packed_latent_height, packed_latent_width = noisy_model_input.shape[2], noisy_model_input.shape[3]
img_ids = flux_utils.prepare_paritioned_img_ids(bsz, packed_latent_height, packed_latent_width).to(device=accelerator.device)
else:
packed_latent_height, packed_latent_width = noisy_model_input.shape[2] // 2, noisy_model_input.shape[3] // 2
img_ids = flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width).to(device=accelerator.device)
# get guidance
# ensure guidance_scale in args is float
@@ -408,7 +418,18 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
)
# unpack latents
<<<<<<< Updated upstream
model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width)
=======
if args.partitioned_vae:
model_pred = flux_utils.unpack_partitioned_latents(model_pred, packed_latent_height, packed_latent_width)
else:
# unpack latents
model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width)
if args.bypass_flux_guidance:
flux_utils.restore_flux_guidance(unet)
>>>>>>> Stashed changes
# apply model prediction type
model_pred, weighting = flux_train_utils.apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas)

View File

@@ -54,6 +54,8 @@ class AutoEncoderParams:
z_channels: int
scale_factor: float
shift_factor: float
stride: int
partitioned: bool
def swish(x: Tensor) -> Tensor:
@@ -228,6 +230,8 @@ class Decoder(nn.Module):
in_channels: int,
resolution: int,
z_channels: int,
partitioned=False,
stride=1,
):
super().__init__()
self.ch = ch
@@ -236,6 +240,8 @@ class Decoder(nn.Module):
self.resolution = resolution
self.in_channels = in_channels
self.ffactor = 2 ** (self.num_resolutions - 1)
self.stride = stride
self.partitioned = partitioned
# compute in_ch_mult, block_in and curr_res at lowest res
block_in = ch * ch_mult[self.num_resolutions - 1]
@@ -272,7 +278,7 @@ class Decoder(nn.Module):
self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
def forward(self, z: Tensor) -> Tensor:
def forward(self, z: Tensor, partitioned=None) -> Tensor:
# z to block_in
h = self.conv_in(z)
@@ -291,9 +297,55 @@ class Decoder(nn.Module):
h = self.up[i_level].upsample(h)
# end
h = self.norm_out(h)
h = swish(h)
h = self.conv_out(h)
# Diffusion4k
partitioned = partitioned if not None else self.partitioned
if self.stride > 1 and partitioned:
h = swish(h)
overlap_size = 1 # because last conv kernel_size = 3
res = []
partitioned_height = h.shape[2] // self.stride
partitioned_width = h.shape[3] // self.stride
assert self.stride == 2 # only support stride = 2 for now
rows = []
for i in range(0, h.shape[2], partitioned_height):
row = []
for j in range(0, h.shape[3], partitioned_width):
partition = h[:,:, max(i - overlap_size, 0) : min(i + partitioned_height + overlap_size, h.shape[2]), max(j - overlap_size, 0) : min(j + partitioned_width + overlap_size, h.shape[3])]
# for strih
if i==0 and j==0:
partition = torch.nn.functional.pad(partition, (1, 0, 1, 0), "constant", 0)
elif i==0:
partition = torch.nn.functional.pad(partition, (0, 1, 1, 0), "constant", 0)
elif i>0 and j==0:
partition = torch.nn.functional.pad(partition, (1, 0, 0, 1), "constant", 0)
elif i>0 and j>0:
partition = torch.nn.functional.pad(partition, (0, 1, 0, 1), "constant", 0)
partition = torch.nn.functional.interpolate(partition, scale_factor=self.stride, mode='nearest')
partition = self.conv_out(partition)
partition = partition[:,:,overlap_size:partitioned_height*2+overlap_size,overlap_size:partitioned_width*2+overlap_size]
row.append(partition)
rows.append(row)
for row in rows:
res.append(torch.cat(row, dim=3))
h = torch.cat(res, dim=2)
# Diffusion4k
elif self.stride > 1:
h = self.norm_out(h)
h = torch.nn.functional.interpolate(h, scale_factor=self.stride, mode='nearest')
h = swish(h)
h = self.conv_out(h)
else:
h = self.norm_out(h)
h = swish(h)
h = self.conv_out(h)
return h
@@ -404,6 +456,9 @@ configs = {
z_channels=16,
scale_factor=0.3611,
shift_factor=0.1159,
# Diffusion4k
stride=1,
partitioned=False,
),
),
"schnell": ModelSpec(
@@ -436,6 +491,9 @@ configs = {
z_channels=16,
scale_factor=0.3611,
shift_factor=0.1159,
# Diffusion4k
stride=1,
partitioned=False,
),
),
}

View File

@@ -689,3 +689,33 @@ def add_flux_train_arguments(parser: argparse.ArgumentParser):
default=3.0,
help="Discrete flow shift for the Euler Discrete Scheduler, default is 3.0. / Euler Discrete Schedulerの離散フローシフト、デフォルトは3.0。",
)
<<<<<<< Updated upstream
=======
parser.add_argument(
"--redux_model_path",
type=str,
help="path to Redux model (*.sft or *.safetensors), should be float16",
)
parser.add_argument(
"--vision_cond_downsample",
type=int,
default=0,
help="Downsample Redux tokens to the specified grid size (default is 27). Zero disables this feature.",
)
parser.add_argument(
"--vision_cond_dropout",
type=float,
default=1.0,
help="Probability of dropout for Redux conditioning.",
)
# bypass guidance module for flux
parser.add_argument(
"--bypass_flux_guidance"
, action="store_true"
, help="bypass flux guidance module for Flex.1-Alpha Training / Flex.1-Alpha トレーニング用バイパス フラックス ガイダンス モジュール"
)
parser.add_argument("--proportional_attention", action="store_true", help="Dynamic attention scale with respect to the resolution. Proportional attention to the image sequence length. From URAE paper")
parser.add_argument("--ntk_factor", type=float, default=1.0, help="NTK Factor for increasing the embedding space for RoPE. Defaults to 1.0. 10.0 for 2k/4k images. From URAE paper.")
parser.add_argument("--partitioned_vae", action="store_true", help="Partitioned VAE from Diffusion4k paper")
>>>>>>> Stashed changes

View File

@@ -346,6 +346,13 @@ def prepare_img_ids(batch_size: int, packed_latent_height: int, packed_latent_wi
img_ids = einops.repeat(img_ids, "h w c -> b (h w) c", b=batch_size)
return img_ids
def prepare_paritioned_img_ids(batch_size: int, packed_latent_height: int, packed_latent_width: int):
img_ids = torch.zeros(packed_latent_height // 2, packed_latent_width // 2, 3)
img_ids[..., 1] = img_ids[..., 1] + torch.arange(packed_latent_height // 2)[:, None]
img_ids[..., 2] = img_ids[..., 2] + torch.arange(packed_latent_width // 2)[None, :]
img_ids = einops.repeat(img_ids, "h w c -> b (h w) c", b=batch_size)
return img_ids
def unpack_latents(x: torch.Tensor, packed_latent_height: int, packed_latent_width: int) -> torch.Tensor:
"""
@@ -354,6 +361,13 @@ def unpack_latents(x: torch.Tensor, packed_latent_height: int, packed_latent_wid
x = einops.rearrange(x, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=packed_latent_height, w=packed_latent_width, ph=2, pw=2)
return x
def unpack_partitioned_latents(x: torch.FloatTensor, packed_latent_height: int, packed_latent_width: int) -> torch.FloatTensor:
"""
x: [b (h w) (c ph pw)] -> [b c (h ph) (w pw)], ph=2, pw=2
"""
x = einops.rearrange(x, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=packed_latent_height // 2, w=packed_latent_width // 2, ph=2, pw=2)
return x
def pack_latents(x: torch.Tensor) -> torch.Tensor:
"""

338
library/incremental_pca.py Normal file
View File

@@ -0,0 +1,338 @@
# Copyright 2024-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
import torch
class IncrementalPCA:
"""
An implementation of Incremental Principal Components Analysis (IPCA) that leverages PyTorch for GPU acceleration.
Adapted from https://github.com/scikit-learn/scikit-learn/blob/main/sklearn/decomposition/_incremental_pca.py
This class provides methods to fit the model on data incrementally in batches, and to transform new data based on
the principal components learned during the fitting process.
Args:
n_components (int, optional): Number of components to keep. If `None`, it's set to the minimum of the
number of samples and features. Defaults to None.
copy (bool): If False, input data will be overwritten. Defaults to True.
batch_size (int, optional): The number of samples to use for each batch. Only needed if self.fit is called.
If `None`, it's inferred from the data and set to `5 * n_features`. Defaults to None.
svd_driver (str, optional): name of the cuSOLVER method to be used for torch.linalg.svd. This keyword
argument only works on CUDA inputs. Available options are: None, gesvd, gesvdj, and gesvda. Defaults to
None.
lowrank (bool, optional): Whether to use torch.svd_lowrank instead of torch.linalg.svd which can be faster.
Defaults to False.
lowrank_q (int, optional): For an adequate approximation of n_components, this parameter defaults to
n_components * 2.
lowrank_niter (int, optional): Number of subspace iterations to conduct for torch.svd_lowrank.
Defaults to 4.
lowrank_seed (int, optional): Seed for making results of torch.svd_lowrank reproducible.
"""
def __init__(
self,
n_components: Optional[int] = None,
copy: Optional[bool] = True,
batch_size: Optional[int] = None,
svd_driver: Optional[str] = None,
lowrank: bool = False,
lowrank_q: Optional[int] = None,
lowrank_niter: int = 4,
lowrank_seed: Optional[int] = None,
):
self.n_components = n_components
self.copy = copy
self.batch_size = batch_size
self.svd_driver = svd_driver
self.lowrank = lowrank
self.lowrank_q = lowrank_q
self.lowrank_niter = lowrank_niter
self.lowrank_seed = lowrank_seed
self.n_features_ = None
if self.lowrank:
self._validate_lowrank_params()
def _validate_lowrank_params(self):
if self.lowrank_q is None:
if self.n_components is None:
raise ValueError("n_components must be specified when using lowrank mode with lowrank_q=None.")
self.lowrank_q = self.n_components * 2
elif self.n_components is not None and self.lowrank_q < self.n_components:
raise ValueError("lowrank_q must be greater than or equal to n_components.")
def _svd_fn_full(self, X):
return torch.linalg.svd(X, full_matrices=False, driver=self.svd_driver)
def _svd_fn_lowrank(self, X):
seed_enabled = self.lowrank_seed is not None
with torch.random.fork_rng(enabled=seed_enabled):
if seed_enabled:
torch.manual_seed(self.lowrank_seed)
U, S, V = torch.svd_lowrank(X, q=self.lowrank_q, niter=self.lowrank_niter)
return U, S, V.mH
def _validate_data(self, X) -> torch.Tensor:
"""
Validates and converts the input data `X` to the appropriate tensor format.
Args:
X (torch.Tensor): Input data.
Returns:
torch.Tensor: Converted to appropriate format.
"""
valid_dtypes = [torch.float32, torch.float64]
if not isinstance(X, torch.Tensor):
X = torch.tensor(X, dtype=torch.float32)
elif self.copy:
X = X.clone()
n_samples, n_features = X.shape
if self.n_components is None:
pass
elif self.n_components > n_features:
raise ValueError(
f"n_components={self.n_components} invalid for n_features={n_features}, "
"need more rows than columns for IncrementalPCA processing."
)
elif self.n_components > n_samples:
raise ValueError(
f"n_components={self.n_components} must be less or equal to the batch number of samples {n_samples}"
)
if X.dtype not in valid_dtypes:
X = X.to(torch.float32)
return X
@staticmethod
def _incremental_mean_and_var(
X, last_mean, last_variance, last_sample_count
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Computes the incremental mean and variance for the data `X`.
Args:
X (torch.Tensor): The batch input data tensor with shape (n_samples, n_features).
last_mean (torch.Tensor): The previous mean tensor with shape (n_features,).
last_variance (torch.Tensor): The previous variance tensor with shape (n_features,).
last_sample_count (torch.Tensor): The count tensor of samples processed before the current batch.
Returns:
Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Updated mean, variance tensors, and total sample count.
"""
if X.shape[0] == 0:
return last_mean, last_variance, last_sample_count
if last_sample_count > 0:
if last_mean is None:
raise ValueError("last_mean should not be None if last_sample_count > 0.")
if last_variance is None:
raise ValueError("last_variance should not be None if last_sample_count > 0.")
new_sample_count = torch.tensor([X.shape[0]], device=X.device)
updated_sample_count = last_sample_count + new_sample_count
if last_mean is None:
last_sum = torch.zeros(X.shape[1], dtype=torch.float64, device=X.device)
else:
last_sum = last_mean * last_sample_count
new_sum = X.sum(dim=0, dtype=torch.float64)
updated_mean = (last_sum + new_sum) / updated_sample_count
T = new_sum / new_sample_count
temp = X - T
correction = temp.sum(dim=0, dtype=torch.float64).square()
temp.square_()
new_unnormalized_variance = temp.sum(dim=0, dtype=torch.float64)
new_unnormalized_variance -= correction / new_sample_count
if last_variance is None:
updated_variance = new_unnormalized_variance / updated_sample_count
else:
last_unnormalized_variance = last_variance * last_sample_count
last_over_new_count = last_sample_count.double() / new_sample_count
updated_unnormalized_variance = (
last_unnormalized_variance
+ new_unnormalized_variance
+ last_over_new_count / updated_sample_count * (last_sum / last_over_new_count - new_sum).square()
)
updated_variance = updated_unnormalized_variance / updated_sample_count
return updated_mean, updated_variance, updated_sample_count
@staticmethod
def _svd_flip(u, v, u_based_decision=True) -> tuple[torch.Tensor, torch.Tensor]:
"""
Adjusts the signs of the singular vectors from the SVD decomposition for deterministic output.
This method ensures that the output remains consistent across different runs.
Args:
u (torch.Tensor): Left singular vectors tensor.
v (torch.Tensor): Right singular vectors tensor.
u_based_decision (bool, optional): If True, uses the left singular vectors to determine the sign flipping.
Defaults to True.
Returns:
Tuple[torch.Tensor, torch.Tensor]: Adjusted left and right singular vectors tensors.
"""
if u_based_decision:
max_abs_cols = torch.argmax(torch.abs(u), dim=0)
signs = torch.sign(u[max_abs_cols, range(u.shape[1])])
else:
max_abs_rows = torch.argmax(torch.abs(v), dim=1)
signs = torch.sign(v[range(v.shape[0]), max_abs_rows])
u *= signs[: u.shape[1]].view(1, -1)
v *= signs.view(-1, 1)
return u, v
def fit(self, X: torch.Tensor, check_input=True):
"""
Fits the model with data `X` using minibatches of size `batch_size`.
Args:
X (torch.Tensor): The input data tensor with shape (n_samples, n_features).
check_input (bool, optional): If True, validates the input. Defaults to True.
Returns:
IncrementalPCA: The fitted IPCA model.
"""
if check_input:
X = self._validate_data(X)
n_samples, n_features = X.shape
if self.batch_size is None:
self.batch_size = 5 * n_features
for batch in self.gen_batches(n_samples, self.batch_size, min_batch_size=self.n_components or 0):
self.partial_fit(X[batch], check_input=False)
return self
def partial_fit(self, X, check_input=True):
"""
Incrementally fits the model with batch data `X`.
Args:
X (torch.Tensor): The batch input data tensor with shape (n_samples, n_features).
check_input (bool, optional): If True, validates the input. Defaults to True.
Returns:
IncrementalPCA: The updated IPCA model after processing the batch.
"""
first_pass = not hasattr(self, "components_")
if check_input:
X = self._validate_data(X)
n_samples, n_features = X.shape
# Initialize attributes to avoid errors during the first call to partial_fit
if first_pass:
self.mean_ = None # Will be initialized properly in _incremental_mean_and_var based on data dimensions
self.var_ = None # Will be initialized properly in _incremental_mean_and_var based on data dimensions
self.n_samples_seen_ = torch.tensor([0], device=X.device)
self.n_features_ = n_features
if not self.n_components:
self.n_components = min(n_samples, n_features)
if n_features != self.n_features_:
raise ValueError(
"Number of features of the new batch does not match the number of features of the first batch."
)
col_mean, col_var, n_total_samples = self._incremental_mean_and_var(
X, self.mean_, self.var_, self.n_samples_seen_
)
if first_pass:
X -= col_mean
else:
col_batch_mean = torch.mean(X, dim=0)
X -= col_batch_mean
mean_correction_factor = torch.sqrt((self.n_samples_seen_.double() / n_total_samples) * n_samples)
mean_correction = mean_correction_factor * (self.mean_ - col_batch_mean)
X = torch.vstack(
(
self.singular_values_.view((-1, 1)) * self.components_,
X,
mean_correction,
)
)
if self.lowrank:
U, S, Vt = self._svd_fn_lowrank(X)
else:
U, S, Vt = self._svd_fn_full(X)
U, Vt = self._svd_flip(U, Vt, u_based_decision=False)
explained_variance = S**2 / (n_total_samples - 1)
explained_variance_ratio = S**2 / torch.sum(col_var * n_total_samples)
self.n_samples_seen_ = n_total_samples
self.components_ = Vt[: self.n_components]
self.singular_values_ = S[: self.n_components]
self.mean_ = col_mean
self.var_ = col_var
self.explained_variance_ = explained_variance[: self.n_components]
self.explained_variance_ratio_ = explained_variance_ratio[: self.n_components]
if self.n_components not in (n_samples, n_features):
self.noise_variance_ = explained_variance[self.n_components :].mean()
else:
self.noise_variance_ = torch.tensor(0.0, device=X.device)
return self
def transform(self, X) -> torch.Tensor:
"""
Applies dimensionality reduction to `X`.
The input data `X` is projected on the first principal components previously extracted from a training set.
Args:
X (torch.Tensor): New data tensor with shape (n_samples, n_features) to be transformed.
Returns:
torch.Tensor: Transformed data tensor with shape (n_samples, n_components).
"""
X = X - self.mean_
return torch.mm(X.double(), self.components_.T).to(X.dtype)
@staticmethod
def gen_batches(n: int, batch_size: int, min_batch_size: int = 0):
"""Generator to create slices containing `batch_size` elements from 0 to `n`.
The last slice may contain less than `batch_size` elements, when `batch_size` does not divide `n`.
Args:
n (int): Size of the sequence.
batch_size (int): Number of elements in each batch.
min_batch_size (int, optional): Minimum number of elements in each batch. Defaults to 0.
Yields:
slice: A slice of `batch_size` elements.
"""
start = 0
for _ in range(int(n // batch_size)):
end = start + batch_size
if end + min_batch_size > n:
continue
yield slice(start, end)
start = end
if start < n:
yield slice(start, n)

View File

@@ -11,7 +11,7 @@ class InitializeParams:
"""Parameters for initialization methods (PiSSA, URAE)"""
use_ipca: bool = False
use_lowrank: bool = True
use_lowrank: bool = False
lowrank_q: Optional[int] = None
lowrank_niter: int = 4
lowrank_seed: Optional[int] = None
@@ -187,7 +187,7 @@ def initialize_pissa(
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
use_ipca: bool = False,
use_lowrank: bool = True,
use_lowrank: bool = False,
lowrank_q: Optional[int] = None,
lowrank_niter: int = 4,
lowrank_seed: Optional[int] = None,

View File

@@ -19,7 +19,7 @@ from tqdm import tqdm
import re
from library.utils import setup_logging
from library.device_utils import clean_memory_on_device
from library.network_utils import initialize_lora, initialize_pissa, initialize_urae
from library.network_utils import initialize_lora, initialize_pissa, initialize_urae, initialize_parse_opts
setup_logging()
import logging