diff --git a/.github/FUNDING.yml b/.github/FUNDING.yml new file mode 100644 index 00000000..3b8943c3 --- /dev/null +++ b/.github/FUNDING.yml @@ -0,0 +1,3 @@ +# These are supported funding model platforms + +github: kohya-ss diff --git a/README.md b/README.md index ae417d05..497969ab 100644 --- a/README.md +++ b/README.md @@ -9,11 +9,26 @@ __Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchv The command to install PyTorch is as follows: `pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124` +If you are using DeepSpeed, please install DeepSpeed with `pip install deepspeed==0.16.7`. + - [FLUX.1 training](#flux1-training) - [SD3 training](#sd3-training) ### Recent Updates +May 1, 2025: +- The error when training FLUX.1 with mixed precision in flux_train.py with DeepSpeed enabled has been resolved. Thanks to sharlynxy for PR [#2060](https://github.com/kohya-ss/sd-scripts/pull/2060). Please refer to the PR for details. + - If you enable DeepSpeed, please install DeepSpeed with `pip install deepspeed==0.16.7`. + +Apr 27, 2025: +- FLUX.1 training now supports CFG scale in the sample generation during training. Please use `--g` option, to specify the CFG scale (note that `--l` is used as the embedded guidance scale.) PR [#2064](https://github.com/kohya-ss/sd-scripts/pull/2064). + - See [here](#sample-image-generation-during-training) for details. + - If you have any issues with this, please let us know. + +Apr 6, 2025: +- IP noise gamma has been enabled in FLUX.1. Thanks to rockerBOO for PR [#1992](https://github.com/kohya-ss/sd-scripts/pull/1992). See the PR for details. + - `--ip_noise_gamma` and `--ip_noise_gamma_random_strength` are available. + Mar 30, 2025: - LoRA-GGPO is added for FLUX.1 LoRA training. Thank you to rockerBOO for PR [#1974](https://github.com/kohya-ss/sd-scripts/pull/1974). - Specify `--network_args ggpo_sigma=0.03 ggpo_beta=0.01` in the command line or `network_args = ["ggpo_sigma=0.03", "ggpo_beta=0.01"]` in .toml file. See PR for details. @@ -866,6 +881,14 @@ Note: Some user reports ``ValueError: fp16 mixed precision requires a GPU`` is o (Single GPU with id `0` will be used.) +## DeepSpeed installation (experimental, Linux or WSL2 only) + +To install DeepSpeed, run the following command in your activated virtual environment: + +```bash +pip install deepspeed==0.16.7 +``` + ## Upgrade When a new release comes out you can upgrade your repo with the following command: @@ -1340,11 +1363,13 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b Lines beginning with `#` are comments. You can specify options for the generated image with options like `--n` after the prompt. The following can be used. - * `--n` Negative prompt up to the next option. + * `--n` Negative prompt up to the next option. Ignored when CFG scale is `1.0`. * `--w` Specifies the width of the generated image. * `--h` Specifies the height of the generated image. * `--d` Specifies the seed of the generated image. * `--l` Specifies the CFG scale of the generated image. + * In guidance distillation models like FLUX.1, this value is used as the embedded guidance scale for backward compatibility. + * `--g` Specifies the CFG scale for the models with embedded guidance scale. The default is `1.0`, `1.0` means no CFG. In general, should not be changed unless you train the un-distilled FLUX.1 models. * `--s` Specifies the number of steps in the generation. The prompt weighting such as `( )` and `[ ]` are working. diff --git a/finetune/tag_images_by_wd14_tagger.py b/finetune/tag_images_by_wd14_tagger.py index 406f12f2..d41b8a11 100644 --- a/finetune/tag_images_by_wd14_tagger.py +++ b/finetune/tag_images_by_wd14_tagger.py @@ -97,15 +97,19 @@ def main(args): else: for file in SUB_DIR_FILES: hf_hub_download( - args.repo_id, - file, + repo_id=args.repo_id, + filename=file, subfolder=SUB_DIR, - cache_dir=os.path.join(model_location, SUB_DIR), + local_dir=os.path.join(model_location, SUB_DIR), force_download=True, - force_filename=file, ) for file in files: - hf_hub_download(args.repo_id, file, cache_dir=model_location, force_download=True, force_filename=file) + hf_hub_download( + repo_id=args.repo_id, + filename=file, + local_dir=model_location, + force_download=True, + ) else: logger.info("using existing wd14 tagger model") diff --git a/library/deepspeed_utils.py b/library/deepspeed_utils.py index 99a7b2b3..a8a05c3a 100644 --- a/library/deepspeed_utils.py +++ b/library/deepspeed_utils.py @@ -5,6 +5,8 @@ from accelerate import DeepSpeedPlugin, Accelerator from .utils import setup_logging +from .device_utils import get_preferred_device + setup_logging() import logging @@ -94,6 +96,7 @@ def prepare_deepspeed_plugin(args: argparse.Namespace): deepspeed_plugin.deepspeed_config["train_batch_size"] = ( args.train_batch_size * args.gradient_accumulation_steps * int(os.environ["WORLD_SIZE"]) ) + deepspeed_plugin.set_mixed_precision(args.mixed_precision) if args.mixed_precision.lower() == "fp16": deepspeed_plugin.deepspeed_config["fp16"]["initial_scale_power"] = 0 # preventing overflow. @@ -122,18 +125,56 @@ def prepare_deepspeed_model(args: argparse.Namespace, **models): class DeepSpeedWrapper(torch.nn.Module): def __init__(self, **kw_models) -> None: super().__init__() + self.models = torch.nn.ModuleDict() + + wrap_model_forward_with_torch_autocast = args.mixed_precision is not "no" for key, model in kw_models.items(): if isinstance(model, list): model = torch.nn.ModuleList(model) + + if wrap_model_forward_with_torch_autocast: + model = self.__wrap_model_with_torch_autocast(model) + assert isinstance( model, torch.nn.Module ), f"model must be an instance of torch.nn.Module, but got {key} is {type(model)}" + self.models.update(torch.nn.ModuleDict({key: model})) + def __wrap_model_with_torch_autocast(self, model): + if isinstance(model, torch.nn.ModuleList): + model = torch.nn.ModuleList([self.__wrap_model_forward_with_torch_autocast(m) for m in model]) + else: + model = self.__wrap_model_forward_with_torch_autocast(model) + return model + + def __wrap_model_forward_with_torch_autocast(self, model): + + assert hasattr(model, "forward"), f"model must have a forward method." + + forward_fn = model.forward + + def forward(*args, **kwargs): + try: + device_type = model.device.type + except AttributeError: + logger.warning( + "[DeepSpeed] model.device is not available. Using get_preferred_device() " + "to determine the device_type for torch.autocast()." + ) + device_type = get_preferred_device().type + + with torch.autocast(device_type = device_type): + return forward_fn(*args, **kwargs) + + model.forward = forward + return model + def get_models(self): return self.models + ds_model = DeepSpeedWrapper(**models) return ds_model diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index c6d2baeb..5f6867a8 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -40,7 +40,7 @@ def sample_images( text_encoders, sample_prompts_te_outputs, prompt_replacement=None, - controlnet=None + controlnet=None, ): if steps == 0: if not args.sample_at_first: @@ -101,7 +101,7 @@ def sample_images( steps, sample_prompts_te_outputs, prompt_replacement, - controlnet + controlnet, ) else: # Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processes available (number of devices available) @@ -125,7 +125,7 @@ def sample_images( steps, sample_prompts_te_outputs, prompt_replacement, - controlnet + controlnet, ) torch.set_rng_state(rng_state) @@ -147,14 +147,16 @@ def sample_image_inference( steps, sample_prompts_te_outputs, prompt_replacement, - controlnet + controlnet, ): assert isinstance(prompt_dict, dict) - # negative_prompt = prompt_dict.get("negative_prompt") + negative_prompt = prompt_dict.get("negative_prompt") sample_steps = prompt_dict.get("sample_steps", 20) width = prompt_dict.get("width", 512) height = prompt_dict.get("height", 512) - scale = prompt_dict.get("scale", 3.5) + # TODO refactor variable names + cfg_scale = prompt_dict.get("guidance_scale", 1.0) + emb_guidance_scale = prompt_dict.get("scale", 3.5) seed = prompt_dict.get("seed") controlnet_image = prompt_dict.get("controlnet_image") prompt: str = prompt_dict.get("prompt", "") @@ -162,8 +164,8 @@ def sample_image_inference( if prompt_replacement is not None: prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1]) - # if negative_prompt is not None: - # negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1]) + if negative_prompt is not None: + negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1]) if seed is not None: torch.manual_seed(seed) @@ -173,16 +175,21 @@ def sample_image_inference( torch.seed() torch.cuda.seed() - # if negative_prompt is None: - # negative_prompt = "" + if negative_prompt is None: + negative_prompt = "" height = max(64, height - height % 16) # round to divisible by 16 width = max(64, width - width % 16) # round to divisible by 16 logger.info(f"prompt: {prompt}") - # logger.info(f"negative_prompt: {negative_prompt}") + if cfg_scale != 1.0: + logger.info(f"negative_prompt: {negative_prompt}") + elif negative_prompt != "": + logger.info(f"negative prompt is ignored because scale is 1.0") logger.info(f"height: {height}") logger.info(f"width: {width}") logger.info(f"sample_steps: {sample_steps}") - logger.info(f"scale: {scale}") + logger.info(f"embedded guidance scale: {emb_guidance_scale}") + if cfg_scale != 1.0: + logger.info(f"CFG scale: {cfg_scale}") # logger.info(f"sample_sampler: {sampler_name}") if seed is not None: logger.info(f"seed: {seed}") @@ -191,26 +198,37 @@ def sample_image_inference( tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy() encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy() - text_encoder_conds = [] - if sample_prompts_te_outputs and prompt in sample_prompts_te_outputs: - text_encoder_conds = sample_prompts_te_outputs[prompt] - print(f"Using cached text encoder outputs for prompt: {prompt}") - if text_encoders is not None: - print(f"Encoding prompt: {prompt}") - tokens_and_masks = tokenize_strategy.tokenize(prompt) - # strategy has apply_t5_attn_mask option - encoded_text_encoder_conds = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, tokens_and_masks) + def encode_prompt(prpt): + text_encoder_conds = [] + if sample_prompts_te_outputs and prpt in sample_prompts_te_outputs: + text_encoder_conds = sample_prompts_te_outputs[prpt] + print(f"Using cached text encoder outputs for prompt: {prpt}") + if text_encoders is not None: + print(f"Encoding prompt: {prpt}") + tokens_and_masks = tokenize_strategy.tokenize(prpt) + # strategy has apply_t5_attn_mask option + encoded_text_encoder_conds = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, tokens_and_masks) - # if text_encoder_conds is not cached, use encoded_text_encoder_conds - if len(text_encoder_conds) == 0: - text_encoder_conds = encoded_text_encoder_conds - else: - # if encoded_text_encoder_conds is not None, update cached text_encoder_conds - for i in range(len(encoded_text_encoder_conds)): - if encoded_text_encoder_conds[i] is not None: - text_encoder_conds[i] = encoded_text_encoder_conds[i] + # if text_encoder_conds is not cached, use encoded_text_encoder_conds + if len(text_encoder_conds) == 0: + text_encoder_conds = encoded_text_encoder_conds + else: + # if encoded_text_encoder_conds is not None, update cached text_encoder_conds + for i in range(len(encoded_text_encoder_conds)): + if encoded_text_encoder_conds[i] is not None: + text_encoder_conds[i] = encoded_text_encoder_conds[i] + return text_encoder_conds - l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds + l_pooled, t5_out, txt_ids, t5_attn_mask = encode_prompt(prompt) + # encode negative prompts + if cfg_scale != 1.0: + neg_l_pooled, neg_t5_out, _, neg_t5_attn_mask = encode_prompt(negative_prompt) + neg_t5_attn_mask = ( + neg_t5_attn_mask.to(accelerator.device) if args.apply_t5_attn_mask and neg_t5_attn_mask is not None else None + ) + neg_cond = (cfg_scale, neg_l_pooled, neg_t5_out, neg_t5_attn_mask) + else: + neg_cond = None # sample image weight_dtype = ae.dtype # TOFO give dtype as argument @@ -235,7 +253,20 @@ def sample_image_inference( controlnet_image = controlnet_image.permute(2, 0, 1).unsqueeze(0).to(weight_dtype).to(accelerator.device) with accelerator.autocast(), torch.no_grad(): - x = denoise(flux, noise, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=scale, t5_attn_mask=t5_attn_mask, controlnet=controlnet, controlnet_img=controlnet_image) + x = denoise( + flux, + noise, + img_ids, + t5_out, + txt_ids, + l_pooled, + timesteps=timesteps, + guidance=emb_guidance_scale, + t5_attn_mask=t5_attn_mask, + controlnet=controlnet, + controlnet_img=controlnet_image, + neg_cond=neg_cond, + ) x = flux_utils.unpack_latents(x, packed_latent_height, packed_latent_width) @@ -305,21 +336,24 @@ def denoise( model: flux_models.Flux, img: torch.Tensor, img_ids: torch.Tensor, - txt: torch.Tensor, + txt: torch.Tensor, # t5_out txt_ids: torch.Tensor, - vec: torch.Tensor, + vec: torch.Tensor, # l_pooled timesteps: list[float], guidance: float = 4.0, t5_attn_mask: Optional[torch.Tensor] = None, controlnet: Optional[flux_models.ControlNetFlux] = None, controlnet_img: Optional[torch.Tensor] = None, + neg_cond: Optional[Tuple[float, torch.Tensor, torch.Tensor, torch.Tensor]] = None, ): # this is ignored for schnell guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype) + do_cfg = neg_cond is not None for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]): t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device) model.prepare_block_swap_before_forward() + if controlnet is not None: block_samples, block_single_samples = controlnet( img=img, @@ -335,20 +369,48 @@ def denoise( else: block_samples = None block_single_samples = None - pred = model( - img=img, - img_ids=img_ids, - txt=txt, - txt_ids=txt_ids, - y=vec, - block_controlnet_hidden_states=block_samples, - block_controlnet_single_hidden_states=block_single_samples, - timesteps=t_vec, - guidance=guidance_vec, - txt_attention_mask=t5_attn_mask, - ) - img = img + (t_prev - t_curr) * pred + if not do_cfg: + pred = model( + img=img, + img_ids=img_ids, + txt=txt, + txt_ids=txt_ids, + y=vec, + block_controlnet_hidden_states=block_samples, + block_controlnet_single_hidden_states=block_single_samples, + timesteps=t_vec, + guidance=guidance_vec, + txt_attention_mask=t5_attn_mask, + ) + + img = img + (t_prev - t_curr) * pred + else: + cfg_scale, neg_l_pooled, neg_t5_out, neg_t5_attn_mask = neg_cond + nc_c_t5_attn_mask = None if t5_attn_mask is None else torch.cat([neg_t5_attn_mask, t5_attn_mask], dim=0) + + # TODO is it ok to use the same block samples for both cond and uncond? + block_samples = None if block_samples is None else torch.cat([block_samples, block_samples], dim=0) + block_single_samples = ( + None if block_single_samples is None else torch.cat([block_single_samples, block_single_samples], dim=0) + ) + + nc_c_pred = model( + img=torch.cat([img, img], dim=0), + img_ids=torch.cat([img_ids, img_ids], dim=0), + txt=torch.cat([neg_t5_out, txt], dim=0), + txt_ids=torch.cat([txt_ids, txt_ids], dim=0), + y=torch.cat([neg_l_pooled, vec], dim=0), + block_controlnet_hidden_states=block_samples, + block_controlnet_single_hidden_states=block_single_samples, + timesteps=t_vec, + guidance=guidance_vec, + txt_attention_mask=nc_c_t5_attn_mask, + ) + neg_pred, pred = torch.chunk(nc_c_pred, 2, dim=0) + pred = neg_pred + (pred - neg_pred) * cfg_scale + + img = img + (t_prev - t_curr) * pred model.prepare_block_swap_before_forward() return img @@ -365,8 +427,6 @@ def get_sigmas(noise_scheduler, timesteps, device, n_dim=4, dtype=torch.float32) step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] sigma = sigmas[step_indices].flatten() - while len(sigma.shape) < n_dim: - sigma = sigma.unsqueeze(-1) return sigma @@ -409,42 +469,34 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): def get_noisy_model_input_and_timesteps( - args, noise_scheduler, latents, noise, device, dtype + args, noise_scheduler, latents: torch.Tensor, noise: torch.Tensor, device, dtype ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: bsz, _, h, w = latents.shape - sigmas = None - + assert bsz > 0, "Batch size not large enough" + num_timesteps = noise_scheduler.config.num_train_timesteps if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid": - # Simple random t-based noise sampling + # Simple random sigma-based noise sampling if args.timestep_sampling == "sigmoid": # https://github.com/XLabs-AI/x-flux/tree/main - t = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=device)) + sigmas = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=device)) else: - t = torch.rand((bsz,), device=device) + sigmas = torch.rand((bsz,), device=device) - timesteps = t * 1000.0 - t = t.view(-1, 1, 1, 1) - noisy_model_input = (1 - t) * latents + t * noise + timesteps = sigmas * num_timesteps elif args.timestep_sampling == "shift": shift = args.discrete_flow_shift - logits_norm = torch.randn(bsz, device=device) - logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling - timesteps = logits_norm.sigmoid() - timesteps = (timesteps * shift) / (1 + (shift - 1) * timesteps) - - t = timesteps.view(-1, 1, 1, 1) - timesteps = timesteps * 1000.0 - noisy_model_input = (1 - t) * latents + t * noise + sigmas = torch.randn(bsz, device=device) + sigmas = sigmas * args.sigmoid_scale # larger scale for more uniform sampling + sigmas = sigmas.sigmoid() + sigmas = (sigmas * shift) / (1 + (shift - 1) * sigmas) + timesteps = sigmas * num_timesteps elif args.timestep_sampling == "flux_shift": - logits_norm = torch.randn(bsz, device=device) - logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling - timesteps = logits_norm.sigmoid() - mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2)) - timesteps = time_shift(mu, 1.0, timesteps) - - t = timesteps.view(-1, 1, 1, 1) - timesteps = timesteps * 1000.0 - noisy_model_input = (1 - t) * latents + t * noise + sigmas = torch.randn(bsz, device=device) + sigmas = sigmas * args.sigmoid_scale # larger scale for more uniform sampling + sigmas = sigmas.sigmoid() + mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2)) # we are pre-packed so must adjust for packed size + sigmas = time_shift(mu, 1.0, sigmas) + timesteps = sigmas * num_timesteps else: # Sample a random timestep for each image # for weighting schemes where we sample timesteps non-uniformly @@ -455,12 +507,24 @@ def get_noisy_model_input_and_timesteps( logit_std=args.logit_std, mode_scale=args.mode_scale, ) - indices = (u * noise_scheduler.config.num_train_timesteps).long() + indices = (u * num_timesteps).long() timesteps = noise_scheduler.timesteps[indices].to(device=device) - - # Add noise according to flow matching. sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype) - noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents + + # Broadcast sigmas to latent shape + sigmas = sigmas.view(-1, 1, 1, 1) + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + if args.ip_noise_gamma: + xi = torch.randn_like(latents, device=latents.device, dtype=dtype) + if args.ip_noise_gamma_random_strength: + ip_noise_gamma = torch.rand(1, device=latents.device, dtype=dtype) * args.ip_noise_gamma + else: + ip_noise_gamma = args.ip_noise_gamma + noisy_model_input = (1.0 - sigmas) * latents + sigmas * (noise + ip_noise_gamma * xi) + else: + noisy_model_input = (1.0 - sigmas) * latents + sigmas * noise return noisy_model_input.to(dtype), timesteps.to(dtype), sigmas @@ -566,7 +630,7 @@ def add_flux_train_arguments(parser: argparse.ArgumentParser): "--controlnet_model_name_or_path", type=str, default=None, - help="path to controlnet (*.sft or *.safetensors) / controlnetのパス(*.sftまたは*.safetensors)" + help="path to controlnet (*.sft or *.safetensors) / controlnetのパス(*.sftまたは*.safetensors)", ) parser.add_argument( "--t5xxl_max_token_length", diff --git a/library/train_util.py b/library/train_util.py index e2d0d175..68019e21 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1070,8 +1070,11 @@ class BaseDataset(torch.utils.data.Dataset): self.bucket_info["buckets"][i] = {"resolution": reso, "count": len(bucket)} logger.info(f"bucket {i}: resolution {reso}, count: {len(bucket)}") - img_ar_errors = np.array(img_ar_errors) - mean_img_ar_error = np.mean(np.abs(img_ar_errors)) + if len(img_ar_errors) == 0: + mean_img_ar_error = 0 # avoid NaN + else: + img_ar_errors = np.array(img_ar_errors) + mean_img_ar_error = np.mean(np.abs(img_ar_errors)) self.bucket_info["mean_img_ar_error"] = mean_img_ar_error logger.info(f"mean ar error (without repeats): {mean_img_ar_error}") @@ -5520,6 +5523,11 @@ def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projectio def patch_accelerator_for_fp16_training(accelerator): + + from accelerate import DistributedType + if accelerator.distributed_type == DistributedType.DEEPSPEED: + return + org_unscale_grads = accelerator.scaler._unscale_grads_ def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16): @@ -6203,6 +6211,11 @@ def line_to_prompt_dict(line: str) -> dict: prompt_dict["scale"] = float(m.group(1)) continue + m = re.match(r"g ([\d\.]+)", parg, re.IGNORECASE) + if m: # guidance scale + prompt_dict["guidance_scale"] = float(m.group(1)) + continue + m = re.match(r"n (.+)", parg, re.IGNORECASE) if m: # negative prompt prompt_dict["negative_prompt"] = m.group(1) diff --git a/networks/lora_flux.py b/networks/lora_flux.py index 92b3979a..0b30f1b8 100644 --- a/networks/lora_flux.py +++ b/networks/lora_flux.py @@ -955,26 +955,26 @@ class LoRANetwork(torch.nn.Module): for lora in self.text_encoder_loras + self.unet_loras: lora.update_grad_norms() - def grad_norms(self) -> Tensor: + def grad_norms(self) -> Tensor | None: grad_norms = [] for lora in self.text_encoder_loras + self.unet_loras: if hasattr(lora, "grad_norms") and lora.grad_norms is not None: grad_norms.append(lora.grad_norms.mean(dim=0)) - return torch.stack(grad_norms) if len(grad_norms) > 0 else torch.tensor([]) + return torch.stack(grad_norms) if len(grad_norms) > 0 else None - def weight_norms(self) -> Tensor: + def weight_norms(self) -> Tensor | None: weight_norms = [] for lora in self.text_encoder_loras + self.unet_loras: if hasattr(lora, "weight_norms") and lora.weight_norms is not None: weight_norms.append(lora.weight_norms.mean(dim=0)) - return torch.stack(weight_norms) if len(weight_norms) > 0 else torch.tensor([]) + return torch.stack(weight_norms) if len(weight_norms) > 0 else None - def combined_weight_norms(self) -> Tensor: + def combined_weight_norms(self) -> Tensor | None: combined_weight_norms = [] for lora in self.text_encoder_loras + self.unet_loras: if hasattr(lora, "combined_weight_norms") and lora.combined_weight_norms is not None: combined_weight_norms.append(lora.combined_weight_norms.mean(dim=0)) - return torch.stack(combined_weight_norms) if len(combined_weight_norms) > 0 else torch.tensor([]) + return torch.stack(combined_weight_norms) if len(combined_weight_norms) > 0 else None def load_weights(self, file): diff --git a/pytest.ini b/pytest.ini index 484d3aef..34b7e9c1 100644 --- a/pytest.ini +++ b/pytest.ini @@ -6,3 +6,4 @@ filterwarnings = ignore::DeprecationWarning ignore::UserWarning ignore::FutureWarning +pythonpath = . diff --git a/tests/library/test_flux_train_utils.py b/tests/library/test_flux_train_utils.py new file mode 100644 index 00000000..2ad7ce4e --- /dev/null +++ b/tests/library/test_flux_train_utils.py @@ -0,0 +1,220 @@ +import pytest +import torch +from unittest.mock import MagicMock, patch +from library.flux_train_utils import ( + get_noisy_model_input_and_timesteps, +) + +# Mock classes and functions +class MockNoiseScheduler: + def __init__(self, num_train_timesteps=1000): + self.config = MagicMock() + self.config.num_train_timesteps = num_train_timesteps + self.timesteps = torch.arange(num_train_timesteps, dtype=torch.long) + + +# Create fixtures for commonly used objects +@pytest.fixture +def args(): + args = MagicMock() + args.timestep_sampling = "uniform" + args.weighting_scheme = "uniform" + args.logit_mean = 0.0 + args.logit_std = 1.0 + args.mode_scale = 1.0 + args.sigmoid_scale = 1.0 + args.discrete_flow_shift = 3.1582 + args.ip_noise_gamma = None + args.ip_noise_gamma_random_strength = False + return args + + +@pytest.fixture +def noise_scheduler(): + return MockNoiseScheduler(num_train_timesteps=1000) + + +@pytest.fixture +def latents(): + return torch.randn(2, 4, 8, 8) + + +@pytest.fixture +def noise(): + return torch.randn(2, 4, 8, 8) + + +@pytest.fixture +def device(): + # return "cuda" if torch.cuda.is_available() else "cpu" + return "cpu" + + +# Mock the required functions +@pytest.fixture(autouse=True) +def mock_functions(): + with ( + patch("torch.sigmoid", side_effect=torch.sigmoid), + patch("torch.rand", side_effect=torch.rand), + patch("torch.randn", side_effect=torch.randn), + ): + yield + + +# Test different timestep sampling methods +def test_uniform_sampling(args, noise_scheduler, latents, noise, device): + args.timestep_sampling = "uniform" + dtype = torch.float32 + + noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) + + assert noisy_input.shape == latents.shape + assert timesteps.shape == (latents.shape[0],) + assert sigmas.shape == (latents.shape[0], 1, 1, 1) + assert noisy_input.dtype == dtype + assert timesteps.dtype == dtype + + +def test_sigmoid_sampling(args, noise_scheduler, latents, noise, device): + args.timestep_sampling = "sigmoid" + args.sigmoid_scale = 1.0 + dtype = torch.float32 + + noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) + + assert noisy_input.shape == latents.shape + assert timesteps.shape == (latents.shape[0],) + assert sigmas.shape == (latents.shape[0], 1, 1, 1) + + +def test_shift_sampling(args, noise_scheduler, latents, noise, device): + args.timestep_sampling = "shift" + args.sigmoid_scale = 1.0 + args.discrete_flow_shift = 3.1582 + dtype = torch.float32 + + noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) + + assert noisy_input.shape == latents.shape + assert timesteps.shape == (latents.shape[0],) + assert sigmas.shape == (latents.shape[0], 1, 1, 1) + + +def test_flux_shift_sampling(args, noise_scheduler, latents, noise, device): + args.timestep_sampling = "flux_shift" + args.sigmoid_scale = 1.0 + dtype = torch.float32 + + noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) + + assert noisy_input.shape == latents.shape + assert timesteps.shape == (latents.shape[0],) + assert sigmas.shape == (latents.shape[0], 1, 1, 1) + + +def test_weighting_scheme(args, noise_scheduler, latents, noise, device): + # Mock the necessary functions for this specific test + with patch("library.flux_train_utils.compute_density_for_timestep_sampling", + return_value=torch.tensor([0.3, 0.7], device=device)), \ + patch("library.flux_train_utils.get_sigmas", + return_value=torch.tensor([[0.3], [0.7]], device=device).view(-1, 1, 1, 1)): + + args.timestep_sampling = "other" # Will trigger the weighting scheme path + args.weighting_scheme = "uniform" + args.logit_mean = 0.0 + args.logit_std = 1.0 + args.mode_scale = 1.0 + dtype = torch.float32 + + noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps( + args, noise_scheduler, latents, noise, device, dtype + ) + + assert noisy_input.shape == latents.shape + assert timesteps.shape == (latents.shape[0],) + assert sigmas.shape == (latents.shape[0], 1, 1, 1) + + +# Test IP noise options +def test_with_ip_noise(args, noise_scheduler, latents, noise, device): + args.ip_noise_gamma = 0.5 + args.ip_noise_gamma_random_strength = False + dtype = torch.float32 + + noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) + + assert noisy_input.shape == latents.shape + assert timesteps.shape == (latents.shape[0],) + assert sigmas.shape == (latents.shape[0], 1, 1, 1) + + +def test_with_random_ip_noise(args, noise_scheduler, latents, noise, device): + args.ip_noise_gamma = 0.1 + args.ip_noise_gamma_random_strength = True + dtype = torch.float32 + + noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) + + assert noisy_input.shape == latents.shape + assert timesteps.shape == (latents.shape[0],) + assert sigmas.shape == (latents.shape[0], 1, 1, 1) + + +# Test different data types +def test_float16_dtype(args, noise_scheduler, latents, noise, device): + dtype = torch.float16 + + noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) + + assert noisy_input.dtype == dtype + assert timesteps.dtype == dtype + + +# Test different batch sizes +def test_different_batch_size(args, noise_scheduler, device): + latents = torch.randn(5, 4, 8, 8) # batch size of 5 + noise = torch.randn(5, 4, 8, 8) + dtype = torch.float32 + + noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) + + assert noisy_input.shape == latents.shape + assert timesteps.shape == (5,) + assert sigmas.shape == (5, 1, 1, 1) + + +# Test different image sizes +def test_different_image_size(args, noise_scheduler, device): + latents = torch.randn(2, 4, 16, 16) # larger image size + noise = torch.randn(2, 4, 16, 16) + dtype = torch.float32 + + noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) + + assert noisy_input.shape == latents.shape + assert timesteps.shape == (2,) + assert sigmas.shape == (2, 1, 1, 1) + + +# Test edge cases +def test_zero_batch_size(args, noise_scheduler, device): + with pytest.raises(AssertionError): # expecting an error with zero batch size + latents = torch.randn(0, 4, 8, 8) + noise = torch.randn(0, 4, 8, 8) + dtype = torch.float32 + + get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) + + +def test_different_timestep_count(args, device): + noise_scheduler = MockNoiseScheduler(num_train_timesteps=500) # different timestep count + latents = torch.randn(2, 4, 8, 8) + noise = torch.randn(2, 4, 8, 8) + dtype = torch.float32 + + noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) + + assert noisy_input.shape == latents.shape + assert timesteps.shape == (2,) + # Check that timesteps are within the proper range + assert torch.all(timesteps < 500) diff --git a/train_network.py b/train_network.py index 4b8a9c73..6073c4c3 100644 --- a/train_network.py +++ b/train_network.py @@ -389,7 +389,18 @@ class NetworkTrainer: latents = typing.cast(torch.FloatTensor, batch["latents"].to(accelerator.device)) else: # latentに変換 - latents = self.encode_images_to_latents(args, vae, batch["images"].to(accelerator.device, dtype=vae_dtype)) + if args.vae_batch_size is None or len(batch["images"]) <= args.vae_batch_size: + latents = self.encode_images_to_latents(args, vae, batch["images"].to(accelerator.device, dtype=vae_dtype)) + else: + chunks = [ + batch["images"][i : i + args.vae_batch_size] for i in range(0, len(batch["images"]), args.vae_batch_size) + ] + list_latents = [] + for chunk in chunks: + with torch.no_grad(): + chunk = self.encode_images_to_latents(args, vae, chunk.to(accelerator.device, dtype=vae_dtype)) + list_latents.append(chunk) + latents = torch.cat(list_latents, dim=0) # NaNが含まれていれば警告を表示し0に置き換える if torch.any(torch.isnan(latents)): @@ -1433,11 +1444,13 @@ class NetworkTrainer: max_mean_logs = {"Keys Scaled": keys_scaled, "Average key norm": mean_norm} else: if hasattr(network, "weight_norms"): - mean_norm = network.weight_norms().mean().item() - mean_grad_norm = network.grad_norms().mean().item() - mean_combined_norm = network.combined_weight_norms().mean().item() weight_norms = network.weight_norms() - maximum_norm = weight_norms.max().item() if weight_norms.numel() > 0 else None + mean_norm = weight_norms.mean().item() if weight_norms is not None else None + grad_norms = network.grad_norms() + mean_grad_norm = grad_norms.mean().item() if grad_norms is not None else None + combined_weight_norms = network.combined_weight_norms() + mean_combined_norm = combined_weight_norms.mean().item() if combined_weight_norms is not None else None + maximum_norm = weight_norms.max().item() if weight_norms is not None else None keys_scaled = None max_mean_logs = {} else: