From c0caf33e3fa7a99c2160946e42d4ef7b8d7660a4 Mon Sep 17 00:00:00 2001 From: sdbds <865105819@qq.com> Date: Sat, 15 Feb 2025 16:38:59 +0800 Subject: [PATCH] update --- library/lumina_util.py | 8 -- lumina_train_network.py | 175 +++++++++++++++++++++++++++++++++++++++- 2 files changed, 171 insertions(+), 12 deletions(-) diff --git a/library/lumina_util.py b/library/lumina_util.py index 990f8c68..b47e057a 100644 --- a/library/lumina_util.py +++ b/library/lumina_util.py @@ -108,14 +108,6 @@ def load_gemma2( logger.info(f"Loaded Gemma2: {info}") return gemma2 -def prepare_img_ids(batch_size: int, packed_latent_height: int, packed_latent_width: int): - img_ids = torch.zeros(packed_latent_height, packed_latent_width, 3) - img_ids[..., 1] = img_ids[..., 1] + torch.arange(packed_latent_height)[:, None] - img_ids[..., 2] = img_ids[..., 2] + torch.arange(packed_latent_width)[None, :] - img_ids = einops.repeat(img_ids, "h w c -> b (h w) c", b=batch_size) - return img_ids - - def unpack_latents(x: torch.Tensor, packed_latent_height: int, packed_latent_width: int) -> torch.Tensor: """ x: [b (h w) (c ph pw)] -> [b c (h ph) (w pw)], ph=2, pw=2 diff --git a/lumina_train_network.py b/lumina_train_network.py index 40b84e14..db329a9b 100644 --- a/lumina_train_network.py +++ b/lumina_train_network.py @@ -53,7 +53,7 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer): self.train_gemma2 = not args.network_train_unet_only def load_target_model(self, args, weight_dtype, accelerator): - loading_dtype = None if args.fp8 else weight_dtype + loading_dtype = None if args.fp8_base else weight_dtype model = lumina_util.load_lumina_model( args.pretrained_model_name_or_path, @@ -67,8 +67,12 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer): # model.enable_block_swap(args.blocks_to_swap, accelerator.device) # self.is_swapping_blocks = True - gemma2 = lumina_util.load_gemma2(args.gemma2, weight_dtype, "cpu") - ae = lumina_util.load_ae(args.ae, weight_dtype, "cpu") + gemma2 = 
lumina_util.load_gemma2( + args.gemma2, weight_dtype, "cpu" + ) + ae = lumina_util.load_ae( + args.ae, weight_dtype, "cpu" + ) return lumina_util.MODEL_VERSION_LUMINA_V2, [gemma2], ae, model @@ -168,11 +172,174 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer): def shift_scale_latents(self, args, latents): return latents + def get_noise_pred_and_target( + self, + args, + accelerator, + noise_scheduler, + latents, + batch, + text_encoder_conds, + unet: lumina_models.NextDiT, + network, + weight_dtype, + train_unet, + is_train=True, + ): + noise = torch.randn_like(latents) + bsz = latents.shape[0] + + # get noisy model input and timesteps + noisy_model_input, timesteps, sigmas = ( + flux_train_utils.get_noisy_model_input_and_timesteps( + args, noise_scheduler, latents, noise, accelerator.device, weight_dtype + ) + ) + + # pack latents and get img_ids - this part can be kept because NextDiT also needs packed-format input + packed_noisy_model_input = lumina_util.pack_latents(noisy_model_input) + packed_latent_height, packed_latent_width = ( + noisy_model_input.shape[2] // 2, + noisy_model_input.shape[3] // 2, + ) + + # ensure the hidden state will require grad + if args.gradient_checkpointing: + noisy_model_input.requires_grad_(True) + for t in text_encoder_conds: + if t is not None and t.dtype.is_floating_point: + t.requires_grad_(True) + + # Unpack Gemma2 outputs + gemma2_hidden_states, gemma2_attn_mask, input_ids = text_encoder_conds + if not args.apply_gemma2_attn_mask: + gemma2_attn_mask = None + + def call_dit(img, gemma2_hidden_states, input_ids, timesteps, gemma2_attn_mask): + with torch.set_grad_enabled(is_train), accelerator.autocast(): + # NextDiT forward expects (x, t, cap_feats, cap_mask) + model_pred = unet( + x=img, # packed latents + t=timesteps / 1000, # timesteps must be divided by 1000 to match what the model expects + cap_feats=gemma2_hidden_states, # Gemma2 hidden states used as caption features + cap_mask=gemma2_attn_mask, # Gemma2 attention mask + ) + return model_pred + + model_pred = call_dit( 
img=packed_noisy_model_input, + gemma2_hidden_states=gemma2_hidden_states, + input_ids=input_ids, + timesteps=timesteps, + gemma2_attn_mask=gemma2_attn_mask, + ) + + # unpack latents + model_pred = lumina_util.unpack_latents( + model_pred, packed_latent_height, packed_latent_width + ) + + # apply model prediction type + model_pred, weighting = flux_train_utils.apply_model_prediction_type( + args, model_pred, noisy_model_input, sigmas + ) + + # flow matching loss: this is different from SD3 + target = noise - latents + + # differential output preservation + if "custom_attributes" in batch: + diff_output_pr_indices = [] + for i, custom_attributes in enumerate(batch["custom_attributes"]): + if ( + "diff_output_preservation" in custom_attributes + and custom_attributes["diff_output_preservation"] + ): + diff_output_pr_indices.append(i) + + if len(diff_output_pr_indices) > 0: + network.set_multiplier(0.0) + with torch.no_grad(): + model_pred_prior = call_dit( + img=packed_noisy_model_input[diff_output_pr_indices], + gemma2_hidden_states=gemma2_hidden_states[ + diff_output_pr_indices + ], + input_ids=input_ids[diff_output_pr_indices], + timesteps=timesteps[diff_output_pr_indices], + gemma2_attn_mask=( + gemma2_attn_mask[diff_output_pr_indices] + if gemma2_attn_mask is not None + else None + ), + ) + network.set_multiplier(1.0) + + model_pred_prior = lumina_util.unpack_latents( + model_pred_prior, packed_latent_height, packed_latent_width + ) + model_pred_prior, _ = flux_train_utils.apply_model_prediction_type( + args, + model_pred_prior, + noisy_model_input[diff_output_pr_indices], + sigmas[diff_output_pr_indices] if sigmas is not None else None, + ) + target[diff_output_pr_indices] = model_pred_prior.to(target.dtype) + + return model_pred, target, timesteps, weighting + def post_process_loss(self, loss, args, timesteps, noise_scheduler): return loss def get_sai_model_spec(self, args): - return train_util.get_sai_model_spec(None, args, False, True, False, flux="dev") + 
return train_util.get_sai_model_spec( + None, args, False, True, False, lumina="lumina2" + ) + + def update_metadata(self, metadata, args): + metadata["ss_apply_gemma2_attn_mask"] = args.apply_gemma2_attn_mask + metadata["ss_weighting_scheme"] = args.weighting_scheme + metadata["ss_logit_mean"] = args.logit_mean + metadata["ss_logit_std"] = args.logit_std + metadata["ss_mode_scale"] = args.mode_scale + metadata["ss_guidance_scale"] = args.guidance_scale + metadata["ss_timestep_sampling"] = args.timestep_sampling + metadata["ss_sigmoid_scale"] = args.sigmoid_scale + metadata["ss_model_prediction_type"] = args.model_prediction_type + metadata["ss_discrete_flow_shift"] = args.discrete_flow_shift + + def is_text_encoder_not_needed_for_training(self, args): + return args.cache_text_encoder_outputs and not self.is_train_text_encoder(args) + + def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder): + text_encoder.model.embed_tokens.requires_grad_(True) + + def prepare_text_encoder_fp8( + self, index, text_encoder, te_weight_dtype, weight_dtype + ): + logger.info( + f"prepare Gemma2 for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}" + ) + text_encoder.to(te_weight_dtype) # fp8 + text_encoder.model.embed_tokens.to(dtype=weight_dtype) + + def prepare_unet_with_accelerator( + self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module + ) -> torch.nn.Module: + if not self.is_swapping_blocks: + return super().prepare_unet_with_accelerator(args, accelerator, unet) + + # if we don't swap blocks, we can move the model to device + nextdit: lumina_models.Nextdit = unet + nextdit = accelerator.prepare( + nextdit, device_placement=[not self.is_swapping_blocks] + ) + accelerator.unwrap_model(nextdit).move_to_device_except_swap_blocks( + accelerator.device + ) # reduce peak memory usage + accelerator.unwrap_model(nextdit).prepare_block_swap_before_forward() + + return nextdit def setup_parser() -> argparse.ArgumentParser: