mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Fix bug in FLUX multi GPU training
This commit is contained in:
@@ -745,7 +745,9 @@ class DoubleStreamBlock(nn.Module):
|
||||
|
||||
return custom_forward
|
||||
|
||||
return torch.utils.checkpoint.checkpoint(create_custom_forward(self._forward), img, txt, vec, pe, txt_attention_mask)
|
||||
return torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(self._forward), img, txt, vec, pe, txt_attention_mask, use_reentrant=False
|
||||
)
|
||||
|
||||
else:
|
||||
return self._forward(img, txt, vec, pe, txt_attention_mask)
|
||||
@@ -836,7 +838,7 @@ class SingleStreamBlock(nn.Module):
|
||||
|
||||
return custom_forward
|
||||
|
||||
return torch.utils.checkpoint.checkpoint(create_custom_forward(self._forward), x, vec, pe)
|
||||
return torch.utils.checkpoint.checkpoint(create_custom_forward(self._forward), x, vec, pe, use_reentrant=False)
|
||||
else:
|
||||
return self._forward(x, vec, pe)
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ from transformers import CLIPTextModel, CLIPConfig, T5EncoderModel, T5Config
|
||||
|
||||
from library import flux_models
|
||||
|
||||
from library.utils import setup_logging
|
||||
from library.utils import setup_logging, MemoryEfficientSafeOpen
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
@@ -19,32 +19,54 @@ logger = logging.getLogger(__name__)
|
||||
MODEL_VERSION_FLUX_V1 = "flux1"
|
||||
|
||||
|
||||
def load_flow_model(name: str, ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device]) -> flux_models.Flux:
|
||||
# temporary copy from sd3_utils TODO refactor
|
||||
def load_safetensors(path: str, device: Union[str, torch.device], disable_mmap: bool = False, dtype: torch.dtype = torch.float32):
|
||||
if disable_mmap:
|
||||
# return safetensors.torch.load(open(path, "rb").read())
|
||||
# use experimental loader
|
||||
logger.info(f"Loading without mmap (experimental)")
|
||||
state_dict = {}
|
||||
with MemoryEfficientSafeOpen(path) as f:
|
||||
for key in f.keys():
|
||||
state_dict[key] = f.get_tensor(key).to(device, dtype=dtype)
|
||||
return state_dict
|
||||
else:
|
||||
try:
|
||||
return load_file(path, device=device)
|
||||
except:
|
||||
return load_file(path) # prevent device invalid Error
|
||||
|
||||
|
||||
def load_flow_model(
|
||||
name: str, ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False
|
||||
) -> flux_models.Flux:
|
||||
logger.info(f"Building Flux model {name}")
|
||||
with torch.device("meta"):
|
||||
model = flux_models.Flux(flux_models.configs[name].params).to(dtype)
|
||||
|
||||
# load_sft doesn't support torch.device
|
||||
logger.info(f"Loading state dict from {ckpt_path}")
|
||||
sd = load_file(ckpt_path, device=str(device))
|
||||
sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)
|
||||
info = model.load_state_dict(sd, strict=False, assign=True)
|
||||
logger.info(f"Loaded Flux: {info}")
|
||||
return model
|
||||
|
||||
|
||||
def load_ae(name: str, ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device]) -> flux_models.AutoEncoder:
|
||||
def load_ae(
|
||||
name: str, ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False
|
||||
) -> flux_models.AutoEncoder:
|
||||
logger.info("Building AutoEncoder")
|
||||
with torch.device("meta"):
|
||||
ae = flux_models.AutoEncoder(flux_models.configs[name].ae_params).to(dtype)
|
||||
|
||||
logger.info(f"Loading state dict from {ckpt_path}")
|
||||
sd = load_file(ckpt_path, device=str(device))
|
||||
sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)
|
||||
info = ae.load_state_dict(sd, strict=False, assign=True)
|
||||
logger.info(f"Loaded AE: {info}")
|
||||
return ae
|
||||
|
||||
|
||||
def load_clip_l(ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device]) -> CLIPTextModel:
|
||||
def load_clip_l(ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False) -> CLIPTextModel:
|
||||
logger.info("Building CLIP")
|
||||
CLIPL_CONFIG = {
|
||||
"_name_or_path": "clip-vit-large-patch14/",
|
||||
@@ -139,13 +161,13 @@ def load_clip_l(ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.dev
|
||||
clip = CLIPTextModel._from_config(config)
|
||||
|
||||
logger.info(f"Loading state dict from {ckpt_path}")
|
||||
sd = load_file(ckpt_path, device=str(device))
|
||||
sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)
|
||||
info = clip.load_state_dict(sd, strict=False, assign=True)
|
||||
logger.info(f"Loaded CLIP: {info}")
|
||||
return clip
|
||||
|
||||
|
||||
def load_t5xxl(ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device]) -> T5EncoderModel:
|
||||
def load_t5xxl(ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False) -> T5EncoderModel:
|
||||
T5_CONFIG_JSON = """
|
||||
{
|
||||
"architectures": [
|
||||
@@ -185,7 +207,7 @@ def load_t5xxl(ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.devi
|
||||
t5xxl = T5EncoderModel._from_config(config)
|
||||
|
||||
logger.info(f"Loading state dict from {ckpt_path}")
|
||||
sd = load_file(ckpt_path, device=str(device))
|
||||
sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)
|
||||
info = t5xxl.load_state_dict(sd, strict=False, assign=True)
|
||||
logger.info(f"Loaded T5xxl: {info}")
|
||||
return t5xxl
|
||||
|
||||
@@ -137,7 +137,7 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
|
||||
t5_attn_mask = data["t5_attn_mask"]
|
||||
|
||||
if self.apply_t5_attn_mask:
|
||||
t5_out = self.mask_t5_attn(t5_out, t5_attn_mask)
|
||||
t5_out = self.mask_t5_attn(t5_out, t5_attn_mask) # FIXME do not mask here!!!
|
||||
|
||||
return [l_pooled, t5_out, txt_ids, t5_attn_mask]
|
||||
|
||||
@@ -149,7 +149,7 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
|
||||
|
||||
tokens_and_masks = tokenize_strategy.tokenize(captions)
|
||||
with torch.no_grad():
|
||||
# attn_mask is not applied when caching to disk: it is applied when loading from disk
|
||||
# attn_mask is not applied when caching to disk: it is applied when loading from disk FIXME apply mask when loading
|
||||
l_pooled, t5_out, txt_ids, _ = flux_text_encoding_strategy.encode_tokens(
|
||||
tokenize_strategy, models, tokens_and_masks, not self.cache_to_disk
|
||||
)
|
||||
|
||||
@@ -1104,10 +1104,6 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
caching_strategy = TextEncoderOutputsCachingStrategy.get_strategy()
|
||||
batch_size = caching_strategy.batch_size or self.batch_size
|
||||
|
||||
# if cache to disk, don't cache TE outputs in non-main process
|
||||
if caching_strategy.cache_to_disk and not is_main_process:
|
||||
return
|
||||
|
||||
logger.info("caching Text Encoder outputs with caching strategy.")
|
||||
image_infos = list(self.image_data.values())
|
||||
|
||||
@@ -1120,9 +1116,9 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
|
||||
# check disk cache exists and size of latents
|
||||
if caching_strategy.cache_to_disk:
|
||||
info.text_encoder_outputs_npz = te_out_npz
|
||||
info.text_encoder_outputs_npz = te_out_npz # set npz filename regardless of cache availability/main process
|
||||
cache_available = caching_strategy.is_disk_cached_outputs_expected(te_out_npz)
|
||||
if cache_available: # do not add to batch
|
||||
if cache_available or not is_main_process: # do not add to batch
|
||||
continue
|
||||
|
||||
batch.append(info)
|
||||
@@ -2638,7 +2634,7 @@ def load_arbitrary_dataset(args, tokenizer=None) -> MinimalDataset:
|
||||
return train_dataset_group
|
||||
|
||||
|
||||
def load_image(image_path, alpha=False):
|
||||
def load_image(image_path, alpha=False):
|
||||
try:
|
||||
with Image.open(image_path) as image:
|
||||
if alpha:
|
||||
|
||||
@@ -153,6 +153,95 @@ def mem_eff_save_file(tensors: Dict[str, torch.Tensor], filename: str, metadata:
|
||||
v.contiguous().view(torch.uint8).numpy().tofile(f)
|
||||
|
||||
|
||||
class MemoryEfficientSafeOpen:
|
||||
# does not support metadata loading
|
||||
def __init__(self, filename):
|
||||
self.filename = filename
|
||||
self.header, self.header_size = self._read_header()
|
||||
self.file = open(filename, "rb")
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.file.close()
|
||||
|
||||
def keys(self):
|
||||
return [k for k in self.header.keys() if k != "__metadata__"]
|
||||
|
||||
def get_tensor(self, key):
|
||||
if key not in self.header:
|
||||
raise KeyError(f"Tensor '{key}' not found in the file")
|
||||
|
||||
metadata = self.header[key]
|
||||
offset_start, offset_end = metadata["data_offsets"]
|
||||
|
||||
if offset_start == offset_end:
|
||||
tensor_bytes = None
|
||||
else:
|
||||
# adjust offset by header size
|
||||
self.file.seek(self.header_size + 8 + offset_start)
|
||||
tensor_bytes = self.file.read(offset_end - offset_start)
|
||||
|
||||
return self._deserialize_tensor(tensor_bytes, metadata)
|
||||
|
||||
def _read_header(self):
|
||||
with open(self.filename, "rb") as f:
|
||||
header_size = struct.unpack("<Q", f.read(8))[0]
|
||||
header_json = f.read(header_size).decode("utf-8")
|
||||
return json.loads(header_json), header_size
|
||||
|
||||
def _deserialize_tensor(self, tensor_bytes, metadata):
|
||||
dtype = self._get_torch_dtype(metadata["dtype"])
|
||||
shape = metadata["shape"]
|
||||
|
||||
if tensor_bytes is None:
|
||||
byte_tensor = torch.empty(0, dtype=torch.uint8)
|
||||
else:
|
||||
tensor_bytes = bytearray(tensor_bytes) # make it writable
|
||||
byte_tensor = torch.frombuffer(tensor_bytes, dtype=torch.uint8)
|
||||
|
||||
# process float8 types
|
||||
if metadata["dtype"] in ["F8_E5M2", "F8_E4M3"]:
|
||||
return self._convert_float8(byte_tensor, metadata["dtype"], shape)
|
||||
|
||||
# convert to the target dtype and reshape
|
||||
return byte_tensor.view(dtype).reshape(shape)
|
||||
|
||||
@staticmethod
|
||||
def _get_torch_dtype(dtype_str):
|
||||
dtype_map = {
|
||||
"F64": torch.float64,
|
||||
"F32": torch.float32,
|
||||
"F16": torch.float16,
|
||||
"BF16": torch.bfloat16,
|
||||
"I64": torch.int64,
|
||||
"I32": torch.int32,
|
||||
"I16": torch.int16,
|
||||
"I8": torch.int8,
|
||||
"U8": torch.uint8,
|
||||
"BOOL": torch.bool,
|
||||
}
|
||||
# add float8 types if available
|
||||
if hasattr(torch, "float8_e5m2"):
|
||||
dtype_map["F8_E5M2"] = torch.float8_e5m2
|
||||
if hasattr(torch, "float8_e4m3fn"):
|
||||
dtype_map["F8_E4M3"] = torch.float8_e4m3fn
|
||||
return dtype_map.get(dtype_str)
|
||||
|
||||
@staticmethod
|
||||
def _convert_float8(byte_tensor, dtype_str, shape):
|
||||
if dtype_str == "F8_E5M2" and hasattr(torch, "float8_e5m2"):
|
||||
return byte_tensor.view(torch.float8_e5m2).reshape(shape)
|
||||
elif dtype_str == "F8_E4M3" and hasattr(torch, "float8_e4m3fn"):
|
||||
return byte_tensor.view(torch.float8_e4m3fn).reshape(shape)
|
||||
else:
|
||||
# # convert to float16 if float8 is not supported
|
||||
# print(f"Warning: {dtype_str} is not supported in this PyTorch version. Converting to float16.")
|
||||
# return byte_tensor.view(torch.uint8).to(torch.float16).reshape(shape)
|
||||
raise ValueError(f"Unsupported float8 type: {dtype_str} (upgrade PyTorch to support float8 types)")
|
||||
|
||||
|
||||
# TODO make inf_utils.py
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user