diff --git a/library/train_util.py b/library/train_util.py
index f894390a..1a5dd1ae 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -3948,6 +3948,12 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
         action="store_true",
         help="enable static_graph for DDP / DDPでstatic_graphを有効にする",
     )
+    parser.add_argument(
+        "--profiler_path",
+        type=str,
+        default="/dev/shm/trace",
+        help="Path for storing PyTorch Profiler traces. Storing them on a RAM drive is recommended. / PyTorch Profiler トレースを保存するためのパス。RAM ドライブに保存することをお勧めします。",
+    )
     parser.add_argument(
         "--clip_skip",
         type=int,
@@ -5478,7 +5484,7 @@ def prepare_accelerator(args: argparse.Namespace):
         (
             ProfileKwargs(
                 activities=["cpu", "cuda"],
-                output_trace_dir="/dev/shm/trace",
+                output_trace_dir=args.profiler_path,
                 profile_memory=True,
                 record_shapes=True,
                 with_flops=True
diff --git a/train_native.py b/train_native.py
index 42e838da..8bb0765c 100644
--- a/train_native.py
+++ b/train_native.py
@@ -7,6 +7,7 @@ import random
 import time
 import json
 from multiprocessing import Value
+from contextlib import nullcontext

 from tqdm import tqdm

@@ -1375,6 +1376,10 @@ class NativeTrainer:

         clean_memory_on_device(accelerator.device)

+        enable_profiler = args.enable_profiler
+        if enable_profiler:
+            logger.warning("PyTorch Profiler is enabled. Disable it after capturing traces. / PyTorch プロファイラーが有効になっています。トレースの取得後は無効にしてください。")
+
         for epoch in range(epoch_to_start, num_train_epochs):
             accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}\n")
             current_epoch.value = epoch + 1
@@ -1391,148 +1396,23 @@ class NativeTrainer:
                 initial_step = 1

             for step, batch in enumerate(skipped_dataloader or train_dataloader):
-                #Enable this for profiler. Hint: select a big area (until EPOCH VALIDATION) and tab / shift tab
-                #with accelerator.profile() as prof:
-                current_step.value = global_step
-                if initial_step > 0:
-                    initial_step -= 1
-                    continue
+                with accelerator.profile() if enable_profiler else nullcontext() as prof:
+                    current_step.value = global_step
+                    if initial_step > 0:
+                        initial_step -= 1
+                        continue

-                if args.fused_optimizer_groups:
-                    optimizer_hooked_count = {i: 0 for i in range(len(optimizers))}  # reset counter for each step
-
-                # Code guide: "network" here was misrepresented as training_model, however some features are capable for all "prepared" models.
-                # Tne correct specific "network" operation has been removed.
-                # The process_batch will wrap all the inference logic (because it will be used for validation dataset also)
-                with accelerator.accumulate(*training_models):
-                    # 250331: From HF guide
-                    # 250406: No need
-                    #optimizer.zero_grad(set_to_none=True)
-
-                    # temporary, for batch processing
-                    self.on_step_start(args, accelerator, text_encoders, unet, batch, weight_dtype)
-
-                    loss = self.process_batch(
-                        batch,
-                        text_encoders,
-                        unet,
-                        vae,
-                        noise_scheduler,
-                        vae_dtype,
-                        weight_dtype,
-                        accelerator,
-                        args,
-                        text_encoding_strategy,
-                        tokenize_strategy,
-                        is_train=True,
-                        train_text_encoder=train_text_encoder,
-                        train_unet=train_unet
-                    )
-
-                    accelerator.backward(loss)
-
-                    #250331: It is required to sync manually. See torch.Tensor.grad
-                    if accelerator.sync_gradients:
-                        for training_model in training_models:
-                            self.all_reduce_training_model(accelerator, training_model)  # sync DDP grad manually
-                            if (args.max_grad_norm != 0.0) and hasattr(training_model, "get_trainable_params"):
-                                params_to_clip = accelerator.unwrap_model(training_model).get_trainable_params()
-                                accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
-
-                            # lora_flux exclusive
-                            if hasattr(training_model, "update_grad_norms"):
-                                training_model.update_grad_norms()
-                            if hasattr(training_model, "update_norms"):
-                                training_model.update_norms()
-
-                    optimizer.step()
-                    lr_scheduler.step()
-                    optimizer.zero_grad(set_to_none=True)
-
-                if args.scale_weight_norms:
-                    for training_model in training_models:
-                        keys_scaled, mean_norm, maximum_norm = accelerator.unwrap_model(training_model).apply_max_norm_regularization(
-                            args.scale_weight_norms, accelerator.device
-                        )
-                    # TODO: Multiple models
-                    mean_grad_norm = None
-                    mean_combined_norm = None
-                    max_mean_logs = {"Keys Scaled": keys_scaled, "Average key norm": mean_norm}
-                else:
-                    keys_scaled, mean_norm, maximum_norm = None, None, None
-                    mean_grad_norm = None
-                    mean_combined_norm = None
-                    max_mean_logs = {}
-
-                # Checks if the accelerator has performed an optimization step behind the scenes
-                if accelerator.sync_gradients:
-                    progress_bar.update(1)
-                    global_step += 1
-
-                    optimizer_eval_fn()
-                    self.sample_images(
-                        accelerator, args, None, global_step, accelerator.device, vae, tokenizers, text_encoder, unet
-                    )
-
-                    # 指定ステップごとにモデルを保存
-                    if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
-                        accelerator.wait_for_everyone()
-                        if accelerator.is_main_process:
-                            # Train network has different approach: It will upload to hf or remove old file immediately.
-                            # Train Native will keep the old *_train_utils.approach, however the class reference is so messy.
-                            # Hint: self.load_target_model
-                            self.save_model_on_epoch_end_or_stepwise(args, False, accelerator, save_dtype, epoch, num_train_epochs, global_step, text_encoders, vae, unet)
-
-                    optimizer_train_fn()
-
-                current_loss = loss.detach().item()
-
-                if len(accelerator.trackers) > 0:
-                    logs = {"loss": current_loss}
-                    if block_lrs is None:
-                        train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=train_unet)
-                    else:
-                        self.append_block_lr_to_logs(block_lrs, logs, lr_scheduler, args.optimizer_type)  # U-Net is included in block_lrs
-
-                    accelerator.log(logs, step=global_step)
-
-                loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
-                avr_loss: float = loss_recorder.moving_average
-                logs = {"avr_loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
-                progress_bar.set_postfix(**{**max_mean_logs, **logs})
-
-                if is_tracking:
-                    logs = self.generate_step_logs(
-                        args,
-                        current_loss,
-                        avr_loss,
-                        lr_scheduler,
-                        lr_descriptions,
-                        optimizer,
-                        keys_scaled,
-                        mean_norm,
-                        maximum_norm,
-                        mean_grad_norm,
-                        mean_combined_norm,
-                    )
-                    accelerator.log(logs, step=global_step)
-
-                # VALIDATION PER STEP
-                should_validate_step = (
-                    args.validate_every_n_steps is not None
-                    and global_step != 0  # Skip first step
-                    and global_step % args.validate_every_n_steps == 0
-                )
-                if accelerator.sync_gradients and validation_steps > 0 and should_validate_step:
-                    val_progress_bar = tqdm(
-                        range(validation_steps), smoothing=0,
-                        disable=not accelerator.is_local_main_process,
-                        desc="validation steps"
-                    )
-                    for val_step, batch in enumerate(val_dataloader):
-                        if val_step >= validation_steps:
-                            break
+                    if args.fused_optimizer_groups:
+                        optimizer_hooked_count = {i: 0 for i in range(len(optimizers))}  # reset counter for each step
+
+                    # Code guide: "network" here was misrepresented as training_model; some features apply to all "prepared" models.
+                    # The network-specific operations have been removed.
+                    # process_batch wraps all the inference logic (it is also used for the validation dataset).
+                    with accelerator.accumulate(*training_models):
+                        # 250331: From HF guide
+                        # 250406: No need
+                        #optimizer.zero_grad(set_to_none=True)
+
                         # temporary, for batch processing
                         self.on_step_start(args, accelerator, text_encoders, unet, batch, weight_dtype)

@@ -1548,33 +1428,157 @@ class NativeTrainer:
                             args,
                             text_encoding_strategy,
                             tokenize_strategy,
-                            is_train=False,
-                            train_text_encoder=False,
-                            train_unet=False
+                            is_train=True,
+                            train_text_encoder=train_text_encoder,
+                            train_unet=train_unet
                         )

-                        current_loss = loss.detach().item()
-                        val_step_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss)
-                        val_progress_bar.update(1)
-                        val_progress_bar.set_postfix({ "val_avg_loss": val_step_loss_recorder.moving_average })
+                        accelerator.backward(loss)

-                        if is_tracking:
-                            logs = {
-                                "loss/validation/step_current": current_loss,
-                                "val_step": (epoch * validation_steps) + val_step,
-                            }
-                            accelerator.log(logs, step=global_step)
+                        # 250331: It is required to sync manually. See torch.Tensor.grad
+                        if accelerator.sync_gradients:
+                            for training_model in training_models:
+                                self.all_reduce_training_model(accelerator, training_model)  # sync DDP grad manually
+                                if (args.max_grad_norm != 0.0) and hasattr(training_model, "get_trainable_params"):
+                                    params_to_clip = accelerator.unwrap_model(training_model).get_trainable_params()
+                                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+
+                                # lora_flux exclusive
+                                if hasattr(training_model, "update_grad_norms"):
+                                    training_model.update_grad_norms()
+                                if hasattr(training_model, "update_norms"):
+                                    training_model.update_norms()
+
+                        optimizer.step()
+                        lr_scheduler.step()
+                        optimizer.zero_grad(set_to_none=True)
+
+                    if args.scale_weight_norms:
+                        for training_model in training_models:
+                            keys_scaled, mean_norm, maximum_norm = accelerator.unwrap_model(training_model).apply_max_norm_regularization(
+                                args.scale_weight_norms, accelerator.device
+                            )
+                        # TODO: Multiple models
+                        mean_grad_norm = None
+                        mean_combined_norm = None
+                        max_mean_logs = {"Keys Scaled": keys_scaled, "Average key norm": mean_norm}
+                    else:
+                        keys_scaled, mean_norm, maximum_norm = None, None, None
+                        mean_grad_norm = None
+                        mean_combined_norm = None
+                        max_mean_logs = {}
+
+                    # Checks if the accelerator has performed an optimization step behind the scenes
+                    if accelerator.sync_gradients:
+                        progress_bar.update(1)
+                        global_step += 1
+
+                        optimizer_eval_fn()
+                        self.sample_images(
+                            accelerator, args, None, global_step, accelerator.device, vae, tokenizers, text_encoder, unet
+                        )
+
+                        # save the model every N steps / 指定ステップごとにモデルを保存
+                        if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
+                            accelerator.wait_for_everyone()
+                            if accelerator.is_main_process:
+                                # Train network takes a different approach: it uploads to HF or removes the old file immediately.
+                                # Train Native keeps the old *_train_utils approach; however, the class references are messy.
+                                # Hint: self.load_target_model
+                                self.save_model_on_epoch_end_or_stepwise(args, False, accelerator, save_dtype, epoch, num_train_epochs, global_step, text_encoders, vae, unet)
+
+                        optimizer_train_fn()
+
+                    current_loss = loss.detach().item()
+
+                    if len(accelerator.trackers) > 0:
+                        logs = {"loss": current_loss}
+                        if block_lrs is None:
+                            train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=train_unet)
+                        else:
+                            self.append_block_lr_to_logs(block_lrs, logs, lr_scheduler, args.optimizer_type)  # U-Net is included in block_lrs
+
+                        accelerator.log(logs, step=global_step)
+
+                    loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
+                    avr_loss: float = loss_recorder.moving_average
+                    logs = {"avr_loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
+                    progress_bar.set_postfix(**{**max_mean_logs, **logs})

                     if is_tracking:
-                        loss_validation_divergence = val_step_loss_recorder.moving_average - loss_recorder.moving_average
-                        logs = {
-                            "loss/validation/step_average": val_step_loss_recorder.moving_average,
-                            "loss/validation/step_divergence": loss_validation_divergence,
-                        }
+                        logs = self.generate_step_logs(
+                            args,
+                            current_loss,
+                            avr_loss,
+                            lr_scheduler,
+                            lr_descriptions,
+                            optimizer,
+                            keys_scaled,
+                            mean_norm,
+                            maximum_norm,
+                            mean_grad_norm,
+                            mean_combined_norm,
+                        )
                         accelerator.log(logs, step=global_step)
-
-                    if global_step >= args.max_train_steps:
-                        break
+
+                    # VALIDATION PER STEP
+                    should_validate_step = (
+                        args.validate_every_n_steps is not None
+                        and global_step != 0  # Skip first step
+                        and global_step % args.validate_every_n_steps == 0
+                    )
+                    if accelerator.sync_gradients and validation_steps > 0 and should_validate_step:
+                        val_progress_bar = tqdm(
+                            range(validation_steps), smoothing=0,
+                            disable=not accelerator.is_local_main_process,
+                            desc="validation steps"
+                        )
+                        for val_step, batch in enumerate(val_dataloader):
+                            if val_step >= validation_steps:
+                                break
+
+                            # temporary, for batch processing
+                            self.on_step_start(args, accelerator, text_encoders, unet, batch, weight_dtype)
+
+                            loss = self.process_batch(
+                                batch,
+                                text_encoders,
+                                unet,
+                                vae,
+                                noise_scheduler,
+                                vae_dtype,
+                                weight_dtype,
+                                accelerator,
+                                args,
+                                text_encoding_strategy,
+                                tokenize_strategy,
+                                is_train=False,
+                                train_text_encoder=False,
+                                train_unet=False
+                            )
+
+                            current_loss = loss.detach().item()
+                            val_step_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss)
+                            val_progress_bar.update(1)
+                            val_progress_bar.set_postfix({ "val_avg_loss": val_step_loss_recorder.moving_average })
+
+                            if is_tracking:
+                                logs = {
+                                    "loss/validation/step_current": current_loss,
+                                    "val_step": (epoch * validation_steps) + val_step,
+                                }
+                                accelerator.log(logs, step=global_step)
+
+                        if is_tracking:
+                            loss_validation_divergence = val_step_loss_recorder.moving_average - loss_recorder.moving_average
+                            logs = {
+                                "loss/validation/step_average": val_step_loss_recorder.moving_average,
+                                "loss/validation/step_divergence": loss_validation_divergence,
+                            }
+                            accelerator.log(logs, step=global_step)
+
+                    if global_step >= args.max_train_steps:
+                        break

             # EPOCH VALIDATION
             should_validate_epoch = (
@@ -1757,6 +1761,7 @@ def setup_parser() -> argparse.ArgumentParser:
         help="initial step number including all epochs, 0 means first step (same as not specifying). overwrites initial_epoch."
         + " / 初期ステップ数、全エポックを含むステップ数、0で最初のステップ(未指定時と同じ)。initial_epochを上書きする",
     )
+    parser.add_argument("--enable_profiler", action="store_true", help="Enable the PyTorch Profiler for in-depth tracing and analysis of the training process. Enabling it makes training very slow. / トレーニング プロセスを詳細にトレース・分析するには、PyTorch Profiler を有効にします。有効にすると、トレーニングが非常に遅くなります。")  # Append to add_dataset_arguments(parser)?
     parser.add_argument(
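
Note (not part of the patch): the conditional profiling pattern introduced above can be exercised as sketched below. `--enable_profiler` and `--profiler_path` are the flags this diff adds, and `accelerator.profile()` / `ProfileKwargs` are the HF Accelerate APIs configured in `prepare_accelerator`; `run_steps` and `train_step` are hypothetical stand-ins for the trainer loop.

    # Minimal sketch, assuming an Accelerator built with the ProfileKwargs shown above.
    from contextlib import nullcontext

    def run_steps(accelerator, train_dataloader, enable_profiler: bool):
        for step, batch in enumerate(train_dataloader):
            # nullcontext() adds no overhead when profiling is off; when on,
            # accelerator.profile() records the wrapped step and exports the
            # trace to output_trace_dir (args.profiler_path, /dev/shm/trace by default).
            with accelerator.profile() if enable_profiler else nullcontext() as prof:
                train_step(batch)  # hypothetical stand-in for the training step body

Typical invocation (all other flags omitted):

    python train_native.py --enable_profiler --profiler_path /dev/shm/trace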