pytorch profiler options

This commit is contained in:
Darren Laurie
2025-04-06 16:59:48 +08:00
parent d932129093
commit 75712d1f2e
2 changed files with 173 additions and 162 deletions

View File

@@ -3948,6 +3948,12 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
action="store_true",
help="enable static_graph for DDP / DDPでstatic_graphを有効にする",
)
parser.add_argument(
"--profiler_path",
type=str,
default="/dev/shm/trace",
help="Path for storing PyTorch Profiler traces. Recommended to store in RAM drive. / PyTorch Profiler トレースを保存するためのパス。RAM ドライブに保存することをお勧めします。",
)
parser.add_argument(
"--clip_skip",
type=int,
@@ -5478,7 +5484,7 @@ def prepare_accelerator(args: argparse.Namespace):
(
ProfileKwargs(
activities=["cpu", "cuda"],
output_trace_dir="/dev/shm/trace",
output_trace_dir=args.profiler_path,
profile_memory=True,
record_shapes=True,
with_flops=True

View File

@@ -7,6 +7,7 @@ import random
import time
import json
from multiprocessing import Value
from contextlib import nullcontext
from tqdm import tqdm
@@ -1375,6 +1376,10 @@ class NativeTrainer:
clean_memory_on_device(accelerator.device)
enable_profiler = args.enable_profiler
if enable_profiler:
logger.warning(f"Pytorch profiler enabled. Disable after capturing traces. / Pytorch プロファイラーが有効になっています。トレースをキャプチャした後は無効にします。")
for epoch in range(epoch_to_start, num_train_epochs):
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}\n")
current_epoch.value = epoch + 1
@@ -1391,148 +1396,23 @@ class NativeTrainer:
initial_step = 1
for step, batch in enumerate(skipped_dataloader or train_dataloader):
#Enable this for profiler. Hint: select a big area (until EPOCH VALIDATION) and tab / shift tab
#with accelerator.profile() as prof:
current_step.value = global_step
if initial_step > 0:
initial_step -= 1
continue
with accelerator.profile() if enable_profiler else nullcontext() as prof:
current_step.value = global_step
if initial_step > 0:
initial_step -= 1
continue
if args.fused_optimizer_groups:
optimizer_hooked_count = {i: 0 for i in range(len(optimizers))} # reset counter for each step
# Code guide: "network" here was misrepresented as training_model, however some features are capable for all "prepared" models.
# Tne correct specific "network" operation has been removed.
# The process_batch will wrap all the inference logic (because it will be used for validation dataset also)
with accelerator.accumulate(*training_models):
# 250331: From HF guide
# 250406: No need
#optimizer.zero_grad(set_to_none=True)
# temporary, for batch processing
self.on_step_start(args, accelerator, text_encoders, unet, batch, weight_dtype)
loss = self.process_batch(
batch,
text_encoders,
unet,
vae,
noise_scheduler,
vae_dtype,
weight_dtype,
accelerator,
args,
text_encoding_strategy,
tokenize_strategy,
is_train=True,
train_text_encoder=train_text_encoder,
train_unet=train_unet
)
accelerator.backward(loss)
#250331: It is required to sync manually. See torch.Tensor.grad
if accelerator.sync_gradients:
for training_model in training_models:
self.all_reduce_training_model(accelerator, training_model) # sync DDP grad manually
if (args.max_grad_norm != 0.0) and hasattr(training_model, "get_trainable_params"):
params_to_clip = accelerator.unwrap_model(training_model).get_trainable_params()
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
# lora_flux exclusive
if hasattr(training_model, "update_grad_norms"):
training_model.update_grad_norms()
if hasattr(training_model, "update_norms"):
training_model.update_norms()
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)
if args.scale_weight_norms:
for training_model in training_models:
keys_scaled, mean_norm, maximum_norm = accelerator.unwrap_model(training_model).apply_max_norm_regularization(
args.scale_weight_norms, accelerator.device
)
# TODO: Multiple models
mean_grad_norm = None
mean_combined_norm = None
max_mean_logs = {"Keys Scaled": keys_scaled, "Average key norm": mean_norm}
else:
keys_scaled, mean_norm, maximum_norm = None, None, None
mean_grad_norm = None
mean_combined_norm = None
max_mean_logs = {}
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)
global_step += 1
optimizer_eval_fn()
self.sample_images(
accelerator, args, None, global_step, accelerator.device, vae, tokenizers, text_encoder, unet
)
# 指定ステップごとにモデルを保存
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
accelerator.wait_for_everyone()
if accelerator.is_main_process:
# Train network has different approach: It will upload to hf or remove old file immediately.
# Train Native will keep the old *_train_utils.approach, however the class reference is so messy.
# Hint: self.load_target_model
self.save_model_on_epoch_end_or_stepwise(args, False, accelerator, save_dtype, epoch, num_train_epochs, global_step, text_encoders, vae, unet)
optimizer_train_fn()
current_loss = loss.detach().item()
if len(accelerator.trackers) > 0:
logs = {"loss": current_loss}
if block_lrs is None:
train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=train_unet)
else:
self.append_block_lr_to_logs(block_lrs, logs, lr_scheduler, args.optimizer_type) # U-Net is included in block_lrs
accelerator.log(logs, step=global_step)
loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
avr_loss: float = loss_recorder.moving_average
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**{**max_mean_logs, **logs})
if is_tracking:
logs = self.generate_step_logs(
args,
current_loss,
avr_loss,
lr_scheduler,
lr_descriptions,
optimizer,
keys_scaled,
mean_norm,
maximum_norm,
mean_grad_norm,
mean_combined_norm,
)
accelerator.log(logs, step=global_step)
# VALIDATION PER STEP
should_validate_step = (
args.validate_every_n_steps is not None
and global_step != 0 # Skip first step
and global_step % args.validate_every_n_steps == 0
)
if accelerator.sync_gradients and validation_steps > 0 and should_validate_step:
val_progress_bar = tqdm(
range(validation_steps), smoothing=0,
disable=not accelerator.is_local_main_process,
desc="validation steps"
)
for val_step, batch in enumerate(val_dataloader):
if val_step >= validation_steps:
break
if args.fused_optimizer_groups:
optimizer_hooked_count = {i: 0 for i in range(len(optimizers))} # reset counter for each step
# Code guide: "network" here was misrepresented as training_model, however some features are capable for all "prepared" models.
# Tne correct specific "network" operation has been removed.
# The process_batch will wrap all the inference logic (because it will be used for validation dataset also)
with accelerator.accumulate(*training_models):
# 250331: From HF guide
# 250406: No need
#optimizer.zero_grad(set_to_none=True)
# temporary, for batch processing
self.on_step_start(args, accelerator, text_encoders, unet, batch, weight_dtype)
@@ -1548,33 +1428,157 @@ class NativeTrainer:
args,
text_encoding_strategy,
tokenize_strategy,
is_train=False,
train_text_encoder=False,
train_unet=False
is_train=True,
train_text_encoder=train_text_encoder,
train_unet=train_unet
)
current_loss = loss.detach().item()
val_step_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss)
val_progress_bar.update(1)
val_progress_bar.set_postfix({ "val_avg_loss": val_step_loss_recorder.moving_average })
accelerator.backward(loss)
if is_tracking:
logs = {
"loss/validation/step_current": current_loss,
"val_step": (epoch * validation_steps) + val_step,
}
accelerator.log(logs, step=global_step)
#250331: It is required to sync manually. See torch.Tensor.grad
if accelerator.sync_gradients:
for training_model in training_models:
self.all_reduce_training_model(accelerator, training_model) # sync DDP grad manually
if (args.max_grad_norm != 0.0) and hasattr(training_model, "get_trainable_params"):
params_to_clip = accelerator.unwrap_model(training_model).get_trainable_params()
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
# lora_flux exclusive
if hasattr(training_model, "update_grad_norms"):
training_model.update_grad_norms()
if hasattr(training_model, "update_norms"):
training_model.update_norms()
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)
if args.scale_weight_norms:
for training_model in training_models:
keys_scaled, mean_norm, maximum_norm = accelerator.unwrap_model(training_model).apply_max_norm_regularization(
args.scale_weight_norms, accelerator.device
)
# TODO: Multiple models
mean_grad_norm = None
mean_combined_norm = None
max_mean_logs = {"Keys Scaled": keys_scaled, "Average key norm": mean_norm}
else:
keys_scaled, mean_norm, maximum_norm = None, None, None
mean_grad_norm = None
mean_combined_norm = None
max_mean_logs = {}
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)
global_step += 1
optimizer_eval_fn()
self.sample_images(
accelerator, args, None, global_step, accelerator.device, vae, tokenizers, text_encoder, unet
)
# 指定ステップごとにモデルを保存
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
accelerator.wait_for_everyone()
if accelerator.is_main_process:
# Train network has different approach: It will upload to hf or remove old file immediately.
# Train Native will keep the old *_train_utils.approach, however the class reference is so messy.
# Hint: self.load_target_model
self.save_model_on_epoch_end_or_stepwise(args, False, accelerator, save_dtype, epoch, num_train_epochs, global_step, text_encoders, vae, unet)
optimizer_train_fn()
current_loss = loss.detach().item()
if len(accelerator.trackers) > 0:
logs = {"loss": current_loss}
if block_lrs is None:
train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=train_unet)
else:
self.append_block_lr_to_logs(block_lrs, logs, lr_scheduler, args.optimizer_type) # U-Net is included in block_lrs
accelerator.log(logs, step=global_step)
loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
avr_loss: float = loss_recorder.moving_average
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**{**max_mean_logs, **logs})
if is_tracking:
loss_validation_divergence = val_step_loss_recorder.moving_average - loss_recorder.moving_average
logs = {
"loss/validation/step_average": val_step_loss_recorder.moving_average,
"loss/validation/step_divergence": loss_validation_divergence,
}
logs = self.generate_step_logs(
args,
current_loss,
avr_loss,
lr_scheduler,
lr_descriptions,
optimizer,
keys_scaled,
mean_norm,
maximum_norm,
mean_grad_norm,
mean_combined_norm,
)
accelerator.log(logs, step=global_step)
if global_step >= args.max_train_steps:
break
# VALIDATION PER STEP
should_validate_step = (
args.validate_every_n_steps is not None
and global_step != 0 # Skip first step
and global_step % args.validate_every_n_steps == 0
)
if accelerator.sync_gradients and validation_steps > 0 and should_validate_step:
val_progress_bar = tqdm(
range(validation_steps), smoothing=0,
disable=not accelerator.is_local_main_process,
desc="validation steps"
)
for val_step, batch in enumerate(val_dataloader):
if val_step >= validation_steps:
break
# temporary, for batch processing
self.on_step_start(args, accelerator, text_encoders, unet, batch, weight_dtype)
loss = self.process_batch(
batch,
text_encoders,
unet,
vae,
noise_scheduler,
vae_dtype,
weight_dtype,
accelerator,
args,
text_encoding_strategy,
tokenize_strategy,
is_train=False,
train_text_encoder=False,
train_unet=False
)
current_loss = loss.detach().item()
val_step_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss)
val_progress_bar.update(1)
val_progress_bar.set_postfix({ "val_avg_loss": val_step_loss_recorder.moving_average })
if is_tracking:
logs = {
"loss/validation/step_current": current_loss,
"val_step": (epoch * validation_steps) + val_step,
}
accelerator.log(logs, step=global_step)
if is_tracking:
loss_validation_divergence = val_step_loss_recorder.moving_average - loss_recorder.moving_average
logs = {
"loss/validation/step_average": val_step_loss_recorder.moving_average,
"loss/validation/step_divergence": loss_validation_divergence,
}
accelerator.log(logs, step=global_step)
if global_step >= args.max_train_steps:
break
# EPOCH VALIDATION
should_validate_epoch = (
@@ -1757,6 +1761,7 @@ def setup_parser() -> argparse.ArgumentParser:
help="initial step number including all epochs, 0 means first step (same as not specifying). overwrites initial_epoch."
+ " / 初期ステップ数、全エポックを含むステップ数、0で最初のステップ未指定時と同じ。initial_epochを上書きする",
)
parser.add_argument("--enable_profiler", action="store_true", help="Enable PyTorch Profiler for in depth analysis on tracing training process. Enable will make training very slow. / トレーニング プロセスのトレースに関する詳細な分析を行うには、PyTorch Profiler を有効にします。有効にすると、トレーニングが非常に遅くなります。")
#Append to add_dataset_arguments(parser)?
parser.add_argument(