mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
[Experimental Feature] FP8 weight dtype for base model when running train_network (or sdxl_train_network) (#1057)
* Add fp8 support * remove some debug prints * Better implementation for te * Fix some misunderstanding * as same as unet, add explicit convert * better impl for convert TE to fp8 * fp8 for not only unet * Better cache TE and TE lr * match arg name * Fix with list * Add timeout settings * Fix arg style * Add custom seperator * Fix typo * Fix typo again * Fix dtype error * Fix gradient problem * Fix req grad * fix merge * Fix merge * Resolve merge * arrangement and document * Resolve merge error * Add assert for mixed precision
This commit is contained in:
@@ -2904,6 +2904,9 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--full_bf16", action="store_true", help="bf16 training including gradients / 勾配も含めてbf16で学習する"
|
"--full_bf16", action="store_true", help="bf16 training including gradients / 勾配も含めてbf16で学習する"
|
||||||
) # TODO move to SDXL training, because it is not supported by SD1/2
|
) # TODO move to SDXL training, because it is not supported by SD1/2
|
||||||
|
parser.add_argument(
|
||||||
|
"--fp8_base", action="store_true", help="use fp8 for base model / base modelにfp8を使う"
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--ddp_timeout",
|
"--ddp_timeout",
|
||||||
type=int,
|
type=int,
|
||||||
|
|||||||
@@ -390,16 +390,36 @@ class NetworkTrainer:
|
|||||||
accelerator.print("enable full bf16 training.")
|
accelerator.print("enable full bf16 training.")
|
||||||
network.to(weight_dtype)
|
network.to(weight_dtype)
|
||||||
|
|
||||||
|
unet_weight_dtype = te_weight_dtype = weight_dtype
|
||||||
|
# Experimental Feature: Put base model into fp8 to save vram
|
||||||
|
if args.fp8_base:
|
||||||
|
assert (
|
||||||
|
torch.__version__ >= '2.1.0'
|
||||||
|
), "fp8_base requires torch>=2.1.0 / fp8を使う場合はtorch>=2.1.0が必要です。"
|
||||||
|
assert (
|
||||||
|
args.mixed_precision != 'no'
|
||||||
|
), "fp8_base requires mixed precision='fp16' or 'bf16' / fp8を使う場合はmixed_precision='fp16'または'bf16'が必要です。"
|
||||||
|
accelerator.print("enable fp8 training.")
|
||||||
|
unet_weight_dtype = torch.float8_e4m3fn
|
||||||
|
te_weight_dtype = torch.float8_e4m3fn
|
||||||
|
|
||||||
unet.requires_grad_(False)
|
unet.requires_grad_(False)
|
||||||
unet.to(dtype=weight_dtype)
|
unet.to(dtype=unet_weight_dtype)
|
||||||
for t_enc in text_encoders:
|
for t_enc in text_encoders:
|
||||||
t_enc.requires_grad_(False)
|
t_enc.requires_grad_(False)
|
||||||
|
t_enc.to(dtype=te_weight_dtype)
|
||||||
|
# nn.Embedding not support FP8
|
||||||
|
t_enc.text_model.embeddings.to(dtype=(
|
||||||
|
weight_dtype
|
||||||
|
if te_weight_dtype == torch.float8_e4m3fn
|
||||||
|
else te_weight_dtype
|
||||||
|
))
|
||||||
|
|
||||||
# acceleratorがなんかよろしくやってくれるらしい / accelerator will do something good
|
# acceleratorがなんかよろしくやってくれるらしい / accelerator will do something good
|
||||||
if train_unet:
|
if train_unet:
|
||||||
unet = accelerator.prepare(unet)
|
unet = accelerator.prepare(unet)
|
||||||
else:
|
else:
|
||||||
unet.to(accelerator.device, dtype=weight_dtype) # move to device because unet is not prepared by accelerator
|
unet.to(accelerator.device, dtype=unet_weight_dtype) # move to device because unet is not prepared by accelerator
|
||||||
if train_text_encoder:
|
if train_text_encoder:
|
||||||
if len(text_encoders) > 1:
|
if len(text_encoders) > 1:
|
||||||
text_encoder = text_encoders = [accelerator.prepare(t_enc) for t_enc in text_encoders]
|
text_encoder = text_encoders = [accelerator.prepare(t_enc) for t_enc in text_encoders]
|
||||||
@@ -421,9 +441,6 @@ class NetworkTrainer:
|
|||||||
if train_text_encoder:
|
if train_text_encoder:
|
||||||
t_enc.text_model.embeddings.requires_grad_(True)
|
t_enc.text_model.embeddings.requires_grad_(True)
|
||||||
|
|
||||||
# set top parameter requires_grad = True for gradient checkpointing works
|
|
||||||
if not train_text_encoder: # train U-Net only
|
|
||||||
unet.parameters().__next__().requires_grad_(True)
|
|
||||||
else:
|
else:
|
||||||
unet.eval()
|
unet.eval()
|
||||||
for t_enc in text_encoders:
|
for t_enc in text_encoders:
|
||||||
@@ -778,10 +795,17 @@ class NetworkTrainer:
|
|||||||
args, noise_scheduler, latents
|
args, noise_scheduler, latents
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# ensure the hidden state will require grad
|
||||||
|
if args.gradient_checkpointing:
|
||||||
|
for x in noisy_latents:
|
||||||
|
x.requires_grad_(True)
|
||||||
|
for t in text_encoder_conds:
|
||||||
|
t.requires_grad_(True)
|
||||||
|
|
||||||
# Predict the noise residual
|
# Predict the noise residual
|
||||||
with accelerator.autocast():
|
with accelerator.autocast():
|
||||||
noise_pred = self.call_unet(
|
noise_pred = self.call_unet(
|
||||||
args, accelerator, unet, noisy_latents, timesteps, text_encoder_conds, batch, weight_dtype
|
args, accelerator, unet, noisy_latents.requires_grad_(train_unet), timesteps, text_encoder_conds, batch, weight_dtype
|
||||||
)
|
)
|
||||||
|
|
||||||
if args.v_parameterization:
|
if args.v_parameterization:
|
||||||
|
|||||||
Reference in New Issue
Block a user