mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-16 17:02:45 +00:00
feat: add --text_encoder_cpu option to reduce VRAM usage by running text encoders on CPU for training
This commit is contained in:
@@ -184,6 +184,8 @@ The script adds HunyuanImage-2.1 specific arguments. For common arguments (like
|
||||
- Enables training the DiT model in scaled FP8 format. This can significantly reduce VRAM usage (can run with as little as 8GB VRAM when combined with `--blocks_to_swap`), but the training results may vary. This is a newer alternative to the unsupported `--fp8_base` option.
|
||||
* `--fp8_vl`
|
||||
- Use FP8 for the VLM (Qwen2.5-VL) text encoder.
|
||||
* `--text_encoder_cpu`
|
||||
- Runs the text encoders on CPU to reduce VRAM usage. This is useful when VRAM is insufficient (less than 12GB). Encoding one text may take a few minutes (depending on CPU). It is highly recommended to use this option with `--cache_text_encoder_outputs_to_disk` to avoid repeated encoding every time training starts.
|
||||
* `--blocks_to_swap=<integer>` **[Experimental Feature]**
|
||||
- Setting to reduce VRAM usage by swapping parts of the model (Transformer blocks) between CPU and GPU. Specify the number of blocks to swap as an integer (e.g., `18`). Larger values reduce VRAM usage but decrease training speed. Adjust according to your GPU's VRAM capacity. Can be used with `gradient_checkpointing`.
|
||||
* `--cache_text_encoder_outputs`
|
||||
@@ -450,8 +452,9 @@ python hunyuan_image_minimal_inference.py \
|
||||
- `--image_size`: Resolution (inference is most stable at 2048x2048)
|
||||
- `--guidance_scale`: CFG scale (default: 3.5)
|
||||
- `--flow_shift`: Flow matching shift parameter (default: 5.0)
|
||||
- `--text_encoder_cpu`: Run the text encoders on CPU to reduce VRAM usage
|
||||
|
||||
`--split_attn` is not supported (since inference is done one at a time).
|
||||
`--split_attn` is not supported (since inference is done one at a time). `--fp8_vl` is not supported, please use CPU for the text encoder if VRAM is insufficient.
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
@@ -464,8 +467,9 @@ python hunyuan_image_minimal_inference.py \
|
||||
- `--image_size`: 解像度(2048x2048で最も安定)
|
||||
- `--guidance_scale`: CFGスケール(推奨: 3.5)
|
||||
- `--flow_shift`: Flow Matchingシフトパラメータ(デフォルト: 5.0)
|
||||
- `--text_encoder_cpu`: テキストエンコーダをCPUで実行してVRAM使用量削減
|
||||
|
||||
`--split_attn`はサポートされていません(1件ずつ推論するため)。
|
||||
`--split_attn`はサポートされていません(1件ずつ推論するため)。`--fp8_vl`もサポートされていません。VRAMが不足する場合はテキストエンコーダをCPUで実行してください。
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
@@ -350,7 +350,7 @@ class HunyuanImageNetworkTrainer(train_network.NetworkTrainer):
|
||||
self.is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0
|
||||
|
||||
vl_dtype = torch.float8_e4m3fn if args.fp8_vl else torch.bfloat16
|
||||
vl_device = "cpu"
|
||||
vl_device = "cpu" # loading to cpu and move to gpu later in cache_text_encoder_outputs_if_needed
|
||||
_, text_encoder_vlm = hunyuan_image_text_encoder.load_qwen2_5_vl(
|
||||
args.text_encoder, dtype=vl_dtype, device=vl_device, disable_mmap=args.disable_mmap_load_safetensors
|
||||
)
|
||||
@@ -440,6 +440,7 @@ class HunyuanImageNetworkTrainer(train_network.NetworkTrainer):
|
||||
def cache_text_encoder_outputs_if_needed(
|
||||
self, args, accelerator: Accelerator, unet, vae, text_encoders, dataset: train_util.DatasetGroup, weight_dtype
|
||||
):
|
||||
vlm_device = "cpu" if args.text_encoder_cpu else accelerator.device
|
||||
if args.cache_text_encoder_outputs:
|
||||
if not args.lowram:
|
||||
# メモリ消費を減らす
|
||||
@@ -448,9 +449,9 @@ class HunyuanImageNetworkTrainer(train_network.NetworkTrainer):
|
||||
vae.to("cpu")
|
||||
clean_memory_on_device(accelerator.device)
|
||||
|
||||
logger.info("move text encoders to gpu")
|
||||
text_encoders[0].to(accelerator.device)
|
||||
text_encoders[1].to(accelerator.device)
|
||||
logger.info(f"move text encoders to {vlm_device} to encode and cache text encoder outputs")
|
||||
text_encoders[0].to(vlm_device)
|
||||
text_encoders[1].to(vlm_device)
|
||||
|
||||
# VLM (bf16) and byT5 (fp16) are used for encoding, so we cannot use autocast here
|
||||
dataset.new_cache_text_encoder_outputs(text_encoders, accelerator)
|
||||
@@ -491,8 +492,8 @@ class HunyuanImageNetworkTrainer(train_network.NetworkTrainer):
|
||||
vae.to(org_vae_device)
|
||||
else:
|
||||
# Text Encoderから毎回出力を取得するので、GPUに乗せておく
|
||||
text_encoders[0].to(accelerator.device)
|
||||
text_encoders[1].to(accelerator.device)
|
||||
text_encoders[0].to(vlm_device)
|
||||
text_encoders[1].to(vlm_device)
|
||||
|
||||
def sample_images(self, accelerator, args, epoch, global_step, device, ae, tokenizer, text_encoder, flux):
|
||||
text_encoders = text_encoder # for compatibility
|
||||
@@ -667,8 +668,11 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
default=5.0,
|
||||
help="Discrete flow shift for the Euler Discrete Scheduler, default is 5.0. / Euler Discrete Schedulerの離散フローシフト、デフォルトは5.0。",
|
||||
)
|
||||
parser.add_argument("--fp8_scaled", action="store_true", help="use scaled fp8 for DiT / DiTにスケーリングされたfp8を使う")
|
||||
parser.add_argument("--fp8_vl", action="store_true", help="use fp8 for VLM text encoder / VLMテキストエンコーダにfp8を使用する")
|
||||
parser.add_argument("--fp8_scaled", action="store_true", help="Use scaled fp8 for DiT / DiTにスケーリングされたfp8を使う")
|
||||
parser.add_argument("--fp8_vl", action="store_true", help="Use fp8 for VLM text encoder / VLMテキストエンコーダにfp8を使用する")
|
||||
parser.add_argument(
|
||||
"--text_encoder_cpu", action="store_true", help="Inference on CPU for Text Encoders / テキストエンコーダをCPUで推論する"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--vae_enable_tiling",
|
||||
action="store_true",
|
||||
|
||||
Reference in New Issue
Block a user