From 554674909a2373378fe6ba2777ac2d98b1906320 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Wed, 7 May 2025 23:33:08 -0400 Subject: [PATCH] Fix default initialization --- flux_train_network.py | 23 ++- library/flux_models.py | 66 ++++++- library/flux_train_utils.py | 30 ++++ library/flux_utils.py | 14 ++ library/incremental_pca.py | 338 ++++++++++++++++++++++++++++++++++++ library/network_utils.py | 4 +- networks/lora_flux.py | 2 +- 7 files changed, 469 insertions(+), 8 deletions(-) create mode 100644 library/incremental_pca.py diff --git a/flux_train_network.py b/flux_train_network.py index def44155..38101583 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -141,6 +141,10 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors) + # Apply partitioned for Diffusion4k + if args.partitioned_vae: + ae.decoder.partitioned = True + return flux_utils.MODEL_VERSION_FLUX_V1, [clip_l, t5xxl], ae, model def get_tokenize_strategy(self, args): @@ -360,7 +364,13 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): # pack latents and get img_ids packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input) # b, c, h*2, w*2 -> b, h*w, c*4 packed_latent_height, packed_latent_width = noisy_model_input.shape[2] // 2, noisy_model_input.shape[3] // 2 - img_ids = flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width).to(device=accelerator.device) + + if args.partitioned_vae: + packed_latent_height, packed_latent_width = noisy_model_input.shape[2], noisy_model_input.shape[3] + img_ids = flux_utils.prepare_paritioned_img_ids(bsz, packed_latent_height, packed_latent_width).to(device=accelerator.device) + else: + packed_latent_height, packed_latent_width = noisy_model_input.shape[2] // 2, noisy_model_input.shape[3] // 2 + img_ids = flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width).to(device=accelerator.device) # get guidance # ensure guidance_scale in args is float @@ -408,7 +418,18 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): ) # unpack latents +<<<<<<< Updated upstream model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width) +======= + if args.partitioned_vae: + model_pred = flux_utils.unpack_partitioned_latents(model_pred, packed_latent_height, packed_latent_width) + else: + # unpack latents + model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width) + + if args.bypass_flux_guidance: + flux_utils.restore_flux_guidance(unet) +>>>>>>> Stashed changes # apply model prediction type model_pred, weighting = flux_train_utils.apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas) diff --git a/library/flux_models.py b/library/flux_models.py index 328ad481..2fbf0902 100644 --- a/library/flux_models.py +++ b/library/flux_models.py @@ -54,6 +54,8 @@ class AutoEncoderParams: z_channels: int scale_factor: float shift_factor: float + stride: int + partitioned: bool def swish(x: Tensor) -> Tensor: @@ -228,6 +230,8 @@ class Decoder(nn.Module): in_channels: int, resolution: int, z_channels: int, + partitioned=False, + stride=1, ): super().__init__() self.ch = ch @@ -236,6 +240,8 @@ class Decoder(nn.Module): self.resolution = resolution self.in_channels = in_channels self.ffactor = 2 ** (self.num_resolutions - 1) + self.stride = stride + self.partitioned = partitioned # compute in_ch_mult, block_in and curr_res at lowest res block_in = ch * ch_mult[self.num_resolutions - 1] @@ -272,7 +278,7 @@ class Decoder(nn.Module): self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True) self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1) - def forward(self, z: Tensor) -> Tensor: + def forward(self, z: Tensor, partitioned=None) -> Tensor: # z to block_in h = self.conv_in(z) @@ -291,9 +297,55 @@ class Decoder(nn.Module): h = self.up[i_level].upsample(h) # end - h = self.norm_out(h) - h = swish(h) - h = self.conv_out(h) + + # Diffusion4k + partitioned = partitioned if not None else self.partitioned + if self.stride > 1 and partitioned: + h = swish(h) + + overlap_size = 1 # because last conv kernel_size = 3 + res = [] + partitioned_height = h.shape[2] // self.stride + partitioned_width = h.shape[3] // self.stride + + assert self.stride == 2 # only support stride = 2 for now + rows = [] + for i in range(0, h.shape[2], partitioned_height): + row = [] + for j in range(0, h.shape[3], partitioned_width): + partition = h[:,:, max(i - overlap_size, 0) : min(i + partitioned_height + overlap_size, h.shape[2]), max(j - overlap_size, 0) : min(j + partitioned_width + overlap_size, h.shape[3])] + + # for strih + if i==0 and j==0: + partition = torch.nn.functional.pad(partition, (1, 0, 1, 0), "constant", 0) + elif i==0: + partition = torch.nn.functional.pad(partition, (0, 1, 1, 0), "constant", 0) + elif i>0 and j==0: + partition = torch.nn.functional.pad(partition, (1, 0, 0, 1), "constant", 0) + elif i>0 and j>0: + partition = torch.nn.functional.pad(partition, (0, 1, 0, 1), "constant", 0) + + partition = torch.nn.functional.interpolate(partition, scale_factor=self.stride, mode='nearest') + partition = self.conv_out(partition) + partition = partition[:,:,overlap_size:partitioned_height*2+overlap_size,overlap_size:partitioned_width*2+overlap_size] + + row.append(partition) + rows.append(row) + + for row in rows: + res.append(torch.cat(row, dim=3)) + + h = torch.cat(res, dim=2) + # Diffusion4k + elif self.stride > 1: + h = self.norm_out(h) + h = torch.nn.functional.interpolate(h, scale_factor=self.stride, mode='nearest') + h = swish(h) + h = self.conv_out(h) + else: + h = self.norm_out(h) + h = swish(h) + h = self.conv_out(h) return h @@ -404,6 +456,9 @@ configs = { z_channels=16, scale_factor=0.3611, shift_factor=0.1159, + # Diffusion4k + stride=1, + partitioned=False, ), ), "schnell": ModelSpec( @@ -436,6 +491,9 @@ configs = { z_channels=16, scale_factor=0.3611, shift_factor=0.1159, + # Diffusion4k + stride=1, + partitioned=False, ), ), } diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index e5fb8163..e2a8fd67 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -689,3 +689,33 @@ def add_flux_train_arguments(parser: argparse.ArgumentParser): default=3.0, help="Discrete flow shift for the Euler Discrete Scheduler, default is 3.0. / Euler Discrete Schedulerの離散フローシフト、デフォルトは3.0。", ) +<<<<<<< Updated upstream +======= + parser.add_argument( + "--redux_model_path", + type=str, + help="path to Redux model (*.sft or *.safetensors), should be float16", + ) + parser.add_argument( + "--vision_cond_downsample", + type=int, + default=0, + help="Downsample Redux tokens to the specified grid size (default is 27). Zero disables this feature.", + ) + + parser.add_argument( + "--vision_cond_dropout", + type=float, + default=1.0, + help="Probability of dropout for Redux conditioning.", + ) + # bypass guidance module for flux + parser.add_argument( + "--bypass_flux_guidance" + , action="store_true" + , help="bypass flux guidance module for Flex.1-Alpha Training / Flex.1-Alpha トレーニング用バイパス フラックス ガイダンス モジュール" + ) + parser.add_argument("--proportional_attention", action="store_true", help="Dynamic attention scale with respect to the resolution. Proportional attention to the image sequence length. From URAE paper") + parser.add_argument("--ntk_factor", type=float, default=1.0, help="NTK Factor for increasing the embedding space for RoPE. Defaults to 1.0. 10.0 for 2k/4k images. From URAE paper.") + parser.add_argument("--partitioned_vae", action="store_true", help="Partitioned VAE from Diffusion4k paper") +>>>>>>> Stashed changes diff --git a/library/flux_utils.py b/library/flux_utils.py index 8be1d63e..bb4b4f2c 100644 --- a/library/flux_utils.py +++ b/library/flux_utils.py @@ -346,6 +346,13 @@ def prepare_img_ids(batch_size: int, packed_latent_height: int, packed_latent_wi img_ids = einops.repeat(img_ids, "h w c -> b (h w) c", b=batch_size) return img_ids +def prepare_paritioned_img_ids(batch_size: int, packed_latent_height: int, packed_latent_width: int): + img_ids = torch.zeros(packed_latent_height // 2, packed_latent_width // 2, 3) + img_ids[..., 1] = img_ids[..., 1] + torch.arange(packed_latent_height // 2)[:, None] + img_ids[..., 2] = img_ids[..., 2] + torch.arange(packed_latent_width // 2)[None, :] + img_ids = einops.repeat(img_ids, "h w c -> b (h w) c", b=batch_size) + return img_ids + def unpack_latents(x: torch.Tensor, packed_latent_height: int, packed_latent_width: int) -> torch.Tensor: """ @@ -354,6 +361,13 @@ def unpack_latents(x: torch.Tensor, packed_latent_height: int, packed_latent_wid x = einops.rearrange(x, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=packed_latent_height, w=packed_latent_width, ph=2, pw=2) return x +def unpack_partitioned_latents(x: torch.FloatTensor, packed_latent_height: int, packed_latent_width: int) -> torch.FloatTensor: + """ + x: [b (h w) (c ph pw)] -> [b c (h ph) (w pw)], ph=2, pw=2 + """ + x = einops.rearrange(x, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=packed_latent_height // 2, w=packed_latent_width // 2, ph=2, pw=2) + return x + def pack_latents(x: torch.Tensor) -> torch.Tensor: """ diff --git a/library/incremental_pca.py b/library/incremental_pca.py new file mode 100644 index 00000000..e9667040 --- /dev/null +++ b/library/incremental_pca.py @@ -0,0 +1,338 @@ +# Copyright 2024-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +import torch + + +class IncrementalPCA: + """ + An implementation of Incremental Principal Components Analysis (IPCA) that leverages PyTorch for GPU acceleration. + Adapted from https://github.com/scikit-learn/scikit-learn/blob/main/sklearn/decomposition/_incremental_pca.py + + This class provides methods to fit the model on data incrementally in batches, and to transform new data based on + the principal components learned during the fitting process. + + Args: + n_components (int, optional): Number of components to keep. If `None`, it's set to the minimum of the + number of samples and features. Defaults to None. + copy (bool): If False, input data will be overwritten. Defaults to True. + batch_size (int, optional): The number of samples to use for each batch. Only needed if self.fit is called. + If `None`, it's inferred from the data and set to `5 * n_features`. Defaults to None. + svd_driver (str, optional): name of the cuSOLVER method to be used for torch.linalg.svd. This keyword + argument only works on CUDA inputs. Available options are: None, gesvd, gesvdj, and gesvda. Defaults to + None. + lowrank (bool, optional): Whether to use torch.svd_lowrank instead of torch.linalg.svd which can be faster. + Defaults to False. + lowrank_q (int, optional): For an adequate approximation of n_components, this parameter defaults to + n_components * 2. + lowrank_niter (int, optional): Number of subspace iterations to conduct for torch.svd_lowrank. + Defaults to 4. + lowrank_seed (int, optional): Seed for making results of torch.svd_lowrank reproducible. + """ + + def __init__( + self, + n_components: Optional[int] = None, + copy: Optional[bool] = True, + batch_size: Optional[int] = None, + svd_driver: Optional[str] = None, + lowrank: bool = False, + lowrank_q: Optional[int] = None, + lowrank_niter: int = 4, + lowrank_seed: Optional[int] = None, + ): + self.n_components = n_components + self.copy = copy + self.batch_size = batch_size + self.svd_driver = svd_driver + self.lowrank = lowrank + self.lowrank_q = lowrank_q + self.lowrank_niter = lowrank_niter + self.lowrank_seed = lowrank_seed + + self.n_features_ = None + + if self.lowrank: + self._validate_lowrank_params() + + def _validate_lowrank_params(self): + if self.lowrank_q is None: + if self.n_components is None: + raise ValueError("n_components must be specified when using lowrank mode with lowrank_q=None.") + self.lowrank_q = self.n_components * 2 + elif self.n_components is not None and self.lowrank_q < self.n_components: + raise ValueError("lowrank_q must be greater than or equal to n_components.") + + def _svd_fn_full(self, X): + return torch.linalg.svd(X, full_matrices=False, driver=self.svd_driver) + + def _svd_fn_lowrank(self, X): + seed_enabled = self.lowrank_seed is not None + with torch.random.fork_rng(enabled=seed_enabled): + if seed_enabled: + torch.manual_seed(self.lowrank_seed) + U, S, V = torch.svd_lowrank(X, q=self.lowrank_q, niter=self.lowrank_niter) + return U, S, V.mH + + def _validate_data(self, X) -> torch.Tensor: + """ + Validates and converts the input data `X` to the appropriate tensor format. + + Args: + X (torch.Tensor): Input data. + + Returns: + torch.Tensor: Converted to appropriate format. + """ + valid_dtypes = [torch.float32, torch.float64] + + if not isinstance(X, torch.Tensor): + X = torch.tensor(X, dtype=torch.float32) + elif self.copy: + X = X.clone() + + n_samples, n_features = X.shape + if self.n_components is None: + pass + elif self.n_components > n_features: + raise ValueError( + f"n_components={self.n_components} invalid for n_features={n_features}, " + "need more rows than columns for IncrementalPCA processing." + ) + elif self.n_components > n_samples: + raise ValueError( + f"n_components={self.n_components} must be less or equal to the batch number of samples {n_samples}" + ) + + if X.dtype not in valid_dtypes: + X = X.to(torch.float32) + + return X + + @staticmethod + def _incremental_mean_and_var( + X, last_mean, last_variance, last_sample_count + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Computes the incremental mean and variance for the data `X`. + + Args: + X (torch.Tensor): The batch input data tensor with shape (n_samples, n_features). + last_mean (torch.Tensor): The previous mean tensor with shape (n_features,). + last_variance (torch.Tensor): The previous variance tensor with shape (n_features,). + last_sample_count (torch.Tensor): The count tensor of samples processed before the current batch. + + Returns: + Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Updated mean, variance tensors, and total sample count. + """ + if X.shape[0] == 0: + return last_mean, last_variance, last_sample_count + + if last_sample_count > 0: + if last_mean is None: + raise ValueError("last_mean should not be None if last_sample_count > 0.") + if last_variance is None: + raise ValueError("last_variance should not be None if last_sample_count > 0.") + + new_sample_count = torch.tensor([X.shape[0]], device=X.device) + updated_sample_count = last_sample_count + new_sample_count + + if last_mean is None: + last_sum = torch.zeros(X.shape[1], dtype=torch.float64, device=X.device) + else: + last_sum = last_mean * last_sample_count + + new_sum = X.sum(dim=0, dtype=torch.float64) + + updated_mean = (last_sum + new_sum) / updated_sample_count + + T = new_sum / new_sample_count + temp = X - T + correction = temp.sum(dim=0, dtype=torch.float64).square() + temp.square_() + new_unnormalized_variance = temp.sum(dim=0, dtype=torch.float64) + new_unnormalized_variance -= correction / new_sample_count + if last_variance is None: + updated_variance = new_unnormalized_variance / updated_sample_count + else: + last_unnormalized_variance = last_variance * last_sample_count + last_over_new_count = last_sample_count.double() / new_sample_count + updated_unnormalized_variance = ( + last_unnormalized_variance + + new_unnormalized_variance + + last_over_new_count / updated_sample_count * (last_sum / last_over_new_count - new_sum).square() + ) + updated_variance = updated_unnormalized_variance / updated_sample_count + + return updated_mean, updated_variance, updated_sample_count + + @staticmethod + def _svd_flip(u, v, u_based_decision=True) -> tuple[torch.Tensor, torch.Tensor]: + """ + Adjusts the signs of the singular vectors from the SVD decomposition for deterministic output. + + This method ensures that the output remains consistent across different runs. + + Args: + u (torch.Tensor): Left singular vectors tensor. + v (torch.Tensor): Right singular vectors tensor. + u_based_decision (bool, optional): If True, uses the left singular vectors to determine the sign flipping. + Defaults to True. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Adjusted left and right singular vectors tensors. + """ + if u_based_decision: + max_abs_cols = torch.argmax(torch.abs(u), dim=0) + signs = torch.sign(u[max_abs_cols, range(u.shape[1])]) + else: + max_abs_rows = torch.argmax(torch.abs(v), dim=1) + signs = torch.sign(v[range(v.shape[0]), max_abs_rows]) + u *= signs[: u.shape[1]].view(1, -1) + v *= signs.view(-1, 1) + return u, v + + def fit(self, X: torch.Tensor, check_input=True): + """ + Fits the model with data `X` using minibatches of size `batch_size`. + + Args: + X (torch.Tensor): The input data tensor with shape (n_samples, n_features). + check_input (bool, optional): If True, validates the input. Defaults to True. + + Returns: + IncrementalPCA: The fitted IPCA model. + """ + if check_input: + X = self._validate_data(X) + n_samples, n_features = X.shape + if self.batch_size is None: + self.batch_size = 5 * n_features + + for batch in self.gen_batches(n_samples, self.batch_size, min_batch_size=self.n_components or 0): + self.partial_fit(X[batch], check_input=False) + + return self + + def partial_fit(self, X, check_input=True): + """ + Incrementally fits the model with batch data `X`. + + Args: + X (torch.Tensor): The batch input data tensor with shape (n_samples, n_features). + check_input (bool, optional): If True, validates the input. Defaults to True. + + Returns: + IncrementalPCA: The updated IPCA model after processing the batch. + """ + first_pass = not hasattr(self, "components_") + + if check_input: + X = self._validate_data(X) + n_samples, n_features = X.shape + + # Initialize attributes to avoid errors during the first call to partial_fit + if first_pass: + self.mean_ = None # Will be initialized properly in _incremental_mean_and_var based on data dimensions + self.var_ = None # Will be initialized properly in _incremental_mean_and_var based on data dimensions + self.n_samples_seen_ = torch.tensor([0], device=X.device) + self.n_features_ = n_features + if not self.n_components: + self.n_components = min(n_samples, n_features) + + if n_features != self.n_features_: + raise ValueError( + "Number of features of the new batch does not match the number of features of the first batch." + ) + + col_mean, col_var, n_total_samples = self._incremental_mean_and_var( + X, self.mean_, self.var_, self.n_samples_seen_ + ) + + if first_pass: + X -= col_mean + else: + col_batch_mean = torch.mean(X, dim=0) + X -= col_batch_mean + mean_correction_factor = torch.sqrt((self.n_samples_seen_.double() / n_total_samples) * n_samples) + mean_correction = mean_correction_factor * (self.mean_ - col_batch_mean) + X = torch.vstack( + ( + self.singular_values_.view((-1, 1)) * self.components_, + X, + mean_correction, + ) + ) + + if self.lowrank: + U, S, Vt = self._svd_fn_lowrank(X) + else: + U, S, Vt = self._svd_fn_full(X) + U, Vt = self._svd_flip(U, Vt, u_based_decision=False) + explained_variance = S**2 / (n_total_samples - 1) + explained_variance_ratio = S**2 / torch.sum(col_var * n_total_samples) + + self.n_samples_seen_ = n_total_samples + self.components_ = Vt[: self.n_components] + self.singular_values_ = S[: self.n_components] + self.mean_ = col_mean + self.var_ = col_var + self.explained_variance_ = explained_variance[: self.n_components] + self.explained_variance_ratio_ = explained_variance_ratio[: self.n_components] + if self.n_components not in (n_samples, n_features): + self.noise_variance_ = explained_variance[self.n_components :].mean() + else: + self.noise_variance_ = torch.tensor(0.0, device=X.device) + return self + + def transform(self, X) -> torch.Tensor: + """ + Applies dimensionality reduction to `X`. + + The input data `X` is projected on the first principal components previously extracted from a training set. + + Args: + X (torch.Tensor): New data tensor with shape (n_samples, n_features) to be transformed. + + Returns: + torch.Tensor: Transformed data tensor with shape (n_samples, n_components). + """ + X = X - self.mean_ + return torch.mm(X.double(), self.components_.T).to(X.dtype) + + @staticmethod + def gen_batches(n: int, batch_size: int, min_batch_size: int = 0): + """Generator to create slices containing `batch_size` elements from 0 to `n`. + + The last slice may contain less than `batch_size` elements, when `batch_size` does not divide `n`. + + Args: + n (int): Size of the sequence. + batch_size (int): Number of elements in each batch. + min_batch_size (int, optional): Minimum number of elements in each batch. Defaults to 0. + + Yields: + slice: A slice of `batch_size` elements. + """ + start = 0 + for _ in range(int(n // batch_size)): + end = start + batch_size + if end + min_batch_size > n: + continue + yield slice(start, end) + start = end + if start < n: + yield slice(start, n) diff --git a/library/network_utils.py b/library/network_utils.py index ca9f836e..65654740 100644 --- a/library/network_utils.py +++ b/library/network_utils.py @@ -11,7 +11,7 @@ class InitializeParams: """Parameters for initialization methods (PiSSA, URAE)""" use_ipca: bool = False - use_lowrank: bool = True + use_lowrank: bool = False lowrank_q: Optional[int] = None lowrank_niter: int = 4 lowrank_seed: Optional[int] = None @@ -187,7 +187,7 @@ def initialize_pissa( device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, use_ipca: bool = False, - use_lowrank: bool = True, + use_lowrank: bool = False, lowrank_q: Optional[int] = None, lowrank_niter: int = 4, lowrank_seed: Optional[int] = None, diff --git a/networks/lora_flux.py b/networks/lora_flux.py index e6780e21..e8cd528f 100644 --- a/networks/lora_flux.py +++ b/networks/lora_flux.py @@ -19,7 +19,7 @@ from tqdm import tqdm import re from library.utils import setup_logging from library.device_utils import clean_memory_on_device -from library.network_utils import initialize_lora, initialize_pissa, initialize_urae +from library.network_utils import initialize_lora, initialize_pissa, initialize_urae, initialize_parse_opts setup_logging() import logging