[Experimental Feature] FP8 weight dtype for base model when running train_network (or sdxl_train_network) (#1057)

* Add fp8 support

* remove some debug prints

* Better implementation for te

* Fix some misunderstanding

* as same as unet, add explicit convert

* better impl for convert TE to fp8

* fp8 for not only unet

* Better cache TE and TE lr

* match arg name

* Fix with list

* Add timeout settings

* Fix arg style

* Add custom seperator

* Fix typo

* Fix typo again

* Fix dtype error

* Fix gradient problem

* Fix req grad

* fix merge

* Fix merge

* Resolve merge

* arrangement and document

* Resolve merge error

* Add assert for mixed precision
This commit is contained in:
Kohaku-Blueleaf
2024-01-20 08:46:53 +08:00
committed by GitHub
parent 0395a35543
commit 9cfa68c92f
2 changed files with 33 additions and 6 deletions

View File

@@ -2904,6 +2904,9 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
parser.add_argument(
"--full_bf16", action="store_true", help="bf16 training including gradients / 勾配も含めてbf16で学習する"
) # TODO move to SDXL training, because it is not supported by SD1/2
parser.add_argument(
"--fp8_base", action="store_true", help="use fp8 for base model / base modelにfp8を使う"
)
parser.add_argument(
"--ddp_timeout",
type=int,

View File

@@ -390,16 +390,36 @@ class NetworkTrainer:
accelerator.print("enable full bf16 training.")
network.to(weight_dtype)
unet_weight_dtype = te_weight_dtype = weight_dtype
# Experimental Feature: Put base model into fp8 to save vram
if args.fp8_base:
assert (
torch.__version__ >= '2.1.0'
), "fp8_base requires torch>=2.1.0 / fp8を使う場合はtorch>=2.1.0が必要です。"
assert (
args.mixed_precision != 'no'
), "fp8_base requires mixed precision='fp16' or 'bf16' / fp8を使う場合はmixed_precision='fp16'または'bf16'が必要です。"
accelerator.print("enable fp8 training.")
unet_weight_dtype = torch.float8_e4m3fn
te_weight_dtype = torch.float8_e4m3fn
unet.requires_grad_(False)
unet.to(dtype=weight_dtype)
unet.to(dtype=unet_weight_dtype)
for t_enc in text_encoders:
t_enc.requires_grad_(False)
t_enc.to(dtype=te_weight_dtype)
# nn.Embedding not support FP8
t_enc.text_model.embeddings.to(dtype=(
weight_dtype
if te_weight_dtype == torch.float8_e4m3fn
else te_weight_dtype
))
# acceleratorがなんかよろしくやってくれるらしい / accelerator will do something good
if train_unet:
unet = accelerator.prepare(unet)
else:
unet.to(accelerator.device, dtype=weight_dtype) # move to device because unet is not prepared by accelerator
unet.to(accelerator.device, dtype=unet_weight_dtype) # move to device because unet is not prepared by accelerator
if train_text_encoder:
if len(text_encoders) > 1:
text_encoder = text_encoders = [accelerator.prepare(t_enc) for t_enc in text_encoders]
@@ -421,9 +441,6 @@ class NetworkTrainer:
if train_text_encoder:
t_enc.text_model.embeddings.requires_grad_(True)
# set top parameter requires_grad = True for gradient checkpointing works
if not train_text_encoder: # train U-Net only
unet.parameters().__next__().requires_grad_(True)
else:
unet.eval()
for t_enc in text_encoders:
@@ -778,10 +795,17 @@ class NetworkTrainer:
args, noise_scheduler, latents
)
# ensure the hidden state will require grad
if args.gradient_checkpointing:
for x in noisy_latents:
x.requires_grad_(True)
for t in text_encoder_conds:
t.requires_grad_(True)
# Predict the noise residual
with accelerator.autocast():
noise_pred = self.call_unet(
args, accelerator, unet, noisy_latents, timesteps, text_encoder_conds, batch, weight_dtype
args, accelerator, unet, noisy_latents.requires_grad_(train_unet), timesteps, text_encoder_conds, batch, weight_dtype
)
if args.v_parameterization: