mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
[Experimental Feature] FP8 weight dtype for base model when running train_network (or sdxl_train_network) (#1057)
* Add fp8 support * remove some debug prints * Better implementation for te * Fix some misunderstanding * as same as unet, add explicit convert * better impl for convert TE to fp8 * fp8 for not only unet * Better cache TE and TE lr * match arg name * Fix with list * Add timeout settings * Fix arg style * Add custom seperator * Fix typo * Fix typo again * Fix dtype error * Fix gradient problem * Fix req grad * fix merge * Fix merge * Resolve merge * arrangement and document * Resolve merge error * Add assert for mixed precision
This commit is contained in:
@@ -2904,6 +2904,9 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
||||
parser.add_argument(
|
||||
"--full_bf16", action="store_true", help="bf16 training including gradients / 勾配も含めてbf16で学習する"
|
||||
) # TODO move to SDXL training, because it is not supported by SD1/2
|
||||
parser.add_argument(
|
||||
"--fp8_base", action="store_true", help="use fp8 for base model / base modelにfp8を使う"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ddp_timeout",
|
||||
type=int,
|
||||
|
||||
@@ -390,16 +390,36 @@ class NetworkTrainer:
|
||||
accelerator.print("enable full bf16 training.")
|
||||
network.to(weight_dtype)
|
||||
|
||||
unet_weight_dtype = te_weight_dtype = weight_dtype
|
||||
# Experimental Feature: Put base model into fp8 to save vram
|
||||
if args.fp8_base:
|
||||
assert (
|
||||
torch.__version__ >= '2.1.0'
|
||||
), "fp8_base requires torch>=2.1.0 / fp8を使う場合はtorch>=2.1.0が必要です。"
|
||||
assert (
|
||||
args.mixed_precision != 'no'
|
||||
), "fp8_base requires mixed precision='fp16' or 'bf16' / fp8を使う場合はmixed_precision='fp16'または'bf16'が必要です。"
|
||||
accelerator.print("enable fp8 training.")
|
||||
unet_weight_dtype = torch.float8_e4m3fn
|
||||
te_weight_dtype = torch.float8_e4m3fn
|
||||
|
||||
unet.requires_grad_(False)
|
||||
unet.to(dtype=weight_dtype)
|
||||
unet.to(dtype=unet_weight_dtype)
|
||||
for t_enc in text_encoders:
|
||||
t_enc.requires_grad_(False)
|
||||
t_enc.to(dtype=te_weight_dtype)
|
||||
# nn.Embedding not support FP8
|
||||
t_enc.text_model.embeddings.to(dtype=(
|
||||
weight_dtype
|
||||
if te_weight_dtype == torch.float8_e4m3fn
|
||||
else te_weight_dtype
|
||||
))
|
||||
|
||||
# acceleratorがなんかよろしくやってくれるらしい / accelerator will do something good
|
||||
if train_unet:
|
||||
unet = accelerator.prepare(unet)
|
||||
else:
|
||||
unet.to(accelerator.device, dtype=weight_dtype) # move to device because unet is not prepared by accelerator
|
||||
unet.to(accelerator.device, dtype=unet_weight_dtype) # move to device because unet is not prepared by accelerator
|
||||
if train_text_encoder:
|
||||
if len(text_encoders) > 1:
|
||||
text_encoder = text_encoders = [accelerator.prepare(t_enc) for t_enc in text_encoders]
|
||||
@@ -421,9 +441,6 @@ class NetworkTrainer:
|
||||
if train_text_encoder:
|
||||
t_enc.text_model.embeddings.requires_grad_(True)
|
||||
|
||||
# set top parameter requires_grad = True for gradient checkpointing works
|
||||
if not train_text_encoder: # train U-Net only
|
||||
unet.parameters().__next__().requires_grad_(True)
|
||||
else:
|
||||
unet.eval()
|
||||
for t_enc in text_encoders:
|
||||
@@ -778,10 +795,17 @@ class NetworkTrainer:
|
||||
args, noise_scheduler, latents
|
||||
)
|
||||
|
||||
# ensure the hidden state will require grad
|
||||
if args.gradient_checkpointing:
|
||||
for x in noisy_latents:
|
||||
x.requires_grad_(True)
|
||||
for t in text_encoder_conds:
|
||||
t.requires_grad_(True)
|
||||
|
||||
# Predict the noise residual
|
||||
with accelerator.autocast():
|
||||
noise_pred = self.call_unet(
|
||||
args, accelerator, unet, noisy_latents, timesteps, text_encoder_conds, batch, weight_dtype
|
||||
args, accelerator, unet, noisy_latents.requires_grad_(train_unet), timesteps, text_encoder_conds, batch, weight_dtype
|
||||
)
|
||||
|
||||
if args.v_parameterization:
|
||||
|
||||
Reference in New Issue
Block a user