[Experimental Feature] FP8 weight dtype for base model when running train_network (or sdxl_train_network) (#1057)

* Add fp8 support

* remove some debug prints

* Better implementation for te

* Fix some misunderstanding

* Same as unet, add explicit convert

* Better impl for converting TE to fp8

* fp8 not only for unet

* Better caching of TE and TE lr

* match arg name

* Fix with list

* Add timeout settings

* Fix arg style

* Add custom separator

* Fix typo

* Fix typo again

* Fix dtype error

* Fix gradient problem

* Fix req grad

* fix merge

* Fix merge

* Resolve merge

* Arrangement and documentation

* Resolve merge error

* Add assert for mixed precision
Author: Kohaku-Blueleaf
Date: 2024-01-20 08:46:53 +08:00
Committed by: GitHub
Parent: 0395a35543
Commit: 9cfa68c92f
2 changed files with 33 additions and 6 deletions


@@ -2904,6 +2904,9 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
     parser.add_argument(
         "--full_bf16", action="store_true", help="bf16 training including gradients / 勾配も含めてbf16で学習する"
     )  # TODO move to SDXL training, because it is not supported by SD1/2
+    parser.add_argument(
+        "--fp8_base", action="store_true", help="use fp8 for base model / base modelにfp8を使う"
+    )
     parser.add_argument(
         "--ddp_timeout",
         type=int,
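The new `--fp8_base` flag is a plain `store_true` switch, so it defaults to `False` and flips to `True` only when passed on the command line. A stdlib-only illustration (not the project's code):

```python
import argparse

parser = argparse.ArgumentParser()
# Mirrors the argument added in the diff above
parser.add_argument(
    "--fp8_base", action="store_true", help="use fp8 for base model"
)

args = parser.parse_args(["--fp8_base"])
print(args.fp8_base)                   # True when the flag is given
print(parser.parse_args([]).fp8_base)  # False by default
```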