Add accelerate profiler support, fix initial_step scaling, add experimental requirements-2.txt

This commit is contained in:
Darren Laurie
2025-04-06 15:37:14 +08:00
parent d35c51a59e
commit 36b007145a
3 changed files with 75 additions and 12 deletions

View File

@@ -14,7 +14,7 @@ import shutil
import time
import typing
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union
from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs, PartialState, DataLoaderConfiguration
from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs, ProfileKwargs, PartialState, DataLoaderConfiguration
import glob
import math
import os
@@ -5448,6 +5448,15 @@ def prepare_accelerator(args: argparse.Namespace):
if args.ddp_gradient_as_bucket_view or args.ddp_static_graph
else None
),
(
ProfileKwargs(
activities=["cpu", "cuda"],
output_trace_dir="/dev/shm/trace",
profile_memory=True,
record_shapes=True,
with_flops=True
)
)
]
kwargs_handlers = [i for i in kwargs_handlers if i is not None]
deepspeed_plugin = deepspeed_utils.prepare_deepspeed_plugin(args)

48
requirements-2.txt Normal file
View File

@@ -0,0 +1,48 @@
accelerate
transformers
diffusers[torch]
ftfy
# albumentations==1.3.0
opencv-python
einops
pytorch-lightning
bitsandbytes
prodigyopt
lion-pytorch
schedulefree
tensorboard
safetensors
torchao
pytorch-optimizer
# gradio==3.16.2
altair
easygui
toml
voluptuous
huggingface-hub
# for Image utils
imagesize
numpy
# for BLIP captioning
# requests==2.28.2
# timm==0.6.12
# fairscale==0.4.13
# for WD14 captioning (tensorflow)
# tensorflow==2.10.1
# for WD14 captioning (onnx)
# onnx==1.15.0
# onnxruntime-gpu==1.17.1
# onnxruntime==1.17.1
# for cuda 12.1(default 11.8)
# onnxruntime-gpu --extra-index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/onnxruntime-cuda-12/pypi/simple/
# this is for onnx:
# protobuf==3.20.3
# open clip for SDXL
# open-clip-torch==2.20.0
# For logging
rich
# for T5XXL tokenizer (SD3/FLUX)
sentencepiece
# for kohya_ss library
-e .

View File

@@ -708,10 +708,10 @@ class NativeTrainer:
unet.to(accelerator.device, dtype=weight_dtype) # because of unet is not prepared
# TODO: SDXL Model Specific
# TODO: Is casting to torch.tensor slowing down the performance so much? (20% slower)
# TODO: Why casting to torch.tensor will slow down the performance so much? (20% slower)
training_models = []
params_to_optimize = []
using_torchao = args.optimizer_type.endswith("4bit") or args.optimizer_type.endswith("Fp8")
using_torchao = args.optimizer_type.endswith("4bit") or args.optimizer_type.endswith("Fp8")
if train_unet:
training_models.append(unet)
if block_lrs is None:
@@ -1302,7 +1302,9 @@ class NativeTrainer:
f"initial_step is specified but not resuming. lr scheduler will be started from the beginning / initial_stepが指定されていますがresumeしていないため、lr schedulerは最初から始まります"
)
logger.info(f"skipping {initial_step} steps / {initial_step}ステップをスキップします")
initial_step *= args.gradient_accumulation_steps
#250406: Do not multiply here — initial_step already accounts for gradient accumulation.
#initial_step *= args.gradient_accumulation_steps
# set epoch to start to make initial_step less than len(train_dataloader)
epoch_to_start = initial_step // math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
@@ -1350,8 +1352,7 @@ class NativeTrainer:
for skip_epoch in range(epoch_to_start): # skip epochs
logger.info(f"skipping epoch {skip_epoch+1} because initial_step (multiplied) is {initial_step}")
initial_step -= len(train_dataloader)
# I have found that the log is screwed up. This should be divided back.
global_step = int(initial_step / args.gradient_accumulation_steps)
global_step = initial_step
# log device and dtype for each model
logger.info(f"unet dtype: {unet_weight_dtype}, device: {unet.device}")
@@ -1380,6 +1381,8 @@ class NativeTrainer:
initial_step = 1
for step, batch in enumerate(skipped_dataloader or train_dataloader):
#Enable this block for the profiler. Hint: select the region below (down to EPOCH VALIDATION) and indent/unindent with Tab / Shift+Tab.
#with accelerator.profile() as prof:
current_step.value = global_step
if initial_step > 0:
initial_step -= 1
@@ -1392,7 +1395,10 @@ class NativeTrainer:
# The correct specific "network" operation has been removed.
# The process_batch will wrap all the inference logic (because it will be used for validation dataset also)
with accelerator.accumulate(*training_models):
# 250331: From HF guide
# 250406: No need
#optimizer.zero_grad(set_to_none=True)
# temporary, for batch processing
self.on_step_start(args, accelerator, text_encoders, unet, batch, weight_dtype)
@@ -1415,13 +1421,13 @@ class NativeTrainer:
accelerator.backward(loss)
#250331: It is required to sync manually. See torch.Tensor.grad
if accelerator.sync_gradients:
for training_model in training_models:
self.all_reduce_training_model(accelerator, training_model) # sync DDP grad manually
if args.max_grad_norm != 0.0:
if hasattr(training_model, "get_trainable_params"):
params_to_clip = accelerator.unwrap_model(training_model).get_trainable_params()
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
self.all_reduce_training_model(accelerator, training_model) # sync DDP grad manually
if (args.max_grad_norm != 0.0) and hasattr(training_model, "get_trainable_params"):
params_to_clip = accelerator.unwrap_model(training_model).get_trainable_params()
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
optimizer.step()
lr_scheduler.step()