Compare commits

...

5 Commits

Author SHA1 Message Date
Kohya S
5585347a52 Merge branch 'sd3' into feat-flux-chroma-fp8-scaled 2026-01-18 16:58:01 +09:00
Kohya S
b9357c662d Merge branch 'sd3' into feat-flux-chroma-fp8-scaled 2026-01-18 14:37:12 +09:00
Kohya S.
08dc0858b5 Merge pull request #2224 from rockerBOO/block-swap-fp8-scaled
Block swap and fp8 scaled, fp8 already quantized warning
2026-01-18 14:35:32 +09:00
rockerBOO
f9710863ca Set is_swapping_blocks before loading_device, add warning for ignoring fp8_scaled if already fp8 2025-10-10 15:58:21 -04:00
Kohya S
f6137a7175 feat: add fp8 optimization for FLUX 2025-09-25 22:10:11 +09:00
8 changed files with 237 additions and 111 deletions

View File

@@ -18,6 +18,7 @@ from safetensors.torch import load_file
from library import device_utils
from library.device_utils import init_ipex, get_preferred_device
from library.safetensors_utils import MemoryEfficientSafeOpen
from networks import oft_flux
init_ipex()
@@ -325,6 +326,7 @@ def generate_image(
# generate image
logger.info("Generating image...")
if args.offload and not (args.blocks_to_swap is not None and args.blocks_to_swap > 0):
model = model.to(device)
if steps is None:
steps = 4 if is_schnell else 50
@@ -411,12 +413,16 @@ if __name__ == "__main__":
parser.add_argument("--ae_dtype", type=str, default=None, help="dtype for ae")
parser.add_argument("--t5xxl_dtype", type=str, default=None, help="dtype for t5xxl")
parser.add_argument("--flux_dtype", type=str, default=None, help="dtype for flux")
parser.add_argument("--fp8_scaled", action="store_true", help="Use scaled fp8 for flux model")
parser.add_argument("--seed", type=int, default=None)
parser.add_argument("--steps", type=int, default=None, help="Number of steps. Default is 4 for schnell, 50 for dev")
parser.add_argument("--guidance", type=float, default=3.5)
parser.add_argument("--negative_prompt", type=str, default=None)
parser.add_argument("--cfg_scale", type=float, default=1.0)
parser.add_argument("--offload", action="store_true", help="Offload to CPU")
parser.add_argument(
"--blocks_to_swap", type=int, default=None, help="Number of blocks to swap between CPU and GPU to reduce memory usage"
)
parser.add_argument(
"--lora_weights",
type=str,
@@ -442,6 +448,8 @@ if __name__ == "__main__":
t5xxl_dtype = str_to_dtype(args.t5xxl_dtype, dtype)
ae_dtype = str_to_dtype(args.ae_dtype, dtype)
flux_dtype = str_to_dtype(args.flux_dtype, dtype)
if args.fp8_scaled and flux_dtype.itemsize == 1:
raise ValueError("fp8_scaled is not supported for fp8 flux_dtype")
logger.info(f"Dtypes for clip_l, t5xxl, ae, flux: {clip_l_dtype}, {t5xxl_dtype}, {ae_dtype}, {flux_dtype}")
@@ -470,13 +478,68 @@ if __name__ == "__main__":
# if is_fp8(t5xxl_dtype):
# t5xxl = accelerator.prepare(t5xxl)
# check LoRA and OFT weights can be mergeable
mergeable_lora_weights = None
mergeable_lora_multipliers = None
if args.fp8_scaled and args.lora_weights:
assert args.merge_lora_weights, "LoRA weights must be merged when using fp8_scaled"
mergeable_lora_weights = []
mergeable_lora_multipliers = []
for weights_file in args.lora_weights:
if ";" in weights_file:
weights_file, multiplier = weights_file.split(";")
multiplier = float(multiplier)
else:
multiplier = 1.0
with MemoryEfficientSafeOpen(weights_file) as f:
keys = list(f.keys())
is_lora = is_oft = False
includes_text_encoder = False
for key in keys:
if key.startswith("lora"):
is_lora = True
if key.startswith("oft"):
is_oft = True
if key.startswith("lora_te") or key.startswith("oft_te"):
includes_text_encoder = True
if (is_lora or is_oft) and includes_text_encoder:
break
if includes_text_encoder or is_oft:
raise ValueError(
f"LoRA weights {weights_file} that includes text encoder or OFT weights cannot be merged when using fp8_scaled"
)
mergeable_lora_weights.append(weights_file)
mergeable_lora_multipliers.append(multiplier)
# DiT
loading_dtype = None if args.fp8_scaled else flux_dtype
loading_device = "cpu" if args.blocks_to_swap or args.offload else device
is_schnell, model = flux_utils.load_flow_model(
args.ckpt_path, None, loading_device, disable_mmap=True, model_type=args.model_type
device,
args.ckpt_path,
loading_dtype,
loading_device,
args.model_type,
args.fp8_scaled,
lora_weights_list=mergeable_lora_weights,
lora_multipliers=mergeable_lora_multipliers,
)
model.eval()
logger.info(f"Casting model to {flux_dtype}")
model.to(flux_dtype) # make sure model is dtype
if args.blocks_to_swap is not None and args.blocks_to_swap > 0:
model.enable_block_swap(args.blocks_to_swap, accelerator.device, supports_backward=False)
model.move_to_device_except_swap_blocks(device)
model.prepare_block_swap_before_forward()
# logger.info(f"Casting model to {flux_dtype}")
# model.to(flux_dtype) # make sure model is dtype
# if is_fp8(flux_dtype):
# model = accelerator.prepare(model)
# if args.offload:
@@ -494,6 +557,7 @@ if __name__ == "__main__":
# LoRA
lora_models: List[lora_flux.LoRANetwork] = []
if not args.fp8_scaled: # LoRA cannot be applied after fp8 scaling and quantization
for weights_file in args.lora_weights:
if ";" in weights_file:
weights_file, multiplier = weights_file.split(";")

View File

@@ -271,7 +271,7 @@ def train(args):
# load FLUX
_, flux = flux_utils.load_flow_model(
args.pretrained_model_name_or_path, weight_dtype, "cpu", args.disable_mmap_load_safetensors, model_type="flux"
accelerator.device, args.pretrained_model_name_or_path, weight_dtype, "cpu", model_type="flux"
)
if args.gradient_checkpointing:
@@ -302,7 +302,7 @@ def train(args):
# Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes.
# This idea is based on 2kpr's great work. Thank you!
logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}")
flux.enable_block_swap(args.blocks_to_swap, accelerator.device)
flux.enable_block_swap(args.blocks_to_swap, accelerator.device, supports_backward=True)
if not cache_latents:
# load VAE here if not cached

View File

@@ -265,7 +265,7 @@ def train(args):
# load FLUX
is_schnell, flux = flux_utils.load_flow_model(
args.pretrained_model_name_or_path, weight_dtype, "cpu", args.disable_mmap_load_safetensors, model_type="flux"
accelerator.device, args.pretrained_model_name_or_path, weight_dtype, "cpu", model_type="flux"
)
flux.requires_grad_(False)
@@ -304,7 +304,7 @@ def train(args):
# Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes.
# This idea is based on 2kpr's great work. Thank you!
logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}")
flux.enable_block_swap(args.blocks_to_swap, accelerator.device)
flux.enable_block_swap(args.blocks_to_swap, accelerator.device, supports_backward=True)
flux.move_to_device_except_swap_blocks(accelerator.device) # reduce peak memory usage
# ControlNet only has two blocks, so we can keep it on GPU
# controlnet.enable_block_swap(args.blocks_to_swap, accelerator.device)

View File

@@ -51,7 +51,9 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
self.use_clip_l = True
else:
self.use_clip_l = False # Chroma does not use CLIP-L
assert args.apply_t5_attn_mask, "apply_t5_attn_mask must be True for Chroma / Chromaではapply_t5_attn_maskを指定する必要があります"
assert (
args.apply_t5_attn_mask
), "apply_t5_attn_mask must be True for Chroma / Chromaではapply_t5_attn_maskを指定する必要があります"
if args.fp8_base_unet:
args.fp8_base = True # if fp8_base_unet is enabled, fp8_base is also enabled for FLUX.1
@@ -99,18 +101,18 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
def load_target_model(self, args, weight_dtype, accelerator):
# currently offload to cpu for some models
# if the file is fp8 and we are using fp8_base, we can load it as is (fp8)
loading_dtype = None if args.fp8_base else weight_dtype
self.is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0
# if we load to cpu, flux.to(fp8) takes a long time, so we should load to gpu in future
# if the file is fp8 and we are using fp8_base, we can load it as is (fp8)
loading_dtype = None if args.fp8_base or args.fp8_scaled else weight_dtype
loading_device = "cpu" if self.is_swapping_blocks else accelerator.device
# load with quantization if needed
_, model = flux_utils.load_flow_model(
args.pretrained_model_name_or_path,
loading_dtype,
"cpu",
disable_mmap=args.disable_mmap_load_safetensors,
model_type=self.model_type,
accelerator.device, args.pretrained_model_name_or_path, loading_dtype, loading_device, self.model_type, args.fp8_scaled
)
if args.fp8_base:
if args.fp8_base and not args.fp8_scaled:
# check dtype of model
if model.dtype == torch.float8_e4m3fnuz or model.dtype == torch.float8_e5m2 or model.dtype == torch.float8_e5m2fnuz:
raise ValueError(f"Unsupported fp8 model dtype: {model.dtype}")
@@ -125,12 +127,10 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
# if args.split_mode:
# model = self.prepare_split_model(model, weight_dtype, accelerator)
self.is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0
if self.is_swapping_blocks:
# Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes.
logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}")
model.enable_block_swap(args.blocks_to_swap, accelerator.device)
model.enable_block_swap(args.blocks_to_swap, accelerator.device, supports_backward=True)
if self.use_clip_l:
clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
@@ -309,6 +309,9 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
def shift_scale_latents(self, args, latents):
return latents
def cast_unet(self, args):
return not args.fp8_scaled
def get_noise_pred_and_target(
self,
args,
@@ -525,6 +528,7 @@ def setup_parser() -> argparse.ArgumentParser:
train_util.add_dit_training_arguments(parser)
flux_train_utils.add_flux_train_arguments(parser)
parser.add_argument("--fp8_scaled", action="store_true", help="Use scaled fp8 for DiT / DiTにスケーリングされたfp8を使う")
parser.add_argument(
"--split_mode",
action="store_true",

View File

@@ -691,25 +691,32 @@ class DoubleStreamBlock(nn.Module):
) -> tuple[Tensor, Tensor]:
img_mod1, img_mod2 = self.img_mod(vec)
txt_mod1, txt_mod2 = self.txt_mod(vec)
del vec
# prepare image for attention
img_modulated = self.img_norm1(img)
img_modulated = self.img_norm1(img.to(torch.float32)).to(img.dtype)
img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
img_qkv = self.img_attn.qkv(img_modulated)
img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
del img_qkv
img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
# prepare txt for attention
txt_modulated = self.txt_norm1(txt)
txt_modulated = self.txt_norm1(txt.to(torch.float32)).to(txt.dtype)
txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
txt_qkv = self.txt_attn.qkv(txt_modulated)
del txt_modulated
txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
del txt_qkv
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
# run actual attention
q = torch.cat((txt_q, img_q), dim=2)
del txt_q, img_q
k = torch.cat((txt_k, img_k), dim=2)
del txt_k, img_k
v = torch.cat((txt_v, img_v), dim=2)
del txt_v, img_v
# make attention mask if not None
attn_mask = None
@@ -725,14 +732,24 @@ class DoubleStreamBlock(nn.Module):
attn = attention(q, k, v, pe=pe, attn_mask=attn_mask)
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
del q, k, v, attn
# calculate the img blocks
img = img + img_mod1.gate * self.img_attn.proj(img_attn)
img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
del img_mod1, img_attn
img = img + img_mod2.gate * self.img_mlp(
(1 + img_mod2.scale) * self.img_norm2(img.to(torch.float32)).to(img.dtype) + img_mod2.shift
)
del img_mod2
# calculate the txt blocks
txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
del txt_mod1, txt_attn
txt = txt + txt_mod2.gate * self.txt_mlp(
(1 + txt_mod2.scale) * self.txt_norm2(txt.to(torch.float32)).to(txt.dtype) + txt_mod2.shift
)
del txt_mod2
return img, txt
def forward(
@@ -805,10 +822,14 @@ class SingleStreamBlock(nn.Module):
def _forward(self, x: Tensor, vec: Tensor, pe: Tensor, txt_attention_mask: Optional[Tensor] = None) -> Tensor:
mod, _ = self.modulation(vec)
x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
del vec
x_mod = (1 + mod.scale) * self.pre_norm(x.to(torch.float32)) + mod.shift
x_mod = x_mod.to(x.dtype)
qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
del qkv
q, k = self.norm(q, k, v)
# make attention mask if not None
@@ -831,9 +852,12 @@ class SingleStreamBlock(nn.Module):
# compute attention
attn = attention(q, k, v, pe=pe, attn_mask=attn_mask)
del q, k, v
# compute activation in mlp stream, cat again and run second linear layer
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
del attn, mlp
return x + mod.gate * output
def forward(self, x: Tensor, vec: Tensor, pe: Tensor, txt_attention_mask: Optional[Tensor] = None) -> Tensor:
@@ -969,7 +993,7 @@ class Flux(nn.Module):
print("FLUX: Gradient checkpointing disabled.")
def enable_block_swap(self, num_blocks: int, device: torch.device):
def enable_block_swap(self, num_blocks: int, device: torch.device, supports_backward: bool = False):
self.blocks_to_swap = num_blocks
double_blocks_to_swap = num_blocks // 2
single_blocks_to_swap = (num_blocks - double_blocks_to_swap) * 2
@@ -980,10 +1004,10 @@ class Flux(nn.Module):
)
self.offloader_double = custom_offloading_utils.ModelOffloader(
self.double_blocks, double_blocks_to_swap, device # , debug=True
self.double_blocks, double_blocks_to_swap, device, supports_backward=supports_backward # , debug=True
)
self.offloader_single = custom_offloading_utils.ModelOffloader(
self.single_blocks, single_blocks_to_swap, device # , debug=True
self.single_blocks, single_blocks_to_swap, device, supports_backward=supports_backward # , debug=True
)
print(
f"FLUX: Block swap enabled. Swapping {num_blocks} blocks, double blocks: {double_blocks_to_swap}, single blocks: {single_blocks_to_swap}."
@@ -1215,7 +1239,7 @@ class ControlNetFlux(nn.Module):
print("FLUX: Gradient checkpointing disabled.")
def enable_block_swap(self, num_blocks: int, device: torch.device):
def enable_block_swap(self, num_blocks: int, device: torch.device, supports_backward: bool = False):
self.blocks_to_swap = num_blocks
double_blocks_to_swap = num_blocks // 2
single_blocks_to_swap = (num_blocks - double_blocks_to_swap) * 2
@@ -1226,10 +1250,10 @@ class ControlNetFlux(nn.Module):
)
self.offloader_double = custom_offloading_utils.ModelOffloader(
self.double_blocks, double_blocks_to_swap, device # , debug=True
self.double_blocks, double_blocks_to_swap, device, supports_backward=supports_backward # , debug=True
)
self.offloader_single = custom_offloading_utils.ModelOffloader(
self.single_blocks, single_blocks_to_swap, device # , debug=True
self.single_blocks, single_blocks_to_swap, device, supports_backward=supports_backward # , debug=True
)
print(
f"FLUX: Block swap enabled. Swapping {num_blocks} blocks, double blocks: {double_blocks_to_swap}, single blocks: {single_blocks_to_swap}."

View File

@@ -1,7 +1,7 @@
import json
import os
from dataclasses import replace
from typing import List, Optional, Tuple, Union
from typing import Dict, List, Optional, Tuple, Union
import einops
import torch
@@ -10,6 +10,8 @@ from safetensors import safe_open
from safetensors.torch import load_file
from transformers import CLIPConfig, CLIPTextModel, T5Config, T5EncoderModel
from library.fp8_optimization_utils import apply_fp8_monkey_patch
from library.lora_utils import load_safetensors_with_lora_and_fp8
from library.utils import setup_logging
setup_logging()
@@ -25,6 +27,9 @@ MODEL_NAME_DEV = "dev"
MODEL_NAME_SCHNELL = "schnell"
MODEL_VERSION_CHROMA = "chroma"
FP8_OPTIMIZATION_TARGET_KEYS = ["double_blocks", "single_blocks"]
FP8_OPTIMIZATION_EXCLUDE_KEYS = ["_mod", "norm", "modulation"]
def analyze_checkpoint_state(ckpt_path: str) -> Tuple[bool, bool, Tuple[int, int], List[str]]:
"""
@@ -93,17 +98,23 @@ def analyze_checkpoint_state(ckpt_path: str) -> Tuple[bool, bool, Tuple[int, int
def load_flow_model(
ckpt_path: str,
dtype: Optional[torch.dtype],
device: Union[str, torch.device],
disable_mmap: bool = False,
ckpt_path: str,
dit_weight_dtype: Optional[torch.dtype],
loading_device: Union[str, torch.device],
model_type: str = "flux",
fp8_scaled: bool = False,
lora_weights_list: Optional[Dict[str, torch.Tensor]] = None,
lora_multipliers: Optional[list[float]] = None,
) -> Tuple[bool, flux_models.Flux]:
device = torch.device(device) # device for calculation, typically "cuda"
loading_device = torch.device(loading_device)
# build model
if model_type == "flux":
is_diffusers, is_schnell, (num_double_blocks, num_single_blocks), ckpt_paths = analyze_checkpoint_state(ckpt_path)
name = MODEL_NAME_DEV if not is_schnell else MODEL_NAME_SCHNELL
# build model
logger.info(f"Building Flux model {name} from {'Diffusers' if is_diffusers else 'BFL'} checkpoint")
with torch.device("meta"):
params = flux_models.configs[name].params
@@ -117,31 +128,8 @@ def load_flow_model(
params = replace(params, depth_single_blocks=num_single_blocks)
model = flux_models.Flux(params)
if dtype is not None:
model = model.to(dtype)
# load_sft doesn't support torch.device
logger.info(f"Loading state dict from {ckpt_path}")
sd = {}
for ckpt_path in ckpt_paths:
sd.update(load_safetensors(ckpt_path, device=device, disable_mmap=disable_mmap, dtype=dtype))
# convert Diffusers to BFL
if is_diffusers:
logger.info("Converting Diffusers to BFL")
sd = convert_diffusers_sd_to_bfl(sd, num_double_blocks, num_single_blocks)
logger.info("Converted Diffusers to BFL")
# if the key has annoying prefix, remove it
for key in list(sd.keys()):
new_key = key.replace("model.diffusion_model.", "")
if new_key == key:
break # the model doesn't have annoying prefix
sd[new_key] = sd.pop(key)
info = model.load_state_dict(sd, strict=False, assign=True)
logger.info(f"Loaded Flux: {info}")
return is_schnell, model
if dit_weight_dtype is not None:
model = model.to(dit_weight_dtype)
elif model_type == "chroma":
from . import chroma_models
@@ -150,12 +138,25 @@ def load_flow_model(
logger.info("Building Chroma model")
with torch.device("meta"):
model = chroma_models.Chroma(chroma_models.chroma_params)
if dtype is not None:
model = model.to(dtype)
if dit_weight_dtype is not None:
model = model.to(dit_weight_dtype)
# load_sft doesn't support torch.device
logger.info(f"Loading state dict from {ckpt_path}")
sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)
ckpt_paths = [ckpt_path]
# load state dict
logger.info(f"Loading DiT model from {ckpt_paths}, device={loading_device}")
sd = load_safetensors_with_lora_and_fp8(
model_files=ckpt_paths,
lora_weights_list=lora_weights_list,
lora_multipliers=lora_multipliers,
fp8_optimization=fp8_scaled,
calc_device=device,
move_to_device=(loading_device == device),
dit_weight_dtype=dit_weight_dtype,
target_keys=FP8_OPTIMIZATION_TARGET_KEYS,
exclude_keys=FP8_OPTIMIZATION_EXCLUDE_KEYS,
)
# if the key has annoying prefix, remove it
for key in list(sd.keys()):
@@ -164,14 +165,32 @@ def load_flow_model(
break # the model doesn't have annoying prefix
sd[new_key] = sd.pop(key)
if fp8_scaled:
apply_fp8_monkey_patch(model, sd, use_scaled_mm=False)
if loading_device.type != "cpu": # in case of no block swapping
# make sure all the model weights are on the loading_device
logger.info(f"Moving weights to {loading_device}")
for key in sd.keys():
sd[key] = sd[key].to(loading_device)
if model_type == "flux":
# convert Diffusers to BFL
if is_diffusers:
logger.info("Converting Diffusers to BFL")
sd = convert_diffusers_sd_to_bfl(sd, num_double_blocks, num_single_blocks)
logger.info("Converted Diffusers to BFL")
info = model.load_state_dict(sd, strict=False, assign=True)
logger.info(f"Loaded Flux: {info}")
return is_schnell, model
elif model_type == "chroma":
info = model.load_state_dict(sd, strict=False, assign=True)
logger.info(f"Loaded Chroma: {info}")
is_schnell = False # Chroma is not schnell
return is_schnell, model
else:
raise ValueError(f"Unsupported model_type: {model_type}. Supported types are 'flux' and 'chroma'.")
def load_ae(
ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False

View File

@@ -306,11 +306,22 @@ def load_safetensors_with_fp8_optimization(
state_dict[key] = value
continue
original_dtype = value.dtype
if original_dtype in (torch.float8_e4m3fn, torch.float8_e5m2, torch.float8_e4m3fnuz, torch.float8_e5m2fnuz):
logger.warning(
f"Skipping FP8 quantization for key {key} as it is already in FP8 format ({original_dtype}). "
"Loading checkpoint as-is without re-quantization."
)
target_device = calc_device if (calc_device is not None and move_to_device) else original_device
value = value.to(target_device)
state_dict[key] = value
continue
# Move to calculation device
if calc_device is not None:
value = value.to(calc_device)
original_dtype = value.dtype
quantized_weight, scale_tensor = quantize_weight(
key, value, fp8_dtype, max_value, min_value, quantization_mode, block_size
)

View File

@@ -824,16 +824,19 @@ class NetworkTrainer:
accelerator.print("enable full bf16 training.")
network.to(weight_dtype)
unet_weight_dtype = te_weight_dtype = weight_dtype
unet_weight_dtype = weight_dtype
te_weight_dtype = weight_dtype if self.cast_text_encoder(args) else None
# Experimental Feature: Put base model into fp8 to save vram
if args.fp8_base or args.fp8_base_unet:
assert torch.__version__ >= "2.1.0", "fp8_base requires torch>=2.1.0 / fp8を使う場合はtorch>=2.1.0が必要です。"
assert (
args.mixed_precision != "no"
), "fp8_base requires mixed precision='fp16' or 'bf16' / fp8を使う場合はmixed_precision='fp16'または'bf16'が必要です。"
if self.cast_unet(args):
accelerator.print("enable fp8 training for U-Net.")
unet_weight_dtype = torch.float8_e4m3fn
if self.cast_text_encoder(args):
if not args.fp8_base_unet:
accelerator.print("enable fp8 training for Text Encoder.")
te_weight_dtype = weight_dtype if args.fp8_base_unet else torch.float8_e4m3fn
@@ -843,6 +846,7 @@ class NetworkTrainer:
# logger.info(f"set U-Net weight dtype to {unet_weight_dtype}, device to {accelerator.device}")
# unet.to(accelerator.device, dtype=unet_weight_dtype) # this seems to be safer than above
if self.cast_unet(args):
logger.info(f"set U-Net weight dtype to {unet_weight_dtype}")
unet.to(dtype=unet_weight_dtype) # do not move to device because unet is not prepared by accelerator