From ce37c08b9a3b8e6567c70712f9d6899a304e98b6 Mon Sep 17 00:00:00 2001 From: sdbds <865105819@qq.com> Date: Wed, 26 Feb 2025 11:20:03 +0800 Subject: [PATCH] clean code and add finetune code --- library/lumina_train_util.py | 212 ++++++-- lumina_train.py | 953 +++++++++++++++++++++++++++++++++++ lumina_train_network.py | 37 +- 3 files changed, 1118 insertions(+), 84 deletions(-) create mode 100644 lumina_train.py diff --git a/library/lumina_train_util.py b/library/lumina_train_util.py index 20df7eef..ca039167 100644 --- a/library/lumina_train_util.py +++ b/library/lumina_train_util.py @@ -32,7 +32,9 @@ logger = logging.getLogger(__name__) # region sample images -def batchify(prompt_dicts, batch_size=None) -> Generator[list[dict[str, str]], None, None]: +def batchify( + prompt_dicts, batch_size=None +) -> Generator[list[dict[str, str]], None, None]: """ Group prompt dictionaries into batches with configurable batch size. @@ -64,7 +66,15 @@ def batchify(prompt_dicts, batch_size=None) -> Generator[list[dict[str, str]], N seed = int(seed) if seed is not None else None # Create a key based on the parameters - key = (width, height, guidance_scale, seed, sample_steps, cfg_trunc_ratio, renorm_cfg) + key = ( + width, + height, + guidance_scale, + seed, + sample_steps, + cfg_trunc_ratio, + renorm_cfg, + ) # Add the prompt_dict to the corresponding batch if key not in batches: @@ -131,7 +141,9 @@ def sample_images( if epoch is None or epoch % args.sample_every_n_epochs != 0: return else: - if global_step % args.sample_every_n_steps != 0 or epoch is not None: # steps is not divisible or end of epoch + if ( + global_step % args.sample_every_n_steps != 0 or epoch is not None + ): # steps is not divisible or end of epoch return assert ( @@ -139,12 +151,21 @@ def sample_images( ), "No sample prompts found. Provide `--sample_prompts` / サンプルプロンプトが見つかりません。`--sample_prompts` を指定してください" logger.info("") - logger.info(f"generating sample images at step / サンプル画像生成 ステップ: {global_step}") - if not os.path.isfile(args.sample_prompts) and sample_prompts_gemma2_outputs is None: - logger.error(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}") + logger.info( + f"generating sample images at step / サンプル画像生成 ステップ: {global_step}" + ) + if ( + not os.path.isfile(args.sample_prompts) + and sample_prompts_gemma2_outputs is None + ): + logger.error( + f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}" + ) return - distributed_state = PartialState() # for multi gpu distributed inference. this is a singleton, so it's safe to use it here + distributed_state = ( + PartialState() + ) # for multi gpu distributed inference. this is a singleton, so it's safe to use it here # unwrap nextdit and gemma2_model nextdit = accelerator.unwrap_model(nextdit) @@ -163,7 +184,9 @@ def sample_images( rng_state = torch.get_rng_state() cuda_rng_state = None try: - cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None + cuda_rng_state = ( + torch.cuda.get_rng_state() if torch.cuda.is_available() else None + ) except Exception: pass @@ -194,7 +217,9 @@ def sample_images( for i in range(distributed_state.num_processes): per_process_prompts.append(prompts[i :: distributed_state.num_processes]) - with distributed_state.split_between_processes(per_process_prompts) as prompt_dict_lists: + with distributed_state.split_between_processes( + per_process_prompts + ) as prompt_dict_lists: # TODO: batch prompts together with buckets of image sizes for prompt_dicts in batchify(prompt_dict_lists[0], batch_size): sample_image_inference( @@ -289,7 +314,9 @@ def sample_image_inference( if prompt_replacement is not None: prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1]) if negative_prompt is not None: - negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1]) + negative_prompt = negative_prompt.replace( + prompt_replacement[0], prompt_replacement[1] + ) if negative_prompt is None: negative_prompt = "" @@ -314,17 +341,26 @@ def sample_image_inference( gemma2_conds = sample_prompts_gemma2_outputs[prompt] logger.info(f"Using cached Gemma2 outputs for prompt: {prompt}") - if sample_prompts_gemma2_outputs and negative_prompt in sample_prompts_gemma2_outputs: + if ( + sample_prompts_gemma2_outputs + and negative_prompt in sample_prompts_gemma2_outputs + ): neg_gemma2_conds = sample_prompts_gemma2_outputs[negative_prompt] - logger.info(f"Using cached Gemma2 outputs for negative prompt: {negative_prompt}") + logger.info( + f"Using cached Gemma2 outputs for negative prompt: {negative_prompt}" + ) # Load sample prompts from Gemma 2 if gemma2_model is not None: tokens_and_masks = tokenize_strategy.tokenize(prompt) - gemma2_conds = encoding_strategy.encode_tokens(tokenize_strategy, [gemma2_model], tokens_and_masks) + gemma2_conds = encoding_strategy.encode_tokens( + tokenize_strategy, [gemma2_model], tokens_and_masks + ) tokens_and_masks = tokenize_strategy.tokenize(negative_prompt) - neg_gemma2_conds = encoding_strategy.encode_tokens(tokenize_strategy, [gemma2_model], tokens_and_masks) + neg_gemma2_conds = encoding_strategy.encode_tokens( + tokenize_strategy, [gemma2_model], tokens_and_masks + ) # Unpack Gemma2 outputs gemma2_hidden_states, _, gemma2_attn_mask = gemma2_conds @@ -340,10 +376,18 @@ def sample_image_inference( ) # Stack conditioning - cond_hidden_states = torch.stack([text_cond[0] for text_cond in text_conds]).to(accelerator.device) - cond_attn_masks = torch.stack([text_cond[1] for text_cond in text_conds]).to(accelerator.device) - uncond_hidden_states = torch.stack([text_cond[2] for text_cond in text_conds]).to(accelerator.device) - uncond_attn_masks = torch.stack([text_cond[3] for text_cond in text_conds]).to(accelerator.device) + cond_hidden_states = torch.stack([text_cond[0] for text_cond in text_conds]).to( + accelerator.device + ) + cond_attn_masks = torch.stack([text_cond[1] for text_cond in text_conds]).to( + accelerator.device + ) + uncond_hidden_states = torch.stack([text_cond[2] for text_cond in text_conds]).to( + accelerator.device + ) + uncond_attn_masks = torch.stack([text_cond[3] for text_cond in text_conds]).to( + accelerator.device + ) # sample image weight_dtype = vae.dtype # TOFO give dtype as argument @@ -362,7 +406,9 @@ def sample_image_inference( noise = noise.repeat(cond_hidden_states.shape[0], 1, 1, 1) scheduler = FlowMatchEulerDiscreteScheduler(shift=6.0) - timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps=sample_steps) + timesteps, num_inference_steps = retrieve_timesteps( + scheduler, num_inference_steps=sample_steps + ) # if controlnet_image is not None: # controlnet_image = Image.open(controlnet_image).convert("RGB") @@ -422,7 +468,9 @@ def sample_image_inference( import wandb # not to commit images to avoid inconsistency between training and logging steps - wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False) # positive prompt as a caption + wandb_tracker.log( + {f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False + ) # positive prompt as a caption vae.to(org_vae_device) clean_memory_on_device(accelerator.device) @@ -437,7 +485,9 @@ def time_shift(mu: float, sigma: float, t: torch.Tensor): return t -def get_lin_function(x1: float = 256, x2: float = 4096, y1: float = 0.5, y2: float = 1.15) -> Callable[[float], float]: +def get_lin_function( + x1: float = 256, x2: float = 4096, y1: float = 0.5, y2: float = 1.15 +) -> Callable[[float], float]: """ Get linear function @@ -481,7 +531,9 @@ def get_schedule( # shifting the schedule to favor high timesteps for higher signal images if shift: # eastimate mu based on linear estimation between two points - mu = get_lin_function(y1=base_shift, y2=max_shift, x1=256, x2=4096)(image_seq_len) + mu = get_lin_function(y1=base_shift, y2=max_shift, x1=256, x2=4096)( + image_seq_len + ) timesteps = time_shift(mu, 1.0, timesteps) return timesteps.tolist() @@ -520,9 +572,13 @@ def retrieve_timesteps( second element is the number of inference steps. """ if timesteps is not None and sigmas is not None: - raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + raise ValueError( + "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values" + ) if timesteps is not None: - accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + accepts_timesteps = "timesteps" in set( + inspect.signature(scheduler.set_timesteps).parameters.keys() + ) if not accepts_timesteps: raise ValueError( f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" @@ -532,7 +588,9 @@ def retrieve_timesteps( timesteps = scheduler.timesteps num_inference_steps = len(timesteps) elif sigmas is not None: - accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + accept_sigmas = "sigmas" in set( + inspect.signature(scheduler.set_timesteps).parameters.keys() + ) if not accept_sigmas: raise ValueError( f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" @@ -593,7 +651,9 @@ def denoise( # reverse the timestep since Lumina uses t=0 as the noise and t=1 as the image current_timestep = 1 - t / scheduler.config.num_train_timesteps # broadcast to batch dimension in a way that's compatible with ONNX/Core ML - current_timestep = current_timestep * torch.ones(img.shape[0], device=img.device) + current_timestep = current_timestep * torch.ones( + img.shape[0], device=img.device + ) noise_pred_cond = model( img, @@ -610,12 +670,20 @@ def denoise( cap_feats=neg_txt, # Gemma2的hidden states作为caption features cap_mask=neg_txt_mask.to(dtype=torch.int32), # Gemma2的attention mask ) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond) + noise_pred = noise_pred_uncond + guidance_scale * ( + noise_pred_cond - noise_pred_uncond + ) # apply normalization after classifier-free guidance if float(renorm_cfg) > 0.0: - cond_norm = torch.linalg.vector_norm(noise_pred_cond, dim=tuple(range(1, len(noise_pred_cond.shape))), keepdim=True) + cond_norm = torch.linalg.vector_norm( + noise_pred_cond, + dim=tuple(range(1, len(noise_pred_cond.shape))), + keepdim=True, + ) max_new_norm = cond_norm * float(renorm_cfg) - noise_norm = torch.linalg.vector_norm(noise_pred, dim=tuple(range(1, len(noise_pred.shape))), keepdim=True) + noise_norm = torch.linalg.vector_norm( + noise_pred, dim=tuple(range(1, len(noise_pred.shape))), keepdim=True + ) if noise_norm >= max_new_norm: noise_pred = noise_pred * (max_new_norm / noise_norm) else: @@ -640,7 +708,11 @@ def denoise( # region train def get_sigmas( - noise_scheduler: FlowMatchEulerDiscreteScheduler, timesteps: Tensor, device: torch.device, n_dim=4, dtype=torch.float32 + noise_scheduler: FlowMatchEulerDiscreteScheduler, + timesteps: Tensor, + device: torch.device, + n_dim=4, + dtype=torch.float32, ) -> Tensor: """ Get sigmas for timesteps @@ -667,7 +739,11 @@ def get_sigmas( def compute_density_for_timestep_sampling( - weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None + weighting_scheme: str, + batch_size: int, + logit_mean: float = None, + logit_std: float = None, + mode_scale: float = None, ): """ Compute the density for sampling the timesteps when doing SD3 training. @@ -688,7 +764,9 @@ def compute_density_for_timestep_sampling( """ if weighting_scheme == "logit_normal": # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$). - u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu") + u = torch.normal( + mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu" + ) u = torch.nn.functional.sigmoid(u) elif weighting_scheme == "mode": u = torch.rand(size=(batch_size,), device="cpu") @@ -722,7 +800,9 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None) -> Tensor return weighting -def get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) -> Tuple[Tensor, Tensor, Tensor]: +def get_noisy_model_input_and_timesteps( + args, noise_scheduler, latents, noise, device, dtype +) -> Tuple[Tensor, Tensor, Tensor]: """ Get noisy model input and timesteps. @@ -753,27 +833,27 @@ def get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, d timesteps = t * 1000.0 t = t.view(-1, 1, 1, 1) - noisy_model_input = (1 - t) * latents + t * noise + noisy_model_input = (1 - t) * noise + t * latents elif args.timestep_sampling == "shift": shift = args.discrete_flow_shift logits_norm = torch.randn(bsz, device=device) - logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling + logits_norm = ( + logits_norm * args.sigmoid_scale + ) # larger scale for more uniform sampling timesteps = logits_norm.sigmoid() timesteps = (timesteps * shift) / (1 + (shift - 1) * timesteps) t = timesteps.view(-1, 1, 1, 1) timesteps = timesteps * 1000.0 - noisy_model_input = (1 - t) * latents + t * noise + noisy_model_input = (1 - t) * noise + t * latents elif args.timestep_sampling == "nextdit_shift": - logits_norm = torch.randn(bsz, device=device) - logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling - timesteps = logits_norm.sigmoid() - mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2)) - timesteps = time_shift(mu, 1.0, timesteps) + t = torch.rand((bsz,), device=device) + mu = get_lin_function(y1=0.5, y2=1.15)((h // 16) * (w // 16)) # lumina use //16 + t = time_shift(mu, 1.0, t) - t = timesteps.view(-1, 1, 1, 1) - timesteps = timesteps * 1000.0 - noisy_model_input = (1 - t) * latents + t * noise + timesteps = t * 1000.0 + t = t.view(-1, 1, 1, 1) + noisy_model_input = (1 - t) * noise + t * latents else: # Sample a random timestep for each image # for weighting schemes where we sample timesteps non-uniformly @@ -788,8 +868,10 @@ def get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, d timesteps = noise_scheduler.timesteps[indices].to(device=device) # Add noise according to flow matching. - sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype) - noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents + sigmas = get_sigmas( + noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype + ) + noisy_model_input = sigmas * latents + (1.0 - sigmas) * noise return noisy_model_input.to(dtype), timesteps.to(dtype), sigmas @@ -821,7 +903,9 @@ def apply_model_prediction_type( # these weighting schemes use a uniform timestep sampling # and instead post-weight the loss - weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas) + weighting = compute_loss_weighting_for_sd3( + weighting_scheme=args.weighting_scheme, sigmas=sigmas + ) return model_pred, weighting @@ -863,15 +947,27 @@ def save_models( def save_lumina_model_on_train_end( - args: argparse.Namespace, save_dtype: torch.dtype, epoch: int, global_step: int, lumina: lumina_models.NextDiT + args: argparse.Namespace, + save_dtype: torch.dtype, + epoch: int, + global_step: int, + lumina: lumina_models.NextDiT, ): def sd_saver(ckpt_file, epoch_no, global_step): sai_metadata = train_util.get_sai_model_spec( - None, args, False, False, False, is_stable_diffusion_ckpt=True, lumina="lumina2" + None, + args, + False, + False, + False, + is_stable_diffusion_ckpt=True, + lumina="lumina2", ) save_models(ckpt_file, lumina, sai_metadata, save_dtype, args.mem_eff_save) - train_util.save_sd_model_on_train_end_common(args, True, True, epoch, global_step, sd_saver, None) + train_util.save_sd_model_on_train_end_common( + args, True, True, epoch, global_step, sd_saver, None + ) # epochとstepの保存、メタデータにepoch/stepが含まれ引数が同じになるため、統合してている @@ -901,7 +997,15 @@ def save_lumina_model_on_epoch_end_or_stepwise( """ def sd_saver(ckpt_file: str, epoch_no: int, global_step: int): - sai_metadata = train_util.get_sai_model_spec({}, args, False, False, False, is_stable_diffusion_ckpt=True, lumina="lumina2") + sai_metadata = train_util.get_sai_model_spec( + {}, + args, + False, + False, + False, + is_stable_diffusion_ckpt=True, + lumina="lumina2", + ) save_models(ckpt_file, lumina, sai_metadata, save_dtype, args.mem_eff_save) train_util.save_sd_model_on_epoch_end_or_stepwise_common( @@ -927,7 +1031,11 @@ def add_lumina_train_arguments(parser: argparse.ArgumentParser): type=str, help="path to gemma2 model (*.sft or *.safetensors), should be float16 / gemma2のパス(*.sftまたは*.safetensors)、float16が前提", ) - parser.add_argument("--ae", type=str, help="path to ae (*.sft or *.safetensors) / aeのパス(*.sftまたは*.safetensors)") + parser.add_argument( + "--ae", + type=str, + help="path to ae (*.sft or *.safetensors) / aeのパス(*.sftまたは*.safetensors)", + ) parser.add_argument( "--gemma2_max_token_length", type=int, diff --git a/lumina_train.py b/lumina_train.py new file mode 100644 index 00000000..330d0093 --- /dev/null +++ b/lumina_train.py @@ -0,0 +1,953 @@ +# training with captions + +# Swap blocks between CPU and GPU: +# This implementation is inspired by and based on the work of 2kpr. +# Many thanks to 2kpr for the original concept and implementation of memory-efficient offloading. +# The original idea has been adapted and extended to fit the current project's needs. + +# Key features: +# - CPU offloading during forward and backward passes +# - Use of fused optimizer and grad_hook for efficient gradient processing +# - Per-block fused optimizer instances + +import argparse +import copy +import math +import os +from multiprocessing import Value +import toml + +from tqdm import tqdm + +import torch +from library.device_utils import init_ipex, clean_memory_on_device + +init_ipex() + +from accelerate.utils import set_seed +from library import ( + deepspeed_utils, + lumina_train_util, + lumina_util, + strategy_base, + strategy_lumina, +) +from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler + +import library.train_util as train_util + +from library.utils import setup_logging, add_logging_arguments + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + +import library.config_util as config_util + +# import library.sdxl_train_util as sdxl_train_util +from library.config_util import ( + ConfigSanitizer, + BlueprintGenerator, +) +from library.custom_train_functions import apply_masked_loss, add_custom_train_arguments + + +def train(args): + train_util.verify_training_args(args) + train_util.prepare_dataset_args(args, True) + # sdxl_train_util.verify_sdxl_training_args(args) + deepspeed_utils.prepare_deepspeed_args(args) + setup_logging(args, reset=True) + + # temporary: backward compatibility for deprecated options. remove in the future + if not args.skip_cache_check: + args.skip_cache_check = args.skip_latents_validity_check + + # assert ( + # not args.weighted_captions + # ), "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません" + if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs: + logger.warning( + "cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled / cache_text_encoder_outputs_to_diskが有効になっているため、cache_text_encoder_outputsも有効になります" + ) + args.cache_text_encoder_outputs = True + + if args.cpu_offload_checkpointing and not args.gradient_checkpointing: + logger.warning( + "cpu_offload_checkpointing is enabled, so gradient_checkpointing is also enabled / cpu_offload_checkpointingが有効になっているため、gradient_checkpointingも有効になります" + ) + args.gradient_checkpointing = True + + # assert ( + # args.blocks_to_swap is None or args.blocks_to_swap == 0 + # ) or not args.cpu_offload_checkpointing, "blocks_to_swap is not supported with cpu_offload_checkpointing / blocks_to_swapはcpu_offload_checkpointingと併用できません" + + cache_latents = args.cache_latents + use_dreambooth_method = args.in_json is None + + if args.seed is not None: + set_seed(args.seed) # 乱数系列を初期化する + + # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. + if args.cache_latents: + latents_caching_strategy = strategy_lumina.LuminaLatentsCachingStrategy( + args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check + ) + strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) + + # データセットを準備する + if args.dataset_class is None: + blueprint_generator = BlueprintGenerator( + ConfigSanitizer(True, True, args.masked_loss, True) + ) + if args.dataset_config is not None: + logger.info(f"Load dataset config from {args.dataset_config}") + user_config = config_util.load_user_config(args.dataset_config) + ignored = ["train_data_dir", "in_json"] + if any(getattr(args, attr) is not None for attr in ignored): + logger.warning( + "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( + ", ".join(ignored) + ) + ) + else: + if use_dreambooth_method: + logger.info("Using DreamBooth method.") + user_config = { + "datasets": [ + { + "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs( + args.train_data_dir, args.reg_data_dir + ) + } + ] + } + else: + logger.info("Training with captions.") + user_config = { + "datasets": [ + { + "subsets": [ + { + "image_dir": args.train_data_dir, + "metadata_file": args.in_json, + } + ] + } + ] + } + + blueprint = blueprint_generator.generate(user_config, args) + train_dataset_group, val_dataset_group = ( + config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + ) + else: + train_dataset_group = train_util.load_arbitrary_dataset(args) + val_dataset_group = None + + current_epoch = Value("i", 0) + current_step = Value("i", 0) + ds_for_collator = ( + train_dataset_group if args.max_data_loader_n_workers == 0 else None + ) + collator = train_util.collator_class(current_epoch, current_step, ds_for_collator) + + train_dataset_group.verify_bucket_reso_steps(16) # TODO これでいいか確認 + + if args.debug_dataset: + if args.cache_text_encoder_outputs: + strategy_base.TextEncoderOutputsCachingStrategy.set_strategy( + strategy_lumina.LuminaTextEncoderOutputsCachingStrategy( + args.cache_text_encoder_outputs_to_disk, + args.text_encoder_batch_size, + args.skip_cache_check, + False, + ) + ) + strategy_base.TokenizeStrategy.set_strategy( + strategy_lumina.LuminaTokenizeStrategy() + ) + + train_dataset_group.set_current_strategies() + train_util.debug_dataset(train_dataset_group, True) + return + if len(train_dataset_group) == 0: + logger.error( + "No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。" + ) + return + + if cache_latents: + assert ( + train_dataset_group.is_latent_cacheable() + ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" + + if args.cache_text_encoder_outputs: + assert ( + train_dataset_group.is_text_encoder_output_cacheable() + ), "when caching text encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / text encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません" + + # acceleratorを準備する + logger.info("prepare accelerator") + accelerator = train_util.prepare_accelerator(args) + + # mixed precisionに対応した型を用意しておき適宜castする + weight_dtype, save_dtype = train_util.prepare_dtype(args) + + # モデルを読み込む + + # load VAE for caching latents + ae = None + if cache_latents: + ae = lumina_util.load_ae( + args.ae, weight_dtype, "cpu", args.disable_mmap_load_safetensors + ) + ae.to(accelerator.device, dtype=weight_dtype) + ae.requires_grad_(False) + ae.eval() + + train_dataset_group.new_cache_latents(ae, accelerator) + + ae.to("cpu") # if no sampling, vae can be deleted + clean_memory_on_device(accelerator.device) + + accelerator.wait_for_everyone() + + # prepare tokenize strategy + if args.gemma2_max_token_length is None: + gemma2_max_token_length = 256 + else: + gemma2_max_token_length = args.gemma2_max_token_length + + lumina_tokenize_strategy = strategy_lumina.LuminaTokenizeStrategy( + gemma2_max_token_length + ) + strategy_base.TokenizeStrategy.set_strategy(lumina_tokenize_strategy) + + # load gemma2 for caching text encoder outputs + gemma2 = lumina_util.load_gemma2( + args.gemma2, weight_dtype, "cpu", args.disable_mmap_load_safetensors + ) + gemma2.eval() + gemma2.requires_grad_(False) + + text_encoding_strategy = strategy_lumina.LuminaTextEncodingStrategy() + strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy) + + # cache text encoder outputs + sample_prompts_te_outputs = None + if args.cache_text_encoder_outputs: + # Text Encodes are eval and no grad here + gemma2.to(accelerator.device) + + text_encoder_caching_strategy = ( + strategy_lumina.LuminaTextEncoderOutputsCachingStrategy( + args.cache_text_encoder_outputs_to_disk, + args.text_encoder_batch_size, + False, + False, + ) + ) + strategy_base.TextEncoderOutputsCachingStrategy.set_strategy( + text_encoder_caching_strategy + ) + + with accelerator.autocast(): + train_dataset_group.new_cache_text_encoder_outputs([gemma2], accelerator) + + # cache sample prompt's embeddings to free text encoder's memory + if args.sample_prompts is not None: + logger.info( + f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}" + ) + + text_encoding_strategy: strategy_lumina.LuminaTextEncodingStrategy = ( + strategy_base.TextEncodingStrategy.get_strategy() + ) + + prompts = train_util.load_prompts(args.sample_prompts) + sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs + with accelerator.autocast(), torch.no_grad(): + for prompt_dict in prompts: + for p in [ + prompt_dict.get("prompt", ""), + prompt_dict.get("negative_prompt", ""), + ]: + if p not in sample_prompts_te_outputs: + logger.info(f"cache Text Encoder outputs for prompt: {p}") + tokens_and_masks = lumina_tokenize_strategy.tokenize(p) + sample_prompts_te_outputs[p] = ( + text_encoding_strategy.encode_tokens( + lumina_tokenize_strategy, + [gemma2], + tokens_and_masks, + ) + ) + + accelerator.wait_for_everyone() + + # now we can delete Text Encoders to free memory + gemma2 = None + clean_memory_on_device(accelerator.device) + + # load lumina + nextdit = lumina_util.load_lumina_model( + args.pretrained_model_name_or_path, + loading_dtype, + torch.device("cpu"), + disable_mmap=args.disable_mmap_load_safetensors, + use_flash_attn=args.use_flash_attn, + ) + + if args.gradient_checkpointing: + nextdit.enable_gradient_checkpointing( + cpu_offload=args.cpu_offload_checkpointing + ) + + nextdit.requires_grad_(True) + + # block swap + + # backward compatibility + # if args.blocks_to_swap is None: + # blocks_to_swap = args.double_blocks_to_swap or 0 + # if args.single_blocks_to_swap is not None: + # blocks_to_swap += args.single_blocks_to_swap // 2 + # if blocks_to_swap > 0: + # logger.warning( + # "double_blocks_to_swap and single_blocks_to_swap are deprecated. Use blocks_to_swap instead." + # " / double_blocks_to_swapとsingle_blocks_to_swapは非推奨です。blocks_to_swapを使ってください。" + # ) + # logger.info( + # f"double_blocks_to_swap={args.double_blocks_to_swap} and single_blocks_to_swap={args.single_blocks_to_swap} are converted to blocks_to_swap={blocks_to_swap}." + # ) + # args.blocks_to_swap = blocks_to_swap + # del blocks_to_swap + + # is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0 + # if is_swapping_blocks: + # # Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes. + # # This idea is based on 2kpr's great work. Thank you! + # logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}") + # flux.enable_block_swap(args.blocks_to_swap, accelerator.device) + + if not cache_latents: + # load VAE here if not cached + ae = lumina_util.load_ae(args.ae, weight_dtype, "cpu") + ae.requires_grad_(False) + ae.eval() + ae.to(accelerator.device, dtype=weight_dtype) + + training_models = [] + params_to_optimize = [] + training_models.append(nextdit) + name_and_params = list(nextdit.named_parameters()) + # single param group for now + params_to_optimize.append( + {"params": [p for _, p in name_and_params], "lr": args.learning_rate} + ) + param_names = [[n for n, _ in name_and_params]] + + # calculate number of trainable parameters + n_params = 0 + for group in params_to_optimize: + for p in group["params"]: + n_params += p.numel() + + accelerator.print(f"number of trainable parameters: {n_params}") + + # 学習に必要なクラスを準備する + accelerator.print("prepare optimizer, data loader etc.") + + if args.blockwise_fused_optimizers: + # fused backward pass: https://pytorch.org/tutorials/intermediate/optimizer_step_in_backward_tutorial.html + # Instead of creating an optimizer for all parameters as in the tutorial, we create an optimizer for each block of parameters. + # This balances memory usage and management complexity. + + # split params into groups. currently different learning rates are not supported + grouped_params = [] + param_group = {} + for group in params_to_optimize: + named_parameters = list(nextdit.named_parameters()) + assert len(named_parameters) == len( + group["params"] + ), "number of parameters does not match" + for p, np in zip(group["params"], named_parameters): + # determine target layer and block index for each parameter + block_type = "other" # double, single or other + if np[0].startswith("double_blocks"): + block_index = int(np[0].split(".")[1]) + block_type = "double" + elif np[0].startswith("single_blocks"): + block_index = int(np[0].split(".")[1]) + block_type = "single" + else: + block_index = -1 + + param_group_key = (block_type, block_index) + if param_group_key not in param_group: + param_group[param_group_key] = [] + param_group[param_group_key].append(p) + + block_types_and_indices = [] + for param_group_key, param_group in param_group.items(): + block_types_and_indices.append(param_group_key) + grouped_params.append({"params": param_group, "lr": args.learning_rate}) + + num_params = 0 + for p in param_group: + num_params += p.numel() + accelerator.print(f"block {param_group_key}: {num_params} parameters") + + # prepare optimizers for each group + optimizers = [] + for group in grouped_params: + _, _, optimizer = train_util.get_optimizer(args, trainable_params=[group]) + optimizers.append(optimizer) + optimizer = optimizers[0] # avoid error in the following code + + logger.info( + f"using {len(optimizers)} optimizers for blockwise fused optimizers" + ) + + if train_util.is_schedulefree_optimizer(optimizers[0], args): + raise ValueError( + "Schedule-free optimizer is not supported with blockwise fused optimizers" + ) + optimizer_train_fn = lambda: None # dummy function + optimizer_eval_fn = lambda: None # dummy function + else: + _, _, optimizer = train_util.get_optimizer( + args, trainable_params=params_to_optimize + ) + optimizer_train_fn, optimizer_eval_fn = train_util.get_optimizer_train_eval_fn( + optimizer, args + ) + + # prepare dataloader + # strategies are set here because they cannot be referenced in another process. Copy them with the dataset + # some strategies can be None + train_dataset_group.set_current_strategies() + + # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 + n_workers = min( + args.max_data_loader_n_workers, os.cpu_count() + ) # cpu_count or max_data_loader_n_workers + train_dataloader = torch.utils.data.DataLoader( + train_dataset_group, + batch_size=1, + shuffle=True, + collate_fn=collator, + num_workers=n_workers, + persistent_workers=args.persistent_data_loader_workers, + ) + + # 学習ステップ数を計算する + if args.max_train_epochs is not None: + args.max_train_steps = args.max_train_epochs * math.ceil( + len(train_dataloader) + / accelerator.num_processes + / args.gradient_accumulation_steps + ) + accelerator.print( + f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}" + ) + + # データセット側にも学習ステップを送信 + train_dataset_group.set_max_train_steps(args.max_train_steps) + + # lr schedulerを用意する + if args.blockwise_fused_optimizers: + # prepare lr schedulers for each optimizer + lr_schedulers = [ + train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) + for optimizer in optimizers + ] + lr_scheduler = lr_schedulers[0] # avoid error in the following code + else: + lr_scheduler = train_util.get_scheduler_fix( + args, optimizer, accelerator.num_processes + ) + + # 実験的機能:勾配も含めたfp16/bf16学習を行う モデル全体をfp16/bf16にする + if args.full_fp16: + assert ( + args.mixed_precision == "fp16" + ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。" + accelerator.print("enable full fp16 training.") + nextdit.to(weight_dtype) + if gemma2 is not None: + gemma2.to(weight_dtype) + elif args.full_bf16: + assert ( + args.mixed_precision == "bf16" + ), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。" + accelerator.print("enable full bf16 training.") + nextdit.to(weight_dtype) + if gemma2 is not None: + gemma2.to(weight_dtype) + + # if we don't cache text encoder outputs, move them to device + if not args.cache_text_encoder_outputs: + gemma2.to(accelerator.device) + + clean_memory_on_device(accelerator.device) + + if args.deepspeed: + ds_model = deepspeed_utils.prepare_deepspeed_model(args, nextdit=nextdit) + # most of ZeRO stage uses optimizer partitioning, so we have to prepare optimizer and ds_model at the same time. # pull/1139#issuecomment-1986790007 + ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + ds_model, optimizer, train_dataloader, lr_scheduler + ) + training_models = [ds_model] + + else: + # accelerator does some magic + # if we doesn't swap blocks, we can move the model to device + nextdit = accelerator.prepare( + nextdit, device_placement=[not is_swapping_blocks] + ) + if is_swapping_blocks: + accelerator.unwrap_model(nextdit).move_to_device_except_swap_blocks( + accelerator.device + ) # reduce peak memory usage + optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + optimizer, train_dataloader, lr_scheduler + ) + + # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする + if args.full_fp16: + # During deepseed training, accelerate not handles fp16/bf16|mixed precision directly via scaler. Let deepspeed engine do. + # -> But we think it's ok to patch accelerator even if deepspeed is enabled. + train_util.patch_accelerator_for_fp16_training(accelerator) + + # resumeする + train_util.resume_from_local_or_hf_if_specified(accelerator, args) + + if args.fused_backward_pass: + # use fused optimizer for backward pass: other optimizers will be supported in the future + import library.adafactor_fused + + library.adafactor_fused.patch_adafactor_fused(optimizer) + + for param_group, param_name_group in zip(optimizer.param_groups, param_names): + for parameter, param_name in zip(param_group["params"], param_name_group): + if parameter.requires_grad: + + def create_grad_hook(p_name, p_group): + def grad_hook(tensor: torch.Tensor): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_(tensor, args.max_grad_norm) + optimizer.step_param(tensor, p_group) + tensor.grad = None + + return grad_hook + + parameter.register_post_accumulate_grad_hook( + create_grad_hook(param_name, param_group) + ) + + elif args.blockwise_fused_optimizers: + # prepare for additional optimizers and lr schedulers + for i in range(1, len(optimizers)): + optimizers[i] = accelerator.prepare(optimizers[i]) + lr_schedulers[i] = accelerator.prepare(lr_schedulers[i]) + + # counters are used to determine when to step the optimizer + global optimizer_hooked_count + global num_parameters_per_group + global parameter_optimizer_map + + optimizer_hooked_count = {} + num_parameters_per_group = [0] * len(optimizers) + parameter_optimizer_map = {} + + for opt_idx, optimizer in enumerate(optimizers): + for param_group in optimizer.param_groups: + for parameter in param_group["params"]: + if parameter.requires_grad: + + def grad_hook(parameter: torch.Tensor): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_( + parameter, args.max_grad_norm + ) + + i = parameter_optimizer_map[parameter] + optimizer_hooked_count[i] += 1 + if optimizer_hooked_count[i] == num_parameters_per_group[i]: + optimizers[i].step() + optimizers[i].zero_grad(set_to_none=True) + + parameter.register_post_accumulate_grad_hook(grad_hook) + parameter_optimizer_map[parameter] = opt_idx + num_parameters_per_group[opt_idx] += 1 + + # epoch数を計算する + num_update_steps_per_epoch = math.ceil( + len(train_dataloader) / args.gradient_accumulation_steps + ) + num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0): + args.save_every_n_epochs = ( + math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1 + ) + + # 学習する + # total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + accelerator.print("running training / 学習開始") + accelerator.print( + f" num examples / サンプル数: {train_dataset_group.num_train_images}" + ) + accelerator.print( + f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}" + ) + accelerator.print(f" num epochs / epoch数: {num_train_epochs}") + accelerator.print( + f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}" + ) + # accelerator.print( + # f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}" + # ) + accelerator.print( + f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}" + ) + accelerator.print( + f" total optimization steps / 学習ステップ数: {args.max_train_steps}" + ) + + progress_bar = tqdm( + range(args.max_train_steps), + smoothing=0, + disable=not accelerator.is_local_main_process, + desc="steps", + ) + global_step = 0 + + noise_scheduler = FlowMatchEulerDiscreteScheduler( + num_train_timesteps=1000, shift=args.discrete_flow_shift + ) + noise_scheduler_copy = copy.deepcopy(noise_scheduler) + + if accelerator.is_main_process: + init_kwargs = {} + if args.wandb_run_name: + init_kwargs["wandb"] = {"name": args.wandb_run_name} + if args.log_tracker_config is not None: + init_kwargs = toml.load(args.log_tracker_config) + accelerator.init_trackers( + "finetuning" if args.log_tracker_name is None else args.log_tracker_name, + config=train_util.get_sanitized_config_or_none(args), + init_kwargs=init_kwargs, + ) + + if is_swapping_blocks: + accelerator.unwrap_model(nextdit).prepare_block_swap_before_forward() + + # For --sample_at_first + optimizer_eval_fn() + lumina_train_util.sample_images( + accelerator, + args, + 0, + global_step, + nextdit, + ae, + gemma2, + sample_prompts_te_outputs, + ) + optimizer_train_fn() + if len(accelerator.trackers) > 0: + # log empty object to commit the sample images to wandb + accelerator.log({}, step=0) + + loss_recorder = train_util.LossRecorder() + epoch = 0 # avoid error when max_train_steps is 0 + for epoch in range(num_train_epochs): + accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") + current_epoch.value = epoch + 1 + + for m in training_models: + m.train() + + for step, batch in enumerate(train_dataloader): + current_step.value = global_step + + if args.blockwise_fused_optimizers: + optimizer_hooked_count = { + i: 0 for i in range(len(optimizers)) + } # reset counter for each step + + with accelerator.accumulate(*training_models): + if "latents" in batch and batch["latents"] is not None: + latents = batch["latents"].to( + accelerator.device, dtype=weight_dtype + ) + else: + with torch.no_grad(): + # encode images to latents. images are [-1, 1] + latents = ae.encode(batch["images"].to(ae.dtype)).to( + accelerator.device, dtype=weight_dtype + ) + + # NaNが含まれていれば警告を表示し0に置き換える + if torch.any(torch.isnan(latents)): + accelerator.print("NaN found in latents, replacing with zeros") + latents = torch.nan_to_num(latents, 0, out=latents) + + text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) + if text_encoder_outputs_list is not None: + text_encoder_conds = text_encoder_outputs_list + else: + # not cached or training, so get from text encoders + tokens_and_masks = batch["input_ids_list"] + with torch.no_grad(): + input_ids = [ + ids.to(accelerator.device) + for ids in batch["input_ids_list"] + ] + text_encoder_conds = text_encoding_strategy.encode_tokens( + lumina_tokenize_strategy, + [gemma2], + input_ids, + ) + if args.full_fp16: + text_encoder_conds = [ + c.to(weight_dtype) for c in text_encoder_conds + ] + + # TODO support some features for noise implemented in get_noise_noisy_latents_and_timesteps + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + + # get noisy model input and timesteps + noisy_model_input, timesteps, sigmas = ( + lumina_train_util.get_noisy_model_input_and_timesteps( + args, + noise_scheduler_copy, + latents, + noise, + accelerator.device, + weight_dtype, + ) + ) + # call model + gemma2_hidden_states, input_ids, gemma2_attn_mask = text_encoder_conds + + with accelerator.autocast(): + # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing) + model_pred = nextdit( + x=img, # image latents (B, C, H, W) + t=timesteps / 1000, # timesteps需要除以1000来匹配模型预期 + cap_feats=gemma2_hidden_states, # Gemma2的hidden states作为caption features + cap_mask=gemma2_attn_mask.to( + dtype=torch.int32 + ), # Gemma2的attention mask + ) + # apply model prediction type + model_pred, weighting = lumina_train_util.apply_model_prediction_type( + args, model_pred, noisy_model_input, sigmas + ) + + # flow matching loss: this is different from SD3 + target = noise - latents + + # calculate loss + huber_c = train_util.get_huber_threshold_if_needed( + args, timesteps, noise_scheduler + ) + loss = train_util.conditional_loss( + model_pred.float(), target.float(), args.loss_type, "none", huber_c + ) + if weighting is not None: + loss = loss * weighting + if args.masked_loss or ( + "alpha_masks" in batch and batch["alpha_masks"] is not None + ): + loss = apply_masked_loss(loss, batch) + loss = loss.mean([1, 2, 3]) + + loss_weights = batch["loss_weights"] # 各sampleごとのweight + loss = loss * loss_weights + loss = loss.mean() + + # backward + accelerator.backward(loss) + + if not (args.fused_backward_pass or args.blockwise_fused_optimizers): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + params_to_clip = [] + for m in training_models: + params_to_clip.extend(m.parameters()) + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + else: + # optimizer.step() and optimizer.zero_grad() are called in the optimizer hook + lr_scheduler.step() + if args.blockwise_fused_optimizers: + for i in range(1, len(optimizers)): + lr_schedulers[i].step() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + optimizer_eval_fn() + lumina_train_util.sample_images( + accelerator, + args, + None, + global_step, + nextdit, + ae, + gemma2, + sample_prompts_te_outputs, + ) + + # 指定ステップごとにモデルを保存 + if ( + args.save_every_n_steps is not None + and global_step % args.save_every_n_steps == 0 + ): + accelerator.wait_for_everyone() + if accelerator.is_main_process: + lumina_train_util.save_lumina_model_on_epoch_end_or_stepwise( + args, + False, + accelerator, + save_dtype, + epoch, + num_train_epochs, + global_step, + accelerator.unwrap_model(nextdit), + ) + optimizer_train_fn() + + current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず + if len(accelerator.trackers) > 0: + logs = {"loss": current_loss} + train_util.append_lr_to_logs( + logs, lr_scheduler, args.optimizer_type, including_unet=True + ) + + accelerator.log(logs, step=global_step) + + loss_recorder.add(epoch=epoch, step=step, loss=current_loss) + avr_loss: float = loss_recorder.moving_average + logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + if global_step >= args.max_train_steps: + break + + if len(accelerator.trackers) > 0: + logs = {"loss/epoch": loss_recorder.moving_average} + accelerator.log(logs, step=epoch + 1) + + accelerator.wait_for_everyone() + + optimizer_eval_fn() + if args.save_every_n_epochs is not None: + if accelerator.is_main_process: + lumina_train_util.save_lumina_model_on_epoch_end_or_stepwise( + args, + True, + accelerator, + save_dtype, + epoch, + num_train_epochs, + global_step, + accelerator.unwrap_model(nextdit), + ) + + lumina_train_util.sample_images( + accelerator, + args, + epoch + 1, + global_step, + nextdit, + ae, + gemma2, + sample_prompts_te_outputs, + ) + optimizer_train_fn() + + is_main_process = accelerator.is_main_process + # if is_main_process: + nextdit = accelerator.unwrap_model(nextdit) + + accelerator.end_training() + optimizer_eval_fn() + + if args.save_state or args.save_state_on_train_end: + train_util.save_state_on_train_end(args, accelerator) + + del accelerator # この後メモリを使うのでこれは消す + + if is_main_process: + lumina_train_util.save_lumina_model_on_train_end( + args, save_dtype, epoch, global_step, nextdit + ) + logger.info("model saved.") + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + + add_logging_arguments(parser) + train_util.add_sd_models_arguments(parser) # TODO split this + train_util.add_dataset_arguments(parser, True, True, True) + train_util.add_training_arguments(parser, False) + train_util.add_masked_loss_arguments(parser) + deepspeed_utils.add_deepspeed_arguments(parser) + train_util.add_sd_saving_arguments(parser) + train_util.add_optimizer_arguments(parser) + config_util.add_config_arguments(parser) + add_custom_train_arguments(parser) # TODO remove this from here + train_util.add_dit_training_arguments(parser) + lumina_train_util.add_lumina_train_arguments(parser) + + parser.add_argument( + "--mem_eff_save", + action="store_true", + help="[EXPERIMENTAL] use memory efficient custom model saving method / メモリ効率の良い独自のモデル保存方法を使う", + ) + + parser.add_argument( + "--fused_optimizer_groups", + type=int, + default=None, + help="**this option is not working** will be removed in the future / このオプションは動作しません。将来削除されます", + ) + parser.add_argument( + "--blockwise_fused_optimizers", + action="store_true", + help="enable blockwise optimizers for fused backward pass and optimizer step / fused backward passとoptimizer step のためブロック単位のoptimizerを有効にする", + ) + parser.add_argument( + "--skip_latents_validity_check", + action="store_true", + help="[Deprecated] use 'skip_cache_check' instead / 代わりに 'skip_cache_check' を使用してください", + ) + parser.add_argument( + "--cpu_offload_checkpointing", + action="store_true", + help="[EXPERIMENTAL] enable offloading of tensors to CPU during checkpointing / チェックポイント時にテンソルをCPUにオフロードする", + ) + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + train_util.verify_command_line_training_args(args) + args = train_util.read_config_from_file(args, parser) + + train(args) diff --git a/lumina_train_network.py b/lumina_train_network.py index 0fd4da6b..5f20c014 100644 --- a/lumina_train_network.py +++ b/lumina_train_network.py @@ -15,7 +15,6 @@ from accelerate import Accelerator import train_network from library import ( lumina_models, - flux_train_utils, lumina_util, lumina_train_util, sd3_train_utils, @@ -250,36 +249,10 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer): ): assert isinstance(noise_scheduler, sd3_train_utils.FlowMatchEulerDiscreteScheduler) noise = torch.randn_like(latents) - bsz = latents.shape[0] - - # Sample a random timestep for each image - # for weighting schemes where we sample timesteps non-uniformly - u = lumina_train_util.compute_density_for_timestep_sampling( - weighting_scheme=args.weighting_scheme, - batch_size=bsz, - logit_mean=args.logit_mean, - logit_std=args.logit_std, - mode_scale=args.mode_scale, + # get noisy model input and timesteps + noisy_model_input, timesteps, sigmas = lumina_train_util.get_noisy_model_input_and_timesteps( + args, noise_scheduler, latents, noise, accelerator.device, weight_dtype ) - indices = (u * noise_scheduler.config.num_train_timesteps).long() - timesteps = noise_scheduler.timesteps[indices].to(device=latents.device) - - def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): - sigmas = noise_scheduler.sigmas.to(device=accelerator.device, dtype=dtype) - schedule_timesteps = noise_scheduler.timesteps.to(accelerator.device) - timesteps = timesteps.to(accelerator.device) - step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] - - sigma = sigmas[step_indices].flatten() - while len(sigma.shape) < n_dim: - sigma = sigma.unsqueeze(-1) - return sigma - - # Add noise according to flow matching. - # zt = (1 - texp) * x + texp * z1 - # Lumina2 reverses the lerp i.e., sigma of 1.0 should mean `latents` - sigmas = get_sigmas(timesteps, n_dim=latents.ndim, dtype=latents.dtype) - noisy_model_input = (1.0 - sigmas) * noise + sigmas * latents # ensure the hidden state will require grad if args.gradient_checkpointing: @@ -310,7 +283,7 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer): ) # apply model prediction type - model_pred, weighting = flux_train_utils.apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas) + model_pred, weighting = lumina_train_util.apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas) # flow matching loss target = latents - noise @@ -336,7 +309,7 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer): # model_pred_prior = lumina_util.unpack_latents( # model_pred_prior, packed_latent_height, packed_latent_width # ) - model_pred_prior, _ = flux_train_utils.apply_model_prediction_type( + model_pred_prior, _ = lumina_train_util.apply_model_prediction_type( args, model_pred_prior, noisy_model_input[diff_output_pr_indices],