diff --git a/.github/FUNDING.yml b/.github/FUNDING.yml new file mode 100644 index 00000000..3b8943c3 --- /dev/null +++ b/.github/FUNDING.yml @@ -0,0 +1,3 @@ +# These are supported funding model platforms + +github: kohya-ss diff --git a/README.md b/README.md index ae417d05..497969ab 100644 --- a/README.md +++ b/README.md @@ -9,11 +9,26 @@ __Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchv The command to install PyTorch is as follows: `pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124` +If you are using DeepSpeed, please install DeepSpeed with `pip install deepspeed==0.16.7`. + - [FLUX.1 training](#flux1-training) - [SD3 training](#sd3-training) ### Recent Updates +May 1, 2025: +- The error when training FLUX.1 with mixed precision in flux_train.py with DeepSpeed enabled has been resolved. Thanks to sharlynxy for PR [#2060](https://github.com/kohya-ss/sd-scripts/pull/2060). Please refer to the PR for details. + - If you enable DeepSpeed, please install DeepSpeed with `pip install deepspeed==0.16.7`. + +Apr 27, 2025: +- FLUX.1 training now supports CFG scale in the sample generation during training. Please use `--g` option, to specify the CFG scale (note that `--l` is used as the embedded guidance scale.) PR [#2064](https://github.com/kohya-ss/sd-scripts/pull/2064). + - See [here](#sample-image-generation-during-training) for details. + - If you have any issues with this, please let us know. + +Apr 6, 2025: +- IP noise gamma has been enabled in FLUX.1. Thanks to rockerBOO for PR [#1992](https://github.com/kohya-ss/sd-scripts/pull/1992). See the PR for details. + - `--ip_noise_gamma` and `--ip_noise_gamma_random_strength` are available. + Mar 30, 2025: - LoRA-GGPO is added for FLUX.1 LoRA training. Thank you to rockerBOO for PR [#1974](https://github.com/kohya-ss/sd-scripts/pull/1974). 
- Specify `--network_args ggpo_sigma=0.03 ggpo_beta=0.01` in the command line or `network_args = ["ggpo_sigma=0.03", "ggpo_beta=0.01"]` in .toml file. See PR for details. @@ -866,6 +881,14 @@ Note: Some user reports ``ValueError: fp16 mixed precision requires a GPU`` is o (Single GPU with id `0` will be used.) +## DeepSpeed installation (experimental, Linux or WSL2 only) + +To install DeepSpeed, run the following command in your activated virtual environment: + +```bash +pip install deepspeed==0.16.7 +``` + ## Upgrade When a new release comes out you can upgrade your repo with the following command: @@ -1340,11 +1363,13 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b Lines beginning with `#` are comments. You can specify options for the generated image with options like `--n` after the prompt. The following can be used. - * `--n` Negative prompt up to the next option. + * `--n` Negative prompt up to the next option. Ignored when CFG scale is `1.0`. * `--w` Specifies the width of the generated image. * `--h` Specifies the height of the generated image. * `--d` Specifies the seed of the generated image. * `--l` Specifies the CFG scale of the generated image. + * In guidance distillation models like FLUX.1, this value is used as the embedded guidance scale for backward compatibility. + * `--g` Specifies the CFG scale for the models with embedded guidance scale. The default is `1.0`, `1.0` means no CFG. In general, should not be changed unless you train the un-distilled FLUX.1 models. * `--s` Specifies the number of steps in the generation. The prompt weighting such as `( )` and `[ ]` are working. 
diff --git a/finetune/tag_images_by_wd14_tagger.py b/finetune/tag_images_by_wd14_tagger.py index 406f12f2..d41b8a11 100644 --- a/finetune/tag_images_by_wd14_tagger.py +++ b/finetune/tag_images_by_wd14_tagger.py @@ -97,15 +97,19 @@ def main(args): else: for file in SUB_DIR_FILES: hf_hub_download( - args.repo_id, - file, + repo_id=args.repo_id, + filename=file, subfolder=SUB_DIR, - cache_dir=os.path.join(model_location, SUB_DIR), + local_dir=os.path.join(model_location, SUB_DIR), force_download=True, - force_filename=file, ) for file in files: - hf_hub_download(args.repo_id, file, cache_dir=model_location, force_download=True, force_filename=file) + hf_hub_download( + repo_id=args.repo_id, + filename=file, + local_dir=model_location, + force_download=True, + ) else: logger.info("using existing wd14 tagger model") diff --git a/library/deepspeed_utils.py b/library/deepspeed_utils.py index 99a7b2b3..a8a05c3a 100644 --- a/library/deepspeed_utils.py +++ b/library/deepspeed_utils.py @@ -5,6 +5,8 @@ from accelerate import DeepSpeedPlugin, Accelerator from .utils import setup_logging +from .device_utils import get_preferred_device + setup_logging() import logging @@ -94,6 +96,7 @@ def prepare_deepspeed_plugin(args: argparse.Namespace): deepspeed_plugin.deepspeed_config["train_batch_size"] = ( args.train_batch_size * args.gradient_accumulation_steps * int(os.environ["WORLD_SIZE"]) ) + deepspeed_plugin.set_mixed_precision(args.mixed_precision) if args.mixed_precision.lower() == "fp16": deepspeed_plugin.deepspeed_config["fp16"]["initial_scale_power"] = 0 # preventing overflow. 
@@ -122,18 +125,56 @@ def prepare_deepspeed_model(args: argparse.Namespace, **models): class DeepSpeedWrapper(torch.nn.Module): def __init__(self, **kw_models) -> None: super().__init__() + self.models = torch.nn.ModuleDict() + + wrap_model_forward_with_torch_autocast = args.mixed_precision != "no" for key, model in kw_models.items(): if isinstance(model, list): model = torch.nn.ModuleList(model) + + if wrap_model_forward_with_torch_autocast: + model = self.__wrap_model_with_torch_autocast(model) + assert isinstance( model, torch.nn.Module ), f"model must be an instance of torch.nn.Module, but got {key} is {type(model)}" + self.models.update(torch.nn.ModuleDict({key: model})) + def __wrap_model_with_torch_autocast(self, model): + if isinstance(model, torch.nn.ModuleList): + model = torch.nn.ModuleList([self.__wrap_model_forward_with_torch_autocast(m) for m in model]) + else: + model = self.__wrap_model_forward_with_torch_autocast(model) + return model + + def __wrap_model_forward_with_torch_autocast(self, model): + + assert hasattr(model, "forward"), "model must have a forward method." + + forward_fn = model.forward + + def forward(*args, **kwargs): + try: + device_type = model.device.type + except AttributeError: + logger.warning( + "[DeepSpeed] model.device is not available. Using get_preferred_device() " + "to determine the device_type for torch.autocast()."
+ ) + device_type = get_preferred_device().type + + with torch.autocast(device_type = device_type): + return forward_fn(*args, **kwargs) + + model.forward = forward + return model + def get_models(self): return self.models + ds_model = DeepSpeedWrapper(**models) return ds_model diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index c6d2baeb..5f6867a8 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -40,7 +40,7 @@ def sample_images( text_encoders, sample_prompts_te_outputs, prompt_replacement=None, - controlnet=None + controlnet=None, ): if steps == 0: if not args.sample_at_first: @@ -101,7 +101,7 @@ def sample_images( steps, sample_prompts_te_outputs, prompt_replacement, - controlnet + controlnet, ) else: # Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processes available (number of devices available) @@ -125,7 +125,7 @@ def sample_images( steps, sample_prompts_te_outputs, prompt_replacement, - controlnet + controlnet, ) torch.set_rng_state(rng_state) @@ -147,14 +147,16 @@ def sample_image_inference( steps, sample_prompts_te_outputs, prompt_replacement, - controlnet + controlnet, ): assert isinstance(prompt_dict, dict) - # negative_prompt = prompt_dict.get("negative_prompt") + negative_prompt = prompt_dict.get("negative_prompt") sample_steps = prompt_dict.get("sample_steps", 20) width = prompt_dict.get("width", 512) height = prompt_dict.get("height", 512) - scale = prompt_dict.get("scale", 3.5) + # TODO refactor variable names + cfg_scale = prompt_dict.get("guidance_scale", 1.0) + emb_guidance_scale = prompt_dict.get("scale", 3.5) seed = prompt_dict.get("seed") controlnet_image = prompt_dict.get("controlnet_image") prompt: str = prompt_dict.get("prompt", "") @@ -162,8 +164,8 @@ def sample_image_inference( if prompt_replacement is not None: prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1]) - # if negative_prompt is not None: - # 
negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1]) + if negative_prompt is not None: + negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1]) if seed is not None: torch.manual_seed(seed) @@ -173,16 +175,21 @@ def sample_image_inference( torch.seed() torch.cuda.seed() - # if negative_prompt is None: - # negative_prompt = "" + if negative_prompt is None: + negative_prompt = "" height = max(64, height - height % 16) # round to divisible by 16 width = max(64, width - width % 16) # round to divisible by 16 logger.info(f"prompt: {prompt}") - # logger.info(f"negative_prompt: {negative_prompt}") + if cfg_scale != 1.0: + logger.info(f"negative_prompt: {negative_prompt}") + elif negative_prompt != "": + logger.info(f"negative prompt is ignored because scale is 1.0") logger.info(f"height: {height}") logger.info(f"width: {width}") logger.info(f"sample_steps: {sample_steps}") - logger.info(f"scale: {scale}") + logger.info(f"embedded guidance scale: {emb_guidance_scale}") + if cfg_scale != 1.0: + logger.info(f"CFG scale: {cfg_scale}") # logger.info(f"sample_sampler: {sampler_name}") if seed is not None: logger.info(f"seed: {seed}") @@ -191,26 +198,37 @@ def sample_image_inference( tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy() encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy() - text_encoder_conds = [] - if sample_prompts_te_outputs and prompt in sample_prompts_te_outputs: - text_encoder_conds = sample_prompts_te_outputs[prompt] - print(f"Using cached text encoder outputs for prompt: {prompt}") - if text_encoders is not None: - print(f"Encoding prompt: {prompt}") - tokens_and_masks = tokenize_strategy.tokenize(prompt) - # strategy has apply_t5_attn_mask option - encoded_text_encoder_conds = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, tokens_and_masks) + def encode_prompt(prpt): + text_encoder_conds = [] + if sample_prompts_te_outputs and prpt in 
sample_prompts_te_outputs: + text_encoder_conds = sample_prompts_te_outputs[prpt] + print(f"Using cached text encoder outputs for prompt: {prpt}") + if text_encoders is not None: + print(f"Encoding prompt: {prpt}") + tokens_and_masks = tokenize_strategy.tokenize(prpt) + # strategy has apply_t5_attn_mask option + encoded_text_encoder_conds = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, tokens_and_masks) - # if text_encoder_conds is not cached, use encoded_text_encoder_conds - if len(text_encoder_conds) == 0: - text_encoder_conds = encoded_text_encoder_conds - else: - # if encoded_text_encoder_conds is not None, update cached text_encoder_conds - for i in range(len(encoded_text_encoder_conds)): - if encoded_text_encoder_conds[i] is not None: - text_encoder_conds[i] = encoded_text_encoder_conds[i] + # if text_encoder_conds is not cached, use encoded_text_encoder_conds + if len(text_encoder_conds) == 0: + text_encoder_conds = encoded_text_encoder_conds + else: + # if encoded_text_encoder_conds is not None, update cached text_encoder_conds + for i in range(len(encoded_text_encoder_conds)): + if encoded_text_encoder_conds[i] is not None: + text_encoder_conds[i] = encoded_text_encoder_conds[i] + return text_encoder_conds - l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds + l_pooled, t5_out, txt_ids, t5_attn_mask = encode_prompt(prompt) + # encode negative prompts + if cfg_scale != 1.0: + neg_l_pooled, neg_t5_out, _, neg_t5_attn_mask = encode_prompt(negative_prompt) + neg_t5_attn_mask = ( + neg_t5_attn_mask.to(accelerator.device) if args.apply_t5_attn_mask and neg_t5_attn_mask is not None else None + ) + neg_cond = (cfg_scale, neg_l_pooled, neg_t5_out, neg_t5_attn_mask) + else: + neg_cond = None # sample image weight_dtype = ae.dtype # TOFO give dtype as argument @@ -235,7 +253,20 @@ def sample_image_inference( controlnet_image = controlnet_image.permute(2, 0, 1).unsqueeze(0).to(weight_dtype).to(accelerator.device) with 
accelerator.autocast(), torch.no_grad(): - x = denoise(flux, noise, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=scale, t5_attn_mask=t5_attn_mask, controlnet=controlnet, controlnet_img=controlnet_image) + x = denoise( + flux, + noise, + img_ids, + t5_out, + txt_ids, + l_pooled, + timesteps=timesteps, + guidance=emb_guidance_scale, + t5_attn_mask=t5_attn_mask, + controlnet=controlnet, + controlnet_img=controlnet_image, + neg_cond=neg_cond, + ) x = flux_utils.unpack_latents(x, packed_latent_height, packed_latent_width) @@ -305,21 +336,24 @@ def denoise( model: flux_models.Flux, img: torch.Tensor, img_ids: torch.Tensor, - txt: torch.Tensor, + txt: torch.Tensor, # t5_out txt_ids: torch.Tensor, - vec: torch.Tensor, + vec: torch.Tensor, # l_pooled timesteps: list[float], guidance: float = 4.0, t5_attn_mask: Optional[torch.Tensor] = None, controlnet: Optional[flux_models.ControlNetFlux] = None, controlnet_img: Optional[torch.Tensor] = None, + neg_cond: Optional[Tuple[float, torch.Tensor, torch.Tensor, torch.Tensor]] = None, ): # this is ignored for schnell guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype) + do_cfg = neg_cond is not None for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]): t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device) model.prepare_block_swap_before_forward() + if controlnet is not None: block_samples, block_single_samples = controlnet( img=img, @@ -335,20 +369,48 @@ def denoise( else: block_samples = None block_single_samples = None - pred = model( - img=img, - img_ids=img_ids, - txt=txt, - txt_ids=txt_ids, - y=vec, - block_controlnet_hidden_states=block_samples, - block_controlnet_single_hidden_states=block_single_samples, - timesteps=t_vec, - guidance=guidance_vec, - txt_attention_mask=t5_attn_mask, - ) - img = img + (t_prev - t_curr) * pred + if not do_cfg: + pred = model( + img=img, + img_ids=img_ids, + txt=txt, + txt_ids=txt_ids, + y=vec, + 
block_controlnet_hidden_states=block_samples, + block_controlnet_single_hidden_states=block_single_samples, + timesteps=t_vec, + guidance=guidance_vec, + txt_attention_mask=t5_attn_mask, + ) + + img = img + (t_prev - t_curr) * pred + else: + cfg_scale, neg_l_pooled, neg_t5_out, neg_t5_attn_mask = neg_cond + nc_c_t5_attn_mask = None if t5_attn_mask is None else torch.cat([neg_t5_attn_mask, t5_attn_mask], dim=0) + + # TODO is it ok to use the same block samples for both cond and uncond? + block_samples = None if block_samples is None else torch.cat([block_samples, block_samples], dim=0) + block_single_samples = ( + None if block_single_samples is None else torch.cat([block_single_samples, block_single_samples], dim=0) + ) + + nc_c_pred = model( + img=torch.cat([img, img], dim=0), + img_ids=torch.cat([img_ids, img_ids], dim=0), + txt=torch.cat([neg_t5_out, txt], dim=0), + txt_ids=torch.cat([txt_ids, txt_ids], dim=0), + y=torch.cat([neg_l_pooled, vec], dim=0), + block_controlnet_hidden_states=block_samples, + block_controlnet_single_hidden_states=block_single_samples, + timesteps=t_vec, + guidance=guidance_vec, + txt_attention_mask=nc_c_t5_attn_mask, + ) + neg_pred, pred = torch.chunk(nc_c_pred, 2, dim=0) + pred = neg_pred + (pred - neg_pred) * cfg_scale + + img = img + (t_prev - t_curr) * pred model.prepare_block_swap_before_forward() return img @@ -365,8 +427,6 @@ def get_sigmas(noise_scheduler, timesteps, device, n_dim=4, dtype=torch.float32) step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] sigma = sigmas[step_indices].flatten() - while len(sigma.shape) < n_dim: - sigma = sigma.unsqueeze(-1) return sigma @@ -409,42 +469,34 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): def get_noisy_model_input_and_timesteps( - args, noise_scheduler, latents, noise, device, dtype + args, noise_scheduler, latents: torch.Tensor, noise: torch.Tensor, device, dtype ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: bsz, _, h, 
w = latents.shape - sigmas = None - + assert bsz > 0, "Batch size not large enough" + num_timesteps = noise_scheduler.config.num_train_timesteps if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid": - # Simple random t-based noise sampling + # Simple random sigma-based noise sampling if args.timestep_sampling == "sigmoid": # https://github.com/XLabs-AI/x-flux/tree/main - t = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=device)) + sigmas = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=device)) else: - t = torch.rand((bsz,), device=device) + sigmas = torch.rand((bsz,), device=device) - timesteps = t * 1000.0 - t = t.view(-1, 1, 1, 1) - noisy_model_input = (1 - t) * latents + t * noise + timesteps = sigmas * num_timesteps elif args.timestep_sampling == "shift": shift = args.discrete_flow_shift - logits_norm = torch.randn(bsz, device=device) - logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling - timesteps = logits_norm.sigmoid() - timesteps = (timesteps * shift) / (1 + (shift - 1) * timesteps) - - t = timesteps.view(-1, 1, 1, 1) - timesteps = timesteps * 1000.0 - noisy_model_input = (1 - t) * latents + t * noise + sigmas = torch.randn(bsz, device=device) + sigmas = sigmas * args.sigmoid_scale # larger scale for more uniform sampling + sigmas = sigmas.sigmoid() + sigmas = (sigmas * shift) / (1 + (shift - 1) * sigmas) + timesteps = sigmas * num_timesteps elif args.timestep_sampling == "flux_shift": - logits_norm = torch.randn(bsz, device=device) - logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling - timesteps = logits_norm.sigmoid() - mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2)) - timesteps = time_shift(mu, 1.0, timesteps) - - t = timesteps.view(-1, 1, 1, 1) - timesteps = timesteps * 1000.0 - noisy_model_input = (1 - t) * latents + t * noise + sigmas = torch.randn(bsz, device=device) + sigmas = sigmas * args.sigmoid_scale # 
larger scale for more uniform sampling + sigmas = sigmas.sigmoid() + mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2)) # we are pre-packed so must adjust for packed size + sigmas = time_shift(mu, 1.0, sigmas) + timesteps = sigmas * num_timesteps else: # Sample a random timestep for each image # for weighting schemes where we sample timesteps non-uniformly @@ -455,12 +507,24 @@ def get_noisy_model_input_and_timesteps( logit_std=args.logit_std, mode_scale=args.mode_scale, ) - indices = (u * noise_scheduler.config.num_train_timesteps).long() + indices = (u * num_timesteps).long() timesteps = noise_scheduler.timesteps[indices].to(device=device) - - # Add noise according to flow matching. sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype) - noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents + + # Broadcast sigmas to latent shape + sigmas = sigmas.view(-1, 1, 1, 1) + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + if args.ip_noise_gamma: + xi = torch.randn_like(latents, device=latents.device, dtype=dtype) + if args.ip_noise_gamma_random_strength: + ip_noise_gamma = torch.rand(1, device=latents.device, dtype=dtype) * args.ip_noise_gamma + else: + ip_noise_gamma = args.ip_noise_gamma + noisy_model_input = (1.0 - sigmas) * latents + sigmas * (noise + ip_noise_gamma * xi) + else: + noisy_model_input = (1.0 - sigmas) * latents + sigmas * noise return noisy_model_input.to(dtype), timesteps.to(dtype), sigmas @@ -566,7 +630,7 @@ def add_flux_train_arguments(parser: argparse.ArgumentParser): "--controlnet_model_name_or_path", type=str, default=None, - help="path to controlnet (*.sft or *.safetensors) / controlnetのパス(*.sftまたは*.safetensors)" + help="path to controlnet (*.sft or *.safetensors) / controlnetのパス(*.sftまたは*.safetensors)", ) parser.add_argument( "--t5xxl_max_token_length", diff --git a/library/lumina_models.py b/library/lumina_models.py 
index 2508cc7d..7e925352 100644 --- a/library/lumina_models.py +++ b/library/lumina_models.py @@ -868,6 +868,8 @@ class NextDiT(nn.Module): cap_feat_dim (int): Dimension of the caption features. axes_dims (List[int]): List of dimensions for the axes. axes_lens (List[int]): List of lengths for the axes. + use_flash_attn (bool): Whether to use Flash Attention. + use_sage_attn (bool): Whether to use Sage Attention. Sage Attention only supports inference. Returns: None @@ -1110,7 +1112,11 @@ class NextDiT(nn.Module): cap_feats = layer(cap_feats, cap_mask, cap_freqs_cis) x = x.view(bsz, channels, height // pH, pH, width // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2) + x_mask = torch.zeros(bsz, image_seq_len, dtype=torch.bool, device=device) + for i in range(bsz): + x[i, :image_seq_len] = x[i] + x_mask[i, :image_seq_len] = True x = self.x_embedder(x) diff --git a/library/lumina_util.py b/library/lumina_util.py index 06f089d4..452b242f 100644 --- a/library/lumina_util.py +++ b/library/lumina_util.py @@ -173,62 +173,61 @@ def pack_latents(x: torch.Tensor) -> torch.Tensor: return x -DIFFUSERS_TO_ALPHA_VLLM_MAP = { + +DIFFUSERS_TO_ALPHA_VLLM_MAP: dict[str, str] = { # Embedding layers - "cap_embedder.0.weight": ["time_caption_embed.caption_embedder.0.weight"], - "cap_embedder.1.weight": "time_caption_embed.caption_embedder.1.weight", - "cap_embedder.1.bias": "text_embedder.1.bias", - "x_embedder.weight": "patch_embedder.proj.weight", - "x_embedder.bias": "patch_embedder.proj.bias", + "time_caption_embed.caption_embedder.0.weight": "cap_embedder.0.weight", + "time_caption_embed.caption_embedder.1.weight": "cap_embedder.1.weight", + "text_embedder.1.bias": "cap_embedder.1.bias", + "patch_embedder.proj.weight": "x_embedder.weight", + "patch_embedder.proj.bias": "x_embedder.bias", # Attention modulation - "layers.().adaLN_modulation.1.weight": "transformer_blocks.().adaln_modulation.1.weight", - "layers.().adaLN_modulation.1.bias": 
"transformer_blocks.().adaln_modulation.1.bias", + "transformer_blocks.().adaln_modulation.1.weight": "layers.().adaLN_modulation.1.weight", + "transformer_blocks.().adaln_modulation.1.bias": "layers.().adaLN_modulation.1.bias", # Final layers - "final_layer.adaLN_modulation.1.weight": "final_adaln_modulation.1.weight", - "final_layer.adaLN_modulation.1.bias": "final_adaln_modulation.1.bias", - "final_layer.linear.weight": "final_linear.weight", - "final_layer.linear.bias": "final_linear.bias", + "final_adaln_modulation.1.weight": "final_layer.adaLN_modulation.1.weight", + "final_adaln_modulation.1.bias": "final_layer.adaLN_modulation.1.bias", + "final_linear.weight": "final_layer.linear.weight", + "final_linear.bias": "final_layer.linear.bias", # Noise refiner - "noise_refiner.().adaLN_modulation.1.weight": "single_transformer_blocks.().adaln_modulation.1.weight", - "noise_refiner.().adaLN_modulation.1.bias": "single_transformer_blocks.().adaln_modulation.1.bias", - "noise_refiner.().attention.qkv.weight": "single_transformer_blocks.().attn.to_qkv.weight", - "noise_refiner.().attention.out.weight": "single_transformer_blocks.().attn.to_out.0.weight", - # Time embedding - "t_embedder.mlp.0.weight": "time_embedder.0.weight", - "t_embedder.mlp.0.bias": "time_embedder.0.bias", - "t_embedder.mlp.2.weight": "time_embedder.2.weight", - "t_embedder.mlp.2.bias": "time_embedder.2.bias", - # Context attention - "context_refiner.().attention.qkv.weight": "transformer_blocks.().attn2.to_qkv.weight", - "context_refiner.().attention.out.weight": "transformer_blocks.().attn2.to_out.0.weight", + "single_transformer_blocks.().adaln_modulation.1.weight": "noise_refiner.().adaLN_modulation.1.weight", + "single_transformer_blocks.().adaln_modulation.1.bias": "noise_refiner.().adaLN_modulation.1.bias", + "single_transformer_blocks.().attn.to_qkv.weight": "noise_refiner.().attention.qkv.weight", + "single_transformer_blocks.().attn.to_out.0.weight": 
"noise_refiner.().attention.out.weight", # Normalization - "layers.().attention_norm1.weight": "transformer_blocks.().norm1.weight", - "layers.().attention_norm2.weight": "transformer_blocks.().norm2.weight", + "transformer_blocks.().norm1.weight": "layers.().attention_norm1.weight", + "transformer_blocks.().norm2.weight": "layers.().attention_norm2.weight", # FFN - "layers.().feed_forward.w1.weight": "transformer_blocks.().ff.net.0.proj.weight", - "layers.().feed_forward.w2.weight": "transformer_blocks.().ff.net.2.weight", - "layers.().feed_forward.w3.weight": "transformer_blocks.().ff.net.4.weight", + "transformer_blocks.().ff.net.0.proj.weight": "layers.().feed_forward.w1.weight", + "transformer_blocks.().ff.net.2.weight": "layers.().feed_forward.w2.weight", + "transformer_blocks.().ff.net.4.weight": "layers.().feed_forward.w3.weight", } def convert_diffusers_sd_to_alpha_vllm(sd: dict, num_double_blocks: int) -> dict: """Convert Diffusers checkpoint to Alpha-VLLM format""" logger.info("Converting Diffusers checkpoint to Alpha-VLLM format") - new_sd = {} + new_sd = sd.copy() # Preserve original keys - for key, value in sd.items(): - new_key = key - for pattern, replacement in DIFFUSERS_TO_ALPHA_VLLM_MAP.items(): - if "()." in pattern: - for block_idx in range(num_double_blocks): - if str(block_idx) in key: - converted = pattern.replace("()", str(block_idx)) - new_key = key.replace(converted, replacement.replace("()", str(block_idx))) - break + for diff_key, alpha_key in DIFFUSERS_TO_ALPHA_VLLM_MAP.items(): + # Handle block-specific patterns + if '().' 
in diff_key: + for block_idx in range(num_double_blocks): + block_alpha_key = alpha_key.replace('().', f'{block_idx}.') + block_diff_key = diff_key.replace('().', f'{block_idx}.') + + # Search for and convert block-specific keys + for input_key, value in list(sd.items()): + if input_key == block_diff_key: + new_sd[block_alpha_key] = value + else: + # Handle static keys + if diff_key in sd: + print(f"Replacing {diff_key} with {alpha_key}") + new_sd[alpha_key] = sd[diff_key] + else: + print(f"Not found: {diff_key}") - if new_key == key: - logger.debug(f"Unmatched key in conversion: {key}") - new_sd[new_key] = value logger.info(f"Converted {len(new_sd)} keys to Alpha-VLLM format") return new_sd diff --git a/library/sd3_train_utils.py b/library/sd3_train_utils.py index 6a4b39b3..c4079884 100644 --- a/library/sd3_train_utils.py +++ b/library/sd3_train_utils.py @@ -610,21 +610,6 @@ from diffusers.utils.torch_utils import randn_tensor from diffusers.utils import BaseOutput -# Copyright 2024 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- - @dataclass class FlowMatchEulerDiscreteSchedulerOutput(BaseOutput): """ @@ -664,49 +649,22 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin): self, num_train_timesteps: int = 1000, shift: float = 1.0, - use_dynamic_shifting=False, - base_shift: Optional[float] = 0.5, - max_shift: Optional[float] = 1.15, - base_image_seq_len: Optional[int] = 256, - max_image_seq_len: Optional[int] = 4096, - invert_sigmas: bool = False, - shift_terminal: Optional[float] = None, - use_karras_sigmas: Optional[bool] = False, - use_exponential_sigmas: Optional[bool] = False, - use_beta_sigmas: Optional[bool] = False, ): - if self.config.use_beta_sigmas and not is_scipy_available(): - raise ImportError("Make sure to install scipy if you want to use beta sigmas.") - if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1: - raise ValueError( - "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used." - ) timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy() timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32) sigmas = timesteps / num_train_timesteps - if not use_dynamic_shifting: - # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution - sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) + sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) self.timesteps = sigmas * num_train_timesteps self._step_index = None self._begin_index = None - self._shift = shift - self.sigmas = sigmas.to("cpu") # to avoid too much CPU/GPU communication self.sigma_min = self.sigmas[-1].item() self.sigma_max = self.sigmas[0].item() - @property - def shift(self): - """ - The value used for shifting. 
- """ - return self._shift - @property def step_index(self): """ @@ -732,9 +690,6 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin): """ self._begin_index = begin_index - def set_shift(self, shift: float): - self._shift = shift - def scale_noise( self, sample: torch.FloatTensor, @@ -754,31 +709,10 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin): `torch.FloatTensor`: A scaled input sample. """ - # Make sure sigmas and timesteps have the same device and dtype as original_samples - sigmas = self.sigmas.to(device=sample.device, dtype=sample.dtype) - - if sample.device.type == "mps" and torch.is_floating_point(timestep): - # mps does not support float64 - schedule_timesteps = self.timesteps.to(sample.device, dtype=torch.float32) - timestep = timestep.to(sample.device, dtype=torch.float32) - else: - schedule_timesteps = self.timesteps.to(sample.device) - timestep = timestep.to(sample.device) - - # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index - if self.begin_index is None: - step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timestep] - elif self.step_index is not None: - # add_noise is called after first denoising step (for inpainting) - step_indices = [self.step_index] * timestep.shape[0] - else: - # add noise is called before first denoising step to create initial latent(img2img) - step_indices = [self.begin_index] * timestep.shape[0] - - sigma = sigmas[step_indices].flatten() - while len(sigma.shape) < len(sample.shape): - sigma = sigma.unsqueeze(-1) + if self.step_index is None: + self._init_step_index(timestep) + sigma = self.sigmas[self.step_index] sample = sigma * noise + (1.0 - sigma) * sample return sample @@ -786,37 +720,7 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin): def _sigma_to_t(self, sigma): return sigma * self.config.num_train_timesteps - def time_shift(self, mu: float, sigma: float, t: torch.Tensor): - 
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) - - def stretch_shift_to_terminal(self, t: torch.Tensor) -> torch.Tensor: - r""" - Stretches and shifts the timestep schedule to ensure it terminates at the configured `shift_terminal` config - value. - - Reference: - https://github.com/Lightricks/LTX-Video/blob/a01a171f8fe3d99dce2728d60a73fecf4d4238ae/ltx_video/schedulers/rf.py#L51 - - Args: - t (`torch.Tensor`): - A tensor of timesteps to be stretched and shifted. - - Returns: - `torch.Tensor`: - A tensor of adjusted timesteps such that the final value equals `self.config.shift_terminal`. - """ - one_minus_z = 1 - t - scale_factor = one_minus_z[-1] / (1 - self.config.shift_terminal) - stretched_t = 1 - (one_minus_z / scale_factor) - return stretched_t - - def set_timesteps( - self, - num_inference_steps: int = None, - device: Union[str, torch.device] = None, - sigmas: Optional[List[float]] = None, - mu: Optional[float] = None, - ): + def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): """ Sets the discrete timesteps used for the diffusion chain (to be run before inference). @@ -826,49 +730,18 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin): device (`str` or `torch.device`, *optional*): The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. 
""" - if self.config.use_dynamic_shifting and mu is None: - raise ValueError(" you have a pass a value for `mu` when `use_dynamic_shifting` is set to be `True`") - - if sigmas is None: - timesteps = np.linspace( - self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps - ) - - sigmas = timesteps / self.config.num_train_timesteps - else: - sigmas = np.array(sigmas).astype(np.float32) - num_inference_steps = len(sigmas) self.num_inference_steps = num_inference_steps - if self.config.use_dynamic_shifting: - sigmas = self.time_shift(mu, 1.0, sigmas) - else: - sigmas = self.shift * sigmas / (1 + (self.shift - 1) * sigmas) - - if self.config.shift_terminal: - sigmas = self.stretch_shift_to_terminal(sigmas) - - if self.config.use_karras_sigmas: - sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) - - elif self.config.use_exponential_sigmas: - sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps) - - elif self.config.use_beta_sigmas: - sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps) + timesteps = np.linspace(self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps) + sigmas = timesteps / self.config.num_train_timesteps + sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas) sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device) + timesteps = sigmas * self.config.num_train_timesteps - - if self.config.invert_sigmas: - sigmas = 1.0 - sigmas - timesteps = sigmas * self.config.num_train_timesteps - sigmas = torch.cat([sigmas, torch.ones(1, device=sigmas.device)]) - else: - sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)]) - self.timesteps = timesteps.to(device=device) - self.sigmas = sigmas + self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)]) + self._step_index = None self._begin_index = None @@ -934,11 +807,7 @@ 
class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin): returned, otherwise a tuple is returned where the first element is the sample tensor. """ - if ( - isinstance(timestep, int) - or isinstance(timestep, torch.IntTensor) - or isinstance(timestep, torch.LongTensor) - ): + if isinstance(timestep, int) or isinstance(timestep, torch.IntTensor) or isinstance(timestep, torch.LongTensor): raise ValueError( ( "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" @@ -954,10 +823,30 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin): sample = sample.to(torch.float32) sigma = self.sigmas[self.step_index] - sigma_next = self.sigmas[self.step_index + 1] - prev_sample = sample + (sigma_next - sigma) * model_output + gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0 + noise = randn_tensor(model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator) + + eps = noise * s_noise + sigma_hat = sigma * (gamma + 1) + + if gamma > 0: + sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5 + + # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise + # NOTE: "original_sample" should not be an expected prediction_type but is left in for + # backwards compatibility + + # if self.config.prediction_type == "vector_field": + + denoised = sample - model_output * sigma + # 2. 
Convert to an ODE derivative + derivative = (sample - denoised) / sigma_hat + + dt = self.sigmas[self.step_index + 1] - sigma_hat + + prev_sample = sample + derivative * dt # Cast sample back to model compatible dtype prev_sample = prev_sample.to(model_output.dtype) @@ -969,86 +858,6 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin): return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample) - # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras - def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor: - """Constructs the noise schedule of Karras et al. (2022).""" - - # Hack to make sure that other schedulers which copy this function don't break - # TODO: Add this logic to the other schedulers - if hasattr(self.config, "sigma_min"): - sigma_min = self.config.sigma_min - else: - sigma_min = None - - if hasattr(self.config, "sigma_max"): - sigma_max = self.config.sigma_max - else: - sigma_max = None - - sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() - sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() - - rho = 7.0 # 7.0 is the value used in the paper - ramp = np.linspace(0, 1, num_inference_steps) - min_inv_rho = sigma_min ** (1 / rho) - max_inv_rho = sigma_max ** (1 / rho) - sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho - return sigmas - - # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential - def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor: - """Constructs an exponential noise schedule.""" - - # Hack to make sure that other schedulers which copy this function don't break - # TODO: Add this logic to the other schedulers - if hasattr(self.config, "sigma_min"): - sigma_min = self.config.sigma_min - else: - sigma_min = None - - if hasattr(self.config, "sigma_max"): - sigma_max = 
self.config.sigma_max - else: - sigma_max = None - - sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() - sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() - - sigmas = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps)) - return sigmas - - # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta - def _convert_to_beta( - self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6 - ) -> torch.Tensor: - """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)""" - - # Hack to make sure that other schedulers which copy this function don't break - # TODO: Add this logic to the other schedulers - if hasattr(self.config, "sigma_min"): - sigma_min = self.config.sigma_min - else: - sigma_min = None - - if hasattr(self.config, "sigma_max"): - sigma_max = self.config.sigma_max - else: - sigma_max = None - - sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() - sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() - - sigmas = np.array( - [ - sigma_min + (ppf * (sigma_max - sigma_min)) - for ppf in [ - scipy.stats.beta.ppf(timestep, alpha, beta) - for timestep in 1 - np.linspace(0, 1, num_inference_steps) - ] - ] - ) - return sigmas - def __len__(self): return self.config.num_train_timesteps diff --git a/library/train_util.py b/library/train_util.py index e2d0d175..68019e21 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1070,8 +1070,11 @@ class BaseDataset(torch.utils.data.Dataset): self.bucket_info["buckets"][i] = {"resolution": reso, "count": len(bucket)} logger.info(f"bucket {i}: resolution {reso}, count: {len(bucket)}") - img_ar_errors = np.array(img_ar_errors) - mean_img_ar_error = np.mean(np.abs(img_ar_errors)) + if len(img_ar_errors) == 0: + mean_img_ar_error = 0 # avoid NaN + else: + img_ar_errors = np.array(img_ar_errors) + 
mean_img_ar_error = np.mean(np.abs(img_ar_errors)) self.bucket_info["mean_img_ar_error"] = mean_img_ar_error logger.info(f"mean ar error (without repeats): {mean_img_ar_error}") @@ -5520,6 +5523,11 @@ def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projectio def patch_accelerator_for_fp16_training(accelerator): + + from accelerate import DistributedType + if accelerator.distributed_type == DistributedType.DEEPSPEED: + return + org_unscale_grads = accelerator.scaler._unscale_grads_ def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16): @@ -6203,6 +6211,11 @@ def line_to_prompt_dict(line: str) -> dict: prompt_dict["scale"] = float(m.group(1)) continue + m = re.match(r"g ([\d\.]+)", parg, re.IGNORECASE) + if m: # guidance scale + prompt_dict["guidance_scale"] = float(m.group(1)) + continue + m = re.match(r"n (.+)", parg, re.IGNORECASE) if m: # negative prompt prompt_dict["negative_prompt"] = m.group(1) diff --git a/networks/lora_flux.py b/networks/lora_flux.py index 92b3979a..0b30f1b8 100644 --- a/networks/lora_flux.py +++ b/networks/lora_flux.py @@ -955,26 +955,26 @@ class LoRANetwork(torch.nn.Module): for lora in self.text_encoder_loras + self.unet_loras: lora.update_grad_norms() - def grad_norms(self) -> Tensor: + def grad_norms(self) -> Tensor | None: grad_norms = [] for lora in self.text_encoder_loras + self.unet_loras: if hasattr(lora, "grad_norms") and lora.grad_norms is not None: grad_norms.append(lora.grad_norms.mean(dim=0)) - return torch.stack(grad_norms) if len(grad_norms) > 0 else torch.tensor([]) + return torch.stack(grad_norms) if len(grad_norms) > 0 else None - def weight_norms(self) -> Tensor: + def weight_norms(self) -> Tensor | None: weight_norms = [] for lora in self.text_encoder_loras + self.unet_loras: if hasattr(lora, "weight_norms") and lora.weight_norms is not None: weight_norms.append(lora.weight_norms.mean(dim=0)) - return torch.stack(weight_norms) if len(weight_norms) > 0 else torch.tensor([]) + 
return torch.stack(weight_norms) if len(weight_norms) > 0 else None - def combined_weight_norms(self) -> Tensor: + def combined_weight_norms(self) -> Tensor | None: combined_weight_norms = [] for lora in self.text_encoder_loras + self.unet_loras: if hasattr(lora, "combined_weight_norms") and lora.combined_weight_norms is not None: combined_weight_norms.append(lora.combined_weight_norms.mean(dim=0)) - return torch.stack(combined_weight_norms) if len(combined_weight_norms) > 0 else torch.tensor([]) + return torch.stack(combined_weight_norms) if len(combined_weight_norms) > 0 else None def load_weights(self, file): diff --git a/pytest.ini b/pytest.ini index 484d3aef..34b7e9c1 100644 --- a/pytest.ini +++ b/pytest.ini @@ -6,3 +6,4 @@ filterwarnings = ignore::DeprecationWarning ignore::UserWarning ignore::FutureWarning +pythonpath = . diff --git a/tests/library/test_flux_train_utils.py b/tests/library/test_flux_train_utils.py new file mode 100644 index 00000000..2ad7ce4e --- /dev/null +++ b/tests/library/test_flux_train_utils.py @@ -0,0 +1,220 @@ +import pytest +import torch +from unittest.mock import MagicMock, patch +from library.flux_train_utils import ( + get_noisy_model_input_and_timesteps, +) + +# Mock classes and functions +class MockNoiseScheduler: + def __init__(self, num_train_timesteps=1000): + self.config = MagicMock() + self.config.num_train_timesteps = num_train_timesteps + self.timesteps = torch.arange(num_train_timesteps, dtype=torch.long) + + +# Create fixtures for commonly used objects +@pytest.fixture +def args(): + args = MagicMock() + args.timestep_sampling = "uniform" + args.weighting_scheme = "uniform" + args.logit_mean = 0.0 + args.logit_std = 1.0 + args.mode_scale = 1.0 + args.sigmoid_scale = 1.0 + args.discrete_flow_shift = 3.1582 + args.ip_noise_gamma = None + args.ip_noise_gamma_random_strength = False + return args + + +@pytest.fixture +def noise_scheduler(): + return MockNoiseScheduler(num_train_timesteps=1000) + + +@pytest.fixture +def 
latents(): + return torch.randn(2, 4, 8, 8) + + +@pytest.fixture +def noise(): + return torch.randn(2, 4, 8, 8) + + +@pytest.fixture +def device(): + # return "cuda" if torch.cuda.is_available() else "cpu" + return "cpu" + + +# Mock the required functions +@pytest.fixture(autouse=True) +def mock_functions(): + with ( + patch("torch.sigmoid", side_effect=torch.sigmoid), + patch("torch.rand", side_effect=torch.rand), + patch("torch.randn", side_effect=torch.randn), + ): + yield + + +# Test different timestep sampling methods +def test_uniform_sampling(args, noise_scheduler, latents, noise, device): + args.timestep_sampling = "uniform" + dtype = torch.float32 + + noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) + + assert noisy_input.shape == latents.shape + assert timesteps.shape == (latents.shape[0],) + assert sigmas.shape == (latents.shape[0], 1, 1, 1) + assert noisy_input.dtype == dtype + assert timesteps.dtype == dtype + + +def test_sigmoid_sampling(args, noise_scheduler, latents, noise, device): + args.timestep_sampling = "sigmoid" + args.sigmoid_scale = 1.0 + dtype = torch.float32 + + noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) + + assert noisy_input.shape == latents.shape + assert timesteps.shape == (latents.shape[0],) + assert sigmas.shape == (latents.shape[0], 1, 1, 1) + + +def test_shift_sampling(args, noise_scheduler, latents, noise, device): + args.timestep_sampling = "shift" + args.sigmoid_scale = 1.0 + args.discrete_flow_shift = 3.1582 + dtype = torch.float32 + + noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) + + assert noisy_input.shape == latents.shape + assert timesteps.shape == (latents.shape[0],) + assert sigmas.shape == (latents.shape[0], 1, 1, 1) + + +def test_flux_shift_sampling(args, noise_scheduler, latents, noise, 
device): + args.timestep_sampling = "flux_shift" + args.sigmoid_scale = 1.0 + dtype = torch.float32 + + noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) + + assert noisy_input.shape == latents.shape + assert timesteps.shape == (latents.shape[0],) + assert sigmas.shape == (latents.shape[0], 1, 1, 1) + + +def test_weighting_scheme(args, noise_scheduler, latents, noise, device): + # Mock the necessary functions for this specific test + with patch("library.flux_train_utils.compute_density_for_timestep_sampling", + return_value=torch.tensor([0.3, 0.7], device=device)), \ + patch("library.flux_train_utils.get_sigmas", + return_value=torch.tensor([[0.3], [0.7]], device=device).view(-1, 1, 1, 1)): + + args.timestep_sampling = "other" # Will trigger the weighting scheme path + args.weighting_scheme = "uniform" + args.logit_mean = 0.0 + args.logit_std = 1.0 + args.mode_scale = 1.0 + dtype = torch.float32 + + noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps( + args, noise_scheduler, latents, noise, device, dtype + ) + + assert noisy_input.shape == latents.shape + assert timesteps.shape == (latents.shape[0],) + assert sigmas.shape == (latents.shape[0], 1, 1, 1) + + +# Test IP noise options +def test_with_ip_noise(args, noise_scheduler, latents, noise, device): + args.ip_noise_gamma = 0.5 + args.ip_noise_gamma_random_strength = False + dtype = torch.float32 + + noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) + + assert noisy_input.shape == latents.shape + assert timesteps.shape == (latents.shape[0],) + assert sigmas.shape == (latents.shape[0], 1, 1, 1) + + +def test_with_random_ip_noise(args, noise_scheduler, latents, noise, device): + args.ip_noise_gamma = 0.1 + args.ip_noise_gamma_random_strength = True + dtype = torch.float32 + + noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, 
noise_scheduler, latents, noise, device, dtype) + + assert noisy_input.shape == latents.shape + assert timesteps.shape == (latents.shape[0],) + assert sigmas.shape == (latents.shape[0], 1, 1, 1) + + +# Test different data types +def test_float16_dtype(args, noise_scheduler, latents, noise, device): + dtype = torch.float16 + + noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) + + assert noisy_input.dtype == dtype + assert timesteps.dtype == dtype + + +# Test different batch sizes +def test_different_batch_size(args, noise_scheduler, device): + latents = torch.randn(5, 4, 8, 8) # batch size of 5 + noise = torch.randn(5, 4, 8, 8) + dtype = torch.float32 + + noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) + + assert noisy_input.shape == latents.shape + assert timesteps.shape == (5,) + assert sigmas.shape == (5, 1, 1, 1) + + +# Test different image sizes +def test_different_image_size(args, noise_scheduler, device): + latents = torch.randn(2, 4, 16, 16) # larger image size + noise = torch.randn(2, 4, 16, 16) + dtype = torch.float32 + + noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) + + assert noisy_input.shape == latents.shape + assert timesteps.shape == (2,) + assert sigmas.shape == (2, 1, 1, 1) + + +# Test edge cases +def test_zero_batch_size(args, noise_scheduler, device): + with pytest.raises(AssertionError): # expecting an error with zero batch size + latents = torch.randn(0, 4, 8, 8) + noise = torch.randn(0, 4, 8, 8) + dtype = torch.float32 + + get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) + + +def test_different_timestep_count(args, device): + noise_scheduler = MockNoiseScheduler(num_train_timesteps=500) # different timestep count + latents = torch.randn(2, 4, 8, 8) + noise = torch.randn(2, 4, 8, 8) 
+ dtype = torch.float32 + + noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) + + assert noisy_input.shape == latents.shape + assert timesteps.shape == (2,) + # Check that timesteps are within the proper range + assert torch.all(timesteps < 500) diff --git a/tests/library/test_lumina_models.py b/tests/library/test_lumina_models.py new file mode 100644 index 00000000..ba063688 --- /dev/null +++ b/tests/library/test_lumina_models.py @@ -0,0 +1,295 @@ +import pytest +import torch + +from library.lumina_models import ( + LuminaParams, + to_cuda, + to_cpu, + RopeEmbedder, + TimestepEmbedder, + modulate, + NextDiT, +) + +cuda_required = pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") + + +def test_lumina_params(): + # Test default configuration + default_params = LuminaParams() + assert default_params.patch_size == 2 + assert default_params.in_channels == 4 + assert default_params.axes_dims == [36, 36, 36] + assert default_params.axes_lens == [300, 512, 512] + + # Test 2B config + config_2b = LuminaParams.get_2b_config() + assert config_2b.dim == 2304 + assert config_2b.in_channels == 16 + assert config_2b.n_layers == 26 + assert config_2b.n_heads == 24 + assert config_2b.cap_feat_dim == 2304 + + # Test 7B config + config_7b = LuminaParams.get_7b_config() + assert config_7b.dim == 4096 + assert config_7b.n_layers == 32 + assert config_7b.n_heads == 32 + assert config_7b.axes_dims == [64, 64, 64] + + +@cuda_required +def test_to_cuda_to_cpu(): + # Test tensor conversion + x = torch.tensor([1, 2, 3]) + x_cuda = to_cuda(x) + x_cpu = to_cpu(x_cuda) + assert x.cpu().tolist() == x_cpu.tolist() + + # Test list conversion + list_data = [torch.tensor([1]), torch.tensor([2])] + list_cuda = to_cuda(list_data) + assert all(tensor.device.type == "cuda" for tensor in list_cuda) + + list_cpu = to_cpu(list_cuda) + assert all(not tensor.device.type == "cuda" for tensor in list_cpu) + + 
# Test dict conversion + dict_data = {"a": torch.tensor([1]), "b": torch.tensor([2])} + dict_cuda = to_cuda(dict_data) + assert all(tensor.device.type == "cuda" for tensor in dict_cuda.values()) + + dict_cpu = to_cpu(dict_cuda) + assert all(not tensor.device.type == "cuda" for tensor in dict_cpu.values()) + + +def test_timestep_embedder(): + # Test initialization + hidden_size = 256 + freq_emb_size = 128 + embedder = TimestepEmbedder(hidden_size, freq_emb_size) + assert embedder.frequency_embedding_size == freq_emb_size + + # Test timestep embedding + t = torch.tensor([0.5, 1.0, 2.0]) + emb_dim = freq_emb_size + embeddings = TimestepEmbedder.timestep_embedding(t, emb_dim) + + assert embeddings.shape == (3, emb_dim) + assert embeddings.dtype == torch.float32 + + # Ensure embeddings are unique for different input times + assert not torch.allclose(embeddings[0], embeddings[1]) + + # Test forward pass + t_emb = embedder(t) + assert t_emb.shape == (3, hidden_size) + + +def test_rope_embedder_simple(): + rope_embedder = RopeEmbedder() + batch_size, seq_len = 2, 10 + + # Create position_ids with valid ranges for each axis + position_ids = torch.stack( + [ + torch.zeros(batch_size, seq_len, dtype=torch.int64), # First axis: only 0 is valid + torch.randint(0, 512, (batch_size, seq_len), dtype=torch.int64), # Second axis: 0-511 + torch.randint(0, 512, (batch_size, seq_len), dtype=torch.int64), # Third axis: 0-511 + ], + dim=-1, + ) + + freqs_cis = rope_embedder(position_ids) + # RoPE embeddings work in pairs, so output dimension is half of total axes_dims + expected_dim = sum(rope_embedder.axes_dims) // 2 # 128 // 2 = 64 + assert freqs_cis.shape == (batch_size, seq_len, expected_dim) + + +def test_modulate(): + # Test modulation with different scales + x = torch.tensor([[1.0, 2.0], [3.0, 4.0]]) + scale = torch.tensor([1.5, 2.0]) + + modulated_x = modulate(x, scale) + + # Check that modulation scales correctly + # The function does x * (1 + scale), so: + # For scale [1.5, 
2.0], (1 + scale) = [2.5, 3.0] + expected_x = torch.tensor([[2.5 * 1.0, 2.5 * 2.0], [3.0 * 3.0, 3.0 * 4.0]]) + # Which equals: [[2.5, 5.0], [9.0, 12.0]] + + assert torch.allclose(modulated_x, expected_x) + + +def test_nextdit_parameter_count_optimized(): + # The constraint is: (dim // n_heads) == sum(axes_dims) + # So for dim=120, n_heads=4: 120//4 = 30, so sum(axes_dims) must = 30 + model_small = NextDiT( + patch_size=2, + in_channels=4, # Smaller + dim=120, # 120 // 4 = 30 + n_layers=2, # Much fewer layers + n_heads=4, # Fewer heads + n_kv_heads=2, + axes_dims=[10, 10, 10], # sum = 30 + axes_lens=[10, 32, 32], # Smaller + ) + param_count_small = model_small.parameter_count() + assert param_count_small > 0 + + # For dim=192, n_heads=6: 192//6 = 32, so sum(axes_dims) must = 32 + model_medium = NextDiT( + patch_size=2, + in_channels=4, + dim=192, # 192 // 6 = 32 + n_layers=4, # More layers + n_heads=6, + n_kv_heads=3, + axes_dims=[10, 11, 11], # sum = 32 + axes_lens=[10, 32, 32], + ) + param_count_medium = model_medium.parameter_count() + assert param_count_medium > param_count_small + print(f"Small model: {param_count_small:,} parameters") + print(f"Medium model: {param_count_medium:,} parameters") + + +@torch.no_grad() +def test_precompute_freqs_cis(): + # Test precompute_freqs_cis + dim = [16, 56, 56] + end = [1, 512, 512] + theta = 10000.0 + + freqs_cis = NextDiT.precompute_freqs_cis(dim, end, theta) + + # Check number of frequency tensors + assert len(freqs_cis) == len(dim) + + # Check each frequency tensor + for i, (d, e) in enumerate(zip(dim, end)): + assert freqs_cis[i].shape == (e, d // 2) + assert freqs_cis[i].dtype == torch.complex128 + + +@torch.no_grad() +def test_nextdit_patchify_and_embed(): + """Test the patchify_and_embed method which is crucial for training""" + # Create a small NextDiT model for testing + # The constraint is: (dim // n_heads) == sum(axes_dims) + # For dim=120, n_heads=4: 120//4 = 30, so sum(axes_dims) must = 30 + model = NextDiT( 
+ patch_size=2, + in_channels=4, + dim=120, # 120 // 4 = 30 + n_layers=1, # Minimal layers for faster testing + n_refiner_layers=1, # Minimal refiner layers + n_heads=4, + n_kv_heads=2, + axes_dims=[10, 10, 10], # sum = 30 + axes_lens=[10, 32, 32], + cap_feat_dim=120, # Match dim for consistency + ) + + # Prepare test inputs + batch_size = 2 + height, width = 64, 64 # Must be divisible by patch_size (2) + caption_seq_len = 8 + + # Create mock inputs + x = torch.randn(batch_size, 4, height, width) # Image latents + cap_feats = torch.randn(batch_size, caption_seq_len, 120) # Caption features + cap_mask = torch.ones(batch_size, caption_seq_len, dtype=torch.bool) # All valid tokens + # Make second batch have shorter caption + cap_mask[1, 6:] = False # Only first 6 tokens are valid for second batch + t = torch.randn(batch_size, 120) # Timestep embeddings + + # Call patchify_and_embed + joint_hidden_states, attention_mask, freqs_cis, l_effective_cap_len, seq_lengths = model.patchify_and_embed( + x, cap_feats, cap_mask, t + ) + + # Validate outputs + image_seq_len = (height // 2) * (width // 2) # patch_size = 2 + expected_seq_lengths = [caption_seq_len + image_seq_len, 6 + image_seq_len] # Second batch has shorter caption + max_seq_len = max(expected_seq_lengths) + + # Check joint hidden states shape + assert joint_hidden_states.shape == (batch_size, max_seq_len, 120) + assert joint_hidden_states.dtype == torch.float32 + + # Check attention mask shape and values + assert attention_mask.shape == (batch_size, max_seq_len) + assert attention_mask.dtype == torch.bool + # First batch should have all positions valid up to its sequence length + assert torch.all(attention_mask[0, : expected_seq_lengths[0]]) + assert torch.all(~attention_mask[0, expected_seq_lengths[0] :]) + # Second batch should have all positions valid up to its sequence length + assert torch.all(attention_mask[1, : expected_seq_lengths[1]]) + assert torch.all(~attention_mask[1, expected_seq_lengths[1] :]) + + # 
Check freqs_cis shape + assert freqs_cis.shape == (batch_size, max_seq_len, sum(model.axes_dims) // 2) + + # Check effective caption lengths + assert l_effective_cap_len == [caption_seq_len, 6] + + # Check sequence lengths + assert seq_lengths == expected_seq_lengths + + # Validate that the joint hidden states contain non-zero values where attention mask is True + for i in range(batch_size): + valid_positions = attention_mask[i] + # Check that valid positions have meaningful data (not all zeros) + valid_data = joint_hidden_states[i][valid_positions] + assert not torch.allclose(valid_data, torch.zeros_like(valid_data)) + + # Check that invalid positions are zeros + if valid_positions.sum() < max_seq_len: + invalid_data = joint_hidden_states[i][~valid_positions] + assert torch.allclose(invalid_data, torch.zeros_like(invalid_data)) + + +@torch.no_grad() +def test_nextdit_patchify_and_embed_edge_cases(): + """Test edge cases for patchify_and_embed""" + # Create minimal model + model = NextDiT( + patch_size=2, + in_channels=4, + dim=60, # 60 // 3 = 20 + n_layers=1, + n_refiner_layers=1, + n_heads=3, + n_kv_heads=1, + axes_dims=[8, 6, 6], # sum = 20 + axes_lens=[10, 16, 16], + cap_feat_dim=60, + ) + + # Test with empty captions (all masked) + batch_size = 1 + height, width = 32, 32 + caption_seq_len = 4 + + x = torch.randn(batch_size, 4, height, width) + cap_feats = torch.randn(batch_size, caption_seq_len, 60) + cap_mask = torch.zeros(batch_size, caption_seq_len, dtype=torch.bool) # All tokens masked + t = torch.randn(batch_size, 60) + + joint_hidden_states, attention_mask, freqs_cis, l_effective_cap_len, seq_lengths = model.patchify_and_embed( + x, cap_feats, cap_mask, t + ) + + # With all captions masked, effective length should be 0 + assert l_effective_cap_len == [0] + + # Sequence length should just be the image sequence length + image_seq_len = (height // 2) * (width // 2) + assert seq_lengths == [image_seq_len] + + # Joint hidden states should only contain image 
data + assert joint_hidden_states.shape == (batch_size, image_seq_len, 60) + assert attention_mask.shape == (batch_size, image_seq_len) + assert torch.all(attention_mask[0]) # All image positions should be valid diff --git a/tests/library/test_lumina_train_util.py b/tests/library/test_lumina_train_util.py new file mode 100644 index 00000000..bcf448c8 --- /dev/null +++ b/tests/library/test_lumina_train_util.py @@ -0,0 +1,241 @@ +import pytest +import torch +import math + +from library.lumina_train_util import ( + batchify, + time_shift, + get_lin_function, + get_schedule, + compute_density_for_timestep_sampling, + get_sigmas, + compute_loss_weighting_for_sd3, + get_noisy_model_input_and_timesteps, + apply_model_prediction_type, + retrieve_timesteps, +) +from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler + + +def test_batchify(): + # Test case with no batch size specified + prompts = [ + {"prompt": "test1"}, + {"prompt": "test2"}, + {"prompt": "test3"} + ] + batchified = list(batchify(prompts)) + assert len(batchified) == 1 + assert len(batchified[0]) == 3 + + # Test case with batch size specified + batchified_sized = list(batchify(prompts, batch_size=2)) + assert len(batchified_sized) == 2 + assert len(batchified_sized[0]) == 2 + assert len(batchified_sized[1]) == 1 + + # Test batching with prompts having same parameters + prompts_with_params = [ + {"prompt": "test1", "width": 512, "height": 512}, + {"prompt": "test2", "width": 512, "height": 512}, + {"prompt": "test3", "width": 1024, "height": 1024} + ] + batchified_params = list(batchify(prompts_with_params)) + assert len(batchified_params) == 2 + + # Test invalid batch size + with pytest.raises(ValueError): + list(batchify(prompts, batch_size=0)) + with pytest.raises(ValueError): + list(batchify(prompts, batch_size=-1)) + + +def test_time_shift(): + # Test standard parameters + t = torch.tensor([0.5]) + mu = 1.0 + sigma = 1.0 + result = time_shift(mu, sigma, t) + assert 0 <= result <= 1 + + # 
Test with edge cases + t_edges = torch.tensor([0.0, 1.0]) + result_edges = time_shift(1.0, 1.0, t_edges) + + # Check that results are bounded within [0, 1] + assert torch.all(result_edges >= 0) + assert torch.all(result_edges <= 1) + + +def test_get_lin_function(): + # Default parameters + func = get_lin_function() + assert func(256) == 0.5 + assert func(4096) == 1.15 + + # Custom parameters + custom_func = get_lin_function(x1=100, x2=1000, y1=0.1, y2=0.9) + assert custom_func(100) == 0.1 + assert custom_func(1000) == 0.9 + + +def test_get_schedule(): + # Basic schedule + schedule = get_schedule(num_steps=10, image_seq_len=256) + assert len(schedule) == 10 + assert all(0 <= x <= 1 for x in schedule) + + # Test different sequence lengths + short_schedule = get_schedule(num_steps=5, image_seq_len=128) + long_schedule = get_schedule(num_steps=15, image_seq_len=1024) + assert len(short_schedule) == 5 + assert len(long_schedule) == 15 + + # Test with shift disabled + unshifted_schedule = get_schedule(num_steps=10, image_seq_len=256, shift=False) + assert torch.allclose( + torch.tensor(unshifted_schedule), + torch.linspace(1, 1/10, 10) + ) + + +def test_compute_density_for_timestep_sampling(): + # Test uniform sampling + uniform_samples = compute_density_for_timestep_sampling("uniform", batch_size=100) + assert len(uniform_samples) == 100 + assert torch.all((uniform_samples >= 0) & (uniform_samples <= 1)) + + # Test logit normal sampling + logit_normal_samples = compute_density_for_timestep_sampling( + "logit_normal", batch_size=100, logit_mean=0.0, logit_std=1.0 + ) + assert len(logit_normal_samples) == 100 + assert torch.all((logit_normal_samples >= 0) & (logit_normal_samples <= 1)) + + # Test mode sampling + mode_samples = compute_density_for_timestep_sampling( + "mode", batch_size=100, mode_scale=0.5 + ) + assert len(mode_samples) == 100 + assert torch.all((mode_samples >= 0) & (mode_samples <= 1)) + + +def test_get_sigmas(): + # Create a mock noise scheduler + 
scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000) + device = torch.device('cpu') + + # Test with default parameters + timesteps = torch.tensor([100, 500, 900]) + sigmas = get_sigmas(scheduler, timesteps, device) + + # Check shape and basic properties + assert sigmas.shape[0] == 3 + assert torch.all(sigmas >= 0) + + # Test with different n_dim + sigmas_4d = get_sigmas(scheduler, timesteps, device, n_dim=4) + assert sigmas_4d.ndim == 4 + + # Test with different dtype + sigmas_float16 = get_sigmas(scheduler, timesteps, device, dtype=torch.float16) + assert sigmas_float16.dtype == torch.float16 + + +def test_compute_loss_weighting_for_sd3(): + # Prepare some mock sigmas + sigmas = torch.tensor([0.1, 0.5, 1.0]) + + # Test sigma_sqrt weighting + sqrt_weighting = compute_loss_weighting_for_sd3("sigma_sqrt", sigmas) + assert torch.allclose(sqrt_weighting, 1 / (sigmas**2), rtol=1e-5) + + # Test cosmap weighting + cosmap_weighting = compute_loss_weighting_for_sd3("cosmap", sigmas) + bot = 1 - 2 * sigmas + 2 * sigmas**2 + expected_cosmap = 2 / (math.pi * bot) + assert torch.allclose(cosmap_weighting, expected_cosmap, rtol=1e-5) + + # Test default weighting + default_weighting = compute_loss_weighting_for_sd3("unknown", sigmas) + assert torch.all(default_weighting == 1) + + +def test_apply_model_prediction_type(): + # Create mock args and tensors + class MockArgs: + model_prediction_type = "raw" + weighting_scheme = "sigma_sqrt" + + args = MockArgs() + model_pred = torch.tensor([1.0, 2.0, 3.0]) + noisy_model_input = torch.tensor([0.5, 1.0, 1.5]) + sigmas = torch.tensor([0.1, 0.5, 1.0]) + + # Test raw prediction type + raw_pred, raw_weighting = apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas) + assert torch.all(raw_pred == model_pred) + assert raw_weighting is None + + # Test additive prediction type + args.model_prediction_type = "additive" + additive_pred, _ = apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas) 
+ assert torch.all(additive_pred == model_pred + noisy_model_input) + + # Test sigma scaled prediction type + args.model_prediction_type = "sigma_scaled" + sigma_scaled_pred, sigma_weighting = apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas) + assert torch.all(sigma_scaled_pred == model_pred * (-sigmas) + noisy_model_input) + assert sigma_weighting is not None + + +def test_retrieve_timesteps(): + # Create a mock scheduler + scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000) + + # Test with num_inference_steps + timesteps, n_steps = retrieve_timesteps(scheduler, num_inference_steps=50) + assert len(timesteps) == 50 + assert n_steps == 50 + + # Test error handling with simultaneous timesteps and sigmas + with pytest.raises(ValueError): + retrieve_timesteps(scheduler, timesteps=[1, 2, 3], sigmas=[0.1, 0.2, 0.3]) + + +def test_get_noisy_model_input_and_timesteps(): + # Create a mock args and setup + class MockArgs: + timestep_sampling = "uniform" + weighting_scheme = "sigma_sqrt" + sigmoid_scale = 1.0 + discrete_flow_shift = 6.0 + + args = MockArgs() + scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000) + device = torch.device('cpu') + + # Prepare mock latents and noise + latents = torch.randn(4, 16, 64, 64) + noise = torch.randn_like(latents) + + # Test uniform sampling + noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps( + args, scheduler, latents, noise, device, torch.float32 + ) + + # Validate output shapes and types + assert noisy_input.shape == latents.shape + assert timesteps.shape[0] == latents.shape[0] + assert noisy_input.dtype == torch.float32 + assert timesteps.dtype == torch.float32 + + # Test different sampling methods + sampling_methods = ["sigmoid", "shift", "nextdit_shift"] + for method in sampling_methods: + args.timestep_sampling = method + noisy_input, timesteps, _ = get_noisy_model_input_and_timesteps( + args, scheduler, latents, noise, device, torch.float32 + ) + 
assert noisy_input.shape == latents.shape + assert timesteps.shape[0] == latents.shape[0] diff --git a/tests/library/test_lumina_util.py b/tests/library/test_lumina_util.py new file mode 100644 index 00000000..397bab5a --- /dev/null +++ b/tests/library/test_lumina_util.py @@ -0,0 +1,112 @@ +import torch +from torch.nn.modules import conv + +from library import lumina_util + + +def test_unpack_latents(): + # Create a test tensor + # Shape: [batch, height*width, channels*patch_height*patch_width] + x = torch.randn(2, 4, 16) # 2 batches, 4 tokens, 16 channels + packed_latent_height = 2 + packed_latent_width = 2 + + # Unpack the latents + unpacked = lumina_util.unpack_latents(x, packed_latent_height, packed_latent_width) + + # Check output shape + # Expected shape: [batch, channels, height*patch_height, width*patch_width] + assert unpacked.shape == (2, 4, 4, 4) + + +def test_pack_latents(): + # Create a test tensor + # Shape: [batch, channels, height*patch_height, width*patch_width] + x = torch.randn(2, 4, 4, 4) + + # Pack the latents + packed = lumina_util.pack_latents(x) + + # Check output shape + # Expected shape: [batch, height*width, channels*patch_height*patch_width] + assert packed.shape == (2, 4, 16) + + +def test_convert_diffusers_sd_to_alpha_vllm(): + num_double_blocks = 2 + # Predefined test cases based on the actual conversion map + test_cases = [ + # Static key conversions with possible list mappings + { + "original_keys": ["time_caption_embed.caption_embedder.0.weight"], + "original_pattern": ["time_caption_embed.caption_embedder.0.weight"], + "expected_converted_keys": ["cap_embedder.0.weight"], + }, + { + "original_keys": ["patch_embedder.proj.weight"], + "original_pattern": ["patch_embedder.proj.weight"], + "expected_converted_keys": ["x_embedder.weight"], + }, + { + "original_keys": ["transformer_blocks.0.norm1.weight"], + "original_pattern": ["transformer_blocks.().norm1.weight"], + "expected_converted_keys": ["layers.0.attention_norm1.weight"], + }, 
+ ] + + + for test_case in test_cases: + for original_key, original_pattern, expected_converted_key in zip( + test_case["original_keys"], test_case["original_pattern"], test_case["expected_converted_keys"] + ): + # Create test state dict + test_sd = {original_key: torch.randn(10, 10)} + + # Convert the state dict + converted_sd = lumina_util.convert_diffusers_sd_to_alpha_vllm(test_sd, num_double_blocks) + + # Verify conversion (handle both string and list keys) + # Find the correct converted key + match_found = False + if expected_converted_key in converted_sd: + # Verify tensor preservation + assert torch.allclose(converted_sd[expected_converted_key], test_sd[original_key], atol=1e-6), ( + f"Tensor mismatch for {original_key}" + ) + match_found = True + break + + assert match_found, f"Failed to convert {original_key}" + + # Ensure original key is also present + assert original_key in converted_sd + + # Test with block-specific keys + block_specific_cases = [ + { + "original_pattern": "transformer_blocks.().norm1.weight", + "converted_pattern": "layers.().attention_norm1.weight", + } + ] + + for case in block_specific_cases: + for block_idx in range(2): # Test multiple block indices + # Prepare block-specific keys + block_original_key = case["original_pattern"].replace("()", str(block_idx)) + block_converted_key = case["converted_pattern"].replace("()", str(block_idx)) + print(block_original_key, block_converted_key) + + # Create test state dict + test_sd = {block_original_key: torch.randn(10, 10)} + + # Convert the state dict + converted_sd = lumina_util.convert_diffusers_sd_to_alpha_vllm(test_sd, num_double_blocks) + + # Verify conversion + # assert block_converted_key in converted_sd, f"Failed to convert block key {block_original_key}" + assert torch.allclose(converted_sd[block_converted_key], test_sd[block_original_key], atol=1e-6), ( + f"Tensor mismatch for block key {block_original_key}" + ) + + # Ensure original key is also present + assert 
block_original_key in converted_sd diff --git a/tests/library/test_strategy_lumina.py b/tests/library/test_strategy_lumina.py new file mode 100644 index 00000000..18e196bf --- /dev/null +++ b/tests/library/test_strategy_lumina.py @@ -0,0 +1,227 @@ +import os +import tempfile +import torch +import numpy as np +from unittest.mock import patch +from transformers import Gemma2Model + +from library.strategy_lumina import ( + LuminaTokenizeStrategy, + LuminaTextEncodingStrategy, + LuminaTextEncoderOutputsCachingStrategy, + LuminaLatentsCachingStrategy, +) + + +class SimpleMockGemma2Model: + """Lightweight mock that avoids initializing the actual Gemma2Model""" + + def __init__(self, hidden_size=2304): + self.device = torch.device("cpu") + self._hidden_size = hidden_size + self._orig_mod = self # For dynamic compilation compatibility + + def __call__(self, input_ids, attention_mask, output_hidden_states=False, return_dict=False): + # Create a mock output object with hidden states + batch_size, seq_len = input_ids.shape + hidden_size = self._hidden_size + + class MockOutput: + def __init__(self, hidden_states): + self.hidden_states = hidden_states + + mock_hidden_states = [ + torch.randn(batch_size, seq_len, hidden_size, device=input_ids.device) + for _ in range(3) # Mimic multiple layers of hidden states + ] + + return MockOutput(mock_hidden_states) + + +def test_lumina_tokenize_strategy(): + # Test default initialization + tokenize_strategy = LuminaTokenizeStrategy(max_length=None) + assert tokenize_strategy.max_length == 256 + assert tokenize_strategy.tokenizer.padding_side == "right" + + # Test tokenization of a single string + text = "Hello" + tokens, attention_mask = tokenize_strategy.tokenize(text) + + assert tokens.ndim == 2 + assert attention_mask.ndim == 2 + assert tokens.shape == attention_mask.shape + assert tokens.shape[1] == 256 # max_length + + # Test tokenize_with_weights + tokens, attention_mask, weights = tokenize_strategy.tokenize_with_weights(text) + 
assert len(weights) == 1 + assert torch.all(weights[0] == 1) + + +def test_lumina_text_encoding_strategy(): + # Create strategies + tokenize_strategy = LuminaTokenizeStrategy(max_length=None) + encoding_strategy = LuminaTextEncodingStrategy() + + # Create a mock model + mock_model = SimpleMockGemma2Model() + + # Patch the isinstance check to accept our simple mock + # (create=True is required: `isinstance` is a builtin, not an attribute of the + # strategy_lumina module, so patch() would raise AttributeError without it) + original_isinstance = isinstance + with patch("library.strategy_lumina.isinstance", create=True) as mock_isinstance: + + def custom_isinstance(obj, class_or_tuple): + if obj is mock_model and class_or_tuple is Gemma2Model: + return True + if hasattr(obj, "_orig_mod") and obj._orig_mod is mock_model and class_or_tuple is Gemma2Model: + return True + return original_isinstance(obj, class_or_tuple) + + mock_isinstance.side_effect = custom_isinstance + + # Prepare sample text + text = "Test encoding strategy" + tokens, attention_mask = tokenize_strategy.tokenize(text) + + # Perform encoding + hidden_states, input_ids, attention_masks = encoding_strategy.encode_tokens( + tokenize_strategy, [mock_model], (tokens, attention_mask) + ) + + # Validate outputs + assert original_isinstance(hidden_states, torch.Tensor) + assert original_isinstance(input_ids, torch.Tensor) + assert original_isinstance(attention_masks, torch.Tensor) + + # Check the shape of the second-to-last hidden state + assert hidden_states.ndim == 3 + + # Test weighted encoding (which falls back to standard encoding for Lumina) + weights = [torch.ones_like(tokens)] + hidden_states_w, input_ids_w, attention_masks_w = encoding_strategy.encode_tokens_with_weights( + tokenize_strategy, [mock_model], (tokens, attention_mask), weights + ) + + # For the mock, we can't guarantee identical outputs since each call returns random tensors + # Instead, check that the outputs have the same shape and are tensors + assert hidden_states_w.shape == hidden_states.shape + assert original_isinstance(hidden_states_w, torch.Tensor) + assert torch.allclose(input_ids, input_ids_w) # 
Input IDs should be the same + assert torch.allclose(attention_masks, attention_masks_w) # Attention masks should be the same + + +def test_lumina_text_encoder_outputs_caching_strategy(): + # Create a temporary directory for caching + with tempfile.TemporaryDirectory() as tmpdir: + # Create a cache file path + cache_file = os.path.join(tmpdir, "test_outputs.npz") + + # Create the caching strategy + caching_strategy = LuminaTextEncoderOutputsCachingStrategy( + cache_to_disk=True, + batch_size=1, + skip_disk_cache_validity_check=False, + ) + + # Create a mock class for ImageInfo + class MockImageInfo: + def __init__(self, caption, system_prompt, cache_path): + self.caption = caption + self.system_prompt = system_prompt + self.text_encoder_outputs_npz = cache_path + + # Create a sample input info + image_info = MockImageInfo("Test caption", "", cache_file) + + # Simulate a batch + batch = [image_info] + + # Create mock strategies and model + tokenize_strategy = LuminaTokenizeStrategy(max_length=None) + encoding_strategy = LuminaTextEncodingStrategy() + mock_model = SimpleMockGemma2Model() + + # Patch the isinstance check to accept our simple mock + # (create=True is required: `isinstance` is a builtin, not a module attribute) + original_isinstance = isinstance + with patch("library.strategy_lumina.isinstance", create=True) as mock_isinstance: + + def custom_isinstance(obj, class_or_tuple): + if obj is mock_model and class_or_tuple is Gemma2Model: + return True + if hasattr(obj, "_orig_mod") and obj._orig_mod is mock_model and class_or_tuple is Gemma2Model: + return True + return original_isinstance(obj, class_or_tuple) + + mock_isinstance.side_effect = custom_isinstance + + # Call cache_batch_outputs + caching_strategy.cache_batch_outputs(tokenize_strategy, [mock_model], encoding_strategy, batch) + + # Verify the npz file was created + assert os.path.exists(cache_file), f"Cache file not created at {cache_file}" + + # Verify the is_disk_cached_outputs_expected method + assert caching_strategy.is_disk_cached_outputs_expected(cache_file) + + # Test loading from 
npz + loaded_data = caching_strategy.load_outputs_npz(cache_file) + assert len(loaded_data) == 3 # hidden_state, input_ids, attention_mask + + +def test_lumina_latents_caching_strategy(): + # Create a temporary directory for caching + with tempfile.TemporaryDirectory() as tmpdir: + # Prepare a mock absolute path + abs_path = os.path.join(tmpdir, "test_image.png") + + # Use smaller image size for faster testing + image_size = (64, 64) + + # Create a smaller dummy image for testing + test_image = np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8) + + # Create the caching strategy + caching_strategy = LuminaLatentsCachingStrategy(cache_to_disk=True, batch_size=1, skip_disk_cache_validity_check=False) + + # Create a simple mock VAE + class MockVAE: + def __init__(self): + self.device = torch.device("cpu") + self.dtype = torch.float32 + + def encode(self, x): + # Return smaller encoded tensor for faster processing + encoded = torch.randn(1, 4, 8, 8, device=x.device) + return type("EncodedLatents", (), {"to": lambda *args, **kwargs: encoded}) + + # Prepare a mock batch + class MockImageInfo: + def __init__(self, path, image): + self.absolute_path = path + self.image = image + self.image_path = path + self.bucket_reso = image_size + self.resized_size = image_size + self.resize_interpolation = "lanczos" + # Specify full path to the latents npz file + self.latents_npz = os.path.join(tmpdir, f"{os.path.splitext(os.path.basename(path))[0]}_0064x0064_lumina.npz") + + batch = [MockImageInfo(abs_path, test_image)] + + # Call cache_batch_latents + mock_vae = MockVAE() + caching_strategy.cache_batch_latents(mock_vae, batch, flip_aug=False, alpha_mask=False, random_crop=False) + + # Generate the expected npz path + npz_path = caching_strategy.get_latents_npz_path(abs_path, image_size) + + # Verify the file was created + assert os.path.exists(npz_path), f"NPZ file not created at {npz_path}" + + # Verify is_disk_cached_latents_expected + assert 
caching_strategy.is_disk_cached_latents_expected(image_size, npz_path, False, False) + + # Test loading from disk + loaded_data = caching_strategy.load_latents_from_disk(npz_path, image_size) + assert len(loaded_data) == 5 # Check for 5 expected elements diff --git a/tests/test_lumina_train_network.py b/tests/test_lumina_train_network.py new file mode 100644 index 00000000..353a742f --- /dev/null +++ b/tests/test_lumina_train_network.py @@ -0,0 +1,173 @@ +import pytest +import torch +from unittest.mock import MagicMock, patch +import argparse + +from library import lumina_models, lumina_util +from lumina_train_network import LuminaNetworkTrainer + + +@pytest.fixture +def lumina_trainer(): + return LuminaNetworkTrainer() + + +@pytest.fixture +def mock_args(): + args = MagicMock() + args.pretrained_model_name_or_path = "test_path" + args.disable_mmap_load_safetensors = False + args.use_flash_attn = False + args.use_sage_attn = False + args.fp8_base = False + args.blocks_to_swap = None + args.gemma2 = "test_gemma2_path" + args.ae = "test_ae_path" + args.cache_text_encoder_outputs = True + args.cache_text_encoder_outputs_to_disk = False + args.network_train_unet_only = False + return args + + +@pytest.fixture +def mock_accelerator(): + accelerator = MagicMock() + accelerator.device = torch.device("cpu") + accelerator.prepare.side_effect = lambda x, **kwargs: x + accelerator.unwrap_model.side_effect = lambda x: x + return accelerator + + +def test_assert_extra_args(lumina_trainer, mock_args): + train_dataset_group = MagicMock() + train_dataset_group.verify_bucket_reso_steps = MagicMock() + val_dataset_group = MagicMock() + val_dataset_group.verify_bucket_reso_steps = MagicMock() + + # Test with default settings + lumina_trainer.assert_extra_args(mock_args, train_dataset_group, val_dataset_group) + + # Verify verify_bucket_reso_steps was called for both groups + assert train_dataset_group.verify_bucket_reso_steps.call_count > 0 + assert 
val_dataset_group.verify_bucket_reso_steps.call_count > 0 + + # Check text encoder output caching + assert lumina_trainer.train_gemma2 is (not mock_args.network_train_unet_only) + assert mock_args.cache_text_encoder_outputs is True + + +def test_load_target_model(lumina_trainer, mock_args, mock_accelerator): + # Patch lumina_util methods + with ( + patch("library.lumina_util.load_lumina_model") as mock_load_lumina_model, + patch("library.lumina_util.load_gemma2") as mock_load_gemma2, + patch("library.lumina_util.load_ae") as mock_load_ae + ): + # Create mock models + mock_model = MagicMock(spec=lumina_models.NextDiT) + mock_model.dtype = torch.float32 + mock_gemma2 = MagicMock() + mock_ae = MagicMock() + + mock_load_lumina_model.return_value = mock_model + mock_load_gemma2.return_value = mock_gemma2 + mock_load_ae.return_value = mock_ae + + # Test load_target_model + version, gemma2_list, ae, model = lumina_trainer.load_target_model(mock_args, torch.float32, mock_accelerator) + + # Verify calls and return values + assert version == lumina_util.MODEL_VERSION_LUMINA_V2 + assert gemma2_list == [mock_gemma2] + assert ae == mock_ae + assert model == mock_model + + # Verify load calls + mock_load_lumina_model.assert_called_once() + mock_load_gemma2.assert_called_once() + mock_load_ae.assert_called_once() + + +def test_get_strategies(lumina_trainer, mock_args): + # Test tokenize strategy + tokenize_strategy = lumina_trainer.get_tokenize_strategy(mock_args) + assert tokenize_strategy.__class__.__name__ == "LuminaTokenizeStrategy" + + # Test latents caching strategy + latents_strategy = lumina_trainer.get_latents_caching_strategy(mock_args) + assert latents_strategy.__class__.__name__ == "LuminaLatentsCachingStrategy" + + # Test text encoding strategy + text_encoding_strategy = lumina_trainer.get_text_encoding_strategy(mock_args) + assert text_encoding_strategy.__class__.__name__ == "LuminaTextEncodingStrategy" + + +def 
test_text_encoder_output_caching_strategy(lumina_trainer, mock_args): + # Call assert_extra_args to set train_gemma2 + train_dataset_group = MagicMock() + train_dataset_group.verify_bucket_reso_steps = MagicMock() + val_dataset_group = MagicMock() + val_dataset_group.verify_bucket_reso_steps = MagicMock() + lumina_trainer.assert_extra_args(mock_args, train_dataset_group, val_dataset_group) + + # With text encoder caching enabled + mock_args.skip_cache_check = False + mock_args.text_encoder_batch_size = 16 + strategy = lumina_trainer.get_text_encoder_outputs_caching_strategy(mock_args) + + assert strategy.__class__.__name__ == "LuminaTextEncoderOutputsCachingStrategy" + assert strategy.cache_to_disk is False # based on mock_args + + # With text encoder caching disabled + mock_args.cache_text_encoder_outputs = False + strategy = lumina_trainer.get_text_encoder_outputs_caching_strategy(mock_args) + assert strategy is None + + +def test_noise_scheduler(lumina_trainer, mock_args): + device = torch.device("cpu") + noise_scheduler = lumina_trainer.get_noise_scheduler(mock_args, device) + + assert noise_scheduler.__class__.__name__ == "FlowMatchEulerDiscreteScheduler" + assert noise_scheduler.num_train_timesteps == 1000 + assert hasattr(lumina_trainer, "noise_scheduler_copy") + + +def test_sai_model_spec(lumina_trainer, mock_args): + with patch("library.train_util.get_sai_model_spec") as mock_get_spec: + mock_get_spec.return_value = "test_spec" + spec = lumina_trainer.get_sai_model_spec(mock_args) + assert spec == "test_spec" + mock_get_spec.assert_called_once_with(None, mock_args, False, True, False, lumina="lumina2") + + +def test_update_metadata(lumina_trainer, mock_args): + metadata = {} + lumina_trainer.update_metadata(metadata, mock_args) + + assert "ss_weighting_scheme" in metadata + assert "ss_logit_mean" in metadata + assert "ss_logit_std" in metadata + assert "ss_mode_scale" in metadata + assert "ss_timestep_sampling" in metadata + assert "ss_sigmoid_scale" in 
metadata + assert "ss_model_prediction_type" in metadata + assert "ss_discrete_flow_shift" in metadata + + +def test_is_text_encoder_not_needed_for_training(lumina_trainer, mock_args): + # Test with text encoder output caching, but not training text encoder + mock_args.cache_text_encoder_outputs = True + with patch.object(lumina_trainer, 'is_train_text_encoder', return_value=False): + result = lumina_trainer.is_text_encoder_not_needed_for_training(mock_args) + assert result is True + + # Test with text encoder output caching and training text encoder + with patch.object(lumina_trainer, 'is_train_text_encoder', return_value=True): + result = lumina_trainer.is_text_encoder_not_needed_for_training(mock_args) + assert result is False + + # Test with no text encoder output caching + mock_args.cache_text_encoder_outputs = False + result = lumina_trainer.is_text_encoder_not_needed_for_training(mock_args) + assert result is False \ No newline at end of file diff --git a/train_network.py b/train_network.py index 4b8a9c73..6073c4c3 100644 --- a/train_network.py +++ b/train_network.py @@ -389,7 +389,18 @@ class NetworkTrainer: latents = typing.cast(torch.FloatTensor, batch["latents"].to(accelerator.device)) else: # latentに変換 - latents = self.encode_images_to_latents(args, vae, batch["images"].to(accelerator.device, dtype=vae_dtype)) + if args.vae_batch_size is None or len(batch["images"]) <= args.vae_batch_size: + latents = self.encode_images_to_latents(args, vae, batch["images"].to(accelerator.device, dtype=vae_dtype)) + else: + chunks = [ + batch["images"][i : i + args.vae_batch_size] for i in range(0, len(batch["images"]), args.vae_batch_size) + ] + list_latents = [] + for chunk in chunks: + with torch.no_grad(): + chunk = self.encode_images_to_latents(args, vae, chunk.to(accelerator.device, dtype=vae_dtype)) + list_latents.append(chunk) + latents = torch.cat(list_latents, dim=0) # NaNが含まれていれば警告を表示し0に置き換える if torch.any(torch.isnan(latents)): @@ -1433,11 +1444,13 @@ class 
NetworkTrainer: max_mean_logs = {"Keys Scaled": keys_scaled, "Average key norm": mean_norm} else: if hasattr(network, "weight_norms"): - mean_norm = network.weight_norms().mean().item() - mean_grad_norm = network.grad_norms().mean().item() - mean_combined_norm = network.combined_weight_norms().mean().item() weight_norms = network.weight_norms() - maximum_norm = weight_norms.max().item() if weight_norms.numel() > 0 else None + mean_norm = weight_norms.mean().item() if weight_norms is not None else None + grad_norms = network.grad_norms() + mean_grad_norm = grad_norms.mean().item() if grad_norms is not None else None + combined_weight_norms = network.combined_weight_norms() + mean_combined_norm = combined_weight_norms.mean().item() if combined_weight_norms is not None else None + maximum_norm = weight_norms.max().item() if weight_norms is not None and weight_norms.numel() > 0 else None + keys_scaled = None + max_mean_logs = {} + else: