diff --git a/README.md b/README.md
index 45349ba3..5125c663 100644
--- a/README.md
+++ b/README.md
@@ -9,6 +9,12 @@ __Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchv
 The command to install PyTorch is as follows:
 `pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124`
 
+Aug 22, 2024:
+Fixed a bug in multi-GPU training. Both fine-tuning and LoRA training should now work with multiple GPUs. Note that `--double_blocks_to_swap` and `--single_blocks_to_swap` cannot be used in multi-GPU training.
+
+The `--disable_mmap_load_safetensors` option now works in `flux_train.py`. It speeds up model loading during training in WSL2, and it also reduces memory usage when loading models in multi-GPU training. Since it uses a custom implementation of safetensors loading, please always verify that the model is loaded correctly.
+
+
 Aug 21, 2024 (update 3):
 - There is a bug where the `--full_bf16` option is enabled in `flux_train.py` even if it is not specified. The bug will be fixed soon. __Please specify the `--full_bf16` option explicitly, especially when training with 24GB VRAM.__
 - Stochastic rounding is now implemented when `--fused_backward_pass` is specified. The implementation is
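For reference, the new option is simply appended to an existing training command; an illustrative multi-GPU launch (paths, process count, and the elided options are placeholders) could look like `accelerate launch --multi_gpu --num_processes 2 flux_train.py --pretrained_model_name_or_path flux1-dev.safetensors --clip_l clip_l.safetensors --t5xxl t5xxl.safetensors --ae ae.safetensors --disable_mmap_load_safetensors ...`.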
diff --git a/flux_train.py b/flux_train.py
index bcf4b956..e7d45e04 100644
--- a/flux_train.py
+++ b/flux_train.py
@@ -174,7 +174,7 @@ def train(args):
     # load VAE for caching latents
     ae = None
     if cache_latents:
-        ae = flux_utils.load_ae(name, args.ae, weight_dtype, "cpu")
+        ae = flux_utils.load_ae(name, args.ae, weight_dtype, "cpu", args.disable_mmap_load_safetensors)
         ae.to(accelerator.device, dtype=weight_dtype)
         ae.requires_grad_(False)
         ae.eval()
@@ -199,8 +199,8 @@ def train(args):
         strategy_base.TokenizeStrategy.set_strategy(flux_tokenize_strategy)
 
         # load clip_l, t5xxl for caching text encoder outputs
-        clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu")
-        t5xxl = flux_utils.load_t5xxl(args.t5xxl, weight_dtype, "cpu")
+        clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu", args.disable_mmap_load_safetensors)
+        t5xxl = flux_utils.load_t5xxl(args.t5xxl, weight_dtype, "cpu", args.disable_mmap_load_safetensors)
         clip_l.eval()
         t5xxl.eval()
         clip_l.requires_grad_(False)
@@ -228,7 +228,6 @@ def train(args):
         if args.sample_prompts is not None:
             logger.info(f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}")
 
-            tokenize_strategy: strategy_flux.FluxTokenizeStrategy = strategy_base.TokenizeStrategy.get_strategy()
             text_encoding_strategy: strategy_flux.FluxTextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy()
 
             prompts = load_prompts(args.sample_prompts)
@@ -238,9 +237,9 @@ def train(args):
                     for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]:
                         if p not in sample_prompts_te_outputs:
                             logger.info(f"cache Text Encoder outputs for prompt: {p}")
-                            tokens_and_masks = tokenize_strategy.tokenize(p)
+                            tokens_and_masks = flux_tokenize_strategy.tokenize(p)
                             sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens(
-                                tokenize_strategy, [clip_l, t5xxl], tokens_and_masks, args.apply_t5_attn_mask
+                                flux_tokenize_strategy, [clip_l, t5xxl], tokens_and_masks, args.apply_t5_attn_mask
                             )
 
         accelerator.wait_for_everyone()
@@ -251,7 +250,9 @@ def train(args):
         clean_memory_on_device(accelerator.device)
 
     # load FLUX
-    flux = flux_utils.load_flow_model(name, args.pretrained_model_name_or_path, weight_dtype, "cpu")
+    flux = flux_utils.load_flow_model(
+        name, args.pretrained_model_name_or_path, weight_dtype, "cpu", args.disable_mmap_load_safetensors
+    )
 
     if args.gradient_checkpointing:
         flux.enable_gradient_checkpointing(args.cpu_offload_checkpointing)
@@ -419,7 +420,7 @@ def train(args):
         # if we don't swap blocks, we can move the model to the device
         flux = accelerator.prepare(flux, device_placement=[not is_swapping_blocks])
         if is_swapping_blocks:
-            flux.move_to_device_except_swap_blocks(accelerator.device)  # reduce peak memory usage
+            accelerator.unwrap_model(flux).move_to_device_except_swap_blocks(accelerator.device)  # reduce peak memory usage
         optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
 
     # experimental feature: fp16 training including gradients; patch PyTorch to enable grad scaling in fp16
@@ -439,8 +440,8 @@ def train(args):
             double_blocks_to_swap = args.double_blocks_to_swap
             single_blocks_to_swap = args.single_blocks_to_swap
 
-            num_double_blocks = len(flux.double_blocks)
-            num_single_blocks = len(flux.single_blocks)
+            num_double_blocks = 19  # len(flux.double_blocks)
+            num_single_blocks = 38  # len(flux.single_blocks)
 
             handled_double_block_indices = set()
             handled_single_block_indices = set()
@@ -537,8 +538,8 @@ def train(args):
             double_blocks_to_swap = args.double_blocks_to_swap
             single_blocks_to_swap = args.single_blocks_to_swap
 
-            num_double_blocks = len(flux.double_blocks)
-            num_single_blocks = len(flux.single_blocks)
+            num_double_blocks = 19  # len(flux.double_blocks)
+            num_single_blocks = 38  # len(flux.single_blocks)
 
             for opt_idx, optimizer in enumerate(optimizers):
                 for param_group in optimizer.param_groups:
@@ -618,7 +619,7 @@ def train(args):
     )
 
     if is_swapping_blocks:
-        flux.prepare_block_swap_before_forward()
+        accelerator.unwrap_model(flux).prepare_block_swap_before_forward()
 
     # For --sample_at_first
     flux_train_utils.sample_images(accelerator, args, 0, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs)
@@ -660,7 +661,7 @@ def train(args):
                 with torch.no_grad():
                     input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]]
                     text_encoder_conds = text_encoding_strategy.encode_tokens(
-                        tokenize_strategy, [clip_l, t5xxl], input_ids, args.apply_t5_attn_mask
+                        flux_tokenize_strategy, [clip_l, t5xxl], input_ids, args.apply_t5_attn_mask
                     )
                     if args.full_fp16:
                         text_encoder_conds = [c.to(weight_dtype) for c in text_encoder_conds]
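Note on the `accelerator.unwrap_model(flux)` changes above: once a model has passed through `accelerator.prepare()`, multi-GPU runs wrap it in `DistributedDataParallel`, which hides methods defined on `Flux` itself, such as `move_to_device_except_swap_blocks` and `prepare_block_swap_before_forward`. A minimal, self-contained sketch of the pattern (the toy module and method name are illustrative):

```python
import torch.nn as nn
from accelerate import Accelerator


class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(2, 2)

    def custom_method(self):  # stands in for move_to_device_except_swap_blocks
        return "called on the underlying module"


accelerator = Accelerator()
model = accelerator.prepare(Toy())
# With multiple processes, `model` is a DDP wrapper that does not expose
# `custom_method`; unwrap_model() returns the real module in every case.
print(accelerator.unwrap_model(model).custom_method())
```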
diff --git a/flux_train_network.py b/flux_train_network.py
index 49bd270c..3e2057e9 100644
--- a/flux_train_network.py
+++ b/flux_train_network.py
@@ -57,19 +57,21 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
         name = self.get_flux_model_name(args)
 
         # if we load to cpu, flux.to(fp8) takes a long time
-        model = flux_utils.load_flow_model(name, args.pretrained_model_name_or_path, weight_dtype, "cpu")
+        model = flux_utils.load_flow_model(
+            name, args.pretrained_model_name_or_path, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors
+        )
         if args.split_mode:
             model = self.prepare_split_model(model, weight_dtype, accelerator)
 
-        clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu")
+        clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
         clip_l.eval()
 
         # loading t5xxl to cpu takes a long time, so we should load it to gpu in the future
-        t5xxl = flux_utils.load_t5xxl(args.t5xxl, weight_dtype, "cpu")
+        t5xxl = flux_utils.load_t5xxl(args.t5xxl, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
         t5xxl.eval()
 
-        ae = flux_utils.load_ae(name, args.ae, weight_dtype, "cpu")
+        ae = flux_utils.load_ae(name, args.ae, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
 
         return flux_utils.MODEL_VERSION_FLUX_V1, [clip_l, t5xxl], ae, model
diff --git a/library/flux_models.py b/library/flux_models.py
index c98d52ec..c045aef6 100644
--- a/library/flux_models.py
+++ b/library/flux_models.py
@@ -745,7 +745,9 @@ class DoubleStreamBlock(nn.Module):
 
                 return custom_forward
 
-            return torch.utils.checkpoint.checkpoint(create_custom_forward(self._forward), img, txt, vec, pe, txt_attention_mask)
+            return torch.utils.checkpoint.checkpoint(
+                create_custom_forward(self._forward), img, txt, vec, pe, txt_attention_mask, use_reentrant=False
+            )
         else:
             return self._forward(img, txt, vec, pe, txt_attention_mask)
 
@@ -836,7 +838,7 @@ class SingleStreamBlock(nn.Module):
 
                 return custom_forward
 
-            return torch.utils.checkpoint.checkpoint(create_custom_forward(self._forward), x, vec, pe)
+            return torch.utils.checkpoint.checkpoint(create_custom_forward(self._forward), x, vec, pe, use_reentrant=False)
         else:
             return self._forward(x, vec, pe)
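The `use_reentrant=False` additions switch the blocks to PyTorch's non-reentrant activation checkpointing, which supports keyword arguments and inputs that do not require gradients, and avoids the warning PyTorch emits when `use_reentrant` is left unspecified. A standalone sketch of the same API:

```python
import torch
import torch.utils.checkpoint


def block(x):
    # stands in for a transformer block's forward pass
    return torch.relu(x) * 2


x = torch.randn(4, 8, requires_grad=True)
# Activations inside `block` are not stored; they are recomputed on backward.
y = torch.utils.checkpoint.checkpoint(block, x, use_reentrant=False)
y.sum().backward()
print(x.grad.shape)
```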
diff --git a/library/flux_utils.py b/library/flux_utils.py
index 166cd833..37166933 100644
--- a/library/flux_utils.py
+++ b/library/flux_utils.py
@@ -9,7 +9,7 @@
 from transformers import CLIPTextModel, CLIPConfig, T5EncoderModel, T5Config
 
 from library import flux_models
-from library.utils import setup_logging
+from library.utils import setup_logging, MemoryEfficientSafeOpen
 
 setup_logging()
 import logging
@@ -19,32 +19,54 @@ logger = logging.getLogger(__name__)
 MODEL_VERSION_FLUX_V1 = "flux1"
 
 
-def load_flow_model(name: str, ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device]) -> flux_models.Flux:
+# temporary copy from sd3_utils TODO refactor
+def load_safetensors(path: str, device: Union[str, torch.device], disable_mmap: bool = False, dtype: torch.dtype = torch.float32):
+    if disable_mmap:
+        # return safetensors.torch.load(open(path, "rb").read())
+        # use experimental loader
+        logger.info(f"Loading without mmap (experimental)")
+        state_dict = {}
+        with MemoryEfficientSafeOpen(path) as f:
+            for key in f.keys():
+                state_dict[key] = f.get_tensor(key).to(device, dtype=dtype)
+        return state_dict
+    else:
+        try:
+            return load_file(path, device=device)
+        except:
+            return load_file(path)  # prevent device invalid Error
+
+
+def load_flow_model(
+    name: str, ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False
+) -> flux_models.Flux:
     logger.info(f"Building Flux model {name}")
     with torch.device("meta"):
         model = flux_models.Flux(flux_models.configs[name].params).to(dtype)
 
     # load_sft doesn't support torch.device
     logger.info(f"Loading state dict from {ckpt_path}")
-    sd = load_file(ckpt_path, device=str(device))
+    sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)
     info = model.load_state_dict(sd, strict=False, assign=True)
     logger.info(f"Loaded Flux: {info}")
     return model
 
 
-def load_ae(name: str, ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device]) -> flux_models.AutoEncoder:
+def load_ae(
+    name: str, ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False
+) -> flux_models.AutoEncoder:
     logger.info("Building AutoEncoder")
     with torch.device("meta"):
         ae = flux_models.AutoEncoder(flux_models.configs[name].ae_params).to(dtype)
 
     logger.info(f"Loading state dict from {ckpt_path}")
-    sd = load_file(ckpt_path, device=str(device))
+    sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)
     info = ae.load_state_dict(sd, strict=False, assign=True)
     logger.info(f"Loaded AE: {info}")
     return ae
 
 
-def load_clip_l(ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device]) -> CLIPTextModel:
+def load_clip_l(ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False) -> CLIPTextModel:
     logger.info("Building CLIP")
     CLIPL_CONFIG = {
         "_name_or_path": "clip-vit-large-patch14/",
@@ -139,13 +161,13 @@ def load_clip_l(ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.dev
     clip = CLIPTextModel._from_config(config)
 
     logger.info(f"Loading state dict from {ckpt_path}")
-    sd = load_file(ckpt_path, device=str(device))
+    sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)
     info = clip.load_state_dict(sd, strict=False, assign=True)
     logger.info(f"Loaded CLIP: {info}")
     return clip
 
 
-def load_t5xxl(ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device]) -> T5EncoderModel:
+def load_t5xxl(ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False) -> T5EncoderModel:
     T5_CONFIG_JSON = """
 {
     "architectures": [
@@ -185,7 +207,7 @@ def load_t5xxl(ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.devi
     t5xxl = T5EncoderModel._from_config(config)
 
     logger.info(f"Loading state dict from {ckpt_path}")
-    sd = load_file(ckpt_path, device=str(device))
+    sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)
     info = t5xxl.load_state_dict(sd, strict=False, assign=True)
     logger.info(f"Loaded T5xxl: {info}")
     return t5xxl
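For reference, a usage sketch of the new `load_safetensors` helper (the checkpoint path is hypothetical). With `disable_mmap=True` it streams each tensor through `MemoryEfficientSafeOpen`, casting it and moving it to the target device one by one, instead of memory-mapping the whole file:

```python
import torch
from library import flux_utils

# Hypothetical path; tensors are read sequentially, cast to bf16,
# and placed on the requested device one at a time.
sd = flux_utils.load_safetensors(
    "models/flux1-dev.safetensors", device="cpu", disable_mmap=True, dtype=torch.bfloat16
)
print(f"{len(sd)} tensors loaded")
```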
diff --git a/library/strategy_flux.py b/library/strategy_flux.py
index 737af390..b3643cbf 100644
--- a/library/strategy_flux.py
+++ b/library/strategy_flux.py
@@ -137,7 +137,7 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
         t5_attn_mask = data["t5_attn_mask"]
 
         if self.apply_t5_attn_mask:
-            t5_out = self.mask_t5_attn(t5_out, t5_attn_mask)
+            t5_out = self.mask_t5_attn(t5_out, t5_attn_mask)  # FIXME do not mask here!!!
 
         return [l_pooled, t5_out, txt_ids, t5_attn_mask]
 
@@ -149,7 +149,7 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
         tokens_and_masks = tokenize_strategy.tokenize(captions)
         with torch.no_grad():
-            # attn_mask is not applied when caching to disk: it is applied when loading from disk
+            # attn_mask is not applied when caching to disk: it is applied when loading from disk FIXME apply mask when loading
             l_pooled, t5_out, txt_ids, _ = flux_text_encoding_strategy.encode_tokens(
                 tokenize_strategy, models, tokens_and_masks, not self.cache_to_disk
             )
diff --git a/library/train_util.py b/library/train_util.py
index 8929c192..989758ad 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -1104,10 +1104,6 @@ class BaseDataset(torch.utils.data.Dataset):
         caching_strategy = TextEncoderOutputsCachingStrategy.get_strategy()
         batch_size = caching_strategy.batch_size or self.batch_size
 
-        # if cache to disk, don't cache TE outputs in non-main process
-        if caching_strategy.cache_to_disk and not is_main_process:
-            return
-
         logger.info("caching Text Encoder outputs with caching strategy.")
         image_infos = list(self.image_data.values())
@@ -1120,9 +1116,9 @@ class BaseDataset(torch.utils.data.Dataset):
 
             # check disk cache exists and size of latents
             if caching_strategy.cache_to_disk:
-                info.text_encoder_outputs_npz = te_out_npz
+                info.text_encoder_outputs_npz = te_out_npz  # set npz filename regardless of cache availability/main process
                 cache_available = caching_strategy.is_disk_cached_outputs_expected(te_out_npz)
-                if cache_available:  # do not add to batch
+                if cache_available or not is_main_process:  # do not add to batch
                     continue
 
             batch.append(info)
@@ -2638,7 +2634,7 @@ def load_arbitrary_dataset(args, tokenizer=None) -> MinimalDataset:
     return train_dataset_group
 
 
-def load_image(image_path, alpha=False):
+def load_image(image_path, alpha=False):
     try:
         with Image.open(image_path) as image:
             if alpha:
diff --git a/library/utils.py b/library/utils.py
index 7de22d5a..a1620997 100644
--- a/library/utils.py
+++ b/library/utils.py
@@ -153,6 +153,95 @@ def mem_eff_save_file(tensors: Dict[str, torch.Tensor], filename: str, metadata:
             v.contiguous().view(torch.uint8).numpy().tofile(f)
 
 
+class MemoryEfficientSafeOpen:
+    # does not support metadata loading
+    def __init__(self, filename):
+        self.filename = filename
+        self.header, self.header_size = self._read_header()
+        self.file = open(filename, "rb")
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.file.close()
+
+    def keys(self):
+        return [k for k in self.header.keys() if k != "__metadata__"]
+
+    def get_tensor(self, key):
+        if key not in self.header:
+            raise KeyError(f"Tensor '{key}' not found in the file")
+
+        metadata = self.header[key]
+        offset_start, offset_end = metadata["data_offsets"]
+
+        if offset_start == offset_end:
+            tensor_bytes = None
+        else:
+            # adjust offset by header size
+            self.file.seek(self.header_size + 8 + offset_start)
+            tensor_bytes = self.file.read(offset_end - offset_start)
+
+        return self._deserialize_tensor(tensor_bytes, metadata)
+
+    def _read_header(self):
+        with open(self.filename, "rb") as f:
+            # safetensors layout: 8-byte little-endian header length, then a JSON header
+            header_size = struct.unpack("<Q", f.read(8))[0]
+            header_json = f.read(header_size).decode("utf-8")
+            return json.loads(header_json), header_size
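A usage sketch for the `MemoryEfficientSafeOpen` reader added above (the file path is hypothetical). Because each tensor is read with plain file I/O rather than mmap, peak additional memory stays near the size of the largest tensor instead of growing with the page cache for the whole file:

```python
from library.utils import MemoryEfficientSafeOpen

# Hypothetical path; iterate over all tensors without memory-mapping the file.
with MemoryEfficientSafeOpen("models/clip_l.safetensors") as f:
    for key in f.keys():
        tensor = f.get_tensor(key)
        print(key, tuple(tensor.shape), tensor.dtype)
```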