Fix a bug in FLUX multi-GPU training

This commit is contained in:
kohya-ss
2024-08-22 12:37:41 +09:00
parent e1cd19c0c0
commit 98c91a7625
8 changed files with 156 additions and 38 deletions

View File

@@ -745,7 +745,9 @@ class DoubleStreamBlock(nn.Module):
return custom_forward
return torch.utils.checkpoint.checkpoint(create_custom_forward(self._forward), img, txt, vec, pe, txt_attention_mask)
return torch.utils.checkpoint.checkpoint(
create_custom_forward(self._forward), img, txt, vec, pe, txt_attention_mask, use_reentrant=False
)
else:
return self._forward(img, txt, vec, pe, txt_attention_mask)
@@ -836,7 +838,7 @@ class SingleStreamBlock(nn.Module):
return custom_forward
return torch.utils.checkpoint.checkpoint(create_custom_forward(self._forward), x, vec, pe)
return torch.utils.checkpoint.checkpoint(create_custom_forward(self._forward), x, vec, pe, use_reentrant=False)
else:
return self._forward(x, vec, pe)

View File

@@ -9,7 +9,7 @@ from transformers import CLIPTextModel, CLIPConfig, T5EncoderModel, T5Config
from library import flux_models
from library.utils import setup_logging
from library.utils import setup_logging, MemoryEfficientSafeOpen
setup_logging()
import logging
@@ -19,32 +19,54 @@ logger = logging.getLogger(__name__)
MODEL_VERSION_FLUX_V1 = "flux1"
def load_flow_model(name: str, ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device]) -> flux_models.Flux:
# temporary copy from sd3_utils TODO refactor
def load_safetensors(path: str, device: Union[str, torch.device], disable_mmap: bool = False, dtype: torch.dtype = torch.float32):
    """Load a safetensors checkpoint into a plain state dict.

    Args:
        path: Path to the ``.safetensors`` file.
        device: Target device for the loaded tensors.
        disable_mmap: If True, bypass safetensors' mmap-based loader and read
            the file sequentially with ``MemoryEfficientSafeOpen`` (experimental).
        dtype: dtype the tensors are cast to when ``disable_mmap`` is True.

    Returns:
        dict mapping tensor names to ``torch.Tensor``.
    """
    if disable_mmap:
        # use the experimental sequential loader instead of mmap
        logger.info("Loading without mmap (experimental)")
        state_dict = {}
        with MemoryEfficientSafeOpen(path) as f:
            for key in f.keys():
                state_dict[key] = f.get_tensor(key).to(device, dtype=dtype)
        return state_dict
    else:
        try:
            return load_file(path, device=device)
        except Exception:
            # some device strings are rejected by safetensors ("device invalid"
            # error); fall back to loading on CPU and let the caller move tensors
            return load_file(path)
def load_flow_model(
    name: str, ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False
) -> flux_models.Flux:
    """Build a Flux model and load its weights from a safetensors checkpoint.

    Args:
        name: Model config name, a key of ``flux_models.configs``.
        ckpt_path: Path to the ``.safetensors`` checkpoint.
        dtype: dtype the model parameters are cast to.
        device: Device the state dict is loaded onto.
        disable_mmap: Forwarded to ``load_safetensors`` (experimental non-mmap loader).

    Returns:
        The Flux model with the checkpoint weights assigned.
    """
    logger.info(f"Building Flux model {name}")
    # instantiate on the meta device so no real memory is allocated before the
    # checkpoint tensors are assigned via load_state_dict(assign=True)
    with torch.device("meta"):
        model = flux_models.Flux(flux_models.configs[name].params).to(dtype)

    # load_sft doesn't support torch.device
    logger.info(f"Loading state dict from {ckpt_path}")
    sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)
    info = model.load_state_dict(sd, strict=False, assign=True)
    logger.info(f"Loaded Flux: {info}")
    return model
def load_ae(
    name: str, ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False
) -> flux_models.AutoEncoder:
    """Build the Flux AutoEncoder and load its weights from a checkpoint.

    Args:
        name: Model config name, a key of ``flux_models.configs``.
        ckpt_path: Path to the ``.safetensors`` checkpoint.
        dtype: dtype the parameters are cast to.
        device: Device the state dict is loaded onto.
        disable_mmap: Forwarded to ``load_safetensors`` (experimental non-mmap loader).

    Returns:
        The AutoEncoder with the checkpoint weights assigned.
    """
    logger.info("Building AutoEncoder")
    # meta device: allocate no real memory until weights are assigned
    with torch.device("meta"):
        ae = flux_models.AutoEncoder(flux_models.configs[name].ae_params).to(dtype)

    logger.info(f"Loading state dict from {ckpt_path}")
    sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)
    info = ae.load_state_dict(sd, strict=False, assign=True)
    logger.info(f"Loaded AE: {info}")
    return ae
def load_clip_l(ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device]) -> CLIPTextModel:
def load_clip_l(ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False) -> CLIPTextModel:
logger.info("Building CLIP")
CLIPL_CONFIG = {
"_name_or_path": "clip-vit-large-patch14/",
@@ -139,13 +161,13 @@ def load_clip_l(ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.dev
clip = CLIPTextModel._from_config(config)
logger.info(f"Loading state dict from {ckpt_path}")
sd = load_file(ckpt_path, device=str(device))
sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)
info = clip.load_state_dict(sd, strict=False, assign=True)
logger.info(f"Loaded CLIP: {info}")
return clip
def load_t5xxl(ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device]) -> T5EncoderModel:
def load_t5xxl(ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False) -> T5EncoderModel:
T5_CONFIG_JSON = """
{
"architectures": [
@@ -185,7 +207,7 @@ def load_t5xxl(ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.devi
t5xxl = T5EncoderModel._from_config(config)
logger.info(f"Loading state dict from {ckpt_path}")
sd = load_file(ckpt_path, device=str(device))
sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)
info = t5xxl.load_state_dict(sd, strict=False, assign=True)
logger.info(f"Loaded T5xxl: {info}")
return t5xxl

View File

@@ -137,7 +137,7 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
t5_attn_mask = data["t5_attn_mask"]
if self.apply_t5_attn_mask:
t5_out = self.mask_t5_attn(t5_out, t5_attn_mask)
t5_out = self.mask_t5_attn(t5_out, t5_attn_mask) # FIXME do not mask here!!!
return [l_pooled, t5_out, txt_ids, t5_attn_mask]
@@ -149,7 +149,7 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
tokens_and_masks = tokenize_strategy.tokenize(captions)
with torch.no_grad():
# attn_mask is not applied when caching to disk: it is applied when loading from disk
# attn_mask is not applied when caching to disk: it is applied when loading from disk FIXME apply mask when loading
l_pooled, t5_out, txt_ids, _ = flux_text_encoding_strategy.encode_tokens(
tokenize_strategy, models, tokens_and_masks, not self.cache_to_disk
)

View File

@@ -1104,10 +1104,6 @@ class BaseDataset(torch.utils.data.Dataset):
caching_strategy = TextEncoderOutputsCachingStrategy.get_strategy()
batch_size = caching_strategy.batch_size or self.batch_size
# if cache to disk, don't cache TE outputs in non-main process
if caching_strategy.cache_to_disk and not is_main_process:
return
logger.info("caching Text Encoder outputs with caching strategy.")
image_infos = list(self.image_data.values())
@@ -1120,9 +1116,9 @@ class BaseDataset(torch.utils.data.Dataset):
# check disk cache exists and size of latents
if caching_strategy.cache_to_disk:
info.text_encoder_outputs_npz = te_out_npz
info.text_encoder_outputs_npz = te_out_npz # set npz filename regardless of cache availability/main process
cache_available = caching_strategy.is_disk_cached_outputs_expected(te_out_npz)
if cache_available: # do not add to batch
if cache_available or not is_main_process: # do not add to batch
continue
batch.append(info)
@@ -2638,7 +2634,7 @@ def load_arbitrary_dataset(args, tokenizer=None) -> MinimalDataset:
return train_dataset_group
def load_image(image_path, alpha=False):
def load_image(image_path, alpha=False):
try:
with Image.open(image_path) as image:
if alpha:

View File

@@ -153,6 +153,95 @@ def mem_eff_save_file(tensors: Dict[str, torch.Tensor], filename: str, metadata:
v.contiguous().view(torch.uint8).numpy().tofile(f)
class MemoryEfficientSafeOpen:
    """Sequential, mmap-free reader for safetensors files.

    Reads the JSON header, then fetches each tensor's bytes with plain
    ``seek``/``read`` calls. Intended to be used as a context manager so the
    underlying file handle is closed deterministically.

    Note: does not support metadata loading (the ``__metadata__`` entry is
    ignored).
    """

    def __init__(self, filename):
        # header is parsed eagerly; the data file handle stays open for get_tensor()
        self.filename = filename
        self.header, self.header_size = self._read_header()
        self.file = open(filename, "rb")

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.file.close()

    def keys(self):
        """Return the tensor names in the file (excluding ``__metadata__``)."""
        return [k for k in self.header.keys() if k != "__metadata__"]

    def get_tensor(self, key):
        """Read and deserialize the tensor stored under ``key``.

        Raises:
            KeyError: if ``key`` is not present in the file.
            ValueError: if the stored dtype is not supported by this PyTorch build.
        """
        if key not in self.header:
            raise KeyError(f"Tensor '{key}' not found in the file")

        metadata = self.header[key]
        offset_start, offset_end = metadata["data_offsets"]

        if offset_start == offset_end:
            # zero-element tensor: no bytes stored
            tensor_bytes = None
        else:
            # data offsets are relative to the start of the data section,
            # which begins after the 8-byte length prefix and the JSON header
            self.file.seek(self.header_size + 8 + offset_start)
            tensor_bytes = self.file.read(offset_end - offset_start)

        return self._deserialize_tensor(tensor_bytes, metadata)

    def _read_header(self):
        # safetensors layout: <u64 little-endian header size><JSON header><data>
        with open(self.filename, "rb") as f:
            header_size = struct.unpack("<Q", f.read(8))[0]
            header_json = f.read(header_size).decode("utf-8")
            return json.loads(header_json), header_size

    def _deserialize_tensor(self, tensor_bytes, metadata):
        dtype = self._get_torch_dtype(metadata["dtype"])
        shape = metadata["shape"]

        if tensor_bytes is None:
            byte_tensor = torch.empty(0, dtype=torch.uint8)
        else:
            tensor_bytes = bytearray(tensor_bytes)  # make it writable for frombuffer
            byte_tensor = torch.frombuffer(tensor_bytes, dtype=torch.uint8)

        # process float8 types (handled separately so older torch gets a clear error)
        if metadata["dtype"] in ["F8_E5M2", "F8_E4M3"]:
            return self._convert_float8(byte_tensor, metadata["dtype"], shape)

        if dtype is None:
            # fail loudly instead of letting view(None) raise a cryptic TypeError
            raise ValueError(f"Unsupported dtype in safetensors file: {metadata['dtype']}")

        # reinterpret the raw bytes as the target dtype and reshape
        return byte_tensor.view(dtype).reshape(shape)

    @staticmethod
    def _get_torch_dtype(dtype_str):
        # returns None for unknown dtype strings; caller decides how to fail
        dtype_map = {
            "F64": torch.float64,
            "F32": torch.float32,
            "F16": torch.float16,
            "BF16": torch.bfloat16,
            "I64": torch.int64,
            "I32": torch.int32,
            "I16": torch.int16,
            "I8": torch.int8,
            "U8": torch.uint8,
            "BOOL": torch.bool,
        }
        # add float8 types if this PyTorch build provides them
        if hasattr(torch, "float8_e5m2"):
            dtype_map["F8_E5M2"] = torch.float8_e5m2
        if hasattr(torch, "float8_e4m3fn"):
            dtype_map["F8_E4M3"] = torch.float8_e4m3fn
        return dtype_map.get(dtype_str)

    @staticmethod
    def _convert_float8(byte_tensor, dtype_str, shape):
        if dtype_str == "F8_E5M2" and hasattr(torch, "float8_e5m2"):
            return byte_tensor.view(torch.float8_e5m2).reshape(shape)
        elif dtype_str == "F8_E4M3" and hasattr(torch, "float8_e4m3fn"):
            return byte_tensor.view(torch.float8_e4m3fn).reshape(shape)
        else:
            # # convert to float16 if float8 is not supported
            # print(f"Warning: {dtype_str} is not supported in this PyTorch version. Converting to float16.")
            # return byte_tensor.view(torch.uint8).to(torch.float16).reshape(shape)
            raise ValueError(f"Unsupported float8 type: {dtype_str} (upgrade PyTorch to support float8 types)")
# TODO make inf_utils.py