Fix bug in FLUX multi GPU training

This commit is contained in:
kohya-ss
2024-08-22 12:37:41 +09:00
parent e1cd19c0c0
commit 98c91a7625
8 changed files with 156 additions and 38 deletions

View File

@@ -9,6 +9,12 @@ __Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchv
The command to install PyTorch is as follows:
`pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124`
Aug 22, 2024:
Fixed a bug in multi-GPU training. It should work with fine-tuning and LoRA training. `--double_blocks_to_swap` and `--single_blocks_to_swap` cannot be used in multi-GPU training.
`--disable_mmap_load_safetensors` option now works in `flux_train.py`. It speeds up model loading during training in WSL2. It is also effective in reducing memory usage when loading models during multi-GPU training. Please always check if the model is loaded correctly, as it uses a custom implementation of safetensors loading.
Aug 21, 2024 (update 3):
- There is a bug that the `--full_bf16` option is enabled even if it is not specified in `flux_train.py`. The bug will be fixed soon. __Please specify the `--full_bf16` option explicitly, especially when training with 24GB VRAM.__
- Stochastic rounding is now implemented when `--fused_backward_pass` is specified. The implementation is

View File

@@ -174,7 +174,7 @@ def train(args):
# load VAE for caching latents
ae = None
if cache_latents:
ae = flux_utils.load_ae(name, args.ae, weight_dtype, "cpu")
ae = flux_utils.load_ae(name, args.ae, weight_dtype, "cpu", args.disable_mmap_load_safetensors)
ae.to(accelerator.device, dtype=weight_dtype)
ae.requires_grad_(False)
ae.eval()
@@ -199,8 +199,8 @@ def train(args):
strategy_base.TokenizeStrategy.set_strategy(flux_tokenize_strategy)
# load clip_l, t5xxl for caching text encoder outputs
clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu")
t5xxl = flux_utils.load_t5xxl(args.t5xxl, weight_dtype, "cpu")
clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu", args.disable_mmap_load_safetensors)
t5xxl = flux_utils.load_t5xxl(args.t5xxl, weight_dtype, "cpu", args.disable_mmap_load_safetensors)
clip_l.eval()
t5xxl.eval()
clip_l.requires_grad_(False)
@@ -228,7 +228,6 @@ def train(args):
if args.sample_prompts is not None:
logger.info(f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}")
tokenize_strategy: strategy_flux.FluxTokenizeStrategy = strategy_base.TokenizeStrategy.get_strategy()
text_encoding_strategy: strategy_flux.FluxTextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy()
prompts = load_prompts(args.sample_prompts)
@@ -238,9 +237,9 @@ def train(args):
for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]:
if p not in sample_prompts_te_outputs:
logger.info(f"cache Text Encoder outputs for prompt: {p}")
tokens_and_masks = tokenize_strategy.tokenize(p)
tokens_and_masks = flux_tokenize_strategy.tokenize(p)
sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens(
tokenize_strategy, [clip_l, t5xxl], tokens_and_masks, args.apply_t5_attn_mask
flux_tokenize_strategy, [clip_l, t5xxl], tokens_and_masks, args.apply_t5_attn_mask
)
accelerator.wait_for_everyone()
@@ -251,7 +250,9 @@ def train(args):
clean_memory_on_device(accelerator.device)
# load FLUX
flux = flux_utils.load_flow_model(name, args.pretrained_model_name_or_path, weight_dtype, "cpu")
flux = flux_utils.load_flow_model(
name, args.pretrained_model_name_or_path, weight_dtype, "cpu", args.disable_mmap_load_safetensors
)
if args.gradient_checkpointing:
flux.enable_gradient_checkpointing(args.cpu_offload_checkpointing)
@@ -419,7 +420,7 @@ def train(args):
# if we don't swap blocks, we can move the model to device
flux = accelerator.prepare(flux, device_placement=[not is_swapping_blocks])
if is_swapping_blocks:
flux.move_to_device_except_swap_blocks(accelerator.device) # reduce peak memory usage
accelerator.unwrap_model(flux).move_to_device_except_swap_blocks(accelerator.device) # reduce peak memory usage
optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
# 実験的機能勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
@@ -439,8 +440,8 @@ def train(args):
double_blocks_to_swap = args.double_blocks_to_swap
single_blocks_to_swap = args.single_blocks_to_swap
num_double_blocks = len(flux.double_blocks)
num_single_blocks = len(flux.single_blocks)
num_double_blocks = 19 # len(flux.double_blocks)
num_single_blocks = 38 # len(flux.single_blocks)
handled_double_block_indices = set()
handled_single_block_indices = set()
@@ -537,8 +538,8 @@ def train(args):
double_blocks_to_swap = args.double_blocks_to_swap
single_blocks_to_swap = args.single_blocks_to_swap
num_double_blocks = len(flux.double_blocks)
num_single_blocks = len(flux.single_blocks)
num_double_blocks = 19 # len(flux.double_blocks)
num_single_blocks = 38 # len(flux.single_blocks)
for opt_idx, optimizer in enumerate(optimizers):
for param_group in optimizer.param_groups:
@@ -618,7 +619,7 @@ def train(args):
)
if is_swapping_blocks:
flux.prepare_block_swap_before_forward()
accelerator.unwrap_model(flux).prepare_block_swap_before_forward()
# For --sample_at_first
flux_train_utils.sample_images(accelerator, args, 0, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs)
@@ -660,7 +661,7 @@ def train(args):
with torch.no_grad():
input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]]
text_encoder_conds = text_encoding_strategy.encode_tokens(
tokenize_strategy, [clip_l, t5xxl], input_ids, args.apply_t5_attn_mask
flux_tokenize_strategy, [clip_l, t5xxl], input_ids, args.apply_t5_attn_mask
)
if args.full_fp16:
text_encoder_conds = [c.to(weight_dtype) for c in text_encoder_conds]

View File

@@ -57,19 +57,21 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
name = self.get_flux_model_name(args)
# if we load to cpu, flux.to(fp8) takes a long time
model = flux_utils.load_flow_model(name, args.pretrained_model_name_or_path, weight_dtype, "cpu")
model = flux_utils.load_flow_model(
name, args.pretrained_model_name_or_path, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors
)
if args.split_mode:
model = self.prepare_split_model(model, weight_dtype, accelerator)
clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu")
clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
clip_l.eval()
# loading t5xxl to cpu takes a long time, so we should load to gpu in future
t5xxl = flux_utils.load_t5xxl(args.t5xxl, weight_dtype, "cpu")
t5xxl = flux_utils.load_t5xxl(args.t5xxl, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
t5xxl.eval()
ae = flux_utils.load_ae(name, args.ae, weight_dtype, "cpu")
ae = flux_utils.load_ae(name, args.ae, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
return flux_utils.MODEL_VERSION_FLUX_V1, [clip_l, t5xxl], ae, model

View File

@@ -745,7 +745,9 @@ class DoubleStreamBlock(nn.Module):
return custom_forward
return torch.utils.checkpoint.checkpoint(create_custom_forward(self._forward), img, txt, vec, pe, txt_attention_mask)
return torch.utils.checkpoint.checkpoint(
create_custom_forward(self._forward), img, txt, vec, pe, txt_attention_mask, use_reentrant=False
)
else:
return self._forward(img, txt, vec, pe, txt_attention_mask)
@@ -836,7 +838,7 @@ class SingleStreamBlock(nn.Module):
return custom_forward
return torch.utils.checkpoint.checkpoint(create_custom_forward(self._forward), x, vec, pe)
return torch.utils.checkpoint.checkpoint(create_custom_forward(self._forward), x, vec, pe, use_reentrant=False)
else:
return self._forward(x, vec, pe)

View File

@@ -9,7 +9,7 @@ from transformers import CLIPTextModel, CLIPConfig, T5EncoderModel, T5Config
from library import flux_models
from library.utils import setup_logging
from library.utils import setup_logging, MemoryEfficientSafeOpen
setup_logging()
import logging
@@ -19,32 +19,54 @@ logger = logging.getLogger(__name__)
MODEL_VERSION_FLUX_V1 = "flux1"
def load_flow_model(name: str, ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device]) -> flux_models.Flux:
# temporary copy from sd3_utils TODO refactor
def load_safetensors(path: str, device: Union[str, torch.device], disable_mmap: bool = False, dtype: torch.dtype = torch.float32):
    """Load a .safetensors state dict, optionally bypassing mmap.

    When ``disable_mmap`` is True the file is read sequentially with
    MemoryEfficientSafeOpen instead of safetensors' mmap-based loader
    (slower to start, but avoids mmap slowness on WSL2 and reduces memory
    pressure during multi-GPU loading); each tensor is moved to ``device``
    and cast to ``dtype``.

    Args:
        path: path to the .safetensors file.
        device: target device for the loaded tensors.
        disable_mmap: use the experimental sequential loader instead of mmap.
        dtype: dtype to cast tensors to when mmap is disabled.

    Returns:
        dict mapping tensor names to tensors.
    """
    if disable_mmap:
        # use experimental sequential loader instead of safetensors' mmap path
        logger.info("Loading without mmap (experimental)")
        state_dict = {}
        with MemoryEfficientSafeOpen(path) as f:
            for key in f.keys():
                state_dict[key] = f.get_tensor(key).to(device, dtype=dtype)
        return state_dict
    else:
        try:
            return load_file(path, device=device)
        # was a bare `except:`, which also swallows SystemExit/KeyboardInterrupt
        except Exception:
            # some device strings are rejected by the native loader;
            # fall back to a plain (CPU) load to prevent a device-invalid error
            return load_file(path)
def load_flow_model(
    name: str, ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False
) -> flux_models.Flux:
    """Build a Flux model and load its weights from a checkpoint.

    The model skeleton is created on the meta device (no allocation), then the
    state dict is loaded with ``assign=True`` so the checkpoint tensors are
    used directly.

    Args:
        name: model config key into ``flux_models.configs`` (e.g. "dev"/"schnell").
        ckpt_path: path to the .safetensors checkpoint.
        dtype: dtype for the model parameters.
        device: device to load the state dict onto.
        disable_mmap: forwarded to load_safetensors (experimental non-mmap load).

    Returns:
        the loaded ``flux_models.Flux`` model.
    """
    logger.info(f"Building Flux model {name}")
    with torch.device("meta"):
        model = flux_models.Flux(flux_models.configs[name].params).to(dtype)

    logger.info(f"Loading state dict from {ckpt_path}")
    # the previous direct load_file() call was left over from before
    # load_safetensors existed; loading once here is sufficient
    sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)
    info = model.load_state_dict(sd, strict=False, assign=True)
    logger.info(f"Loaded Flux: {info}")
    return model
def load_ae(name: str, ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device]) -> flux_models.AutoEncoder:
def load_ae(
    name: str, ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False
) -> flux_models.AutoEncoder:
    """Build the Flux AutoEncoder and load its weights from a checkpoint.

    The AE skeleton is created on the meta device, then weights are loaded
    with ``assign=True`` so the checkpoint tensors are used directly.

    Args:
        name: model config key into ``flux_models.configs``.
        ckpt_path: path to the AE .safetensors checkpoint.
        dtype: dtype for the AE parameters.
        device: device to load the state dict onto.
        disable_mmap: forwarded to load_safetensors (experimental non-mmap load).

    Returns:
        the loaded ``flux_models.AutoEncoder``.
    """
    logger.info("Building AutoEncoder")
    with torch.device("meta"):
        ae = flux_models.AutoEncoder(flux_models.configs[name].ae_params).to(dtype)

    logger.info(f"Loading state dict from {ckpt_path}")
    # the previous direct load_file() call was leftover diff residue; one load suffices
    sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)
    info = ae.load_state_dict(sd, strict=False, assign=True)
    logger.info(f"Loaded AE: {info}")
    return ae
def load_clip_l(ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device]) -> CLIPTextModel:
def load_clip_l(ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False) -> CLIPTextModel:
logger.info("Building CLIP")
CLIPL_CONFIG = {
"_name_or_path": "clip-vit-large-patch14/",
@@ -139,13 +161,13 @@ def load_clip_l(ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.dev
clip = CLIPTextModel._from_config(config)
logger.info(f"Loading state dict from {ckpt_path}")
sd = load_file(ckpt_path, device=str(device))
sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)
info = clip.load_state_dict(sd, strict=False, assign=True)
logger.info(f"Loaded CLIP: {info}")
return clip
def load_t5xxl(ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device]) -> T5EncoderModel:
def load_t5xxl(ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False) -> T5EncoderModel:
T5_CONFIG_JSON = """
{
"architectures": [
@@ -185,7 +207,7 @@ def load_t5xxl(ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.devi
t5xxl = T5EncoderModel._from_config(config)
logger.info(f"Loading state dict from {ckpt_path}")
sd = load_file(ckpt_path, device=str(device))
sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)
info = t5xxl.load_state_dict(sd, strict=False, assign=True)
logger.info(f"Loaded T5xxl: {info}")
return t5xxl

View File

@@ -137,7 +137,7 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
t5_attn_mask = data["t5_attn_mask"]
if self.apply_t5_attn_mask:
t5_out = self.mask_t5_attn(t5_out, t5_attn_mask)
t5_out = self.mask_t5_attn(t5_out, t5_attn_mask) # FIXME do not mask here!!!
return [l_pooled, t5_out, txt_ids, t5_attn_mask]
@@ -149,7 +149,7 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
tokens_and_masks = tokenize_strategy.tokenize(captions)
with torch.no_grad():
# attn_mask is not applied when caching to disk: it is applied when loading from disk
# attn_mask is not applied when caching to disk: it is applied when loading from disk FIXME apply mask when loading
l_pooled, t5_out, txt_ids, _ = flux_text_encoding_strategy.encode_tokens(
tokenize_strategy, models, tokens_and_masks, not self.cache_to_disk
)

View File

@@ -1104,10 +1104,6 @@ class BaseDataset(torch.utils.data.Dataset):
caching_strategy = TextEncoderOutputsCachingStrategy.get_strategy()
batch_size = caching_strategy.batch_size or self.batch_size
# if cache to disk, don't cache TE outputs in non-main process
if caching_strategy.cache_to_disk and not is_main_process:
return
logger.info("caching Text Encoder outputs with caching strategy.")
image_infos = list(self.image_data.values())
@@ -1120,9 +1116,9 @@ class BaseDataset(torch.utils.data.Dataset):
# check disk cache exists and size of latents
if caching_strategy.cache_to_disk:
info.text_encoder_outputs_npz = te_out_npz
info.text_encoder_outputs_npz = te_out_npz # set npz filename regardless of cache availability/main process
cache_available = caching_strategy.is_disk_cached_outputs_expected(te_out_npz)
if cache_available: # do not add to batch
if cache_available or not is_main_process: # do not add to batch
continue
batch.append(info)
@@ -2638,7 +2634,7 @@ def load_arbitrary_dataset(args, tokenizer=None) -> MinimalDataset:
return train_dataset_group
def load_image(image_path, alpha=False):
def load_image(image_path, alpha=False):
try:
with Image.open(image_path) as image:
if alpha:

View File

@@ -153,6 +153,95 @@ def mem_eff_save_file(tensors: Dict[str, torch.Tensor], filename: str, metadata:
v.contiguous().view(torch.uint8).numpy().tofile(f)
class MemoryEfficientSafeOpen:
    """Sequential (non-mmap) reader for .safetensors files.

    Reads the 8-byte header-size field and the JSON header once, then serves
    tensors by seeking to their recorded byte ranges in the same file handle.
    Metadata loading is not supported.

    Intended for use as a context manager:
        with MemoryEfficientSafeOpen(path) as f:
            for key in f.keys():
                tensor = f.get_tensor(key)
    """

    def __init__(self, filename):
        self.filename = filename
        # open once and read the header from this same handle
        # (previously the file was opened a second time just for the header)
        self.file = open(filename, "rb")
        self.header, self.header_size = self._read_header()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()

    def close(self):
        """Close the underlying file handle (safe to call more than once)."""
        self.file.close()

    def keys(self):
        """Return the tensor names, excluding the reserved __metadata__ entry."""
        return [k for k in self.header.keys() if k != "__metadata__"]

    def get_tensor(self, key):
        """Read and deserialize the tensor stored under ``key``.

        Raises:
            KeyError: if ``key`` is not present in the file header.
            ValueError: if the stored dtype is not supported by this PyTorch.
        """
        if key not in self.header:
            raise KeyError(f"Tensor '{key}' not found in the file")

        metadata = self.header[key]
        offset_start, offset_end = metadata["data_offsets"]

        if offset_start == offset_end:
            # zero-byte payload: empty tensor
            tensor_bytes = None
        else:
            # data offsets are relative to the end of the header:
            # 8-byte size field + header JSON itself
            self.file.seek(self.header_size + 8 + offset_start)
            tensor_bytes = self.file.read(offset_end - offset_start)

        return self._deserialize_tensor(tensor_bytes, metadata)

    def _read_header(self):
        # safetensors layout: little-endian uint64 header size, then JSON header
        header_size = struct.unpack("<Q", self.file.read(8))[0]
        header_json = self.file.read(header_size).decode("utf-8")
        return json.loads(header_json), header_size

    def _deserialize_tensor(self, tensor_bytes, metadata):
        dtype = self._get_torch_dtype(metadata["dtype"])
        shape = metadata["shape"]

        if tensor_bytes is None:
            byte_tensor = torch.empty(0, dtype=torch.uint8)
        else:
            tensor_bytes = bytearray(tensor_bytes)  # make it writable for frombuffer
            byte_tensor = torch.frombuffer(tensor_bytes, dtype=torch.uint8)

        # float8 types need special handling (this torch may not support them)
        if metadata["dtype"] in ["F8_E5M2", "F8_E4M3"]:
            return self._convert_float8(byte_tensor, metadata["dtype"], shape)

        if dtype is None:
            # previously fell through to byte_tensor.view(None) with a cryptic error
            raise ValueError(f"Unsupported dtype in safetensors file: {metadata['dtype']}")

        # reinterpret the raw bytes as the target dtype and restore the shape
        return byte_tensor.view(dtype).reshape(shape)

    @staticmethod
    def _get_torch_dtype(dtype_str):
        """Map a safetensors dtype string to a torch dtype (None if unknown)."""
        dtype_map = {
            "F64": torch.float64,
            "F32": torch.float32,
            "F16": torch.float16,
            "BF16": torch.bfloat16,
            "I64": torch.int64,
            "I32": torch.int32,
            "I16": torch.int16,
            "I8": torch.int8,
            "U8": torch.uint8,
            "BOOL": torch.bool,
        }
        # float8 types exist only in sufficiently new PyTorch builds
        if hasattr(torch, "float8_e5m2"):
            dtype_map["F8_E5M2"] = torch.float8_e5m2
        if hasattr(torch, "float8_e4m3fn"):
            dtype_map["F8_E4M3"] = torch.float8_e4m3fn
        return dtype_map.get(dtype_str)

    @staticmethod
    def _convert_float8(byte_tensor, dtype_str, shape):
        """Reinterpret raw bytes as a float8 tensor, or raise if unsupported."""
        if dtype_str == "F8_E5M2" and hasattr(torch, "float8_e5m2"):
            return byte_tensor.view(torch.float8_e5m2).reshape(shape)
        elif dtype_str == "F8_E4M3" and hasattr(torch, "float8_e4m3fn"):
            return byte_tensor.view(torch.float8_e4m3fn).reshape(shape)
        else:
            # deliberately an error rather than a silent float16 downcast
            raise ValueError(f"Unsupported float8 type: {dtype_str} (upgrade PyTorch to support float8 types)")
# TODO make inf_utils.py