mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-15 00:32:25 +00:00
pytorch profiler options
This commit is contained in:
@@ -3948,6 +3948,12 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
||||
action="store_true",
|
||||
help="enable static_graph for DDP / DDPでstatic_graphを有効にする",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--profiler_path",
|
||||
type=str,
|
||||
default="/dev/shm/trace",
|
||||
help="Path for storing PyTorch Profiler traces. Recommended to store in RAM drive. / PyTorch Profiler トレースを保存するためのパス。RAM ドライブに保存することをお勧めします。",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--clip_skip",
|
||||
type=int,
|
||||
@@ -5478,7 +5484,7 @@ def prepare_accelerator(args: argparse.Namespace):
|
||||
(
|
||||
ProfileKwargs(
|
||||
activities=["cpu", "cuda"],
|
||||
output_trace_dir="/dev/shm/trace",
|
||||
output_trace_dir=args.profiler_path,
|
||||
profile_memory=True,
|
||||
record_shapes=True,
|
||||
with_flops=True
|
||||
|
||||
327
train_native.py
327
train_native.py
@@ -7,6 +7,7 @@ import random
|
||||
import time
|
||||
import json
|
||||
from multiprocessing import Value
|
||||
from contextlib import nullcontext
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
@@ -1375,6 +1376,10 @@ class NativeTrainer:
|
||||
|
||||
clean_memory_on_device(accelerator.device)
|
||||
|
||||
enable_profiler = args.enable_profiler
|
||||
if enable_profiler:
|
||||
logger.warning(f"Pytorch profiler enabled. Disable after capturing traces. / Pytorch プロファイラーが有効になっています。トレースをキャプチャした後は無効にします。")
|
||||
|
||||
for epoch in range(epoch_to_start, num_train_epochs):
|
||||
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}\n")
|
||||
current_epoch.value = epoch + 1
|
||||
@@ -1391,148 +1396,23 @@ class NativeTrainer:
|
||||
initial_step = 1
|
||||
|
||||
for step, batch in enumerate(skipped_dataloader or train_dataloader):
|
||||
#Enable this for profiler. Hint: select a big area (until EPOCH VALIDATION) and tab / shift tab
|
||||
#with accelerator.profile() as prof:
|
||||
current_step.value = global_step
|
||||
if initial_step > 0:
|
||||
initial_step -= 1
|
||||
continue
|
||||
with accelerator.profile() if enable_profiler else nullcontext() as prof:
|
||||
current_step.value = global_step
|
||||
if initial_step > 0:
|
||||
initial_step -= 1
|
||||
continue
|
||||
|
||||
if args.fused_optimizer_groups:
|
||||
optimizer_hooked_count = {i: 0 for i in range(len(optimizers))} # reset counter for each step
|
||||
|
||||
# Code guide: "network" here was misrepresented as training_model, however some features are capable for all "prepared" models.
|
||||
# The correct specific "network" operation has been removed.
|
||||
# The process_batch will wrap all the inference logic (because it will be used for validation dataset also)
|
||||
with accelerator.accumulate(*training_models):
|
||||
# 250331: From HF guide
|
||||
# 250406: No need
|
||||
#optimizer.zero_grad(set_to_none=True)
|
||||
|
||||
# temporary, for batch processing
|
||||
self.on_step_start(args, accelerator, text_encoders, unet, batch, weight_dtype)
|
||||
|
||||
loss = self.process_batch(
|
||||
batch,
|
||||
text_encoders,
|
||||
unet,
|
||||
vae,
|
||||
noise_scheduler,
|
||||
vae_dtype,
|
||||
weight_dtype,
|
||||
accelerator,
|
||||
args,
|
||||
text_encoding_strategy,
|
||||
tokenize_strategy,
|
||||
is_train=True,
|
||||
train_text_encoder=train_text_encoder,
|
||||
train_unet=train_unet
|
||||
)
|
||||
|
||||
accelerator.backward(loss)
|
||||
|
||||
#250331: It is required to sync manually. See torch.Tensor.grad
|
||||
if accelerator.sync_gradients:
|
||||
for training_model in training_models:
|
||||
self.all_reduce_training_model(accelerator, training_model) # sync DDP grad manually
|
||||
if (args.max_grad_norm != 0.0) and hasattr(training_model, "get_trainable_params"):
|
||||
params_to_clip = accelerator.unwrap_model(training_model).get_trainable_params()
|
||||
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
||||
|
||||
# lora_flux exclusive
|
||||
if hasattr(training_model, "update_grad_norms"):
|
||||
training_model.update_grad_norms()
|
||||
if hasattr(training_model, "update_norms"):
|
||||
training_model.update_norms()
|
||||
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
|
||||
if args.scale_weight_norms:
|
||||
for training_model in training_models:
|
||||
keys_scaled, mean_norm, maximum_norm = accelerator.unwrap_model(training_model).apply_max_norm_regularization(
|
||||
args.scale_weight_norms, accelerator.device
|
||||
)
|
||||
# TODO: Multiple models
|
||||
mean_grad_norm = None
|
||||
mean_combined_norm = None
|
||||
max_mean_logs = {"Keys Scaled": keys_scaled, "Average key norm": mean_norm}
|
||||
else:
|
||||
keys_scaled, mean_norm, maximum_norm = None, None, None
|
||||
mean_grad_norm = None
|
||||
mean_combined_norm = None
|
||||
max_mean_logs = {}
|
||||
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if accelerator.sync_gradients:
|
||||
progress_bar.update(1)
|
||||
global_step += 1
|
||||
|
||||
optimizer_eval_fn()
|
||||
self.sample_images(
|
||||
accelerator, args, None, global_step, accelerator.device, vae, tokenizers, text_encoder, unet
|
||||
)
|
||||
|
||||
# 指定ステップごとにモデルを保存
|
||||
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
# Train network has different approach: It will upload to hf or remove old file immediately.
|
||||
# Train Native will keep the old *_train_utils.approach, however the class reference is so messy.
|
||||
# Hint: self.load_target_model
|
||||
self.save_model_on_epoch_end_or_stepwise(args, False, accelerator, save_dtype, epoch, num_train_epochs, global_step, text_encoders, vae, unet)
|
||||
|
||||
optimizer_train_fn()
|
||||
|
||||
current_loss = loss.detach().item()
|
||||
|
||||
if len(accelerator.trackers) > 0:
|
||||
logs = {"loss": current_loss}
|
||||
if block_lrs is None:
|
||||
train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=train_unet)
|
||||
else:
|
||||
self.append_block_lr_to_logs(block_lrs, logs, lr_scheduler, args.optimizer_type) # U-Net is included in block_lrs
|
||||
|
||||
accelerator.log(logs, step=global_step)
|
||||
|
||||
loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
|
||||
avr_loss: float = loss_recorder.moving_average
|
||||
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
|
||||
progress_bar.set_postfix(**{**max_mean_logs, **logs})
|
||||
|
||||
if is_tracking:
|
||||
logs = self.generate_step_logs(
|
||||
args,
|
||||
current_loss,
|
||||
avr_loss,
|
||||
lr_scheduler,
|
||||
lr_descriptions,
|
||||
optimizer,
|
||||
keys_scaled,
|
||||
mean_norm,
|
||||
maximum_norm,
|
||||
mean_grad_norm,
|
||||
mean_combined_norm,
|
||||
)
|
||||
accelerator.log(logs, step=global_step)
|
||||
|
||||
# VALIDATION PER STEP
|
||||
should_validate_step = (
|
||||
args.validate_every_n_steps is not None
|
||||
and global_step != 0 # Skip first step
|
||||
and global_step % args.validate_every_n_steps == 0
|
||||
)
|
||||
if accelerator.sync_gradients and validation_steps > 0 and should_validate_step:
|
||||
val_progress_bar = tqdm(
|
||||
range(validation_steps), smoothing=0,
|
||||
disable=not accelerator.is_local_main_process,
|
||||
desc="validation steps"
|
||||
)
|
||||
for val_step, batch in enumerate(val_dataloader):
|
||||
if val_step >= validation_steps:
|
||||
break
|
||||
if args.fused_optimizer_groups:
|
||||
optimizer_hooked_count = {i: 0 for i in range(len(optimizers))} # reset counter for each step
|
||||
|
||||
# Code guide: "network" here was misrepresented as training_model, however some features are capable for all "prepared" models.
|
||||
# The correct specific "network" operation has been removed.
|
||||
# The process_batch will wrap all the inference logic (because it will be used for validation dataset also)
|
||||
with accelerator.accumulate(*training_models):
|
||||
# 250331: From HF guide
|
||||
# 250406: No need
|
||||
#optimizer.zero_grad(set_to_none=True)
|
||||
|
||||
# temporary, for batch processing
|
||||
self.on_step_start(args, accelerator, text_encoders, unet, batch, weight_dtype)
|
||||
|
||||
@@ -1548,33 +1428,157 @@ class NativeTrainer:
|
||||
args,
|
||||
text_encoding_strategy,
|
||||
tokenize_strategy,
|
||||
is_train=False,
|
||||
train_text_encoder=False,
|
||||
train_unet=False
|
||||
is_train=True,
|
||||
train_text_encoder=train_text_encoder,
|
||||
train_unet=train_unet
|
||||
)
|
||||
|
||||
current_loss = loss.detach().item()
|
||||
val_step_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss)
|
||||
val_progress_bar.update(1)
|
||||
val_progress_bar.set_postfix({ "val_avg_loss": val_step_loss_recorder.moving_average })
|
||||
accelerator.backward(loss)
|
||||
|
||||
if is_tracking:
|
||||
logs = {
|
||||
"loss/validation/step_current": current_loss,
|
||||
"val_step": (epoch * validation_steps) + val_step,
|
||||
}
|
||||
accelerator.log(logs, step=global_step)
|
||||
#250331: It is required to sync manually. See torch.Tensor.grad
|
||||
if accelerator.sync_gradients:
|
||||
for training_model in training_models:
|
||||
self.all_reduce_training_model(accelerator, training_model) # sync DDP grad manually
|
||||
if (args.max_grad_norm != 0.0) and hasattr(training_model, "get_trainable_params"):
|
||||
params_to_clip = accelerator.unwrap_model(training_model).get_trainable_params()
|
||||
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
||||
|
||||
# lora_flux exclusive
|
||||
if hasattr(training_model, "update_grad_norms"):
|
||||
training_model.update_grad_norms()
|
||||
if hasattr(training_model, "update_norms"):
|
||||
training_model.update_norms()
|
||||
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
|
||||
if args.scale_weight_norms:
|
||||
for training_model in training_models:
|
||||
keys_scaled, mean_norm, maximum_norm = accelerator.unwrap_model(training_model).apply_max_norm_regularization(
|
||||
args.scale_weight_norms, accelerator.device
|
||||
)
|
||||
# TODO: Multiple models
|
||||
mean_grad_norm = None
|
||||
mean_combined_norm = None
|
||||
max_mean_logs = {"Keys Scaled": keys_scaled, "Average key norm": mean_norm}
|
||||
else:
|
||||
keys_scaled, mean_norm, maximum_norm = None, None, None
|
||||
mean_grad_norm = None
|
||||
mean_combined_norm = None
|
||||
max_mean_logs = {}
|
||||
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if accelerator.sync_gradients:
|
||||
progress_bar.update(1)
|
||||
global_step += 1
|
||||
|
||||
optimizer_eval_fn()
|
||||
self.sample_images(
|
||||
accelerator, args, None, global_step, accelerator.device, vae, tokenizers, text_encoder, unet
|
||||
)
|
||||
|
||||
# 指定ステップごとにモデルを保存
|
||||
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
# Train network has different approach: It will upload to hf or remove old file immediately.
|
||||
# Train Native will keep the old *_train_utils.approach, however the class reference is so messy.
|
||||
# Hint: self.load_target_model
|
||||
self.save_model_on_epoch_end_or_stepwise(args, False, accelerator, save_dtype, epoch, num_train_epochs, global_step, text_encoders, vae, unet)
|
||||
|
||||
optimizer_train_fn()
|
||||
|
||||
current_loss = loss.detach().item()
|
||||
|
||||
if len(accelerator.trackers) > 0:
|
||||
logs = {"loss": current_loss}
|
||||
if block_lrs is None:
|
||||
train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=train_unet)
|
||||
else:
|
||||
self.append_block_lr_to_logs(block_lrs, logs, lr_scheduler, args.optimizer_type) # U-Net is included in block_lrs
|
||||
|
||||
accelerator.log(logs, step=global_step)
|
||||
|
||||
loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
|
||||
avr_loss: float = loss_recorder.moving_average
|
||||
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
|
||||
progress_bar.set_postfix(**{**max_mean_logs, **logs})
|
||||
|
||||
if is_tracking:
|
||||
loss_validation_divergence = val_step_loss_recorder.moving_average - loss_recorder.moving_average
|
||||
logs = {
|
||||
"loss/validation/step_average": val_step_loss_recorder.moving_average,
|
||||
"loss/validation/step_divergence": loss_validation_divergence,
|
||||
}
|
||||
logs = self.generate_step_logs(
|
||||
args,
|
||||
current_loss,
|
||||
avr_loss,
|
||||
lr_scheduler,
|
||||
lr_descriptions,
|
||||
optimizer,
|
||||
keys_scaled,
|
||||
mean_norm,
|
||||
maximum_norm,
|
||||
mean_grad_norm,
|
||||
mean_combined_norm,
|
||||
)
|
||||
accelerator.log(logs, step=global_step)
|
||||
|
||||
if global_step >= args.max_train_steps:
|
||||
break
|
||||
|
||||
# VALIDATION PER STEP
|
||||
should_validate_step = (
|
||||
args.validate_every_n_steps is not None
|
||||
and global_step != 0 # Skip first step
|
||||
and global_step % args.validate_every_n_steps == 0
|
||||
)
|
||||
if accelerator.sync_gradients and validation_steps > 0 and should_validate_step:
|
||||
val_progress_bar = tqdm(
|
||||
range(validation_steps), smoothing=0,
|
||||
disable=not accelerator.is_local_main_process,
|
||||
desc="validation steps"
|
||||
)
|
||||
for val_step, batch in enumerate(val_dataloader):
|
||||
if val_step >= validation_steps:
|
||||
break
|
||||
|
||||
# temporary, for batch processing
|
||||
self.on_step_start(args, accelerator, text_encoders, unet, batch, weight_dtype)
|
||||
|
||||
loss = self.process_batch(
|
||||
batch,
|
||||
text_encoders,
|
||||
unet,
|
||||
vae,
|
||||
noise_scheduler,
|
||||
vae_dtype,
|
||||
weight_dtype,
|
||||
accelerator,
|
||||
args,
|
||||
text_encoding_strategy,
|
||||
tokenize_strategy,
|
||||
is_train=False,
|
||||
train_text_encoder=False,
|
||||
train_unet=False
|
||||
)
|
||||
|
||||
current_loss = loss.detach().item()
|
||||
val_step_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss)
|
||||
val_progress_bar.update(1)
|
||||
val_progress_bar.set_postfix({ "val_avg_loss": val_step_loss_recorder.moving_average })
|
||||
|
||||
if is_tracking:
|
||||
logs = {
|
||||
"loss/validation/step_current": current_loss,
|
||||
"val_step": (epoch * validation_steps) + val_step,
|
||||
}
|
||||
accelerator.log(logs, step=global_step)
|
||||
|
||||
if is_tracking:
|
||||
loss_validation_divergence = val_step_loss_recorder.moving_average - loss_recorder.moving_average
|
||||
logs = {
|
||||
"loss/validation/step_average": val_step_loss_recorder.moving_average,
|
||||
"loss/validation/step_divergence": loss_validation_divergence,
|
||||
}
|
||||
accelerator.log(logs, step=global_step)
|
||||
|
||||
if global_step >= args.max_train_steps:
|
||||
break
|
||||
|
||||
# EPOCH VALIDATION
|
||||
should_validate_epoch = (
|
||||
@@ -1757,6 +1761,7 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
help="initial step number including all epochs, 0 means first step (same as not specifying). overwrites initial_epoch."
|
||||
+ " / 初期ステップ数、全エポックを含むステップ数、0で最初のステップ(未指定時と同じ)。initial_epochを上書きする",
|
||||
)
|
||||
parser.add_argument("--enable_profiler", action="store_true", help="Enable PyTorch Profiler for in depth analysis on tracing training process. Enable will make training very slow. / トレーニング プロセスのトレースに関する詳細な分析を行うには、PyTorch Profiler を有効にします。有効にすると、トレーニングが非常に遅くなります。")
|
||||
|
||||
#Append to add_dataset_arguments(parser)?
|
||||
parser.add_argument(
|
||||
|
||||
Reference in New Issue
Block a user