From 87526942a67fd71bb775bc479b0a7449df516dd8 Mon Sep 17 00:00:00 2001 From: sdbds <865105819@qq.com> Date: Fri, 12 Jul 2024 22:56:38 +0800 Subject: [PATCH 01/87] judge image size for using diff interpolation --- library/train_util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/library/train_util.py b/library/train_util.py index 15c23f3c..74720fec 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2362,7 +2362,7 @@ def trim_and_resize_if_required( if image_width != resized_size[0] or image_height != resized_size[1]: # リサイズする - image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ + image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA if image_width > resized_size[0] and image_height > resized_size[1] else cv2.INTER_LANCZOS4) image_height, image_width = image.shape[0:2] From 2e67978ee243a20f169ce76d7644bb1f9dec9bad Mon Sep 17 00:00:00 2001 From: Millie Date: Thu, 18 Jul 2024 11:52:58 -0700 Subject: [PATCH 02/87] Generate sample images without having CUDA (such as on Macs) --- library/train_util.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 15c23f3c..9b0397d7 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -5229,7 +5229,7 @@ def sample_images_common( clean_memory_on_device(accelerator.device) torch.set_rng_state(rng_state) - if cuda_rng_state is not None: + if torch.cuda.is_available() and cuda_rng_state is not None: torch.cuda.set_rng_state(cuda_rng_state) vae.to(org_vae_device) @@ -5263,11 +5263,13 @@ def sample_image_inference( if seed is not None: torch.manual_seed(seed) - torch.cuda.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) else: # True random sample image generation torch.seed() - torch.cuda.seed() + if torch.cuda.is_available(): + torch.cuda.seed() scheduler = get_my_scheduler( sample_sampler=sampler_name, @@ -5302,8 +5304,9 @@ def sample_image_inference( controlnet_image=controlnet_image, ) - with torch.cuda.device(torch.cuda.current_device()): - torch.cuda.empty_cache() + if torch.cuda.is_available(): + with torch.cuda.device(torch.cuda.current_device()): + torch.cuda.empty_cache() image = pipeline.latents_to_image(latents)[0] From 1f16b80e88b1c4f05d49b4fc328d3b9b105ebcbe Mon Sep 17 00:00:00 2001 From: sdbds <865105819@qq.com> Date: Sat, 20 Jul 2024 21:35:24 +0800 Subject: [PATCH 03/87] Revert "judge image size for using diff interpolation" This reverts commit 87526942a67fd71bb775bc479b0a7449df516dd8. --- library/train_util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/library/train_util.py b/library/train_util.py index 74720fec..15c23f3c 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2362,7 +2362,7 @@ def trim_and_resize_if_required( if image_width != resized_size[0] or image_height != resized_size[1]: # リサイズする - image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA if image_width > resized_size[0] and image_height > resized_size[1] else cv2.INTER_LANCZOS4) + image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ image_height, image_width = image.shape[0:2] From 9ca7a5b6cc99e25820a1aa6d02a779004d73bca0 Mon Sep 17 00:00:00 2001 From: sdbds <865105819@qq.com> Date: Sat, 20 Jul 2024 21:59:11 +0800 Subject: [PATCH 04/87] instead cv2 LANCZOS4 resize to pil resize --- finetune/tag_images_by_wd14_tagger.py | 8 +++++--- library/train_util.py | 11 ++++++----- library/utils.py | 14 +++++++++++++- tools/detect_face_rotate.py | 7 +++++-- tools/resize_images_to_resolution.py | 11 +++++++---- 5 files changed, 36 insertions(+), 15 deletions(-) diff --git a/finetune/tag_images_by_wd14_tagger.py b/finetune/tag_images_by_wd14_tagger.py index a327bbd6..6f5bdd36 100644 --- a/finetune/tag_images_by_wd14_tagger.py +++ b/finetune/tag_images_by_wd14_tagger.py @@ -11,7 +11,7 @@ from PIL import Image from tqdm import tqdm import library.train_util as train_util -from library.utils import setup_logging +from library.utils import setup_logging, pil_resize setup_logging() import logging @@ -42,8 +42,10 @@ def preprocess_image(image): pad_t = pad_y // 2 image = np.pad(image, ((pad_t, pad_y - pad_t), (pad_l, pad_x - pad_l), (0, 0)), mode="constant", constant_values=255) - interp = cv2.INTER_AREA if size > IMAGE_SIZE else cv2.INTER_LANCZOS4 - image = cv2.resize(image, (IMAGE_SIZE, IMAGE_SIZE), interpolation=interp) + if size > IMAGE_SIZE: + image = cv2.resize(image, (IMAGE_SIZE, IMAGE_SIZE), cv2.INTER_AREA) + else: + image = pil_resize(image, (IMAGE_SIZE, IMAGE_SIZE)) image = image.astype(np.float32) return image diff --git a/library/train_util.py b/library/train_util.py index 15c23f3c..160e3b44 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -71,7 +71,7 @@ import library.model_util as model_util import library.huggingface_util as huggingface_util import library.sai_model_spec as sai_model_spec import library.deepspeed_utils as deepspeed_utils -from library.utils import setup_logging +from library.utils import setup_logging, pil_resize setup_logging() import logging @@ -2028,9 +2028,7 @@ class ControlNetDataset(BaseDataset): # ), f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}" # resize to target if cond_img.shape[0] != target_size_hw[0] or cond_img.shape[1] != target_size_hw[1]: - cond_img = cv2.resize( - cond_img, (int(target_size_hw[1]), int(target_size_hw[0])), interpolation=cv2.INTER_LANCZOS4 - ) + cond_img=pil_resize(cond_img,(int(target_size_hw[1]), int(target_size_hw[0]))) if flipped: cond_img = cond_img[:, ::-1, :].copy() # copy to avoid negative stride @@ -2362,7 +2360,10 @@ def trim_and_resize_if_required( if image_width != resized_size[0] or image_height != resized_size[1]: # リサイズする - image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ + if image_width > resized_size[0] and image_height > resized_size[1]: + image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ + else: + image = pil_resize(image, resized_size) image_height, image_width = image.shape[0:2] diff --git a/library/utils.py b/library/utils.py index 3037c055..a219f6cb 100644 --- a/library/utils.py +++ b/library/utils.py @@ -7,7 +7,9 @@ from typing import * from diffusers import EulerAncestralDiscreteScheduler import diffusers.schedulers.scheduling_euler_ancestral_discrete from diffusers.schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteSchedulerOutput - +import cv2 +from PIL import Image +import numpy as np def fire_in_thread(f, *args, **kwargs): threading.Thread(target=f, args=args, kwargs=kwargs).start() @@ -78,7 +80,17 @@ def setup_logging(args=None, log_level=None, reset=False): logger = logging.getLogger(__name__) logger.info(msg_init) +def pil_resize(image, size, interpolation=Image.LANCZOS): + pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) + + # use Pillow resize + resized_pil = pil_image.resize(size, interpolation) + + # return cv2 image + resized_cv2 = cv2.cvtColor(np.array(resized_pil), cv2.COLOR_RGB2BGR) + + return resized_cv2 # TODO make inf_utils.py diff --git a/tools/detect_face_rotate.py b/tools/detect_face_rotate.py index bbc643ed..d2a4d9cf 100644 --- a/tools/detect_face_rotate.py +++ b/tools/detect_face_rotate.py @@ -15,7 +15,7 @@ import os from anime_face_detector import create_detector from tqdm import tqdm import numpy as np -from library.utils import setup_logging +from library.utils import setup_logging, pil_resize setup_logging() import logging logger = logging.getLogger(__name__) @@ -172,7 +172,10 @@ def process(args): if scale != 1.0: w = int(w * scale + .5) h = int(h * scale + .5) - face_img = cv2.resize(face_img, (w, h), interpolation=cv2.INTER_AREA if scale < 1.0 else cv2.INTER_LANCZOS4) + if scale < 1.0: + face_img = cv2.resize(face_img, (w, h), interpolation=cv2.INTER_AREA) + else: + face_img = pil_resize(face_img, (w, h)) cx = int(cx * scale + .5) cy = int(cy * scale + .5) fw = int(fw * scale + .5) diff --git a/tools/resize_images_to_resolution.py b/tools/resize_images_to_resolution.py index b8069fc1..0f9e00b1 100644 --- a/tools/resize_images_to_resolution.py +++ b/tools/resize_images_to_resolution.py @@ -6,7 +6,7 @@ import shutil import math from PIL import Image import numpy as np -from library.utils import setup_logging +from library.utils import setup_logging, pil_resize setup_logging() import logging logger = logging.getLogger(__name__) @@ -24,9 +24,9 @@ def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divi # Select interpolation method if interpolation == 'lanczos4': - cv2_interpolation = cv2.INTER_LANCZOS4 + pil_interpolation = Image.LANCZOS elif interpolation == 'cubic': - cv2_interpolation = cv2.INTER_CUBIC + pil_interpolation = Image.BICUBIC else: cv2_interpolation = cv2.INTER_AREA @@ -64,7 +64,10 @@ def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divi new_width = int(img.shape[1] * math.sqrt(scale_factor)) # Resize image - img = cv2.resize(img, (new_width, new_height), interpolation=cv2_interpolation) + if cv2_interpolation: + img = cv2.resize(img, (new_width, new_height), interpolation=cv2_interpolation) + else: + img = pil_resize(img, (new_width, new_height), interpolation=pil_interpolation) else: new_height, new_width = img.shape[0:2] From 1e8108fec9962333e4cf2a8db1dcedf657049900 Mon Sep 17 00:00:00 2001 From: liesen Date: Sat, 24 Aug 2024 01:38:17 +0300 Subject: [PATCH 05/87] Handle args.v_parameterization properly for MinSNR and changed prediction target --- sdxl_train.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/sdxl_train.py b/sdxl_train.py index 46d7860b..14b25965 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -590,7 +590,11 @@ def train(args): with accelerator.autocast(): noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding) - target = noise + if args.v_parameterization: + # v-parameterization training + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + target = noise if ( args.min_snr_gamma @@ -606,7 +610,7 @@ def train(args): loss = loss.mean([1, 2, 3]) if args.min_snr_gamma: - loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) if args.scale_v_pred_loss_like_noise_pred: loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) if args.v_pred_like_loss: From 2e89cd2cc634c27add7a04c21fcb6d0e16716a2b Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 24 Aug 2024 12:39:54 +0900 Subject: [PATCH 06/87] Fix issue with attention mask not being applied in single blocks --- README.md | 3 ++ flux_train_network.py | 4 +-- library/flux_models.py | 62 +++++++++++++++++++++--------------------- 3 files changed, 36 insertions(+), 33 deletions(-) diff --git a/README.md b/README.md index 33b3a9a9..4151bf44 100644 --- a/README.md +++ b/README.md @@ -9,6 +9,9 @@ __Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchv The command to install PyTorch is as follows: `pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124` +Aug 24, 2024: +Fixed an issue where the attention mask was not applied in single blocks when `--apply_t5_attn_mask` was specified. + Aug 22, 2024 (update 2): Fixed a bug that the embedding was zero-padded when `--apply_t5_attn_mask` option was applied. Also, the cache file for text encoder outputs now records whether the mask is applied or not. Please note that the cache file will be recreated when switching the `--apply_t5_attn_mask` option. diff --git a/flux_train_network.py b/flux_train_network.py index 3e2057e9..82f77a77 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -243,7 +243,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): self.flux_upper.to("cpu") clean_memory_on_device(self.target_device) self.flux_lower.to(self.target_device) - return self.flux_lower(img, txt, vec, pe) + return self.flux_lower(img, txt, vec, pe, txt_attention_mask) wrapper = FluxUpperLowerWrapper(self.flux_upper, flux, accelerator.device) clean_memory_on_device(accelerator.device) @@ -352,7 +352,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): intermediate_txt.requires_grad_(True) vec.requires_grad_(True) pe.requires_grad_(True) - model_pred = unet(img=intermediate_img, txt=intermediate_txt, vec=vec, pe=pe) + model_pred = unet(img=intermediate_img, txt=intermediate_txt, vec=vec, pe=pe, txt_attention_mask=t5_attn_mask) # unpack latents model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width) diff --git a/library/flux_models.py b/library/flux_models.py index c045aef6..b5726c29 100644 --- a/library/flux_models.py +++ b/library/flux_models.py @@ -752,18 +752,6 @@ class DoubleStreamBlock(nn.Module): else: return self._forward(img, txt, vec, pe, txt_attention_mask) - # def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor): - # if self.training and self.gradient_checkpointing: - # def create_custom_forward(func): - # def custom_forward(*inputs): - # return func(*inputs) - # return custom_forward - # return torch.utils.checkpoint.checkpoint( - # create_custom_forward(self._forward), img, txt, vec, pe, use_reentrant=USE_REENTRANT - # ) - # else: - # return self._forward(img, txt, vec, pe) - class SingleStreamBlock(nn.Module): """ @@ -809,7 +797,7 @@ class SingleStreamBlock(nn.Module): self.gradient_checkpointing = False self.cpu_offload_checkpointing = False - def _forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor: + def _forward(self, x: Tensor, vec: Tensor, pe: Tensor, txt_attention_mask: Optional[Tensor] = None) -> Tensor: mod, _ = self.modulation(vec) x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1) @@ -817,16 +805,35 @@ class SingleStreamBlock(nn.Module): q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) q, k = self.norm(q, k, v) + # make attention mask if not None + attn_mask = None + if txt_attention_mask is not None: + # F.scaled_dot_product_attention expects attn_mask to be bool for binary mask + attn_mask = txt_attention_mask.to(torch.bool) # b, seq_len + attn_mask = torch.cat( + ( + attn_mask, + torch.ones( + attn_mask.shape[0], x.shape[1] - txt_attention_mask.shape[1], device=attn_mask.device, dtype=torch.bool + ), + ), + dim=1, + ) # b, seq_len + img_len = x_len + + # broadcast attn_mask to all heads + attn_mask = attn_mask[:, None, None, :].expand(-1, q.shape[1], q.shape[2], -1) + # compute attention - attn = attention(q, k, v, pe=pe) + attn = attention(q, k, v, pe=pe, attn_mask=attn_mask) + # compute activation in mlp stream, cat again and run second linear layer output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2)) return x + mod.gate * output - def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor: + def forward(self, x: Tensor, vec: Tensor, pe: Tensor, txt_attention_mask: Optional[Tensor] = None) -> Tensor: if self.training and self.gradient_checkpointing: if not self.cpu_offload_checkpointing: - return checkpoint(self._forward, x, vec, pe, use_reentrant=False) + return checkpoint(self._forward, x, vec, pe, txt_attention_mask, use_reentrant=False) # cpu offload checkpointing @@ -838,19 +845,11 @@ class SingleStreamBlock(nn.Module): return custom_forward - return torch.utils.checkpoint.checkpoint(create_custom_forward(self._forward), x, vec, pe, use_reentrant=False) + return torch.utils.checkpoint.checkpoint( + create_custom_forward(self._forward), x, vec, pe, txt_attention_mask, use_reentrant=False + ) else: - return self._forward(x, vec, pe) - - # def forward(self, x: Tensor, vec: Tensor, pe: Tensor): - # if self.training and self.gradient_checkpointing: - # def create_custom_forward(func): - # def custom_forward(*inputs): - # return func(*inputs) - # return custom_forward - # return torch.utils.checkpoint.checkpoint(create_custom_forward(self._forward), x, vec, pe, use_reentrant=USE_REENTRANT) - # else: - # return self._forward(x, vec, pe) + return self._forward(x, vec, pe, txt_attention_mask) class LastLayer(nn.Module): @@ -1053,7 +1052,7 @@ class Flux(nn.Module): if not self.single_blocks_to_swap: for block in self.single_blocks: - img = block(img, vec=vec, pe=pe) + img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) else: # make sure first n blocks are on cuda, and last n blocks are on cpu at beginning for block_idx in range(self.single_blocks_to_swap): @@ -1075,7 +1074,7 @@ class Flux(nn.Module): block.to(self.device) # move to cuda # print(f"Moved single block {block_idx} to cuda.") - img = block(img, vec=vec, pe=pe) + img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) if moving: self.single_blocks[to_cpu_block_index].to("cpu") # , non_blocking=True) @@ -1250,10 +1249,11 @@ class FluxLower(nn.Module): txt: Tensor, vec: Tensor | None = None, pe: Tensor | None = None, + txt_attention_mask: Tensor | None = None, ) -> Tensor: img = torch.cat((txt, img), 1) for block in self.single_blocks: - img = block(img, vec=vec, pe=pe) + img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) img = img[:, txt.shape[1] :, ...] img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) From cf689e7aa697877a0eee58622035ab702ce59d3e Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 24 Aug 2024 16:35:43 +0900 Subject: [PATCH 07/87] feat: Add option to split projection layers and apply LoRA --- README.md | 14 ++ networks/check_lora_weights.py | 2 +- networks/convert_flux_lora.py | 51 ++++-- networks/lora_flux.py | 322 +++++++++++++++++++++++++++------ 4 files changed, 323 insertions(+), 66 deletions(-) diff --git a/README.md b/README.md index 4151bf44..7d326a86 100644 --- a/README.md +++ b/README.md @@ -9,6 +9,20 @@ __Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchv The command to install PyTorch is as follows: `pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124` +Aug 24, 2024 (update 2): + +__Experimental__ Added an option to split the projection layers of q/k/v/txt in the attention and apply LoRA to each of them in FLUX.1 LoRA training. Specify `"split_qkv=True"` in network_args like `--network_args "split_qkv=True"` (`train_blocks` is also available). + +The number of parameters may increase slightly, so the expressiveness may increase, but the training time may be longer. No detailed verification has been done. + +This implementation is experimental, so it may be deprecated or changed in the future. + +The .safetensors file of the trained model is compatible with the normal LoRA model of sd-scripts, so it should be usable in inference environments such as ComfyUI as it is. Also, converting it to AI-toolkit (Diffusers) format with `convert_flux_lora.py` will reduce the size. It should be no problem to convert it if you use it in the inference environment. + +Technical details: In the implementation of Black Forest Labs' model, the projection layers of q/k/v (and txt in single blocks) are concatenated into one. If LoRA is added there as it is, the LoRA module is only one, and the dimension is large. In contrast, in the implementation of Diffusers, the projection layers of q/k/v/txt are separated. Therefore, the LoRA module is applied to q/k/v/txt separately, and the dimension is smaller. This option is for training LoRA similar to the latter. + +The compatibility of the saved model (state dict) is ensured by concatenating the weights of multiple LoRAs. However, since there are zero weights in some parts, the model size will be large. + Aug 24, 2024: Fixed an issue where the attention mask was not applied in single blocks when `--apply_t5_attn_mask` was specified. diff --git a/networks/check_lora_weights.py b/networks/check_lora_weights.py index 794659c9..b5b5e61a 100644 --- a/networks/check_lora_weights.py +++ b/networks/check_lora_weights.py @@ -18,7 +18,7 @@ def main(file): keys = list(sd.keys()) for key in keys: - if "lora_up" in key or "lora_down" in key: + if "lora_up" in key or "lora_down" in key or "lora_A" in key or "lora_B" in key: values.append((key, sd[key])) print(f"number of LoRA modules: {len(values)}") diff --git a/networks/convert_flux_lora.py b/networks/convert_flux_lora.py index e9743534..bd4c1cf7 100644 --- a/networks/convert_flux_lora.py +++ b/networks/convert_flux_lora.py @@ -266,11 +266,12 @@ def convert_to_ai_toolkit_cat(sds_sd, ait_sd, sds_key, ait_keys, dims=None): if sds_key + ".lora_down.weight" not in sds_sd: return down_weight = sds_sd.pop(sds_key + ".lora_down.weight") + up_weight = sds_sd.pop(sds_key + ".lora_up.weight") + sd_lora_rank = down_weight.shape[0] # scale weight by alpha and dim - rank = down_weight.shape[0] alpha = sds_sd.pop(sds_key + ".alpha") - scale = alpha / rank + scale = alpha / sd_lora_rank # calculate scale_down and scale_up scale_down = scale @@ -279,23 +280,49 @@ def convert_to_ai_toolkit_cat(sds_sd, ait_sd, sds_key, ait_keys, dims=None): scale_down *= 2 scale_up /= 2 - ait_down_keys = [k + ".lora_A.weight" for k in ait_keys] - ait_up_keys = [k + ".lora_B.weight" for k in ait_keys] - - num_splits = len(ait_keys) - up_weight = sds_sd.pop(sds_key + ".lora_up.weight") - - # down_weight is copied to each split - ait_sd.update({k: down_weight * scale_down for k in ait_down_keys}) + down_weight = down_weight * scale_down + up_weight = up_weight * scale_up # calculate dims if not provided + num_splits = len(ait_keys) if dims is None: dims = [up_weight.shape[0] // num_splits] * num_splits else: assert sum(dims) == up_weight.shape[0] - # up_weight is split to each split - ait_sd.update({k: v * scale_up for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))}) + # check upweight is sparse or not + is_sparse = False + if sd_lora_rank % num_splits == 0: + ait_rank = sd_lora_rank // num_splits + is_sparse = True + i = 0 + for j in range(len(dims)): + for k in range(len(dims)): + if j == k: + continue + is_sparse = is_sparse and torch.all(up_weight[i : i + dims[j], k * ait_rank : (k + 1) * ait_rank] == 0) + i += dims[j] + if is_sparse: + logger.info(f"weight is sparse: {sds_key}") + + # make ai-toolkit weight + ait_down_keys = [k + ".lora_A.weight" for k in ait_keys] + ait_up_keys = [k + ".lora_B.weight" for k in ait_keys] + if not is_sparse: + # down_weight is copied to each split + ait_sd.update({k: down_weight for k in ait_down_keys}) + + # up_weight is split to each split + ait_sd.update({k: v for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))}) + else: + # down_weight is chunked to each split + ait_sd.update({k: v for k, v in zip(ait_down_keys, torch.chunk(down_weight, num_splits, dim=0))}) + + # up_weight is sparse: only non-zero values are copied to each split + i = 0 + for j in range(len(dims)): + ait_sd[ait_up_keys[j]] = up_weight[i : i + dims[j], j * ait_rank : (j + 1) * ait_rank].contiguous() + i += dims[j] def convert_sd_scripts_to_ai_toolkit(sds_sd): diff --git a/networks/lora_flux.py b/networks/lora_flux.py index 4da33542..efc7847e 100644 --- a/networks/lora_flux.py +++ b/networks/lora_flux.py @@ -39,6 +39,7 @@ class LoRAModule(torch.nn.Module): dropout=None, rank_dropout=None, module_dropout=None, + split_dims: Optional[List[int]] = None, ): """if alpha == 0 or None, alpha is rank (no scaling).""" super().__init__() @@ -52,16 +53,34 @@ class LoRAModule(torch.nn.Module): out_dim = org_module.out_features self.lora_dim = lora_dim + self.split_dims = split_dims - if org_module.__class__.__name__ == "Conv2d": - kernel_size = org_module.kernel_size - stride = org_module.stride - padding = org_module.padding - self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False) - self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False) + if split_dims is None: + if org_module.__class__.__name__ == "Conv2d": + kernel_size = org_module.kernel_size + stride = org_module.stride + padding = org_module.padding + self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False) + self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False) + else: + self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False) + self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False) + + torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5)) + torch.nn.init.zeros_(self.lora_up.weight) else: - self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False) - self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False) + # conv2d not supported + assert sum(split_dims) == out_dim, "sum of split_dims must be equal to out_dim" + assert org_module.__class__.__name__ == "Linear", "split_dims is only supported for Linear" + # print(f"split_dims: {split_dims}") + self.lora_down = torch.nn.ModuleList( + [torch.nn.Linear(in_dim, self.lora_dim, bias=False) for _ in range(len(split_dims))] + ) + self.lora_up = torch.nn.ModuleList([torch.nn.Linear(self.lora_dim, split_dim, bias=False) for split_dim in split_dims]) + for lora_down in self.lora_down: + torch.nn.init.kaiming_uniform_(lora_down.weight, a=math.sqrt(5)) + for lora_up in self.lora_up: + torch.nn.init.zeros_(lora_up.weight) if type(alpha) == torch.Tensor: alpha = alpha.detach().float().numpy() # without casting, bf16 causes error @@ -70,9 +89,6 @@ class LoRAModule(torch.nn.Module): self.register_buffer("alpha", torch.tensor(alpha)) # 定数として扱える # same as microsoft's - torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5)) - torch.nn.init.zeros_(self.lora_up.weight) - self.multiplier = multiplier self.org_module = org_module # remove in applying self.dropout = dropout @@ -92,30 +108,56 @@ class LoRAModule(torch.nn.Module): if torch.rand(1) < self.module_dropout: return org_forwarded - lx = self.lora_down(x) + if self.split_dims is None: + lx = self.lora_down(x) - # normal dropout - if self.dropout is not None and self.training: - lx = torch.nn.functional.dropout(lx, p=self.dropout) + # normal dropout + if self.dropout is not None and self.training: + lx = torch.nn.functional.dropout(lx, p=self.dropout) - # rank dropout - if self.rank_dropout is not None and self.training: - mask = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout - if len(lx.size()) == 3: - mask = mask.unsqueeze(1) # for Text Encoder - elif len(lx.size()) == 4: - mask = mask.unsqueeze(-1).unsqueeze(-1) # for Conv2d - lx = lx * mask + # rank dropout + if self.rank_dropout is not None and self.training: + mask = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout + if len(lx.size()) == 3: + mask = mask.unsqueeze(1) # for Text Encoder + elif len(lx.size()) == 4: + mask = mask.unsqueeze(-1).unsqueeze(-1) # for Conv2d + lx = lx * mask - # scaling for rank dropout: treat as if the rank is changed - # maskから計算することも考えられるが、augmentation的な効果を期待してrank_dropoutを用いる - scale = self.scale * (1.0 / (1.0 - self.rank_dropout)) # redundant for readability + # scaling for rank dropout: treat as if the rank is changed + # maskから計算することも考えられるが、augmentation的な効果を期待してrank_dropoutを用いる + scale = self.scale * (1.0 / (1.0 - self.rank_dropout)) # redundant for readability + else: + scale = self.scale + + lx = self.lora_up(lx) + + return org_forwarded + lx * self.multiplier * scale else: - scale = self.scale + lxs = [lora_down(x) for lora_down in self.lora_down] - lx = self.lora_up(lx) + # normal dropout + if self.dropout is not None and self.training: + lxs = [torch.nn.functional.dropout(lx, p=self.dropout) for lx in lxs] - return org_forwarded + lx * self.multiplier * scale + # rank dropout + if self.rank_dropout is not None and self.training: + masks = [torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout for lx in lxs] + for i in range(len(lxs)): + if len(lx.size()) == 3: + masks[i] = masks[i].unsqueeze(1) + elif len(lx.size()) == 4: + masks[i] = masks[i].unsqueeze(-1).unsqueeze(-1) + lxs[i] = lxs[i] * masks[i] + + # scaling for rank dropout: treat as if the rank is changed + scale = self.scale * (1.0 / (1.0 - self.rank_dropout)) # redundant for readability + else: + scale = self.scale + + lxs = [lora_up(lx) for lora_up, lx in zip(self.lora_up, lxs)] + + return org_forwarded + torch.cat(lxs, dim=-1) * self.multiplier * scale class LoRAInfModule(LoRAModule): @@ -152,31 +194,50 @@ class LoRAInfModule(LoRAModule): if device is None: device = org_device - # get up/down weight - up_weight = sd["lora_up.weight"].to(torch.float).to(device) - down_weight = sd["lora_down.weight"].to(torch.float).to(device) + if self.split_dims is None: + # get up/down weight + down_weight = sd["lora_down.weight"].to(torch.float).to(device) + up_weight = sd["lora_up.weight"].to(torch.float).to(device) - # merge weight - if len(weight.size()) == 2: - # linear - weight = weight + self.multiplier * (up_weight @ down_weight) * self.scale - elif down_weight.size()[2:4] == (1, 1): - # conv2d 1x1 - weight = ( - weight - + self.multiplier - * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) - * self.scale - ) + # merge weight + if len(weight.size()) == 2: + # linear + weight = weight + self.multiplier * (up_weight @ down_weight) * self.scale + elif down_weight.size()[2:4] == (1, 1): + # conv2d 1x1 + weight = ( + weight + + self.multiplier + * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) + * self.scale + ) + else: + # conv2d 3x3 + conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) + # logger.info(conved.size(), weight.size(), module.stride, module.padding) + weight = weight + self.multiplier * conved * self.scale + + # set weight to org_module + org_sd["weight"] = weight.to(dtype) + self.org_module.load_state_dict(org_sd) else: - # conv2d 3x3 - conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) - # logger.info(conved.size(), weight.size(), module.stride, module.padding) - weight = weight + self.multiplier * conved * self.scale + # split_dims + total_dims = sum(self.split_dims) + for i in range(len(self.split_dims)): + # get up/down weight + down_weight = sd[f"lora_down.{i}.weight"].to(torch.float).to(device) # (rank, in_dim) + up_weight = sd[f"lora_up.{i}.weight"].to(torch.float).to(device) # (split dim, rank) - # set weight to org_module - org_sd["weight"] = weight.to(dtype) - self.org_module.load_state_dict(org_sd) + # pad up_weight -> (total_dims, rank) + padded_up_weight = torch.zeros((total_dims, up_weight.size(0)), device=device, dtype=torch.float) + padded_up_weight[sum(self.split_dims[:i]) : sum(self.split_dims[: i + 1])] = up_weight + + # merge weight + weight = weight + self.multiplier * (up_weight @ down_weight) * self.scale + + # set weight to org_module + org_sd["weight"] = weight.to(dtype) + self.org_module.load_state_dict(org_sd) # 復元できるマージのため、このモジュールのweightを返す def get_weight(self, multiplier=None): @@ -211,7 +272,14 @@ class LoRAInfModule(LoRAModule): def default_forward(self, x): # logger.info(f"default_forward {self.lora_name} {x.size()}") - return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale + if self.split_dims is None: + lx = self.lora_down(x) + lx = self.lora_up(lx) + return self.org_forward(x) + lx * self.multiplier * self.scale + else: + lxs = [lora_down(x) for lora_down in self.lora_down] + lxs = [lora_up(lx) for lora_up, lx in zip(self.lora_up, lxs)] + return self.org_forward(x) + torch.cat(lxs, dim=-1) * self.multiplier * self.scale def forward(self, x): if not self.enabled: @@ -257,6 +325,11 @@ def create_network( if train_blocks is not None: assert train_blocks in ["all", "single", "double"], f"invalid train_blocks: {train_blocks}" + # split qkv + split_qkv = kwargs.get("split_qkv", False) + if split_qkv is not None: + split_qkv = True if split_qkv == "True" else False + # すごく引数が多いな ( ^ω^)・・・ network = LoRANetwork( text_encoders, @@ -270,6 +343,7 @@ def create_network( conv_lora_dim=conv_dim, conv_alpha=conv_alpha, train_blocks=train_blocks, + split_qkv=split_qkv, varbose=True, ) @@ -311,10 +385,34 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, flux, weigh modules_dim[lora_name] = dim # logger.info(lora_name, value.size(), dim) + # # split qkv + # double_qkv_rank = None + # single_qkv_rank = None + # rank = None + # for lora_name, dim in modules_dim.items(): + # if "double" in lora_name and "qkv" in lora_name: + # double_qkv_rank = dim + # elif "single" in lora_name and "linear1" in lora_name: + # single_qkv_rank = dim + # elif rank is None: + # rank = dim + # if double_qkv_rank is not None and single_qkv_rank is not None and rank is not None: + # break + # split_qkv = (double_qkv_rank is not None and double_qkv_rank != rank) or ( + # single_qkv_rank is not None and single_qkv_rank != rank + # ) + split_qkv = False # split_qkv is not needed to care, because state_dict is qkv combined + module_class = LoRAInfModule if for_inference else LoRAModule network = LoRANetwork( - text_encoders, flux, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha, module_class=module_class + text_encoders, + flux, + multiplier=multiplier, + modules_dim=modules_dim, + modules_alpha=modules_alpha, + module_class=module_class, + split_qkv=split_qkv, ) return network, weights_sd @@ -344,6 +442,7 @@ class LoRANetwork(torch.nn.Module): modules_dim: Optional[Dict[str, int]] = None, modules_alpha: Optional[Dict[str, int]] = None, train_blocks: Optional[str] = None, + split_qkv: bool = False, varbose: Optional[bool] = False, ) -> None: super().__init__() @@ -357,6 +456,7 @@ class LoRANetwork(torch.nn.Module): self.rank_dropout = rank_dropout self.module_dropout = module_dropout self.train_blocks = train_blocks if train_blocks is not None else "all" + self.split_qkv = split_qkv self.loraplus_lr_ratio = None self.loraplus_unet_lr_ratio = None @@ -373,6 +473,8 @@ class LoRANetwork(torch.nn.Module): logger.info( f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}" ) + if self.split_qkv: + logger.info(f"split qkv for LoRA") # create module instances def create_modules( @@ -420,6 +522,14 @@ class LoRANetwork(torch.nn.Module): skipped.append(lora_name) continue + # qkv split + split_dims = None + if is_flux and split_qkv: + if "double" in lora_name and "qkv" in lora_name: + split_dims = [3072] * 3 + elif "single" in lora_name and "linear1" in lora_name: + split_dims = [3072] * 3 + [12288] + lora = module_class( lora_name, child_module, @@ -429,6 +539,7 @@ class LoRANetwork(torch.nn.Module): dropout=dropout, rank_dropout=rank_dropout, module_dropout=module_dropout, + split_dims=split_dims, ) loras.append(lora) return loras, skipped @@ -492,6 +603,111 @@ class LoRANetwork(torch.nn.Module): info = self.load_state_dict(weights_sd, False) return info + def load_state_dict(self, state_dict, strict=True): + # override to convert original weight to splitted qkv weight + if not self.split_qkv: + return super().load_state_dict(state_dict, strict) + + # split qkv + for key in list(state_dict.keys()): + if "double" in key and "qkv" in key: + split_dims = [3072] * 3 + elif "single" in key and "linear1" in key: + split_dims = [3072] * 3 + [12288] + else: + continue + + weight = state_dict[key] + lora_name = key.split(".")[0] + if "lora_down" in key and "weight" in key: + # dense weight (rank*3, in_dim) + split_weight = torch.chunk(weight, len(split_dims), dim=0) + for i, split_w in enumerate(split_weight): + state_dict[f"{lora_name}.lora_down.{i}.weight"] = split_w + + del state_dict[key] + # print(f"split {key}: {weight.shape} to {[w.shape for w in split_weight]}") + elif "lora_up" in key and "weight" in key: + # sparse weight (out_dim=sum(split_dims), rank*3) + rank = weight.size(1) // len(split_dims) + i = 0 + for j in range(len(split_dims)): + state_dict[f"{lora_name}.lora_up.{j}.weight"] = weight[i : i + split_dims[j], j * rank : (j + 1) * rank] + i += split_dims[j] + del state_dict[key] + + # # check is sparse + # i = 0 + # is_zero = True + # for j in range(len(split_dims)): + # for k in range(len(split_dims)): + # if j == k: + # continue + # is_zero = is_zero and torch.all(weight[i : i + split_dims[j], k * rank : (k + 1) * rank] == 0) + # i += split_dims[j] + # if not is_zero: + # logger.warning(f"weight is not sparse: {key}") + # else: + # logger.info(f"weight is sparse: {key}") + + # print( + # f"split {key}: {weight.shape} to {[state_dict[k].shape for k in [f'{lora_name}.lora_up.{j}.weight' for j in range(len(split_dims))]]}" + # ) + + # alpha is unchanged + + return super().load_state_dict(state_dict, strict) + + def state_dict(self, destination=None, prefix="", keep_vars=False): + if not self.split_qkv: + return super().state_dict(destination, prefix, keep_vars) + + # merge qkv + state_dict = super().state_dict(destination, prefix, keep_vars) + new_state_dict = {} + for key in list(state_dict.keys()): + if "double" in key and "qkv" in key: + split_dims = [3072] * 3 + elif "single" in key and "linear1" in key: + split_dims = [3072] * 3 + [12288] + else: + new_state_dict[key] = state_dict[key] + continue + + if key not in state_dict: + continue # already merged + + lora_name = key.split(".")[0] + + # (rank, in_dim) * 3 + down_weights = [state_dict.pop(f"{lora_name}.lora_down.{i}.weight") for i in range(len(split_dims))] + # (split dim, rank) * 3 + up_weights = [state_dict.pop(f"{lora_name}.lora_up.{i}.weight") for i in range(len(split_dims))] + + alpha = state_dict.pop(f"{lora_name}.alpha") + + # merge down weight + down_weight = torch.cat(down_weights, dim=0) # (rank, split_dim) * 3 -> (rank*3, sum of split_dim) + + # merge up weight (sum of split_dim, rank*3) + rank = up_weights[0].size(1) + up_weight = torch.zeros((sum(split_dims), down_weight.size(0)), device=down_weight.device, dtype=down_weight.dtype) + i = 0 + for j in range(len(split_dims)): + up_weight[i : i + split_dims[j], j * rank : (j + 1) * rank] = up_weights[j] + i += split_dims[j] + + new_state_dict[f"{lora_name}.lora_down.weight"] = down_weight + new_state_dict[f"{lora_name}.lora_up.weight"] = up_weight + new_state_dict[f"{lora_name}.alpha"] = alpha + + # print( + # f"merged {lora_name}: {lora_name}, {[w.shape for w in down_weights]}, {[w.shape for w in up_weights]} to {down_weight.shape}, {up_weight.shape}" + # ) + print(f"new key: {lora_name}.lora_down.weight, {lora_name}.lora_up.weight, {lora_name}.alpha") + + return new_state_dict + def apply_to(self, text_encoders, flux, apply_text_encoder=True, apply_unet=True): if apply_text_encoder: logger.info(f"enable LoRA for text encoder: {len(self.text_encoder_loras)} modules") From 5639c2adc0085e2e995bb3eee5a278aace397e7a Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 24 Aug 2024 16:37:49 +0900 Subject: [PATCH 08/87] fix typo --- networks/lora_flux.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/networks/lora_flux.py b/networks/lora_flux.py index efc7847e..07a80f0b 100644 --- a/networks/lora_flux.py +++ b/networks/lora_flux.py @@ -604,7 +604,7 @@ class LoRANetwork(torch.nn.Module): return info def load_state_dict(self, state_dict, strict=True): - # override to convert original weight to splitted qkv weight + # override to convert original weight to split qkv if not self.split_qkv: return super().load_state_dict(state_dict, strict) From d5c076cf9007f86f6dd1b9ecdfc5531336774b2f Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 24 Aug 2024 21:21:39 +0900 Subject: [PATCH 09/87] update readme --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 946df58f..81a54937 100644 --- a/README.md +++ b/README.md @@ -139,6 +139,7 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser ### Working in progress +- `--v_parameterization` is available in `sdxl_train.py`. The results are unpredictable, so use with caution. PR [#1505](https://github.com/kohya-ss/sd-scripts/pull/1505) Thanks to liesened! - Fused optimizer is available for SDXL training. PR [#1259](https://github.com/kohya-ss/sd-scripts/pull/1259) Thanks to 2kpr! - The memory usage during training is significantly reduced by integrating the optimizer's backward pass with step. The training results are the same as before, but if you have plenty of memory, the speed will be slower. - Specify the `--fused_backward_pass` option in `sdxl_train.py`. At this time, only AdaFactor is supported. Gradient accumulation is not available. From 72287d39c76176c0e1c16e8da4f5ddc6f94ea7d6 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 25 Aug 2024 16:01:24 +0900 Subject: [PATCH 10/87] feat: Add `shift` option to `--timestep_sampling` in FLUX.1 fine-tuning and LoRA training --- README.md | 4 ++++ library/flux_train_utils.py | 15 +++++++++++++-- 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 282f3b3b..562dcdb2 100644 --- a/README.md +++ b/README.md @@ -9,6 +9,10 @@ __Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchv The command to install PyTorch is as follows: `pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124` +Aug 25, 2024: +Added `shift` option to `--timestep_sampling` in FLUX.1 fine-tuning and LoRA training. Shifts timesteps according to the value of `--discrete_flow_shift` (shifts the value of sigmoid of normal distribution random number). It may be good to start with a value of 3.1582 (=e^1.15) for `--discrete_flow_shift`. +Sample command: `--timestep_sampling shift --discrete_flow_shift 3.1582 --model_prediction_type raw --guidance_scale 1.0` + Aug 24, 2024 (update 2): __Experimental__ Added an option to split the projection layers of q/k/v/txt in the attention and apply LoRA to each of them in FLUX.1 LoRA training. Specify `"split_qkv=True"` in network_args like `--network_args "split_qkv=True"` (`train_blocks` is also available). diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index 1d3f80d7..75f70a54 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -380,9 +380,19 @@ def get_noisy_model_input_and_timesteps( t = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=device)) else: t = torch.rand((bsz,), device=device) + timesteps = t * 1000.0 t = t.view(-1, 1, 1, 1) noisy_model_input = (1 - t) * latents + t * noise + elif args.timestep_sampling == "shift": + shift = args.discrete_flow_shift + logits_norm = torch.randn(bsz, device=device) + timesteps = logits_norm.sigmoid() + timesteps = (timesteps * shift) / (1 + (shift - 1) * timesteps) + + t = timesteps.view(-1, 1, 1, 1) + timesteps = timesteps * 1000.0 + noisy_model_input = (1 - t) * latents + t * noise else: # Sample a random timestep for each image # for weighting schemes where we sample timesteps non-uniformly @@ -559,9 +569,10 @@ def add_flux_train_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--timestep_sampling", - choices=["sigma", "uniform", "sigmoid"], + choices=["sigma", "uniform", "sigmoid", "shift"], default="sigma", - help="Method to sample timesteps: sigma-based, uniform random, or sigmoid of random normal. / タイムステップをサンプリングする方法:sigma、random uniform、またはrandom normalのsigmoid。", + help="Method to sample timesteps: sigma-based, uniform random, sigmoid of random normal and shift of sigmoid." + " / タイムステップをサンプリングする方法:sigma、random uniform、random normalのsigmoid、sigmoidのシフト。", ) parser.add_argument( "--sigmoid_scale", From 0087a46e14c8e568982cbe3a5d9b9c561b175abf Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 27 Aug 2024 19:59:40 +0900 Subject: [PATCH 11/87] FLUX.1 LoRA supports CLIP-L --- README.md | 8 ++++ flux_train_network.py | 40 +++++++++++++----- library/flux_train_utils.py | 8 ++-- library/strategy_flux.py | 3 +- networks/lora_flux.py | 4 +- train_network.py | 81 ++++++++++++++++++++++++------------- 6 files changed, 101 insertions(+), 43 deletions(-) diff --git a/README.md b/README.md index 562dcdb2..1203b5eb 100644 --- a/README.md +++ b/README.md @@ -9,6 +9,14 @@ __Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchv The command to install PyTorch is as follows: `pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124` +Aug 27, 2024: + +- FLUX.1 LoRA training now supports CLIP-L LoRA. Please remove `--network_train_unet_only`. T5XXL is not trained. The output of T5XXL is still cached, so `--cache_text_encoder_outputs` or `--cache_text_encoder_outputs_to_disk` is still required. The trained LoRA can be used with ComfyUI. + - `flux_extract_lora.py` and `convert_flux_lora.py` do not support CLIP-L LoRA. +- `--sigmoid_scale` is now effective even when `--timestep_sampling shift` is specified. Normally, leave it at 1.0. Larger values make the value before shift closer to a uniform distribution. + +- __Experimental__ `--fp8_base_unet` option is added to `flux_train_network.py`. Flux can be trained with fp8, and CLIP-L can be trained with bf16/fp16. When specifying this option, the `--fp8_base` option is not required (Flux is fp8, and CLIP-L is bf16/fp16, regardless of the `--fp8_base` option). + Aug 25, 2024: Added `shift` option to `--timestep_sampling` in FLUX.1 fine-tuning and LoRA training. Shifts timesteps according to the value of `--discrete_flow_shift` (shifts the value of sigmoid of normal distribution random number). It may be good to start with a value of 3.1582 (=e^1.15) for `--discrete_flow_shift`. Sample command: `--timestep_sampling shift --discrete_flow_shift 3.1582 --model_prediction_type raw --guidance_scale 1.0` diff --git a/flux_train_network.py b/flux_train_network.py index 82f77a77..1a40de61 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -40,9 +40,13 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): train_dataset_group.is_text_encoder_output_cacheable() ), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません" - assert ( - args.network_train_unet_only or not args.cache_text_encoder_outputs - ), "network for Text Encoder cannot be trained with caching Text Encoder outputs / Text Encoderの出力をキャッシュしながらText Encoderのネットワークを学習することはできません" + # assert ( + # args.network_train_unet_only or not args.cache_text_encoder_outputs + # ), "network for Text Encoder cannot be trained with caching Text Encoder outputs / Text Encoderの出力をキャッシュしながらText Encoderのネットワークを学習することはできません" + if not args.network_train_unet_only: + logger.info( + "network for CLIP-L only will be trained. T5XXL will not be trained / CLIP-Lのネットワークのみが学習されます。T5XXLは学習されません" + ) if args.max_token_length is not None: logger.warning("max_token_length is not used in Flux training / max_token_lengthはFluxのトレーニングでは使用されません") @@ -137,12 +141,25 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): return strategy_flux.FluxTextEncodingStrategy(apply_t5_attn_mask=args.apply_t5_attn_mask) def get_models_for_text_encoding(self, args, accelerator, text_encoders): - return text_encoders # + [accelerator.unwrap_model(text_encoders[-1])] + if args.cache_text_encoder_outputs: + if self.is_train_text_encoder(args): + return text_encoders[0:1] # only CLIP-L is needed for encoding because T5XXL is cached + else: + return text_encoders # ignored + else: + return text_encoders # both CLIP-L and T5XXL are needed for encoding + + def get_text_encoders_train_flags(self, args, text_encoders): + return [True, False] if self.is_train_text_encoder(args) else [False, False] def get_text_encoder_outputs_caching_strategy(self, args): if args.cache_text_encoder_outputs: return strategy_flux.FluxTextEncoderOutputsCachingStrategy( - args.cache_text_encoder_outputs_to_disk, None, False, apply_t5_attn_mask=args.apply_t5_attn_mask + args.cache_text_encoder_outputs_to_disk, + None, + False, + is_partial=self.is_train_text_encoder(args), + apply_t5_attn_mask=args.apply_t5_attn_mask, ) else: return None @@ -190,9 +207,11 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): accelerator.wait_for_everyone() # move back to cpu - logger.info("move text encoders back to cpu") - text_encoders[0].to("cpu") # , dtype=torch.float32) # Text Encoder doesn't work with fp16 on CPU - text_encoders[1].to("cpu") # , dtype=torch.float32) + if not self.is_train_text_encoder(args): + logger.info("move CLIP-L back to cpu") + text_encoders[0].to("cpu") + logger.info("move t5XXL back to cpu") + text_encoders[1].to("cpu") clean_memory_on_device(accelerator.device) if not args.lowram: @@ -297,7 +316,8 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): if args.gradient_checkpointing: noisy_model_input.requires_grad_(True) for t in text_encoder_conds: - t.requires_grad_(True) + if t.dtype.is_floating_point: + t.requires_grad_(True) img_ids.requires_grad_(True) guidance_vec.requires_grad_(True) @@ -384,7 +404,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): metadata["ss_discrete_flow_shift"] = args.discrete_flow_shift def is_text_encoder_not_needed_for_training(self, args): - return args.cache_text_encoder_outputs + return args.cache_text_encoder_outputs and not self.is_train_text_encoder(args) def setup_parser() -> argparse.ArgumentParser: diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index 75f70a54..a8e94ac0 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -58,7 +58,7 @@ def sample_images( logger.info("") logger.info(f"generating sample images at step / サンプル画像生成 ステップ: {steps}") - if not os.path.isfile(args.sample_prompts): + if not os.path.isfile(args.sample_prompts) and sample_prompts_te_outputs is None: logger.error(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}") return @@ -66,7 +66,8 @@ def sample_images( # unwrap unet and text_encoder(s) flux = accelerator.unwrap_model(flux) - text_encoders = [accelerator.unwrap_model(te) for te in text_encoders] + if text_encoders is not None: + text_encoders = [accelerator.unwrap_model(te) for te in text_encoders] # print([(te.parameters().__next__().device if te is not None else None) for te in text_encoders]) prompts = load_prompts(args.sample_prompts) @@ -134,7 +135,7 @@ def sample_image_inference( accelerator: Accelerator, args: argparse.Namespace, flux: flux_models.Flux, - text_encoders: List[CLIPTextModel], + text_encoders: Optional[List[CLIPTextModel]], ae: flux_models.AutoEncoder, save_dir, prompt_dict, @@ -387,6 +388,7 @@ def get_noisy_model_input_and_timesteps( elif args.timestep_sampling == "shift": shift = args.discrete_flow_shift logits_norm = torch.randn(bsz, device=device) + logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling timesteps = logits_norm.sigmoid() timesteps = (timesteps * shift) / (1 + (shift - 1) * timesteps) diff --git a/library/strategy_flux.py b/library/strategy_flux.py index d52b3b8d..5d083913 100644 --- a/library/strategy_flux.py +++ b/library/strategy_flux.py @@ -60,7 +60,7 @@ class FluxTextEncodingStrategy(TextEncodingStrategy): if apply_t5_attn_mask is None: apply_t5_attn_mask = self.apply_t5_attn_mask - clip_l, t5xxl = models + clip_l, t5xxl = models if len(models) == 2 else (models[0], None) l_tokens, t5_tokens = tokens[:2] t5_attn_mask = tokens[2] if len(tokens) > 2 else None @@ -81,6 +81,7 @@ class FluxTextEncodingStrategy(TextEncodingStrategy): else: t5_out = None txt_ids = None + t5_attn_mask = None # caption may be dropped/shuffled, so t5_attn_mask should not be used to make sure the mask is same as the cached one return [l_pooled, t5_out, txt_ids, t5_attn_mask] # returns t5_attn_mask for attention mask in transformer diff --git a/networks/lora_flux.py b/networks/lora_flux.py index 07a80f0b..fcb56a46 100644 --- a/networks/lora_flux.py +++ b/networks/lora_flux.py @@ -401,7 +401,7 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, flux, weigh # split_qkv = (double_qkv_rank is not None and double_qkv_rank != rank) or ( # single_qkv_rank is not None and single_qkv_rank != rank # ) - split_qkv = False # split_qkv is not needed to care, because state_dict is qkv combined + split_qkv = False # split_qkv is not needed to care, because state_dict is qkv combined module_class = LoRAInfModule if for_inference else LoRAModule @@ -421,7 +421,7 @@ class LoRANetwork(torch.nn.Module): # FLUX_TARGET_REPLACE_MODULE = ["DoubleStreamBlock", "SingleStreamBlock"] FLUX_TARGET_REPLACE_MODULE_DOUBLE = ["DoubleStreamBlock"] FLUX_TARGET_REPLACE_MODULE_SINGLE = ["SingleStreamBlock"] - TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"] + TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPSdpaAttention", "CLIPMLP"] LORA_PREFIX_FLUX = "lora_unet" # make ComfyUI compatible LORA_PREFIX_TEXT_ENCODER_CLIP = "lora_te1" LORA_PREFIX_TEXT_ENCODER_T5 = "lora_te2" diff --git a/train_network.py b/train_network.py index cab0ec52..048c7e7b 100644 --- a/train_network.py +++ b/train_network.py @@ -127,8 +127,15 @@ class NetworkTrainer: return None def get_models_for_text_encoding(self, args, accelerator, text_encoders): + """ + Returns a list of models that will be used for text encoding. SDXL uses wrapped and unwrapped models. + """ return text_encoders + # returns a list of bool values indicating whether each text encoder should be trained + def get_text_encoders_train_flags(self, args, text_encoders): + return [True] * len(text_encoders) if self.is_train_text_encoder(args) else [False] * len(text_encoders) + def is_train_text_encoder(self, args): return not args.network_train_unet_only @@ -136,11 +143,6 @@ class NetworkTrainer: for t_enc in text_encoders: t_enc.to(accelerator.device, dtype=weight_dtype) - def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype): - input_ids = batch["input_ids"].to(accelerator.device) - encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizers[0], text_encoders[0], weight_dtype) - return encoder_hidden_states - def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype): noise_pred = unet(noisy_latents, timesteps, text_conds[0]).sample return noise_pred @@ -313,7 +315,7 @@ class NetworkTrainer: collator = train_util.collator_class(current_epoch, current_step, ds_for_collator) if args.debug_dataset: - train_dataset_group.set_current_strategies() # dasaset needs to know the strategies explicitly + train_dataset_group.set_current_strategies() # dasaset needs to know the strategies explicitly train_util.debug_dataset(train_dataset_group) return if len(train_dataset_group) == 0: @@ -437,8 +439,10 @@ class NetworkTrainer: if args.gradient_checkpointing: unet.enable_gradient_checkpointing() - for t_enc in text_encoders: - t_enc.gradient_checkpointing_enable() + for t_enc, flag in zip(text_encoders, self.get_text_encoders_train_flags(args, text_encoders)): + if flag: + if t_enc.supports_gradient_checkpointing: + t_enc.gradient_checkpointing_enable() del t_enc network.enable_gradient_checkpointing() # may have no effect @@ -522,14 +526,17 @@ class NetworkTrainer: unet_weight_dtype = te_weight_dtype = weight_dtype # Experimental Feature: Put base model into fp8 to save vram - if args.fp8_base: + if args.fp8_base or args.fp8_base_unet: assert torch.__version__ >= "2.1.0", "fp8_base requires torch>=2.1.0 / fp8を使う場合はtorch>=2.1.0が必要です。" assert ( args.mixed_precision != "no" ), "fp8_base requires mixed precision='fp16' or 'bf16' / fp8を使う場合はmixed_precision='fp16'または'bf16'が必要です。" - accelerator.print("enable fp8 training.") + accelerator.print("enable fp8 training for U-Net.") unet_weight_dtype = torch.float8_e4m3fn - te_weight_dtype = torch.float8_e4m3fn + + if not args.fp8_base_unet: + accelerator.print("enable fp8 training for Text Encoder.") + te_weight_dtype = weight_dtype if args.fp8_base_unet else torch.float8_e4m3fn # unet.to(accelerator.device) # this makes faster `to(dtype)` below, but consumes 23 GB VRAM # unet.to(dtype=unet_weight_dtype) # without moving to gpu, this takes a lot of time and main memory @@ -546,19 +553,18 @@ class NetworkTrainer: t_enc.to(dtype=te_weight_dtype) if hasattr(t_enc, "text_model") and hasattr(t_enc.text_model, "embeddings"): # nn.Embedding not support FP8 - t_enc.text_model.embeddings.to( - dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype)) + t_enc.text_model.embeddings.to(dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype)) elif hasattr(t_enc, "encoder") and hasattr(t_enc.encoder, "embeddings"): - t_enc.encoder.embeddings.to( - dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype)) + t_enc.encoder.embeddings.to(dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype)) # acceleratorがなんかよろしくやってくれるらしい / accelerator will do something good if args.deepspeed: + flags = self.get_text_encoders_train_flags(args, text_encoders) ds_model = deepspeed_utils.prepare_deepspeed_model( args, unet=unet if train_unet else None, - text_encoder1=text_encoders[0] if train_text_encoder else None, - text_encoder2=text_encoders[1] if train_text_encoder and len(text_encoders) > 1 else None, + text_encoder1=text_encoders[0] if flags[0] else None, + text_encoder2=(text_encoders[1] if flags[1] else None) if len(text_encoders) > 1 else None, network=network, ) ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( @@ -571,11 +577,14 @@ class NetworkTrainer: else: unet.to(accelerator.device, dtype=unet_weight_dtype) # move to device because unet is not prepared by accelerator if train_text_encoder: + text_encoders = [ + (accelerator.prepare(t_enc) if flag else t_enc) + for t_enc, flag in zip(text_encoders, self.get_text_encoders_train_flags(args, text_encoders)) + ] if len(text_encoders) > 1: - text_encoder = text_encoders = [accelerator.prepare(t_enc) for t_enc in text_encoders] + text_encoder = text_encoders else: - text_encoder = accelerator.prepare(text_encoder) - text_encoders = [text_encoder] + text_encoder = text_encoders[0] else: pass # if text_encoder is not trained, no need to prepare. and device and dtype are already set @@ -587,11 +596,11 @@ class NetworkTrainer: if args.gradient_checkpointing: # according to TI example in Diffusers, train is required unet.train() - for t_enc in text_encoders: + for t_enc, frag in zip(text_encoders, self.get_text_encoders_train_flags(args, text_encoders)): t_enc.train() # set top parameter requires_grad = True for gradient checkpointing works - if train_text_encoder: + if frag: t_enc.text_model.embeddings.requires_grad_(True) else: @@ -736,6 +745,7 @@ class NetworkTrainer: "ss_huber_schedule": args.huber_schedule, "ss_huber_c": args.huber_c, "ss_fp8_base": args.fp8_base, + "ss_fp8_base_unet": args.fp8_base_unet, } self.update_metadata(metadata, args) # architecture specific metadata @@ -1004,6 +1014,7 @@ class NetworkTrainer: for t_enc in text_encoders: del t_enc text_encoders = [] + text_encoder = None # For --sample_at_first self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizers, text_encoder, unet) @@ -1018,7 +1029,7 @@ class NetworkTrainer: # log device and dtype for each model logger.info(f"unet dtype: {unet_weight_dtype}, device: {unet.device}") for t_enc in text_encoders: - logger.info(f"text_encoder dtype: {te_weight_dtype}, device: {t_enc.device}") + logger.info(f"text_encoder dtype: {t_enc.dtype}, device: {t_enc.device}") clean_memory_on_device(accelerator.device) @@ -1073,12 +1084,17 @@ class NetworkTrainer: text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) if text_encoder_outputs_list is not None: text_encoder_conds = text_encoder_outputs_list # List of text encoder outputs - else: + if ( + text_encoder_conds is None + or len(text_encoder_conds) == 0 + or text_encoder_conds[0] is None + or train_text_encoder + ): with torch.set_grad_enabled(train_text_encoder), accelerator.autocast(): # Get the text embedding for conditioning if args.weighted_captions: # SD only - text_encoder_conds = get_weighted_text_embeddings( + encoded_text_encoder_conds = get_weighted_text_embeddings( tokenizers[0], text_encoder, batch["captions"], @@ -1088,13 +1104,18 @@ class NetworkTrainer: ) else: input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]] - text_encoder_conds = text_encoding_strategy.encode_tokens( + encoded_text_encoder_conds = text_encoding_strategy.encode_tokens( tokenize_strategy, self.get_models_for_text_encoding(args, accelerator, text_encoders), input_ids, ) if args.full_fp16: - text_encoder_conds = [c.to(weight_dtype) for c in text_encoder_conds] + encoded_text_encoder_conds = [c.to(weight_dtype) for c in encoded_text_encoder_conds] + + # if encoded_text_encoder_conds is not None, update cached text_encoder_conds + for i in range(len(encoded_text_encoder_conds)): + if encoded_text_encoder_conds[i] is not None: + text_encoder_conds[i] = encoded_text_encoder_conds[i] # sample noise, call unet, get target noise_pred, target, timesteps, huber_c, weighting = self.get_noise_pred_and_target( @@ -1257,6 +1278,12 @@ def setup_parser() -> argparse.ArgumentParser: parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率") parser.add_argument("--text_encoder_lr", type=float, default=None, help="learning rate for Text Encoder / Text Encoderの学習率") + parser.add_argument( + "--fp8_base_unet", + action="store_true", + help="use fp8 for U-Net (or DiT), Text Encoder is fp16 or bf16" + " / U-Net(またはDiT)にfp8を使用する。Text Encoderはfp16またはbf16", + ) parser.add_argument( "--network_weights", type=str, default=None, help="pretrained weights for network / 学習するネットワークの初期重み" From 3be712e3e011b0378fad389641cec0c1869555ab Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 27 Aug 2024 21:40:02 +0900 Subject: [PATCH 12/87] feat: Update direct loading fp8 ckpt for LoRA training --- README.md | 7 +++- flux_minimal_inference.py | 27 +----------- flux_train_network.py | 16 +++++++- library/flux_utils.py | 12 ++++-- library/utils.py | 62 +++++++++++++++++++++++++++- networks/flux_merge_lora.py | 82 ++++++++++++++++++++++++++----------- 6 files changed, 151 insertions(+), 55 deletions(-) diff --git a/README.md b/README.md index 1203b5eb..0108ada5 100644 --- a/README.md +++ b/README.md @@ -9,13 +9,18 @@ __Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchv The command to install PyTorch is as follows: `pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124` +Aug 27, 2024 (update 2): +In FLUX.1 LoRA training, when `--fp8_base` is specified, the FLUX.1 model file with fp8 (`float8_e4m3fn` type) can be loaded directly. Also, in `flux_minimal_inference.py`, it is possible to load it by specifying `fp8 (float8_e4m3fn)` in `--flux_dtype`. + +In `flux_merge_lora.py`, you can now specify the precision at save time with `fp8` (see `--help` for details). Also, if you do not specify the merge model, only the model type conversion will be performed. + Aug 27, 2024: - FLUX.1 LoRA training now supports CLIP-L LoRA. Please remove `--network_train_unet_only`. T5XXL is not trained. The output of T5XXL is still cached, so `--cache_text_encoder_outputs` or `--cache_text_encoder_outputs_to_disk` is still required. The trained LoRA can be used with ComfyUI. - `flux_extract_lora.py` and `convert_flux_lora.py` do not support CLIP-L LoRA. - `--sigmoid_scale` is now effective even when `--timestep_sampling shift` is specified. Normally, leave it at 1.0. Larger values make the value before shift closer to a uniform distribution. -- __Experimental__ `--fp8_base_unet` option is added to `flux_train_network.py`. Flux can be trained with fp8, and CLIP-L can be trained with bf16/fp16. When specifying this option, the `--fp8_base` option is not required (Flux is fp8, and CLIP-L is bf16/fp16, regardless of the `--fp8_base` option). +- __Experimental__ `--fp8_base_unet` option is added to `flux_train_network.py`. Flux can be trained with fp8, and CLIP-L can be trained with bf16/fp16. When specifying this option, the `--fp8_base` option is automatically enabled. Aug 25, 2024: Added `shift` option to `--timestep_sampling` in FLUX.1 fine-tuning and LoRA training. Shifts timesteps according to the value of `--discrete_flow_shift` (shifts the value of sigmoid of normal distribution random number). It may be good to start with a value of 3.1582 (=e^1.15) for `--discrete_flow_shift`. diff --git a/flux_minimal_inference.py b/flux_minimal_inference.py index 5b8aa250..56c1b198 100644 --- a/flux_minimal_inference.py +++ b/flux_minimal_inference.py @@ -10,7 +10,6 @@ import einops import numpy as np import torch -from safetensors.torch import safe_open, load_file from tqdm import tqdm from PIL import Image import accelerate @@ -21,7 +20,7 @@ from library.device_utils import init_ipex, get_preferred_device init_ipex() -from library.utils import setup_logging +from library.utils import setup_logging, str_to_dtype setup_logging() import logging @@ -288,28 +287,6 @@ if __name__ == "__main__": name = "schnell" if "schnell" in args.ckpt_path else "dev" # TODO change this to a more robust way is_schnell = name == "schnell" - def str_to_dtype(s: Optional[str], default_dtype: Optional[torch.dtype] = None) -> torch.dtype: - if s is None: - return default_dtype - if s in ["bf16", "bfloat16"]: - return torch.bfloat16 - elif s in ["fp16", "float16"]: - return torch.float16 - elif s in ["fp32", "float32"]: - return torch.float32 - elif s in ["fp8_e4m3fn", "e4m3fn", "float8_e4m3fn"]: - return torch.float8_e4m3fn - elif s in ["fp8_e4m3fnuz", "e4m3fnuz", "float8_e4m3fnuz"]: - return torch.float8_e4m3fnuz - elif s in ["fp8_e5m2", "e5m2", "float8_e5m2"]: - return torch.float8_e5m2 - elif s in ["fp8_e5m2fnuz", "e5m2fnuz", "float8_e5m2fnuz"]: - return torch.float8_e5m2fnuz - elif s in ["fp8", "float8"]: - return torch.float8_e4m3fn # default fp8 - else: - raise ValueError(f"Unsupported dtype: {s}") - def is_fp8(dt): return dt in [torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz] @@ -348,7 +325,7 @@ if __name__ == "__main__": encoding_strategy = strategy_flux.FluxTextEncodingStrategy() # DiT - model = flux_utils.load_flow_model(name, args.ckpt_path, flux_dtype, loading_device) + model = flux_utils.load_flow_model(name, args.ckpt_path, None, loading_device) model.eval() logger.info(f"Casting model to {flux_dtype}") model.to(flux_dtype) # make sure model is dtype diff --git a/flux_train_network.py b/flux_train_network.py index 1a40de61..4a63c2de 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -29,6 +29,9 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): super().assert_extra_args(args, train_dataset_group) # sdxl_train_util.verify_sdxl_training_args(args) + if args.fp8_base_unet: + args.fp8_base = True # if fp8_base_unet is enabled, fp8_base is also enabled for FLUX.1 + if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs: logger.warning( "cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled / cache_text_encoder_outputs_to_diskが有効になっているため、cache_text_encoder_outputsも有効になります" @@ -61,9 +64,20 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): name = self.get_flux_model_name(args) # if we load to cpu, flux.to(fp8) takes a long time + if args.fp8_base: + loading_dtype = None # as is + else: + loading_dtype = weight_dtype + model = flux_utils.load_flow_model( - name, args.pretrained_model_name_or_path, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors + name, args.pretrained_model_name_or_path, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors ) + if args.fp8_base: + # check dtype of model + if model.dtype == torch.float8_e4m3fnuz or model.dtype == torch.float8_e5m2 or model.dtype == torch.float8_e5m2fnuz: + raise ValueError(f"Unsupported fp8 model dtype: {model.dtype}") + elif model.dtype == torch.float8_e4m3fn: + logger.info("Loaded fp8 FLUX model") if args.split_mode: model = self.prepare_split_model(model, weight_dtype, accelerator) diff --git a/library/flux_utils.py b/library/flux_utils.py index 37166933..68083616 100644 --- a/library/flux_utils.py +++ b/library/flux_utils.py @@ -1,5 +1,5 @@ import json -from typing import Union +from typing import Optional, Union import einops import torch @@ -20,7 +20,9 @@ MODEL_VERSION_FLUX_V1 = "flux1" # temporary copy from sd3_utils TODO refactor -def load_safetensors(path: str, device: Union[str, torch.device], disable_mmap: bool = False, dtype: torch.dtype = torch.float32): +def load_safetensors( + path: str, device: Union[str, torch.device], disable_mmap: bool = False, dtype: Optional[torch.dtype] = torch.float32 +): if disable_mmap: # return safetensors.torch.load(open(path, "rb").read()) # use experimental loader @@ -38,11 +40,13 @@ def load_safetensors(path: str, device: Union[str, torch.device], disable_mmap: def load_flow_model( - name: str, ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False + name: str, ckpt_path: str, dtype: Optional[torch.dtype], device: Union[str, torch.device], disable_mmap: bool = False ) -> flux_models.Flux: logger.info(f"Building Flux model {name}") with torch.device("meta"): - model = flux_models.Flux(flux_models.configs[name].params).to(dtype) + model = flux_models.Flux(flux_models.configs[name].params) + if dtype is not None: + model = model.to(dtype) # load_sft doesn't support torch.device logger.info(f"Loading state dict from {ckpt_path}") diff --git a/library/utils.py b/library/utils.py index a1620997..d355cb10 100644 --- a/library/utils.py +++ b/library/utils.py @@ -82,6 +82,66 @@ def setup_logging(args=None, log_level=None, reset=False): logger.info(msg_init) +def str_to_dtype(s: Optional[str], default_dtype: Optional[torch.dtype] = None) -> torch.dtype: + """ + Convert a string to a torch.dtype + + Args: + s: string representation of the dtype + default_dtype: default dtype to return if s is None + + Returns: + torch.dtype: the corresponding torch.dtype + + Raises: + ValueError: if the dtype is not supported + + Examples: + >>> str_to_dtype("float32") + torch.float32 + >>> str_to_dtype("fp32") + torch.float32 + >>> str_to_dtype("float16") + torch.float16 + >>> str_to_dtype("fp16") + torch.float16 + >>> str_to_dtype("bfloat16") + torch.bfloat16 + >>> str_to_dtype("bf16") + torch.bfloat16 + >>> str_to_dtype("fp8") + torch.float8_e4m3fn + >>> str_to_dtype("fp8_e4m3fn") + torch.float8_e4m3fn + >>> str_to_dtype("fp8_e4m3fnuz") + torch.float8_e4m3fnuz + >>> str_to_dtype("fp8_e5m2") + torch.float8_e5m2 + >>> str_to_dtype("fp8_e5m2fnuz") + torch.float8_e5m2fnuz + """ + if s is None: + return default_dtype + if s in ["bf16", "bfloat16"]: + return torch.bfloat16 + elif s in ["fp16", "float16"]: + return torch.float16 + elif s in ["fp32", "float32", "float"]: + return torch.float32 + elif s in ["fp8_e4m3fn", "e4m3fn", "float8_e4m3fn"]: + return torch.float8_e4m3fn + elif s in ["fp8_e4m3fnuz", "e4m3fnuz", "float8_e4m3fnuz"]: + return torch.float8_e4m3fnuz + elif s in ["fp8_e5m2", "e5m2", "float8_e5m2"]: + return torch.float8_e5m2 + elif s in ["fp8_e5m2fnuz", "e5m2fnuz", "float8_e5m2fnuz"]: + return torch.float8_e5m2fnuz + elif s in ["fp8", "float8"]: + return torch.float8_e4m3fn # default fp8 + else: + raise ValueError(f"Unsupported dtype: {s}") + + def mem_eff_save_file(tensors: Dict[str, torch.Tensor], filename: str, metadata: Dict[str, Any] = None): """ memory efficient save file @@ -198,7 +258,7 @@ class MemoryEfficientSafeOpen: if tensor_bytes is None: byte_tensor = torch.empty(0, dtype=torch.uint8) else: - tensor_bytes = bytearray(tensor_bytes) # make it writable + tensor_bytes = bytearray(tensor_bytes) # make it writable byte_tensor = torch.frombuffer(tensor_bytes, dtype=torch.uint8) # process float8 types diff --git a/networks/flux_merge_lora.py b/networks/flux_merge_lora.py index d5e82920..2e0d4c29 100644 --- a/networks/flux_merge_lora.py +++ b/networks/flux_merge_lora.py @@ -8,7 +8,7 @@ from safetensors import safe_open from safetensors.torch import load_file, save_file from tqdm import tqdm -from library.utils import setup_logging +from library.utils import setup_logging, str_to_dtype, MemoryEfficientSafeOpen, mem_eff_save_file setup_logging() import logging @@ -34,18 +34,23 @@ def load_state_dict(file_name, dtype): return sd, metadata -def save_to_file(file_name, state_dict, dtype, metadata): +def save_to_file(file_name, state_dict, dtype, metadata, mem_eff_save=False): if dtype is not None: logger.info(f"converting to {dtype}...") - for key in list(state_dict.keys()): + for key in tqdm(list(state_dict.keys())): if type(state_dict[key]) == torch.Tensor: state_dict[key] = state_dict[key].to(dtype) logger.info(f"saving to: {file_name}") - save_file(state_dict, file_name, metadata=metadata) + if mem_eff_save: + mem_eff_save_file(state_dict, file_name, metadata=metadata) + else: + save_file(state_dict, file_name, metadata=metadata) -def merge_to_flux_model(loading_device, working_device, flux_model, models, ratios, merge_dtype, save_dtype): +def merge_to_flux_model( + loading_device, working_device, flux_model, models, ratios, merge_dtype, save_dtype, mem_eff_load_save=False +): # create module map without loading state_dict logger.info(f"loading keys from FLUX.1 model: {flux_model}") lora_name_to_module_key = {} @@ -57,7 +62,14 @@ def merge_to_flux_model(loading_device, working_device, flux_model, models, rati lora_name = lora_flux.LoRANetwork.LORA_PREFIX_FLUX + "_" + module_name.replace(".", "_") lora_name_to_module_key[lora_name] = key - flux_state_dict = load_file(flux_model, device=loading_device) + if mem_eff_load_save: + flux_state_dict = {} + with MemoryEfficientSafeOpen(flux_model) as flux_file: + for key in tqdm(flux_file.keys()): + flux_state_dict[key] = flux_file.get_tensor(key).to(loading_device) # dtype is not changed + else: + flux_state_dict = load_file(flux_model, device=loading_device) + for model, ratio in zip(models, ratios): logger.info(f"loading: {model}") lora_sd, _ = load_state_dict(model, merge_dtype) # loading on CPU @@ -120,9 +132,17 @@ def merge_to_flux_model(loading_device, working_device, flux_model, models, rati return flux_state_dict -def merge_to_flux_model_diffusers(loading_device, working_device, flux_model, models, ratios, merge_dtype, save_dtype): +def merge_to_flux_model_diffusers( + loading_device, working_device, flux_model, models, ratios, merge_dtype, save_dtype, mem_eff_load_save=False +): logger.info(f"loading keys from FLUX.1 model: {flux_model}") - flux_state_dict = load_file(flux_model, device=loading_device) + if mem_eff_load_save: + flux_state_dict = {} + with MemoryEfficientSafeOpen(flux_model) as flux_file: + for key in tqdm(flux_file.keys()): + flux_state_dict[key] = flux_file.get_tensor(key).to(loading_device) # dtype is not changed + else: + flux_state_dict = load_file(flux_model, device=loading_device) def create_key_map(n_double_layers, n_single_layers): key_map = {} @@ -474,19 +494,15 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False): def merge(args): + if args.models is None: + args.models = [] + if args.ratios is None: + args.ratios = [] + assert len(args.models) == len( args.ratios ), "number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください" - def str_to_dtype(p): - if p == "float": - return torch.float - if p == "fp16": - return torch.float16 - if p == "bf16": - return torch.bfloat16 - return None - merge_dtype = str_to_dtype(args.precision) save_dtype = str_to_dtype(args.save_precision) if save_dtype is None: @@ -500,11 +516,25 @@ def merge(args): if args.flux_model is not None: if not args.diffusers: state_dict = merge_to_flux_model( - args.loading_device, args.working_device, args.flux_model, args.models, args.ratios, merge_dtype, save_dtype + args.loading_device, + args.working_device, + args.flux_model, + args.models, + args.ratios, + merge_dtype, + save_dtype, + args.mem_eff_load_save, ) else: state_dict = merge_to_flux_model_diffusers( - args.loading_device, args.working_device, args.flux_model, args.models, args.ratios, merge_dtype, save_dtype + args.loading_device, + args.working_device, + args.flux_model, + args.models, + args.ratios, + merge_dtype, + save_dtype, + args.mem_eff_load_save, ) if args.no_metadata: @@ -517,7 +547,7 @@ def merge(args): ) logger.info(f"saving FLUX model to: {args.save_to}") - save_to_file(args.save_to, state_dict, save_dtype, sai_metadata) + save_to_file(args.save_to, state_dict, save_dtype, sai_metadata, args.mem_eff_load_save) else: state_dict, metadata = merge_lora_models(args.models, args.ratios, merge_dtype, args.concat, args.shuffle) @@ -546,14 +576,14 @@ def setup_parser() -> argparse.ArgumentParser: "--save_precision", type=str, default=None, - choices=[None, "float", "fp16", "bf16"], - help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ", + help="precision in saving, same to merging if omitted. supported types: " + "float32, fp16, bf16, fp8 (same as fp8_e4m3fn), fp8_e4m3fn, fp8_e4m3fnuz, fp8_e5m2, fp8_e5m2fnuz" + " / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ", ) parser.add_argument( "--precision", type=str, default="float", - choices=["float", "fp16", "bf16"], help="precision in merging (float is recommended) / マージの計算時の精度(floatを推奨)", ) parser.add_argument( @@ -562,6 +592,12 @@ def setup_parser() -> argparse.ArgumentParser: default=None, help="FLUX.1 model to load, merge LoRA models if omitted / 読み込むモデル、指定しない場合はLoRAモデルをマージする", ) + parser.add_argument( + "--mem_eff_load_save", + action="store_true", + help="use custom memory efficient load and save functions for FLUX.1 model" + " / カスタムのメモリ効率の良い読み込みと保存関数をFLUX.1モデルに使用する", + ) parser.add_argument( "--loading_device", type=str, From a61cf73a5cb5209c3f4d1a3688dd276a4dfd1ecb Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 27 Aug 2024 21:44:10 +0900 Subject: [PATCH 13/87] update readme --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 0108ada5..7b1d9cc6 100644 --- a/README.md +++ b/README.md @@ -12,7 +12,7 @@ The command to install PyTorch is as follows: Aug 27, 2024 (update 2): In FLUX.1 LoRA training, when `--fp8_base` is specified, the FLUX.1 model file with fp8 (`float8_e4m3fn` type) can be loaded directly. Also, in `flux_minimal_inference.py`, it is possible to load it by specifying `fp8 (float8_e4m3fn)` in `--flux_dtype`. -In `flux_merge_lora.py`, you can now specify the precision at save time with `fp8` (see `--help` for details). Also, if you do not specify the merge model, only the model type conversion will be performed. +In `flux_merge_lora.py`, you can now specify `fp8` for the save precision (see `--help` for details). Also, if you do not specify the merge model, only the dtype conversion will be performed. Aug 27, 2024: From 6c0e8a5a1740dbd50a0a45ec1f08983877605cd7 Mon Sep 17 00:00:00 2001 From: Akegarasu Date: Thu, 29 Aug 2024 14:50:29 +0800 Subject: [PATCH 14/87] make guidance_scale keep float in args --- flux_train_network.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/flux_train_network.py b/flux_train_network.py index 4a63c2de..354a8c6f 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -324,7 +324,8 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): img_ids = flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width).to(device=accelerator.device) # get guidance - guidance_vec = torch.full((bsz,), args.guidance_scale, device=accelerator.device) + # ensure guidance_scale in args is float + guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device) # ensure the hidden state will require grad if args.gradient_checkpointing: From a0cfb0894c4be4ea27412e4c12ed13f68b57094b Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 29 Aug 2024 21:20:33 +0900 Subject: [PATCH 15/87] Cleaned up README --- README.md | 299 +++++++++++++++++++++++++++--------------------------- 1 file changed, 152 insertions(+), 147 deletions(-) diff --git a/README.md b/README.md index 7b1d9cc6..a73eead0 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ This repository contains training, generation and utility scripts for Stable Diffusion. -## FLUX.1 LoRA training (WIP) +## FLUX.1 training (WIP) This feature is experimental. The options and the training script may change in the future. Please let us know if you have any idea to improve the training. @@ -9,127 +9,24 @@ __Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchv The command to install PyTorch is as follows: `pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124` -Aug 27, 2024 (update 2): -In FLUX.1 LoRA training, when `--fp8_base` is specified, the FLUX.1 model file with fp8 (`float8_e4m3fn` type) can be loaded directly. Also, in `flux_minimal_inference.py`, it is possible to load it by specifying `fp8 (float8_e4m3fn)` in `--flux_dtype`. - -In `flux_merge_lora.py`, you can now specify `fp8` for the save precision (see `--help` for details). Also, if you do not specify the merge model, only the dtype conversion will be performed. - -Aug 27, 2024: - -- FLUX.1 LoRA training now supports CLIP-L LoRA. Please remove `--network_train_unet_only`. T5XXL is not trained. The output of T5XXL is still cached, so `--cache_text_encoder_outputs` or `--cache_text_encoder_outputs_to_disk` is still required. The trained LoRA can be used with ComfyUI. - - `flux_extract_lora.py` and `convert_flux_lora.py` do not support CLIP-L LoRA. -- `--sigmoid_scale` is now effective even when `--timestep_sampling shift` is specified. Normally, leave it at 1.0. Larger values make the value before shift closer to a uniform distribution. - -- __Experimental__ `--fp8_base_unet` option is added to `flux_train_network.py`. Flux can be trained with fp8, and CLIP-L can be trained with bf16/fp16. When specifying this option, the `--fp8_base` option is automatically enabled. - -Aug 25, 2024: -Added `shift` option to `--timestep_sampling` in FLUX.1 fine-tuning and LoRA training. Shifts timesteps according to the value of `--discrete_flow_shift` (shifts the value of sigmoid of normal distribution random number). It may be good to start with a value of 3.1582 (=e^1.15) for `--discrete_flow_shift`. -Sample command: `--timestep_sampling shift --discrete_flow_shift 3.1582 --model_prediction_type raw --guidance_scale 1.0` - -Aug 24, 2024 (update 2): - -__Experimental__ Added an option to split the projection layers of q/k/v/txt in the attention and apply LoRA to each of them in FLUX.1 LoRA training. Specify `"split_qkv=True"` in network_args like `--network_args "split_qkv=True"` (`train_blocks` is also available). - -The number of parameters may increase slightly, so the expressiveness may increase, but the training time may be longer. No detailed verification has been done. - -This implementation is experimental, so it may be deprecated or changed in the future. - -The .safetensors file of the trained model is compatible with the normal LoRA model of sd-scripts, so it should be usable in inference environments such as ComfyUI as it is. Also, converting it to AI-toolkit (Diffusers) format with `convert_flux_lora.py` will reduce the size. It should be no problem to convert it if you use it in the inference environment. - -Technical details: In the implementation of Black Forest Labs' model, the projection layers of q/k/v (and txt in single blocks) are concatenated into one. If LoRA is added there as it is, the LoRA module is only one, and the dimension is large. In contrast, in the implementation of Diffusers, the projection layers of q/k/v/txt are separated. Therefore, the LoRA module is applied to q/k/v/txt separately, and the dimension is smaller. This option is for training LoRA similar to the latter. - -The compatibility of the saved model (state dict) is ensured by concatenating the weights of multiple LoRAs. However, since there are zero weights in some parts, the model size will be large. - -Aug 24, 2024: -Fixed an issue where the attention mask was not applied in single blocks when `--apply_t5_attn_mask` was specified. - -Aug 22, 2024 (update 2): -Fixed a bug that the embedding was zero-padded when `--apply_t5_attn_mask` option was applied. Also, the cache file for text encoder outputs now records whether the mask is applied or not. Please note that the cache file will be recreated when switching the `--apply_t5_attn_mask` option. - -Added a script to extract LoRA from the difference between the two models of FLUX.1. Use `networks/flux_extract_lora.py`. See `--help` for details. Normally, more than 50GB of memory is required, but specifying the `--mem_eff_safe_open` option significantly reduces memory usage. However, this option is a custom implementation, so unexpected problems may occur. Please always check if the model is loaded correctly. - -Aug 22, 2024: -Fixed a bug in multi-GPU training. It should work with fine-tuning and LoRA training. `--double_blocks_to_swap` and `--single_blocks_to_swap` cannot be used in multi-GPU training. - -`--disable_mmap_load_safetensors` option now works in `flux_train.py`. It speeds up model loading during training in WSL2. It is also effective in reducing memory usage when loading models during multi-GPU training. Please always check if the model is loaded correctly, as it uses a custom implementation of safetensors loading. - - -Aug 21, 2024 (update 3): -- There is a bug that `--full_bf16` option is enabled even if it is not specified in `flux_train.py`. The bug will be fixed sooner. __Please specify the `--full_bf16` option explicitly, especially when training with 24GB VRAM.__ -- Stochastic rounding is now implemented when `--fused_backward_pass` is specified. The implementation is -based on the code provided by 2kpr. Thank you so much! - - With this change, `--fused_backward_pass` is recommended over `--blockwise_fused_optimizers` when `--full_bf16` is specified. - - Please note that `--fused_backward_pass` is only supported with Adafactor. -- The sample command in [FLUX.1 fine-tuning](#flux1-fine-tuning) is updated to reflect these changes. -- Fixed `--single_blocks_to_swap` is not working in `flux_train.py`. - -Aug 21, 2024 (update 2): -Fixed an error in applying mask in Attention. The attention mask was float, but it should be bool. - -Added a script `convert_flux_lora.py` to convert LoRA between sd-scripts format (BFL-based) and AI-toolkit format (Diffusers-based). See `--help` for details. BFL-based LoRA has a large module, so converting it to Diffusers format may reduce temporary memory usage in the inference environment. Note that re-conversion will increase the size of LoRA. - - -Aug 21, 2024: -The specification of `--apply_t5_attn_mask` has been changed. Previously, the T5 output was zero-padded, but now, two steps are taken: "1. Apply mask when encoding T5" and "2. Apply mask in the attention of Double Block". Fine tuning, LoRA training, and inference in `flux_mini_inference.py` have been changed. - -Aug 20, 2024 (update 3): -__Experimental__ The multi-resolution training is now supported with caching latents to disk. - -The cache files now hold latents for multiple resolutions. Since the latents are appended to the current cache file, it is recommended to delete the cache file in advance (if not, the old latents is kept in .npz file). - -See [FLUX.1 Multi-resolution training](#flux1-multi-resolution-training) for details. - -Aug 20, 2024 (update 2): -`flux_merge_lora.py` now supports LoRA from AI-toolkit (Diffusers based keys). Specify `--diffusers` option to merge LoRA with Diffusers based keys. Thanks to exveria1015! - -Aug 20, 2024: -FLUX.1 supports multi-resolution inference, so training at multiple resolutions may be possible and the results may be improved (like 1024x1024, 768x768 and 512x512 ... you can use any resolution). - -The script seems to support multi-resolution even in the current version, ~~if `--cache_latents_to_disk` is not specified~~ -> `--cache_latents_to_disk` is now supported for multi-resolution training. Please try if you are interested. See [FLUX.1 Multi-resolution training](#flux1-multi-resolution-training) for details. - -We will support multi-resolution caching to disk in the near future. - -Aug 19, 2024: -In `flux_train.py`, the memory consumption during model saving is reduced when `--save_precision` is set to the same value as `--mixed_precision` (about 22GB). Please set the same value unless there is a reason. - -An experimental option `--mem_eff_save` is also added. When specified, it can further reduce memory consumption (about 22GB), but since it is a custom implementation, unexpected problems may occur. We do not recommend using it unless you are familiar with the code. - -Aug 18, 2024: -Memory-efficient training based on 2kpr's implementation is implemented in `flux_train.py`. Thanks to 2kpr! See [FLUX.1 fine-tuning](#flux1-fine-tuning) for details. - -Aug 17, 2024: -Added a script `flux_train.py` to train FLUX.1. The script is experimental and not an optimized version. It needs >28GB VRAM for training. - -Aug 16, 2024: - -Added a script `networks/flux_merge_lora.py` to merge LoRA into FLUX.1 checkpoint. See [Merge LoRA to FLUX.1 checkpoint](#merge-lora-to-flux1-checkpoint) for details. - -FLUX.1 schnell model based training is now supported (but not tested). If the name of the model file contains `schnell`, the model is treated as a schnell model. - -Added `--t5xxl_max_token_length` option to specify the maximum token length of T5XXL. The default is 512 in dev and 256 in schnell. - -Previously, when `--max_token_length` was specified, that value was used, and 512 was used when omitted (default). Therefore, there is no impact if `--max_token_length` was not specified. If `--max_token_length` was specified, please specify `--t5xxl_max_token_length` instead. `--max_token_length` is ignored during FLUX.1 training. - -Aug 14, 2024: Sample image generation during training is now supported. Specify options such as `--sample_prompts` and `--sample_every_n_epochs`. It will be very slow when `--split_mode` is specified. - -Aug 13, 2024: - -__Experimental__ A network argument `train_blocks` is added to `lora_flux`. This is to select the target blocks of LoRA from FLUX double blocks and single blocks. Specify like `--network_args "train_blocks=single"`. `all` trains both double blocks and single blocks, `double` trains only double blocks, and `single` trains only single blocks. The default (omission) is `all`. - -This argument is available even if `--split_mode` is not specified. - -__Experimental__ `--split_mode` option is added to `flux_train_network.py`. This splits FLUX into double blocks and single blocks for training. By enabling gradients only for the single blocks part, memory usage is reduced. When this option is specified, you need to specify `"train_blocks=single"` in the network arguments. - -This option enables training with 12GB VRAM GPUs, but the training speed is 2-3 times slower than the default. - -Aug 11, 2024: Fix `--apply_t5_attn_mask` option to work. Please remove and re-generate the latents cache file if you have used the option before. - -Aug 10, 2024: LoRA key prefix is changed to `lora_unet` from `lora_flex` to make it compatible with ComfyUI. - +- [FLUX.1 LoRA training](#flux1-lora-training) + - [Key Options for FLUX.1 LoRA training](#key-options-for-flux1-lora-training) + - [Inference for FLUX.1 LoRA model](#inference-for-flux1-lora-model) + - [Key Features for FLUX.1 LoRA training](#key-features-for-flux1-lora-training) +- [FLUX.1 fine-tuning](#flux1-fine-tuning) + - [Key Features for FLUX.1 fine-tuning](#key-features-for-flux1-fine-tuning) +- [Extract LoRA from FLUX.1 Models](#extract-lora-from-flux1-models) +- [Convert FLUX LoRA](#convert-flux-lora) +- [Merge LoRA to FLUX.1 checkpoint](#merge-lora-to-flux1-checkpoint) +- [FLUX.1 Multi-resolution training](#flux1-multi-resolution-training) ### FLUX.1 LoRA training -We have added a new training script for LoRA training. The script is `flux_train_network.py`. See `--help` for options. Sample command is below, settings are based on [AI Toolkit by Ostris](https://github.com/ostris/ai-toolkit). It will work with 24GB VRAM GPUs. +We have added a new training script for LoRA training. The script is `flux_train_network.py`. See `--help` for options. + +FLUX.1 model, CLIP-L, and T5XXL models are recommended to be in bf16/fp16 format. If you specify `--fp8_base`, you can use fp8 models for FLUX.1. The fp8 model is only compatible with `float8_e4m3fn` format. + +Sample command is below. It will work with 24GB VRAM GPUs. ``` accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_train_network.py @@ -137,46 +34,107 @@ accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_t --ae ae.sft --cache_latents_to_disk --save_model_as safetensors --sdpa --persistent_data_loader_workers --max_data_loader_n_workers 2 --seed 42 --gradient_checkpointing --mixed_precision bf16 --save_precision bf16 --network_module networks.lora_flux --network_dim 4 --optimizer_type adamw8bit --learning_rate 1e-4 ---network_train_unet_only --cache_text_encoder_outputs --cache_text_encoder_outputs_to_disk --fp8_base +--cache_text_encoder_outputs --cache_text_encoder_outputs_to_disk --fp8_base --highvram --max_train_epochs 4 --save_every_n_epochs 1 --dataset_config dataset_1024_bs2.toml ---output_dir path/to/output/dir --output_name flux-lora-name --timestep_sampling sigmoid ---model_prediction_type raw --guidance_scale 1.0 --loss_type l2 +--output_dir path/to/output/dir --output_name flux-lora-name +--timestep_sampling shift --discrete_flow_shift 3.1582 --model_prediction_type raw --guidance_scale 1.0 ``` (The command is multi-line for readability. Please combine it into one line.) The training can be done with 16GB VRAM GPUs with Adafactor optimizer. Please use settings like below: ``` ---optimizer_type adafactor --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False" +--optimizer_type adafactor --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False" --lr_scheduler constant_with_warmup --max_grad_norm 0.0 ``` The training can be done with 12GB VRAM GPUs with Adafactor optimizer, `--split_mode` and `train_blocks=single` options. Please use settings like below: ``` ---optimizer_type adafactor --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False" --split_mode --network_args "train_blocks=single" +--optimizer_type adafactor --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False" --split_mode --network_args "train_blocks=single" --lr_scheduler constant_with_warmup --max_grad_norm 0.0 ``` -LoRAs for Text Encoders are not tested yet. - -We have added some new options (Aug 10, 2024): `--time_sampling`, `--sigmoid_scale`, `--model_prediction_type` and `--discrete_flow_shift`. The options are as follows: - -- `--timestep_sampling` is the method to sample timesteps (0-1): `sigma` (sigma-based, same as SD3), `uniform` (uniform random), or `sigmoid` (sigmoid of random normal, same as x-flux). -- `--sigmoid_scale` is the scale factor for sigmoid timestep sampling (only used when timestep-sampling is "sigmoid"). The default is 1.0. Larger values will make the sampling more uniform. -- `--model_prediction_type` is how to interpret and process the model prediction: `raw` (use as is, same as x-flux), `additive` (add to noisy input), `sigma_scaled` (apply sigma scaling, same as SD3). -- `--discrete_flow_shift` is the discrete flow shift for the Euler Discrete Scheduler, default is 3.0 (same as SD3). - -`--loss_type` may be useful for FLUX.1 training. The default is `l2`. - -In our experiments, `--timestep_sampling sigma --model_prediction_type raw --discrete_flow_shift 1.0` with `--loss_type l2` seems to work better than the default (SD3) settings. The multiplier of LoRA should be adjusted. - -additional note (Aug 11): A quick check shows that the settings in [AI Toolkit by Ostris](https://github.com/ostris/ai-toolkit) seems to be equivalent to `--timestep_sampling sigmoid --model_prediction_type raw --guidance_scale 1.0` (with the default `l2` loss_type). This seems to be a good starting point. Thanks to Ostris for the great work! - -Other settings may work better, so please try different settings. - We also not sure how many epochs are needed for convergence, and how the learning rate should be adjusted. The trained LoRA model can be used with ComfyUI. +#### Key Options for FLUX.1 LoRA training + +There are many unknown points in FLUX.1 training, so some settings can be specified by arguments. Here are the arguments. The arguments and sample settings are still experimental and may change in the future. Feedback on the settings is welcome. + +- `--timestep_sampling` is the method to sample timesteps (0-1): + - `sigma`: sigma-based, same as SD3 + - `uniform`: uniform random + - `sigmoid`: sigmoid of random normal, same as x-flux, AI-toolkit etc. + - `shift`: shifts the value of sigmoid of normal distribution random number +- `--sigmoid_scale` is the scale factor for sigmoid timestep sampling (only used when timestep-sampling is "sigmoid"). The default is 1.0. Larger values will make the sampling more uniform. + - This option is effective even when`--timestep_sampling shift` is specified. + - Normally, leave it at 1.0. Larger values make the value before shift closer to a uniform distribution. +- `--model_prediction_type` is how to interpret and process the model prediction: + - `raw`: use as is, same as x-flux + - `additive`: add to noisy input + - `sigma_scaled`: apply sigma scaling, same as SD3 +- `--discrete_flow_shift` is the discrete flow shift for the Euler Discrete Scheduler, default is 3.0 (same as SD3). + +The existing `--loss_type` option may be useful for FLUX.1 training. The default is `l2`. + +~~In our experiments, `--timestep_sampling sigma --model_prediction_type raw --discrete_flow_shift 1.0` with `--loss_type l2` seems to work better than the default (SD3) settings. The multiplier of LoRA should be adjusted. ~~ + +In our experiments, `--timestep_sampling shift --discrete_flow_shift 3.1582 --model_prediction_type raw --guidance_scale 1.0` with `--loss_type l2` seems to work better than other settings. + +The settings in [AI Toolkit by Ostris](https://github.com/ostris/ai-toolkit) seems to be equivalent to `--timestep_sampling sigmoid --model_prediction_type raw --guidance_scale 1.0` (with the default `l2` loss_type). + +Other settings may work better, so please try different settings. + +Other options are described below. + +#### Distribution of timesteps + +`--timestep_sampling` and `--sigmoid_scale`, `--discrete_flow_shift` adjust the distribution of timesteps. The distribution is shown in the figures below. + +The effect of `--discrete_flow_shift` with `--timestep_sampling shift` (when `--sigmoid_scale` is not specified, the default is 1.0): + +The difference between `--timestep_sampling uniform` and `--timestep_sampling sigma`: + +The effect of `--timestep_sampling sigmoid` and `--sigmoid_scale` (when `--timestep_sampling sigmoid` is specified, `--discrete_flow_shift` is ignored): + +#### Key Features for FLUX.1 LoRA training + +1. CLIP-L LoRA Support: + - FLUX.1 LoRA training now supports CLIP-L LoRA. + - Remove `--network_train_unet_only` from your command. + - T5XXL is not trained. Its output is still cached, so `--cache_text_encoder_outputs` or `--cache_text_encoder_outputs_to_disk` is still required. + - The trained LoRA can be used with ComfyUI. + - Note: `flux_extract_lora.py` and `convert_flux_lora.py` do not support CLIP-L LoRA. + +2. Experimental FP8/FP16 mixed training: + - `--fp8_base_unet` enables training with fp8 for FLUX and bf16/fp16 for CLIP-L. + - FLUX can be trained with fp8, and CLIP-L can be trained with bf16/fp16. + - When specifying this option, the `--fp8_base` option is automatically enabled. + +3. Split Q/K/V Projection Layers (Experimental): + - Added an option to split the projection layers of q/k/v/txt in the attention and apply LoRA to each of them. + - Specify `"split_qkv=True"` in network_args like `--network_args "split_qkv=True"` (`train_blocks` is also available). + - May increase expressiveness but also training time. + - The trained model is compatible with normal LoRA models in sd-scripts and can be used in environments like ComfyUI. + - Converting to AI-toolkit (Diffusers) format with `convert_flux_lora.py` will reduce the size. + +4. T5 Attention Mask Application: + - T5 attention mask is applied when `--apply_t5_attn_mask` is specified. + - Now applies mask when encoding T5 and in the attention of Double and Single Blocks + - Affects fine-tuning, LoRA training, and inference in `flux_minimal_inference.py`. + +5. Multi-resolution Training Support: + - FLUX.1 now supports multi-resolution training, even with caching latents to disk. + + +Technical details of Q/K/V split: + +In the implementation of Black Forest Labs' model, the projection layers of q/k/v (and txt in single blocks) are concatenated into one. If LoRA is added there as it is, the LoRA module is only one, and the dimension is large. In contrast, in the implementation of Diffusers, the projection layers of q/k/v/txt are separated. Therefore, the LoRA module is applied to q/k/v/txt separately, and the dimension is smaller. This option is for training LoRA similar to the latter. + +The compatibility of the saved model (state dict) is ensured by concatenating the weights of multiple LoRAs. However, since there are zero weights in some parts, the model size will be large. + +### Inference for FLUX.1 with LoRA model + The inference script is also available. The script is `flux_minimal_inference.py`. See `--help` for options. ``` @@ -185,6 +143,8 @@ python flux_minimal_inference.py --ckpt flux1-dev.sft --clip_l sd3/clip_l.safete ### FLUX.1 fine-tuning +The memory-efficient training with block swap is based on 2kpr's implementation. Thanks to 2kpr! + Sample command for FLUX.1 fine-tuning is below. This will work with 24GB VRAM GPUs, and 64GB main memory is recommended. ``` @@ -195,15 +155,13 @@ accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_t --dataset_config dataset_1024_bs1.toml --output_dir path/to/output/dir --output_name output-name --learning_rate 5e-5 --max_train_epochs 4 --sdpa --highvram --cache_text_encoder_outputs_to_disk --cache_latents_to_disk --save_every_n_epochs 1 --optimizer_type adafactor --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False" ---timestep_sampling sigmoid --model_prediction_type raw --guidance_scale 1.0 +--lr_scheduler constant_with_warmup --max_grad_norm 0.0 +--timestep_sampling shift --discrete_flow_shift 3.1582 --model_prediction_type raw --guidance_scale 1.0 --fused_backward_pass --double_blocks_to_swap 6 --cpu_offload_checkpointing --full_bf16 ``` +(The command is multi-line for readability. Please combine it into one line.) -(Combine the command into one line.) - -Sample image generation during training is not tested yet. - -Options are almost the same as LoRA training. The difference is `--blockwise_fused_optimizers`, `--double_blocks_to_swap` and `--cpu_offload_checkpointing`. `--single_blocks_to_swap` is also available. +Options are almost the same as LoRA training. The difference is `--full_bf16`, `--blockwise_fused_optimizers`, `--double_blocks_to_swap` and `--cpu_offload_checkpointing`. `--single_blocks_to_swap` is also available. `--full_bf16` enables the training with bf16 (weights and gradients). @@ -223,6 +181,53 @@ Swap 6 double blocks and use cpu offload checkpointing may be a good starting po The learning rate and the number of epochs are not optimized yet. Please adjust them according to the training results. +#### Key Features for FLUX.1 fine-tuning + +1. Sample Image Generation: + - Sample image generation during training is now supported. + - The prompts are cached and used for generation if `--cache_latents` is specified. So changing the prompts during training will not affect the generated images. + - Specify options such as `--sample_prompts` and `--sample_every_n_epochs`. + - Note: It will be very slow when `--split_mode` is specified. + +2. Experimental Memory-Efficient Saving: + - `--mem_eff_save` option can further reduce memory consumption during model saving (about 22GB). + - This is a custom implementation and may cause unexpected issues. Use with caution. + +3. T5XXL Token Length Control: + - Added `--t5xxl_max_token_length` option to specify the maximum token length of T5XXL. + - Default is 512 in dev and 256 in schnell models. + +4. Multi-GPU Training Support: + - Note: `--double_blocks_to_swap` and `--single_blocks_to_swap` cannot be used in multi-GPU training. + +5. Disable mmap Load for Safetensors: + - `--disable_mmap_load_safetensors` option now works in `flux_train.py`. + - Speeds up model loading during training in WSL2. + - Effective in reducing memory usage when loading models during multi-GPU training. + + +### Extract LoRA from FLUX.1 Models + +Script: `networks/flux_extract_lora.py` + +Extracts LoRA from the difference between two FLUX.1 models. + +Offers memory-efficient option with `--mem_eff_safe_open`. + +CLIP-L LoRA is not supported. + +### Convert FLUX LoRA + +Script: `convert_flux_lora.py` + +Converts LoRA between sd-scripts format (BFL-based) and AI-toolkit format (Diffusers-based). + +If you use LoRA in the inference environment, converting it to AI-toolkit format may reduce temporary memory usage. + +Note that re-conversion will increase the size of LoRA. + +CLIP-L LoRA is not supported. + ### Merge LoRA to FLUX.1 checkpoint `networks/flux_merge_lora.py` merges LoRA to FLUX.1 checkpoint. __The script is experimental.__ From daa6ad516581872aa6acaa15c0d24aad4f998838 Mon Sep 17 00:00:00 2001 From: "Kohya S." <52813779+kohya-ss@users.noreply.github.com> Date: Thu, 29 Aug 2024 21:25:30 +0900 Subject: [PATCH 16/87] Update README.md --- README.md | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index a73eead0..6e2ae337 100644 --- a/README.md +++ b/README.md @@ -77,9 +77,9 @@ There are many unknown points in FLUX.1 training, so some settings can be specif The existing `--loss_type` option may be useful for FLUX.1 training. The default is `l2`. -~~In our experiments, `--timestep_sampling sigma --model_prediction_type raw --discrete_flow_shift 1.0` with `--loss_type l2` seems to work better than the default (SD3) settings. The multiplier of LoRA should be adjusted. ~~ +~~In our experiments, `--timestep_sampling sigma --model_prediction_type raw --discrete_flow_shift 1.0` with `--loss_type l2` seems to work better than the default (SD3) settings. The multiplier of LoRA should be adjusted.~~ -In our experiments, `--timestep_sampling shift --discrete_flow_shift 3.1582 --model_prediction_type raw --guidance_scale 1.0` with `--loss_type l2` seems to work better than other settings. +In our experiments, `--timestep_sampling shift --discrete_flow_shift 3.1582 --model_prediction_type raw --guidance_scale 1.0` (with the default `l2` loss_type) seems to work better. The settings in [AI Toolkit by Ostris](https://github.com/ostris/ai-toolkit) seems to be equivalent to `--timestep_sampling sigmoid --model_prediction_type raw --guidance_scale 1.0` (with the default `l2` loss_type). @@ -92,10 +92,13 @@ Other options are described below. `--timestep_sampling` and `--sigmoid_scale`, `--discrete_flow_shift` adjust the distribution of timesteps. The distribution is shown in the figures below. The effect of `--discrete_flow_shift` with `--timestep_sampling shift` (when `--sigmoid_scale` is not specified, the default is 1.0): +![Figure_2](https://github.com/user-attachments/assets/d9de42f9-f17d-40da-b88d-d964402569c6) -The difference between `--timestep_sampling uniform` and `--timestep_sampling sigma`: +The difference between `--timestep_sampling sigmoid` and `--timestep_sampling uniform` (when `--timestep_sampling sigmoid` or `uniform` is specified, `--discrete_flow_shift` is ignored): +![Figure_3](https://github.com/user-attachments/assets/27029009-1f5d-4dc0-bb24-13d02ac4fdad) The effect of `--timestep_sampling sigmoid` and `--sigmoid_scale` (when `--timestep_sampling sigmoid` is specified, `--discrete_flow_shift` is ignored): +![Figure_4](https://github.com/user-attachments/assets/08a2267c-e47e-48b7-826e-f9a080787cdc) #### Key Features for FLUX.1 LoRA training From 8ecf0fc4bfd1b03cfc6fd4055af0b3363f5d1f38 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 29 Aug 2024 22:10:57 +0900 Subject: [PATCH 17/87] Refactor code to ensure args.guidance_scale is always a float #1525 --- flux_train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flux_train.py b/flux_train.py index 410728d4..32a36f03 100644 --- a/flux_train.py +++ b/flux_train.py @@ -688,8 +688,8 @@ def train(args): packed_latent_height, packed_latent_width = noisy_model_input.shape[2] // 2, noisy_model_input.shape[3] // 2 img_ids = flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width).to(device=accelerator.device) - # get guidance - guidance_vec = torch.full((bsz,), args.guidance_scale, device=accelerator.device) + # get guidance: ensure args.guidance_scale is float + guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device) # call model l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds From 8fdfd8c857a88aaa78ac9c2488432ef8115982f2 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 29 Aug 2024 22:26:29 +0900 Subject: [PATCH 18/87] Update safetensors to version 0.4.4 in requirements.txt #1524 --- README.md | 7 +++++++ requirements.txt | 2 +- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 6e2ae337..30264e73 100644 --- a/README.md +++ b/README.md @@ -9,6 +9,13 @@ __Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchv The command to install PyTorch is as follows: `pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124` +### Recent Updates + +Aug 29, 2024: +Please update `safetensors` to `0.4.4` to fix the error when using `--resume`. `requirements.txt` is updated. + +### Contents + - [FLUX.1 LoRA training](#flux1-lora-training) - [Key Options for FLUX.1 LoRA training](#key-options-for-flux1-lora-training) - [Inference for FLUX.1 LoRA model](#inference-for-flux1-lora-model) diff --git a/requirements.txt b/requirements.txt index 4ee19b3e..4c1bc392 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,7 +10,7 @@ bitsandbytes==0.43.3 prodigyopt==1.0 lion-pytorch==0.0.6 tensorboard -safetensors==0.4.2 +safetensors==0.4.4 # gradio==3.16.2 altair==4.2.2 easygui==0.98.3 From 34f2315047f8d5b89b7a8a6093bb56679bff13c3 Mon Sep 17 00:00:00 2001 From: Akegarasu Date: Thu, 29 Aug 2024 22:33:37 +0800 Subject: [PATCH 19/87] fix: text_encoder_conds referenced before assignment --- train_network.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/train_network.py b/train_network.py index 048c7e7b..628c421c 100644 --- a/train_network.py +++ b/train_network.py @@ -1081,12 +1081,12 @@ class NetworkTrainer: # print(f"set multiplier: {multipliers}") accelerator.unwrap_model(network).set_multiplier(multipliers) + text_encoder_conds = [] text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) if text_encoder_outputs_list is not None: text_encoder_conds = text_encoder_outputs_list # List of text encoder outputs if ( - text_encoder_conds is None - or len(text_encoder_conds) == 0 + len(text_encoder_conds) == 0 or text_encoder_conds[0] is None or train_text_encoder ): From 35882f8d5bbd076a97622cf6193c988621481803 Mon Sep 17 00:00:00 2001 From: Akegarasu Date: Thu, 29 Aug 2024 23:03:43 +0800 Subject: [PATCH 20/87] fix --- train_network.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/train_network.py b/train_network.py index 628c421c..4204bce3 100644 --- a/train_network.py +++ b/train_network.py @@ -1112,10 +1112,14 @@ class NetworkTrainer: if args.full_fp16: encoded_text_encoder_conds = [c.to(weight_dtype) for c in encoded_text_encoder_conds] - # if encoded_text_encoder_conds is not None, update cached text_encoder_conds - for i in range(len(encoded_text_encoder_conds)): - if encoded_text_encoder_conds[i] is not None: - text_encoder_conds[i] = encoded_text_encoder_conds[i] + # if text_encoder_conds is not cached, use encoded_text_encoder_conds + if len(text_encoder_conds) == 0: + text_encoder_conds = encoded_text_encoder_conds + else: + # if encoded_text_encoder_conds is not None, update cached text_encoder_conds + for i in range(len(encoded_text_encoder_conds)): + if encoded_text_encoder_conds[i] is not None: + text_encoder_conds[i] = encoded_text_encoder_conds[i] # sample noise, call unet, get target noise_pred, target, timesteps, huber_c, weighting = self.get_noise_pred_and_target( From 2a3aefb4e44dce1f189677d0a996ba0244633956 Mon Sep 17 00:00:00 2001 From: Nando Metzger <42088121+nandometzger@users.noreply.github.com> Date: Fri, 30 Aug 2024 08:15:05 +0200 Subject: [PATCH 21/87] Update train_util.py, bug fix --- library/train_util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/library/train_util.py b/library/train_util.py index 15c23f3c..0fec565d 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1489,7 +1489,7 @@ class DreamBoothDataset(BaseDataset): def load_dreambooth_dir(subset: DreamBoothSubset): if not os.path.isdir(subset.image_dir): logger.warning(f"not directory: {subset.image_dir}") - return [], [] + return [], [], [] info_cache_file = os.path.join(subset.image_dir, self.IMAGE_INFO_CACHE_FILE) use_cached_info_for_subset = subset.cache_info From 3a6154b7b0dbcae82d24adacf5a76f75288b98f4 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Sat, 31 Aug 2024 06:21:16 +0000 Subject: [PATCH 22/87] Bump opencv-python from 4.7.0.68 to 4.8.1.78 Bumps [opencv-python](https://github.com/opencv/opencv-python) from 4.7.0.68 to 4.8.1.78. - [Release notes](https://github.com/opencv/opencv-python/releases) - [Commits](https://github.com/opencv/opencv-python/commits) --- updated-dependencies: - dependency-name: opencv-python dependency-type: direct:production ... Signed-off-by: dependabot[bot] --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index e99775b8..977c5cd9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,7 +3,7 @@ transformers==4.36.2 diffusers[torch]==0.25.0 ftfy==6.1.1 # albumentations==1.3.0 -opencv-python==4.7.0.68 +opencv-python==4.8.1.78 einops==0.7.0 pytorch-lightning==1.9.0 bitsandbytes==0.43.0 From 25c9040f4fbbcbddc0297895369337846152fea4 Mon Sep 17 00:00:00 2001 From: sdbds <865105819@qq.com> Date: Sat, 31 Aug 2024 03:05:19 +0800 Subject: [PATCH 23/87] Update flux_train_utils.py --- library/flux_train_utils.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index a8e94ac0..735bcced 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -371,7 +371,7 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): def get_noisy_model_input_and_timesteps( args, noise_scheduler, latents, noise, device, dtype ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - bsz = latents.shape[0] + bsz, _, H, W = latents.shape sigmas = None if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid": @@ -392,6 +392,16 @@ def get_noisy_model_input_and_timesteps( timesteps = logits_norm.sigmoid() timesteps = (timesteps * shift) / (1 + (shift - 1) * timesteps) + t = timesteps.view(-1, 1, 1, 1) + timesteps = timesteps * 1000.0 + noisy_model_input = (1 - t) * latents + t * noise + elif args.timestep_sampling == "flux_shift": + logits_norm = torch.randn(bsz, device=device) + logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling + timesteps = logits_norm.sigmoid() + mu=get_lin_function(y1=0.5, y2=1.15)((H//2) * (W//2)) + timesteps = time_shift(mu, 1.0, timesteps) + t = timesteps.view(-1, 1, 1, 1) timesteps = timesteps * 1000.0 noisy_model_input = (1 - t) * latents + t * noise @@ -571,7 +581,7 @@ def add_flux_train_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--timestep_sampling", - choices=["sigma", "uniform", "sigmoid", "shift"], + choices=["sigma", "uniform", "sigmoid", "shift", "flux_shift"], default="sigma", help="Method to sample timesteps: sigma-based, uniform random, sigmoid of random normal and shift of sigmoid." " / タイムステップをサンプリングする方法:sigma、random uniform、random normalのsigmoid、sigmoidのシフト。", From 1bcf8d600bfb9f4314a41a12a5e7b272a17ceaed Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Sun, 1 Sep 2024 01:33:04 +0000 Subject: [PATCH 24/87] Bump crate-ci/typos from 1.19.0 to 1.24.3 Bumps [crate-ci/typos](https://github.com/crate-ci/typos) from 1.19.0 to 1.24.3. - [Release notes](https://github.com/crate-ci/typos/releases) - [Changelog](https://github.com/crate-ci/typos/blob/master/CHANGELOG.md) - [Commits](https://github.com/crate-ci/typos/compare/v1.19.0...v1.24.3) --- updated-dependencies: - dependency-name: crate-ci/typos dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] --- .github/workflows/typos.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/typos.yml b/.github/workflows/typos.yml index e8b06483..0149dcdd 100644 --- a/.github/workflows/typos.yml +++ b/.github/workflows/typos.yml @@ -18,4 +18,4 @@ jobs: - uses: actions/checkout@v4 - name: typos-action - uses: crate-ci/typos@v1.19.0 + uses: crate-ci/typos@v1.24.3 From ef510b3cb94427d72df681389e1214251813b1a3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9D=92=E9=BE=8D=E8=81=96=E8=80=85=40bdsqlsz?= <865105819@qq.com> Date: Sun, 1 Sep 2024 17:41:01 +0800 Subject: [PATCH 25/87] Sd3 freeze x_block (#1417) * Update sd3_train.py * add freeze block lr * Update train_util.py * update --- library/train_util.py | 21 +++++++++++++++++++++ sd3_train.py | 9 ++++++++- 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/library/train_util.py b/library/train_util.py index 989758ad..74aae0a7 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3246,6 +3246,12 @@ def add_sd_models_arguments(parser: argparse.ArgumentParser): default=None, help="directory for caching Tokenizer (for offline training) / Tokenizerをキャッシュするディレクトリ(ネット接続なしでの学習のため)", ) + parser.add_argument( + "--num_last_block_to_freeze", + type=int, + default=None, + help="num_last_block_to_freeze", + ) def add_optimizer_arguments(parser: argparse.ArgumentParser): @@ -5758,6 +5764,21 @@ def sample_image_inference( pass +def freeze_blocks(model, num_last_block_to_freeze, block_name="x_block"): + + filtered_blocks = [(name, param) for name, param in model.named_parameters() if block_name in name] + print(f"filtered_blocks: {len(filtered_blocks)}") + + num_blocks_to_freeze = min(len(filtered_blocks), num_last_block_to_freeze) + + print(f"freeze_blocks: {num_blocks_to_freeze}") + + start_freezing_from = max(0, len(filtered_blocks) - num_blocks_to_freeze) + + for i in range(start_freezing_from, len(filtered_blocks)): + _, param = filtered_blocks[i] + param.requires_grad = False + # endregion diff --git a/sd3_train.py b/sd3_train.py index 3b6c8a11..ce9500b0 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -368,12 +368,19 @@ def train(args): vae.eval() vae.to(accelerator.device, dtype=vae_dtype) + mmdit.requires_grad_(train_mmdit) + if not train_mmdit: + mmdit.to(accelerator.device, dtype=weight_dtype) # because of unet is not prepared + + if args.num_last_block_to_freeze: + train_util.freeze_blocks(mmdit,num_last_block_to_freeze=args.num_last_block_to_freeze) + training_models = [] params_to_optimize = [] # if train_unet: training_models.append(mmdit) # if block_lrs is None: - params_to_optimize.append({"params": list(mmdit.parameters()), "lr": args.learning_rate}) + params_to_optimize.append({"params": list(filter(lambda p: p.requires_grad, mmdit.parameters())), "lr": args.learning_rate}) # else: # params_to_optimize.extend(get_block_params_to_optimize(mmdit, block_lrs)) From 92e7600cc2fea604321004f260e7db76c764f388 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 1 Sep 2024 18:57:07 +0900 Subject: [PATCH 26/87] Move freeze_blocks to sd3_train because it's only for sd3 --- README.md | 3 +++ library/train_util.py | 21 --------------------- sd3_train.py | 22 ++++++++++++++++++++-- 3 files changed, 23 insertions(+), 23 deletions(-) diff --git a/README.md b/README.md index 30264e73..d9636719 100644 --- a/README.md +++ b/README.md @@ -309,6 +309,9 @@ resolution = [512, 512] SD3 training is done with `sd3_train.py`. +__Sep 1, 2024__: +- `--num_last_block_to_freeze` is added to `sd3_train.py`. This option is to freeze the last n blocks of the MMDiT. See [#1417](https://github.com/kohya-ss/sd-scripts/pull/1417) for details. Thanks to sdbds! + __Jul 27, 2024__: - Latents and text encoder outputs caching mechanism is refactored significantly. - Existing cache files for SD3 need to be recreated. Please delete the previous cache files. diff --git a/library/train_util.py b/library/train_util.py index 74aae0a7..989758ad 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3246,12 +3246,6 @@ def add_sd_models_arguments(parser: argparse.ArgumentParser): default=None, help="directory for caching Tokenizer (for offline training) / Tokenizerをキャッシュするディレクトリ(ネット接続なしでの学習のため)", ) - parser.add_argument( - "--num_last_block_to_freeze", - type=int, - default=None, - help="num_last_block_to_freeze", - ) def add_optimizer_arguments(parser: argparse.ArgumentParser): @@ -5764,21 +5758,6 @@ def sample_image_inference( pass -def freeze_blocks(model, num_last_block_to_freeze, block_name="x_block"): - - filtered_blocks = [(name, param) for name, param in model.named_parameters() if block_name in name] - print(f"filtered_blocks: {len(filtered_blocks)}") - - num_blocks_to_freeze = min(len(filtered_blocks), num_last_block_to_freeze) - - print(f"freeze_blocks: {num_blocks_to_freeze}") - - start_freezing_from = max(0, len(filtered_blocks) - num_blocks_to_freeze) - - for i in range(start_freezing_from, len(filtered_blocks)): - _, param = filtered_blocks[i] - param.requires_grad = False - # endregion diff --git a/sd3_train.py b/sd3_train.py index ce9500b0..87011b21 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -373,7 +373,20 @@ def train(args): mmdit.to(accelerator.device, dtype=weight_dtype) # because of unet is not prepared if args.num_last_block_to_freeze: - train_util.freeze_blocks(mmdit,num_last_block_to_freeze=args.num_last_block_to_freeze) + # freeze last n blocks of MM-DIT + block_name = "x_block" + filtered_blocks = [(name, param) for name, param in mmdit.named_parameters() if block_name in name] + accelerator.print(f"filtered_blocks: {len(filtered_blocks)}") + + num_blocks_to_freeze = min(len(filtered_blocks), args.num_last_block_to_freeze) + + accelerator.print(f"freeze_blocks: {num_blocks_to_freeze}") + + start_freezing_from = max(0, len(filtered_blocks) - num_blocks_to_freeze) + + for i in range(start_freezing_from, len(filtered_blocks)): + _, param = filtered_blocks[i] + param.requires_grad = False training_models = [] params_to_optimize = [] @@ -1033,12 +1046,17 @@ def setup_parser() -> argparse.ArgumentParser: default=None, help="number of optimizers for fused backward pass and optimizer step / fused backward passとoptimizer stepのためのoptimizer数", ) - parser.add_argument( "--skip_latents_validity_check", action="store_true", help="skip latents validity check / latentsの正当性チェックをスキップする", ) + parser.add_argument( + "--num_last_block_to_freeze", + type=int, + default=None, + help="freeze last n blocks of MM-DIT / MM-DITの最後のnブロックを凍結する", + ) return parser From 4f6d915d15262447b1049a78a55678b2825784a3 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 1 Sep 2024 19:12:29 +0900 Subject: [PATCH 27/87] update help and README --- README.md | 5 +++++ library/flux_train_utils.py | 8 ++++---- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index d9636719..331951ef 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,10 @@ The command to install PyTorch is as follows: ### Recent Updates +Sep 1, 2024: +- `--timestamp_sampling` has `flux_shift` option. Thanks to sdbds! + - This is the same shift as FLUX.1 dev inference, adjusting the timestep sampling depending on the resolution. `--discrete_flow_shift` is ignored when `flux_shift` is specified. It is not verified which is better, `shift` or `flux_shift`. + Aug 29, 2024: Please update `safetensors` to `0.4.4` to fix the error when using `--resume`. `requirements.txt` is updated. @@ -73,6 +77,7 @@ There are many unknown points in FLUX.1 training, so some settings can be specif - `uniform`: uniform random - `sigmoid`: sigmoid of random normal, same as x-flux, AI-toolkit etc. - `shift`: shifts the value of sigmoid of normal distribution random number + - `flux_shift`: shifts the value of sigmoid of normal distribution random number, depending on the resolution (same as FLUX.1 dev inference). `--discrete_flow_shift` is ignored when `flux_shift` is specified. - `--sigmoid_scale` is the scale factor for sigmoid timestep sampling (only used when timestep-sampling is "sigmoid"). The default is 1.0. Larger values will make the sampling more uniform. - This option is effective even when`--timestep_sampling shift` is specified. - Normally, leave it at 1.0. Larger values make the value before shift closer to a uniform distribution. diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index 735bcced..9dad4baa 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -371,7 +371,7 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): def get_noisy_model_input_and_timesteps( args, noise_scheduler, latents, noise, device, dtype ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - bsz, _, H, W = latents.shape + bsz, _, h, w = latents.shape sigmas = None if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid": @@ -399,7 +399,7 @@ def get_noisy_model_input_and_timesteps( logits_norm = torch.randn(bsz, device=device) logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling timesteps = logits_norm.sigmoid() - mu=get_lin_function(y1=0.5, y2=1.15)((H//2) * (W//2)) + mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2)) timesteps = time_shift(mu, 1.0, timesteps) t = timesteps.view(-1, 1, 1, 1) @@ -583,8 +583,8 @@ def add_flux_train_arguments(parser: argparse.ArgumentParser): "--timestep_sampling", choices=["sigma", "uniform", "sigmoid", "shift", "flux_shift"], default="sigma", - help="Method to sample timesteps: sigma-based, uniform random, sigmoid of random normal and shift of sigmoid." - " / タイムステップをサンプリングする方法:sigma、random uniform、random normalのsigmoid、sigmoidのシフト。", + help="Method to sample timesteps: sigma-based, uniform random, sigmoid of random normal, shift of sigmoid and FLUX.1 shifting." + " / タイムステップをサンプリングする方法:sigma、random uniform、random normalのsigmoid、sigmoidのシフト、FLUX.1のシフト。", ) parser.add_argument( "--sigmoid_scale", From 6abacf04da756808ffca567f6660445ecdf478bd Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 2 Sep 2024 13:05:26 +0900 Subject: [PATCH 28/87] update README --- README.md | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index 331951ef..5dd916aa 100644 --- a/README.md +++ b/README.md @@ -184,7 +184,7 @@ Options are almost the same as LoRA training. The difference is `--full_bf16`, ` `--blockwise_fused_optimizers` enables the fusing of the optimizer step into the backward pass for each block. This is similar to `--fused_backward_pass`. Any optimizer can be used, but Adafactor is recommended for memory efficiency. `--blockwise_fused_optimizers` cannot be used with `--fused_backward_pass`. Stochastic rounding is not supported for now. -`--double_blocks_to_swap` and `--single_blocks_to_swap` are the number of double blocks and single blocks to swap. The default is None (no swap). These options must be combined with `--fused_backward_pass` or `--blockwise_fused_optimizers`. `--double_blocks_to_swap` can be specified with `--single_blocks_to_swap`. The recommended maximum number of blocks to swap is 9 for double blocks and 18 for single blocks. +`--double_blocks_to_swap` and `--single_blocks_to_swap` are the number of double blocks and single blocks to swap. The default is None (no swap). These options must be combined with `--fused_backward_pass` or `--blockwise_fused_optimizers`. `--double_blocks_to_swap` can be specified with `--single_blocks_to_swap`. The recommended maximum number of blocks to swap is 9 for double blocks and 18 for single blocks. Please see the next chapter for details. `--cpu_offload_checkpointing` is to offload the gradient checkpointing to CPU. This reduces about 2GB of VRAM usage. @@ -198,24 +198,32 @@ The learning rate and the number of epochs are not optimized yet. Please adjust #### Key Features for FLUX.1 fine-tuning -1. Sample Image Generation: +1. Technical details of double/single block swap: + - Reduce memory usage by transferring double and single blocks of FLUX.1 from GPU to CPU when they are not needed. + - During forward pass, the weights of the blocks that have finished calculation are transferred to CPU, and the weights of the blocks to be calculated are transferred to GPU. + - The same is true for the backward pass, but the order is reversed. The gradients remain on the GPU. + - Since the transfer between CPU and GPU takes time, the training will be slower. + - `--double_blocks_to_swap` and `--single_blocks_to_swap` specify the number of blocks to swap. For example, `--double_blocks_to_swap 6` swaps 6 blocks at each step of training, but the remaining 13 blocks are always on the GPU. + - About 640MB of memory can be saved per double block, and about 320MB of memory can be saved per single block. + +2. Sample Image Generation: - Sample image generation during training is now supported. - The prompts are cached and used for generation if `--cache_latents` is specified. So changing the prompts during training will not affect the generated images. - Specify options such as `--sample_prompts` and `--sample_every_n_epochs`. - Note: It will be very slow when `--split_mode` is specified. -2. Experimental Memory-Efficient Saving: +3. Experimental Memory-Efficient Saving: - `--mem_eff_save` option can further reduce memory consumption during model saving (about 22GB). - This is a custom implementation and may cause unexpected issues. Use with caution. -3. T5XXL Token Length Control: +4. T5XXL Token Length Control: - Added `--t5xxl_max_token_length` option to specify the maximum token length of T5XXL. - Default is 512 in dev and 256 in schnell models. -4. Multi-GPU Training Support: +5. Multi-GPU Training Support: - Note: `--double_blocks_to_swap` and `--single_blocks_to_swap` cannot be used in multi-GPU training. -5. Disable mmap Load for Safetensors: +6. Disable mmap Load for Safetensors: - `--disable_mmap_load_safetensors` option now works in `flux_train.py`. - Speeds up model loading during training in WSL2. - Effective in reducing memory usage when loading models during multi-GPU training. From b65ae9b439e4324359014d6d720aa01def3a19fc Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 4 Sep 2024 21:33:17 +0900 Subject: [PATCH 29/87] T5XXL LoRA training, fp8 T5XXL support --- README.md | 45 +++++++++++---- flux_train_network.py | 112 +++++++++++++++++++++++++++++------- library/flux_train_utils.py | 23 ++++++-- library/flux_utils.py | 9 ++- library/strategy_flux.py | 13 ++++- networks/lora_flux.py | 39 ++++++++++--- train_network.py | 48 ++++++++++------ 7 files changed, 222 insertions(+), 67 deletions(-) diff --git a/README.md b/README.md index 5dd916aa..84065570 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,11 @@ The command to install PyTorch is as follows: ### Recent Updates +Sep 4, 2024: +- T5XXL LoRA is supported in LoRA training. Remove `--network_train_unet_only` and add `train_t5xxl=True` to `--network_args`. CLIP-L is also trained at the same time (T5XXL only cannot be trained). The trained model can be used with ComfyUI. +- In LoRA training, when `--fp8_base` is specified, you can specify `t5xxl_fp8_e4m3fn.safetensors` as the T5XXL weights. However, it is recommended to use fp16 weights for caching. +- Fixed an issue where the training CLIP-L LoRA was not used in sample image generation during LoRA training. + Sep 1, 2024: - `--timestamp_sampling` has `flux_shift` option. Thanks to sdbds! - This is the same shift as FLUX.1 dev inference, adjusting the timestep sampling depending on the resolution. `--discrete_flow_shift` is ignored when `flux_shift` is specified. It is not verified which is better, `shift` or `flux_shift`. @@ -41,8 +46,8 @@ Sample command is below. It will work with 24GB VRAM GPUs. ``` accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_train_network.py ---pretrained_model_name_or_path flux1-dev.sft --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors ---ae ae.sft --cache_latents_to_disk --save_model_as safetensors --sdpa --persistent_data_loader_workers +--pretrained_model_name_or_path flux1-dev.safetensors --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors +--ae ae.safetensors --cache_latents_to_disk --save_model_as safetensors --sdpa --persistent_data_loader_workers --max_data_loader_n_workers 2 --seed 42 --gradient_checkpointing --mixed_precision bf16 --save_precision bf16 --network_module networks.lora_flux --network_dim 4 --optimizer_type adamw8bit --learning_rate 1e-4 --cache_text_encoder_outputs --cache_text_encoder_outputs_to_disk --fp8_base @@ -72,6 +77,11 @@ The trained LoRA model can be used with ComfyUI. There are many unknown points in FLUX.1 training, so some settings can be specified by arguments. Here are the arguments. The arguments and sample settings are still experimental and may change in the future. Feedback on the settings is welcome. +- `--pretrained_model_name_or_path` is the path to the pretrained model (FLUX.1). bf16 (original BFL model) is recommended (`flux1-dev.safetensors` or `flux1-dev.sft`). If you specify `--fp8_base`, you can use fp8 models for FLUX.1. The fp8 model is only compatible with `float8_e4m3fn` format. +- `--clip_l` is the path to the CLIP-L model. +- `--t5xxl` is the path to the T5XXL model. If you specify `--fp8_base`, you can use fp8 (float8_e4m3fn) models for T5XXL. However, it is recommended to use fp16 models for caching. +- `--ae` is the path to the autoencoder model (`ae.safetensors` or `ae.sft`). + - `--timestep_sampling` is the method to sample timesteps (0-1): - `sigma`: sigma-based, same as SD3 - `uniform`: uniform random @@ -114,16 +124,29 @@ The effect of `--timestep_sampling sigmoid` and `--sigmoid_scale` (when `--times #### Key Features for FLUX.1 LoRA training -1. CLIP-L LoRA Support: - - FLUX.1 LoRA training now supports CLIP-L LoRA. +1. CLIP-L and T5XXL LoRA Support: + - FLUX.1 LoRA training now supports CLIP-L and T5XXL LoRA training. - Remove `--network_train_unet_only` from your command. - - T5XXL is not trained. Its output is still cached, so `--cache_text_encoder_outputs` or `--cache_text_encoder_outputs_to_disk` is still required. + - Add `train_t5xxl=True` to `--network_args` to train T5XXL LoRA. CLIP-L is also trained at the same time. + - T5XXL output can be cached for CLIP-L LoRA training. So, `--cache_text_encoder_outputs` or `--cache_text_encoder_outputs_to_disk` is also available. - The trained LoRA can be used with ComfyUI. - - Note: `flux_extract_lora.py` and `convert_flux_lora.py` do not support CLIP-L LoRA. + - Note: `flux_extract_lora.py`, `convert_flux_lora.py`and `merge_flux_lora.py` do not support CLIP-L and T5XXL LoRA yet. + + | trained LoRA|option|network_args|cache_text_encoder_outputs (*1)| + |---|---|---|---| + |FLUX.1|`--network_train_unet_only`|-|o| + |FLUX.1 + CLIP-L|-|-|o (*2)| + |FLUX.1 + CLIP-L + T5XXL|-|`train_t5xxl=True`|-| + |CLIP-L (*3)|`--network_train_text_encoder_only`|-|o (*2)| + |CLIP-L + T5XXL (*3)|`--network_train_text_encoder_only`|`train_t5xxl=True`|-| + + - *1: `--cache_text_encoder_outputs` or `--cache_text_encoder_outputs_to_disk` is also available. + - *2: T5XXL output can be cached for CLIP-L LoRA training. + - *3: Not tested yet. 2. Experimental FP8/FP16 mixed training: - - `--fp8_base_unet` enables training with fp8 for FLUX and bf16/fp16 for CLIP-L. - - FLUX can be trained with fp8, and CLIP-L can be trained with bf16/fp16. + - `--fp8_base_unet` enables training with fp8 for FLUX and bf16/fp16 for CLIP-L/T5XXL. + - FLUX can be trained with fp8, and CLIP-L/T5XXL can be trained with bf16/fp16. - When specifying this option, the `--fp8_base` option is automatically enabled. 3. Split Q/K/V Projection Layers (Experimental): @@ -153,7 +176,7 @@ The compatibility of the saved model (state dict) is ensured by concatenating th The inference script is also available. The script is `flux_minimal_inference.py`. See `--help` for options. ``` -python flux_minimal_inference.py --ckpt flux1-dev.sft --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors --ae ae.sft --dtype bf16 --prompt "a cat holding a sign that says hello world" --out path/to/output/dir --seed 1 --flux_dtype fp8 --offload --lora lora-flux-name.safetensors;1.0 +python flux_minimal_inference.py --ckpt flux1-dev.safetensors --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors --ae ae.safetensors --dtype bf16 --prompt "a cat holding a sign that says hello world" --out path/to/output/dir --seed 1 --flux_dtype fp8 --offload --lora lora-flux-name.safetensors;1.0 ``` ### FLUX.1 fine-tuning @@ -164,7 +187,7 @@ Sample command for FLUX.1 fine-tuning is below. This will work with 24GB VRAM GP ``` accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_train.py ---pretrained_model_name_or_path flux1-dev.sft --clip_l clip_l.safetensors --t5xxl t5xxl_fp16.safetensors --ae ae_dev.sft +--pretrained_model_name_or_path flux1-dev.safetensors --clip_l clip_l.safetensors --t5xxl t5xxl_fp16.safetensors --ae ae_dev.safetensors --save_model_as safetensors --sdpa --persistent_data_loader_workers --max_data_loader_n_workers 2 --seed 42 --gradient_checkpointing --mixed_precision bf16 --save_precision bf16 --dataset_config dataset_1024_bs1.toml --output_dir path/to/output/dir --output_name output-name @@ -256,7 +279,7 @@ CLIP-L LoRA is not supported. `networks/flux_merge_lora.py` merges LoRA to FLUX.1 checkpoint. __The script is experimental.__ ``` -python networks/flux_merge_lora.py --flux_model flux1-dev.sft --save_to output.safetensors --models lora1.safetensors --ratios 2.0 --save_precision fp16 --loading_device cuda --working_device cpu +python networks/flux_merge_lora.py --flux_model flux1-dev.safetensors --save_to output.safetensors --models lora1.safetensors --ratios 2.0 --save_precision fp16 --loading_device cuda --working_device cpu ``` You can also merge multiple LoRA models into a FLUX.1 model. Specify multiple LoRA models in `--models`. Specify the same number of ratios in `--ratios`. diff --git a/flux_train_network.py b/flux_train_network.py index 354a8c6f..2fc0f323 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -43,13 +43,9 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): train_dataset_group.is_text_encoder_output_cacheable() ), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません" - # assert ( - # args.network_train_unet_only or not args.cache_text_encoder_outputs - # ), "network for Text Encoder cannot be trained with caching Text Encoder outputs / Text Encoderの出力をキャッシュしながらText Encoderのネットワークを学習することはできません" - if not args.network_train_unet_only: - logger.info( - "network for CLIP-L only will be trained. T5XXL will not be trained / CLIP-Lのネットワークのみが学習されます。T5XXLは学習されません" - ) + # prepare CLIP-L/T5XXL training flags + self.train_clip_l = not args.network_train_unet_only + self.train_t5xxl = False # default is False even if args.network_train_unet_only is False if args.max_token_length is not None: logger.warning("max_token_length is not used in Flux training / max_token_lengthはFluxのトレーニングでは使用されません") @@ -63,12 +59,10 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): # currently offload to cpu for some models name = self.get_flux_model_name(args) - # if we load to cpu, flux.to(fp8) takes a long time - if args.fp8_base: - loading_dtype = None # as is - else: - loading_dtype = weight_dtype + # if the file is fp8 and we are using fp8_base, we can load it as is (fp8) + loading_dtype = None if args.fp8_base else weight_dtype + # if we load to cpu, flux.to(fp8) takes a long time, so we should load to gpu in future model = flux_utils.load_flow_model( name, args.pretrained_model_name_or_path, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors ) @@ -85,9 +79,21 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors) clip_l.eval() + # if the file is fp8 and we are using fp8_base (not unet), we can load it as is (fp8) + if args.fp8_base and not args.fp8_base_unet: + loading_dtype = None # as is + else: + loading_dtype = weight_dtype + # loading t5xxl to cpu takes a long time, so we should load to gpu in future - t5xxl = flux_utils.load_t5xxl(args.t5xxl, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors) + t5xxl = flux_utils.load_t5xxl(args.t5xxl, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors) t5xxl.eval() + if args.fp8_base and not args.fp8_base_unet: + # check dtype of model + if t5xxl.dtype == torch.float8_e4m3fnuz or t5xxl.dtype == torch.float8_e5m2 or t5xxl.dtype == torch.float8_e5m2fnuz: + raise ValueError(f"Unsupported fp8 model dtype: {t5xxl.dtype}") + elif t5xxl.dtype == torch.float8_e4m3fn: + logger.info("Loaded fp8 T5XXL model") ae = flux_utils.load_ae(name, args.ae, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors) @@ -154,25 +160,35 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): def get_text_encoding_strategy(self, args): return strategy_flux.FluxTextEncodingStrategy(apply_t5_attn_mask=args.apply_t5_attn_mask) + def post_process_network(self, args, accelerator, network, text_encoders, unet): + # check t5xxl is trained or not + self.train_t5xxl = network.train_t5xxl + + if self.train_t5xxl and args.cache_text_encoder_outputs: + raise ValueError( + "T5XXL is trained, so cache_text_encoder_outputs cannot be used / T5XXL学習時はcache_text_encoder_outputsは使用できません" + ) + def get_models_for_text_encoding(self, args, accelerator, text_encoders): if args.cache_text_encoder_outputs: - if self.is_train_text_encoder(args): + if self.train_clip_l and not self.train_t5xxl: return text_encoders[0:1] # only CLIP-L is needed for encoding because T5XXL is cached else: - return text_encoders # ignored + return None # no text encoders are needed for encoding because both are cached else: return text_encoders # both CLIP-L and T5XXL are needed for encoding def get_text_encoders_train_flags(self, args, text_encoders): - return [True, False] if self.is_train_text_encoder(args) else [False, False] + return [self.train_clip_l, self.train_t5xxl] def get_text_encoder_outputs_caching_strategy(self, args): if args.cache_text_encoder_outputs: + # if the text encoders is trained, we need tokenization, so is_partial is True return strategy_flux.FluxTextEncoderOutputsCachingStrategy( args.cache_text_encoder_outputs_to_disk, None, False, - is_partial=self.is_train_text_encoder(args), + is_partial=self.train_clip_l or self.train_t5xxl, apply_t5_attn_mask=args.apply_t5_attn_mask, ) else: @@ -193,8 +209,16 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): # When TE is not be trained, it will not be prepared so we need to use explicit autocast logger.info("move text encoders to gpu") - text_encoders[0].to(accelerator.device, dtype=weight_dtype) - text_encoders[1].to(accelerator.device, dtype=weight_dtype) + text_encoders[0].to(accelerator.device, dtype=weight_dtype) # always not fp8 + text_encoders[1].to(accelerator.device) + + if text_encoders[1].dtype == torch.float8_e4m3fn: + # if we load fp8 weights, the model is already fp8, so we use it as is + self.prepare_text_encoder_fp8(1, text_encoders[1], text_encoders[1].dtype, weight_dtype) + else: + # otherwise, we need to convert it to target dtype + text_encoders[1].to(weight_dtype) + with accelerator.autocast(): dataset.new_cache_text_encoder_outputs(text_encoders, accelerator.is_main_process) @@ -235,7 +259,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): else: # Text Encoderから毎回出力を取得するので、GPUに乗せておく text_encoders[0].to(accelerator.device, dtype=weight_dtype) - text_encoders[1].to(accelerator.device, dtype=weight_dtype) + text_encoders[1].to(accelerator.device) # def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype): # noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype @@ -255,9 +279,12 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): # return noise_pred def sample_images(self, accelerator, args, epoch, global_step, device, ae, tokenizer, text_encoder, flux): + text_encoders = text_encoder # for compatibility + text_encoders = self.get_models_for_text_encoding(args, accelerator, text_encoders) + if not args.split_mode: flux_train_utils.sample_images( - accelerator, args, epoch, global_step, flux, ae, text_encoder, self.sample_prompts_te_outputs + accelerator, args, epoch, global_step, flux, ae, text_encoders, self.sample_prompts_te_outputs ) return @@ -281,7 +308,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): wrapper = FluxUpperLowerWrapper(self.flux_upper, flux, accelerator.device) clean_memory_on_device(accelerator.device) flux_train_utils.sample_images( - accelerator, args, epoch, global_step, wrapper, ae, text_encoder, self.sample_prompts_te_outputs + accelerator, args, epoch, global_step, wrapper, ae, text_encoders, self.sample_prompts_te_outputs ) clean_memory_on_device(accelerator.device) @@ -421,6 +448,47 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): def is_text_encoder_not_needed_for_training(self, args): return args.cache_text_encoder_outputs and not self.is_train_text_encoder(args) + def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder): + if index == 0: # CLIP-L + return super().prepare_text_encoder_grad_ckpt_workaround(index, text_encoder) + else: # T5XXL + text_encoder.encoder.embed_tokens.requires_grad_(True) + + def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype): + if index == 0: # CLIP-L + logger.info(f"prepare CLIP-L for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}") + text_encoder.to(te_weight_dtype) # fp8 + text_encoder.text_model.embeddings.to(dtype=weight_dtype) + else: # T5XXL + + def prepare_fp8(text_encoder, target_dtype): + def forward_hook(module): + def forward(hidden_states): + hidden_gelu = module.act(module.wi_0(hidden_states)) + hidden_linear = module.wi_1(hidden_states) + hidden_states = hidden_gelu * hidden_linear + hidden_states = module.dropout(hidden_states) + + hidden_states = module.wo(hidden_states) + return hidden_states + + return forward + + for module in text_encoder.modules(): + if module.__class__.__name__ in ["T5LayerNorm", "Embedding"]: + # print("set", module.__class__.__name__, "to", target_dtype) + module.to(target_dtype) + if module.__class__.__name__ in ["T5DenseGatedActDense"]: + # print("set", module.__class__.__name__, "hooks") + module.forward = forward_hook(module) + + if flux_utils.get_t5xxl_actual_dtype(text_encoder) == torch.float8_e4m3fn and text_encoder.dtype == weight_dtype: + logger.info(f"T5XXL already prepared for fp8") + else: + logger.info(f"prepare T5XXL for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}, add hooks") + text_encoder.to(te_weight_dtype) # fp8 + prepare_fp8(text_encoder, weight_dtype) + def setup_parser() -> argparse.ArgumentParser: parser = train_network.setup_parser() diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index 9dad4baa..0b5d4d90 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -85,7 +85,7 @@ def sample_images( if distributed_state.num_processes <= 1: # If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts. - with torch.no_grad(): + with torch.no_grad(), accelerator.autocast(): for prompt_dict in prompts: sample_image_inference( accelerator, @@ -187,14 +187,27 @@ def sample_image_inference( tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy() encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy() + text_encoder_conds = [] if sample_prompts_te_outputs and prompt in sample_prompts_te_outputs: - te_outputs = sample_prompts_te_outputs[prompt] - else: + text_encoder_conds = sample_prompts_te_outputs[prompt] + print(f"Using cached text encoder outputs for prompt: {prompt}") + if text_encoders is not None: + print(f"Encoding prompt: {prompt}") tokens_and_masks = tokenize_strategy.tokenize(prompt) # strategy has apply_t5_attn_mask option - te_outputs = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, tokens_and_masks) + encoded_text_encoder_conds = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, tokens_and_masks) + print([x.shape if x is not None else None for x in encoded_text_encoder_conds]) - l_pooled, t5_out, txt_ids, t5_attn_mask = te_outputs + # if text_encoder_conds is not cached, use encoded_text_encoder_conds + if len(text_encoder_conds) == 0: + text_encoder_conds = encoded_text_encoder_conds + else: + # if encoded_text_encoder_conds is not None, update cached text_encoder_conds + for i in range(len(encoded_text_encoder_conds)): + if encoded_text_encoder_conds[i] is not None: + text_encoder_conds[i] = encoded_text_encoder_conds[i] + + l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds # sample image weight_dtype = ae.dtype # TOFO give dtype as argument diff --git a/library/flux_utils.py b/library/flux_utils.py index 68083616..7b0a41a8 100644 --- a/library/flux_utils.py +++ b/library/flux_utils.py @@ -171,7 +171,9 @@ def load_clip_l(ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.dev return clip -def load_t5xxl(ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False) -> T5EncoderModel: +def load_t5xxl( + ckpt_path: str, dtype: Optional[torch.dtype], device: Union[str, torch.device], disable_mmap: bool = False +) -> T5EncoderModel: T5_CONFIG_JSON = """ { "architectures": [ @@ -217,6 +219,11 @@ def load_t5xxl(ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.devi return t5xxl +def get_t5xxl_actual_dtype(t5xxl: T5EncoderModel) -> torch.dtype: + # nn.Embedding is the first layer, but it could be casted to bfloat16 or float32 + return t5xxl.encoder.block[0].layer[0].SelfAttention.q.weight.dtype + + def prepare_img_ids(batch_size: int, packed_latent_height: int, packed_latent_width: int): img_ids = torch.zeros(packed_latent_height, packed_latent_width, 3) img_ids[..., 1] = img_ids[..., 1] + torch.arange(packed_latent_height)[:, None] diff --git a/library/strategy_flux.py b/library/strategy_flux.py index 5d083913..6c9ef5e4 100644 --- a/library/strategy_flux.py +++ b/library/strategy_flux.py @@ -5,8 +5,7 @@ import torch import numpy as np from transformers import CLIPTokenizer, T5TokenizerFast -from library import sd3_utils, train_util -from library import sd3_models +from library import flux_utils, train_util from library.strategy_base import LatentsCachingStrategy, TextEncodingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy from library.utils import setup_logging @@ -100,6 +99,8 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy): super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial) self.apply_t5_attn_mask = apply_t5_attn_mask + self.warn_fp8_weights = False + def get_outputs_npz_path(self, image_abs_path: str) -> str: return os.path.splitext(image_abs_path)[0] + FluxTextEncoderOutputsCachingStrategy.FLUX_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX @@ -144,6 +145,14 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy): def cache_batch_outputs( self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, infos: List ): + if not self.warn_fp8_weights: + if flux_utils.get_t5xxl_actual_dtype(models[1]) == torch.float8_e4m3fn: + logger.warning( + "T5 model is using fp8 weights for caching. This may affect the quality of the cached outputs." + " / T5モデルはfp8の重みを使用しています。これはキャッシュの品質に影響を与える可能性があります。" + ) + self.warn_fp8_weights = True + flux_text_encoding_strategy: FluxTextEncodingStrategy = text_encoding_strategy captions = [info.caption for info in infos] diff --git a/networks/lora_flux.py b/networks/lora_flux.py index fcb56a46..295267be 100644 --- a/networks/lora_flux.py +++ b/networks/lora_flux.py @@ -330,6 +330,11 @@ def create_network( if split_qkv is not None: split_qkv = True if split_qkv == "True" else False + # train T5XXL + train_t5xxl = kwargs.get("train_t5xxl", False) + if train_t5xxl is not None: + train_t5xxl = True if train_t5xxl == "True" else False + # すごく引数が多いな ( ^ω^)・・・ network = LoRANetwork( text_encoders, @@ -344,6 +349,7 @@ def create_network( conv_alpha=conv_alpha, train_blocks=train_blocks, split_qkv=split_qkv, + train_t5xxl=train_t5xxl, varbose=True, ) @@ -370,9 +376,10 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, flux, weigh else: weights_sd = torch.load(file, map_location="cpu") - # get dim/alpha mapping + # get dim/alpha mapping, and train t5xxl modules_dim = {} modules_alpha = {} + train_t5xxl = None for key, value in weights_sd.items(): if "." not in key: continue @@ -385,6 +392,12 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, flux, weigh modules_dim[lora_name] = dim # logger.info(lora_name, value.size(), dim) + if train_t5xxl is None: + train_t5xxl = "lora_te3" in lora_name + + if train_t5xxl is None: + train_t5xxl = False + # # split qkv # double_qkv_rank = None # single_qkv_rank = None @@ -413,6 +426,7 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, flux, weigh modules_alpha=modules_alpha, module_class=module_class, split_qkv=split_qkv, + train_t5xxl=train_t5xxl, ) return network, weights_sd @@ -421,10 +435,10 @@ class LoRANetwork(torch.nn.Module): # FLUX_TARGET_REPLACE_MODULE = ["DoubleStreamBlock", "SingleStreamBlock"] FLUX_TARGET_REPLACE_MODULE_DOUBLE = ["DoubleStreamBlock"] FLUX_TARGET_REPLACE_MODULE_SINGLE = ["SingleStreamBlock"] - TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPSdpaAttention", "CLIPMLP"] + TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPSdpaAttention", "CLIPMLP", "T5Attention", "T5DenseGatedActDense"] LORA_PREFIX_FLUX = "lora_unet" # make ComfyUI compatible LORA_PREFIX_TEXT_ENCODER_CLIP = "lora_te1" - LORA_PREFIX_TEXT_ENCODER_T5 = "lora_te2" + LORA_PREFIX_TEXT_ENCODER_T5 = "lora_te3" # make ComfyUI compatible def __init__( self, @@ -443,6 +457,7 @@ class LoRANetwork(torch.nn.Module): modules_alpha: Optional[Dict[str, int]] = None, train_blocks: Optional[str] = None, split_qkv: bool = False, + train_t5xxl: bool = False, varbose: Optional[bool] = False, ) -> None: super().__init__() @@ -457,6 +472,7 @@ class LoRANetwork(torch.nn.Module): self.module_dropout = module_dropout self.train_blocks = train_blocks if train_blocks is not None else "all" self.split_qkv = split_qkv + self.train_t5xxl = train_t5xxl self.loraplus_lr_ratio = None self.loraplus_unet_lr_ratio = None @@ -469,12 +485,16 @@ class LoRANetwork(torch.nn.Module): logger.info( f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}" ) - if self.conv_lora_dim is not None: - logger.info( - f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}" - ) + # if self.conv_lora_dim is not None: + # logger.info( + # f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}" + # ) if self.split_qkv: logger.info(f"split qkv for LoRA") + if self.train_blocks is not None: + logger.info(f"train {self.train_blocks} blocks only") + if train_t5xxl: + logger.info(f"train T5XXL as well") # create module instances def create_modules( @@ -550,12 +570,15 @@ class LoRANetwork(torch.nn.Module): skipped_te = [] for i, text_encoder in enumerate(text_encoders): index = i + if not train_t5xxl and index > 0: # 0: CLIP, 1: T5XXL, so we skip T5XXL if train_t5xxl is False + break + logger.info(f"create LoRA for Text Encoder {index+1}:") text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE) + logger.info(f"create LoRA for Text Encoder {index+1}: {len(text_encoder_loras)} modules.") self.text_encoder_loras.extend(text_encoder_loras) skipped_te += skipped - logger.info(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.") # create LoRA for U-Net if self.train_blocks == "all": diff --git a/train_network.py b/train_network.py index 4204bce3..a68ccfcc 100644 --- a/train_network.py +++ b/train_network.py @@ -157,6 +157,9 @@ class NetworkTrainer: # region SD/SDXL + def post_process_network(self, args, accelerator, network, text_encoders, unet): + pass + def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any: noise_scheduler = DDPMScheduler( beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False @@ -237,6 +240,13 @@ class NetworkTrainer: def is_text_encoder_not_needed_for_training(self, args): return False # use for sample images + def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder): + # set top parameter requires_grad = True for gradient checkpointing works + text_encoder.text_model.embeddings.requires_grad_(True) + + def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype): + text_encoder.text_model.embeddings.to(dtype=weight_dtype) + # endregion def train(self, args): @@ -329,7 +339,7 @@ class NetworkTrainer: train_dataset_group.is_latent_cacheable() ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" - self.assert_extra_args(args, train_dataset_group) + self.assert_extra_args(args, train_dataset_group) # may change some args # acceleratorを準備する logger.info("preparing accelerator") @@ -428,12 +438,15 @@ class NetworkTrainer: ) args.scale_weight_norms = False + self.post_process_network(args, accelerator, network, text_encoders, unet) + + # apply network to unet and text_encoder train_unet = not args.network_train_text_encoder_only train_text_encoder = self.is_train_text_encoder(args) network.apply_to(text_encoder, unet, train_text_encoder, train_unet) if args.network_weights is not None: - # FIXME consider alpha of weights + # FIXME consider alpha of weights: this assumes that the alpha is not changed info = network.load_weights(args.network_weights) accelerator.print(f"load network weights from {args.network_weights}: {info}") @@ -533,7 +546,7 @@ class NetworkTrainer: ), "fp8_base requires mixed precision='fp16' or 'bf16' / fp8を使う場合はmixed_precision='fp16'または'bf16'が必要です。" accelerator.print("enable fp8 training for U-Net.") unet_weight_dtype = torch.float8_e4m3fn - + if not args.fp8_base_unet: accelerator.print("enable fp8 training for Text Encoder.") te_weight_dtype = weight_dtype if args.fp8_base_unet else torch.float8_e4m3fn @@ -545,17 +558,16 @@ class NetworkTrainer: unet.requires_grad_(False) unet.to(dtype=unet_weight_dtype) - for t_enc in text_encoders: + for i, t_enc in enumerate(text_encoders): t_enc.requires_grad_(False) # in case of cpu, dtype is already set to fp32 because cpu does not support fp8/fp16/bf16 if t_enc.device.type != "cpu": t_enc.to(dtype=te_weight_dtype) - if hasattr(t_enc, "text_model") and hasattr(t_enc.text_model, "embeddings"): - # nn.Embedding not support FP8 - t_enc.text_model.embeddings.to(dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype)) - elif hasattr(t_enc, "encoder") and hasattr(t_enc.encoder, "embeddings"): - t_enc.encoder.embeddings.to(dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype)) + + # nn.Embedding not support FP8 + if te_weight_dtype != weight_dtype: + self.prepare_text_encoder_fp8(i, t_enc, te_weight_dtype, weight_dtype) # acceleratorがなんかよろしくやってくれるらしい / accelerator will do something good if args.deepspeed: @@ -596,12 +608,12 @@ class NetworkTrainer: if args.gradient_checkpointing: # according to TI example in Diffusers, train is required unet.train() - for t_enc, frag in zip(text_encoders, self.get_text_encoders_train_flags(args, text_encoders)): + for i, (t_enc, frag) in enumerate(zip(text_encoders, self.get_text_encoders_train_flags(args, text_encoders))): t_enc.train() # set top parameter requires_grad = True for gradient checkpointing works if frag: - t_enc.text_model.embeddings.requires_grad_(True) + self.prepare_text_encoder_grad_ckpt_workaround(i, t_enc) else: unet.eval() @@ -1028,8 +1040,12 @@ class NetworkTrainer: # log device and dtype for each model logger.info(f"unet dtype: {unet_weight_dtype}, device: {unet.device}") - for t_enc in text_encoders: - logger.info(f"text_encoder dtype: {t_enc.dtype}, device: {t_enc.device}") + for i, t_enc in enumerate(text_encoders): + params_itr = t_enc.parameters() + params_itr.__next__() # skip the first parameter + params_itr.__next__() # skip the second parameter. because CLIP first two parameters are embeddings + param_3rd = params_itr.__next__() + logger.info(f"text_encoder [{i}] dtype: {param_3rd.dtype}, device: {t_enc.device}") clean_memory_on_device(accelerator.device) @@ -1085,11 +1101,7 @@ class NetworkTrainer: text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) if text_encoder_outputs_list is not None: text_encoder_conds = text_encoder_outputs_list # List of text encoder outputs - if ( - len(text_encoder_conds) == 0 - or text_encoder_conds[0] is None - or train_text_encoder - ): + if len(text_encoder_conds) == 0 or text_encoder_conds[0] is None or train_text_encoder: with torch.set_grad_enabled(train_text_encoder), accelerator.autocast(): # Get the text embedding for conditioning if args.weighted_captions: From b7cff0a7548e5e33f735f06293ba24119fdaa585 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 4 Sep 2024 21:35:47 +0900 Subject: [PATCH 30/87] update README --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 84065570..c0acfa1d 100644 --- a/README.md +++ b/README.md @@ -12,7 +12,7 @@ The command to install PyTorch is as follows: ### Recent Updates Sep 4, 2024: -- T5XXL LoRA is supported in LoRA training. Remove `--network_train_unet_only` and add `train_t5xxl=True` to `--network_args`. CLIP-L is also trained at the same time (T5XXL only cannot be trained). The trained model can be used with ComfyUI. +- T5XXL LoRA is supported in LoRA training. Remove `--network_train_unet_only` and add `train_t5xxl=True` to `--network_args`. CLIP-L is also trained at the same time (T5XXL only cannot be trained). The trained model can be used with ComfyUI. See [Key Features for FLUX.1 LoRA training](#key-features-for-flux1-lora-training) for details. - In LoRA training, when `--fp8_base` is specified, you can specify `t5xxl_fp8_e4m3fn.safetensors` as the T5XXL weights. However, it is recommended to use fp16 weights for caching. - Fixed an issue where the training CLIP-L LoRA was not used in sample image generation during LoRA training. From 56cb2fc885d818e9c4493fb2843870d7a141db1c Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 4 Sep 2024 23:15:27 +0900 Subject: [PATCH 31/87] support T5XXL LoRA, reduce peak memory usage #1560 --- flux_minimal_inference.py | 73 +++++++++++++++++++++++++++++++-------- networks/lora_flux.py | 2 +- 2 files changed, 59 insertions(+), 16 deletions(-) diff --git a/flux_minimal_inference.py b/flux_minimal_inference.py index 56c1b198..1c194e7c 100644 --- a/flux_minimal_inference.py +++ b/flux_minimal_inference.py @@ -5,7 +5,7 @@ import datetime import math import os import random -from typing import Callable, List, Optional, Tuple +from typing import Callable, List, Optional import einops import numpy as np @@ -13,6 +13,7 @@ import torch from tqdm import tqdm from PIL import Image import accelerate +from transformers import CLIPTextModel from library import device_utils from library.device_utils import init_ipex, get_preferred_device @@ -125,7 +126,7 @@ def do_sample( def generate_image( model, - clip_l, + clip_l: CLIPTextModel, t5xxl, ae, prompt: str, @@ -141,12 +142,13 @@ def generate_image( # make first noise with packed shape # original: b,16,2*h//16,2*w//16, packed: b,h//16*w//16,16*2*2 packed_latent_height, packed_latent_width = math.ceil(image_height / 16), math.ceil(image_width / 16) + noise_dtype = torch.float32 if is_fp8(dtype) else dtype noise = torch.randn( 1, packed_latent_height * packed_latent_width, 16 * 2 * 2, device=device, - dtype=dtype, + dtype=noise_dtype, generator=torch.Generator(device=device).manual_seed(seed), ) @@ -166,9 +168,48 @@ def generate_image( clip_l = clip_l.to(device) t5xxl = t5xxl.to(device) with torch.no_grad(): - if is_fp8(clip_l_dtype) or is_fp8(t5xxl_dtype): - clip_l.to(clip_l_dtype) - t5xxl.to(t5xxl_dtype) + if is_fp8(clip_l_dtype): + param_itr = clip_l.parameters() + param_itr.__next__() # skip first + param_2nd = param_itr.__next__() + if param_2nd.dtype != clip_l_dtype: + logger.info(f"prepare CLIP-L for fp8: set to {clip_l_dtype}, set embeddings to {torch.bfloat16}") + clip_l.to(clip_l_dtype) # fp8 + clip_l.text_model.embeddings.to(dtype=torch.bfloat16) + + with accelerator.autocast(): + l_pooled, _, _, _ = encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, None], tokens_and_masks) + + if is_fp8(t5xxl_dtype): + if flux_utils.get_t5xxl_actual_dtype(t5xxl) != t5xxl_dtype or not hasattr(t5xxl, "fp8_prepared"): + logger.info(f"prepare T5xxl for fp8: set to {t5xxl_dtype}") + + def prepare_fp8(text_encoder, target_dtype): + def forward_hook(module): + def forward(hidden_states): + hidden_gelu = module.act(module.wi_0(hidden_states)) + hidden_linear = module.wi_1(hidden_states) + hidden_states = hidden_gelu * hidden_linear + hidden_states = module.dropout(hidden_states) + + hidden_states = module.wo(hidden_states) + return hidden_states + + return forward + + for module in text_encoder.modules(): + if module.__class__.__name__ in ["T5LayerNorm", "Embedding"]: + # print("set", module.__class__.__name__, "to", target_dtype) + module.to(target_dtype) + if module.__class__.__name__ in ["T5DenseGatedActDense"]: + # print("set", module.__class__.__name__, "hooks") + module.forward = forward_hook(module) + + text_encoder.fp8_prepared = True + + t5xxl.to(t5xxl_dtype) + prepare_fp8(t5xxl.encoder, torch.bfloat16) + with accelerator.autocast(): _, t5_out, txt_ids, t5_attn_mask = encoding_strategy.encode_tokens( tokenize_strategy, [clip_l, t5xxl], tokens_and_masks, args.apply_t5_attn_mask @@ -315,10 +356,10 @@ if __name__ == "__main__": t5xxl = flux_utils.load_t5xxl(args.t5xxl, t5xxl_dtype, loading_device) t5xxl.eval() - if is_fp8(clip_l_dtype): - clip_l = accelerator.prepare(clip_l) - if is_fp8(t5xxl_dtype): - t5xxl = accelerator.prepare(t5xxl) + # if is_fp8(clip_l_dtype): + # clip_l = accelerator.prepare(clip_l) + # if is_fp8(t5xxl_dtype): + # t5xxl = accelerator.prepare(t5xxl) t5xxl_max_length = 256 if is_schnell else 512 tokenize_strategy = strategy_flux.FluxTokenizeStrategy(t5xxl_max_length) @@ -329,14 +370,16 @@ if __name__ == "__main__": model.eval() logger.info(f"Casting model to {flux_dtype}") model.to(flux_dtype) # make sure model is dtype - if is_fp8(flux_dtype): - model = accelerator.prepare(model) + # if is_fp8(flux_dtype): + # model = accelerator.prepare(model) + # if args.offload: + # model = model.to("cpu") # AE ae = flux_utils.load_ae(name, args.ae, ae_dtype, loading_device) ae.eval() - if is_fp8(ae_dtype): - ae = accelerator.prepare(ae) + # if is_fp8(ae_dtype): + # ae = accelerator.prepare(ae) # LoRA lora_models: List[lora_flux.LoRANetwork] = [] @@ -360,7 +403,7 @@ if __name__ == "__main__": lora_model.to(device) lora_models.append(lora_model) - + if not args.interactive: generate_image(model, clip_l, t5xxl, ae, args.prompt, args.seed, args.width, args.height, args.steps, args.guidance) else: diff --git a/networks/lora_flux.py b/networks/lora_flux.py index 295267be..ab9ccc4d 100644 --- a/networks/lora_flux.py +++ b/networks/lora_flux.py @@ -392,7 +392,7 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, flux, weigh modules_dim[lora_name] = dim # logger.info(lora_name, value.size(), dim) - if train_t5xxl is None: + if train_t5xxl is None or train_t5xxl is False: train_t5xxl = "lora_te3" in lora_name if train_t5xxl is None: From 90ed2dfb526168b2e77b8d367e928d8cc44b4278 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 5 Sep 2024 08:39:29 +0900 Subject: [PATCH 32/87] feat: Add support for merging CLIP-L and T5XXL LoRA models --- README.md | 22 ++++- networks/flux_merge_lora.py | 180 ++++++++++++++++++++++++++++-------- 2 files changed, 162 insertions(+), 40 deletions(-) diff --git a/README.md b/README.md index c0acfa1d..fa81f6c0 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,9 @@ The command to install PyTorch is as follows: ### Recent Updates +Sep 5, 2024: +The LoRA merge script now supports CLIP-L and T5XXL LoRA. Please specify `--clip_l` and `--t5xxl`. `--clip_l_save_to` and `--t5xxl_save_to` specify the save destination for CLIP-L and T5XXL. See [Merge LoRA to FLUX.1 checkpoint](#merge-lora-to-flux1-checkpoint) for details. + Sep 4, 2024: - T5XXL LoRA is supported in LoRA training. Remove `--network_train_unet_only` and add `train_t5xxl=True` to `--network_args`. CLIP-L is also trained at the same time (T5XXL only cannot be trained). The trained model can be used with ComfyUI. See [Key Features for FLUX.1 LoRA training](#key-features-for-flux1-lora-training) for details. - In LoRA training, when `--fp8_base` is specified, you can specify `t5xxl_fp8_e4m3fn.safetensors` as the T5XXL weights. However, it is recommended to use fp16 weights for caching. @@ -276,7 +279,7 @@ CLIP-L LoRA is not supported. ### Merge LoRA to FLUX.1 checkpoint -`networks/flux_merge_lora.py` merges LoRA to FLUX.1 checkpoint. __The script is experimental.__ +`networks/flux_merge_lora.py` merges LoRA to FLUX.1 checkpoint, CLIP-L or T5XXL models. __The script is experimental.__ ``` python networks/flux_merge_lora.py --flux_model flux1-dev.safetensors --save_to output.safetensors --models lora1.safetensors --ratios 2.0 --save_precision fp16 --loading_device cuda --working_device cpu @@ -284,13 +287,24 @@ python networks/flux_merge_lora.py --flux_model flux1-dev.safetensors --save_to You can also merge multiple LoRA models into a FLUX.1 model. Specify multiple LoRA models in `--models`. Specify the same number of ratios in `--ratios`. -`--loading_device` is the device to load the LoRA models. `--working_device` is the device to merge (calculate) the models. Default is `cpu` for both. Loading / working device examples are below (in the case of `--save_precision fp16` or `--save_precision bf16`): +CLIP-L and T5XXL LoRA are supported. `--clip_l` and `--clip_l_save_to` are for CLIP-L, `--t5xxl` and `--t5xxl_save_to` are for T5XXL. Sample command is below. + +``` +--clip_l clip_l.safetensors --clip_l_save_to merged_clip_l.safetensors --t5xxl t5xxl_fp16.safetensors --t5xxl_save_to merged_t5xxl.safetensors +``` + +FLUX.1, CLIP-L, and T5XXL can be merged together or separately for memory efficiency. + +An experimental option `--mem_eff_load_save` is available. This option is for memory-efficient loading and saving. It may also speed up loading and saving. + +`--loading_device` is the device to load the LoRA models. `--working_device` is the device to merge (calculate) the models. Default is `cpu` for both. Loading / working device examples are below (in the case of `--save_precision fp16` or `--save_precision bf16`, `float32` will consume more memory): - 'cpu' / 'cpu': Uses >50GB of RAM, but works on any machine. - 'cuda' / 'cpu': Uses 24GB of VRAM, but requires 30GB of RAM. -- 'cuda' / 'cuda': Uses 30GB of VRAM, but requires 30GB of RAM, faster than 'cuda' / 'cpu'. +- 'cpu' / 'cuda': Uses 4GB of VRAM, but requires 50GB of RAM, faster than 'cpu' / 'cpu' or 'cuda' / 'cpu'. +- 'cuda' / 'cuda': Uses 30GB of VRAM, but requires 30GB of RAM, faster than 'cpu' / 'cpu' or 'cuda' / 'cpu'. -In the case of LoRA models are trained with `bf16`, we are not sure which is better, `fp16` or `bf16` for `--save_precision`. +`--save_precision` is the precision to save the merged model. In the case of LoRA models are trained with `bf16`, we are not sure which is better, `fp16` or `bf16` for `--save_precision`. The script can merge multiple LoRA models. If you want to merge multiple LoRA models, specify `--concat` option to work the merged LoRA model properly. diff --git a/networks/flux_merge_lora.py b/networks/flux_merge_lora.py index 2e0d4c29..5e100a3b 100644 --- a/networks/flux_merge_lora.py +++ b/networks/flux_merge_lora.py @@ -2,6 +2,7 @@ import argparse import math import os import time +from typing import Any, Dict, Union import torch from safetensors import safe_open @@ -34,11 +35,11 @@ def load_state_dict(file_name, dtype): return sd, metadata -def save_to_file(file_name, state_dict, dtype, metadata, mem_eff_save=False): +def save_to_file(file_name, state_dict: Dict[str, Union[Any, torch.Tensor]], dtype, metadata, mem_eff_save=False): if dtype is not None: logger.info(f"converting to {dtype}...") for key in tqdm(list(state_dict.keys())): - if type(state_dict[key]) == torch.Tensor: + if type(state_dict[key]) == torch.Tensor and state_dict[key].dtype.is_floating_point: state_dict[key] = state_dict[key].to(dtype) logger.info(f"saving to: {file_name}") @@ -49,26 +50,76 @@ def save_to_file(file_name, state_dict, dtype, metadata, mem_eff_save=False): def merge_to_flux_model( - loading_device, working_device, flux_model, models, ratios, merge_dtype, save_dtype, mem_eff_load_save=False + loading_device, + working_device, + flux_path: str, + clip_l_path: str, + t5xxl_path: str, + models, + ratios, + merge_dtype, + save_dtype, + mem_eff_load_save=False, ): # create module map without loading state_dict - logger.info(f"loading keys from FLUX.1 model: {flux_model}") lora_name_to_module_key = {} - with safe_open(flux_model, framework="pt", device=loading_device) as flux_file: - keys = list(flux_file.keys()) - for key in keys: - if key.endswith(".weight"): - module_name = ".".join(key.split(".")[:-1]) - lora_name = lora_flux.LoRANetwork.LORA_PREFIX_FLUX + "_" + module_name.replace(".", "_") - lora_name_to_module_key[lora_name] = key + if flux_path is not None: + logger.info(f"loading keys from FLUX.1 model: {flux_path}") + with safe_open(flux_path, framework="pt", device=loading_device) as flux_file: + keys = list(flux_file.keys()) + for key in keys: + if key.endswith(".weight"): + module_name = ".".join(key.split(".")[:-1]) + lora_name = lora_flux.LoRANetwork.LORA_PREFIX_FLUX + "_" + module_name.replace(".", "_") + lora_name_to_module_key[lora_name] = key + lora_name_to_clip_l_key = {} + if clip_l_path is not None: + logger.info(f"loading keys from clip_l model: {clip_l_path}") + with safe_open(clip_l_path, framework="pt", device=loading_device) as clip_l_file: + keys = list(clip_l_file.keys()) + for key in keys: + if key.endswith(".weight"): + module_name = ".".join(key.split(".")[:-1]) + lora_name = lora_flux.LoRANetwork.LORA_PREFIX_TEXT_ENCODER_CLIP + "_" + module_name.replace(".", "_") + lora_name_to_clip_l_key[lora_name] = key + + lora_name_to_t5xxl_key = {} + if t5xxl_path is not None: + logger.info(f"loading keys from t5xxl model: {t5xxl_path}") + with safe_open(t5xxl_path, framework="pt", device=loading_device) as t5xxl_file: + keys = list(t5xxl_file.keys()) + for key in keys: + if key.endswith(".weight"): + module_name = ".".join(key.split(".")[:-1]) + lora_name = lora_flux.LoRANetwork.LORA_PREFIX_TEXT_ENCODER_T5 + "_" + module_name.replace(".", "_") + lora_name_to_t5xxl_key[lora_name] = key + + flux_state_dict = {} + clip_l_state_dict = {} + t5xxl_state_dict = {} if mem_eff_load_save: - flux_state_dict = {} - with MemoryEfficientSafeOpen(flux_model) as flux_file: - for key in tqdm(flux_file.keys()): - flux_state_dict[key] = flux_file.get_tensor(key).to(loading_device) # dtype is not changed + if flux_path is not None: + with MemoryEfficientSafeOpen(flux_path) as flux_file: + for key in tqdm(flux_file.keys()): + flux_state_dict[key] = flux_file.get_tensor(key).to(loading_device) # dtype is not changed + + if clip_l_path is not None: + with MemoryEfficientSafeOpen(clip_l_path) as clip_l_file: + for key in tqdm(clip_l_file.keys()): + clip_l_state_dict[key] = clip_l_file.get_tensor(key).to(loading_device) + + if t5xxl_path is not None: + with MemoryEfficientSafeOpen(t5xxl_path) as t5xxl_file: + for key in tqdm(t5xxl_file.keys()): + t5xxl_state_dict[key] = t5xxl_file.get_tensor(key).to(loading_device) else: - flux_state_dict = load_file(flux_model, device=loading_device) + if flux_path is not None: + flux_state_dict = load_file(flux_path, device=loading_device) + if clip_l_path is not None: + clip_l_state_dict = load_file(clip_l_path, device=loading_device) + if t5xxl_path is not None: + t5xxl_state_dict = load_file(t5xxl_path, device=loading_device) for model, ratio in zip(models, ratios): logger.info(f"loading: {model}") @@ -81,8 +132,20 @@ def merge_to_flux_model( up_key = key.replace("lora_down", "lora_up") alpha_key = key[: key.index("lora_down")] + "alpha" - if lora_name not in lora_name_to_module_key: - logger.warning(f"no module found for LoRA weight: {key}. LoRA for Text Encoder is not supported yet.") + if lora_name in lora_name_to_module_key: + module_weight_key = lora_name_to_module_key[lora_name] + state_dict = flux_state_dict + elif lora_name in lora_name_to_clip_l_key: + module_weight_key = lora_name_to_clip_l_key[lora_name] + state_dict = clip_l_state_dict + elif lora_name in lora_name_to_t5xxl_key: + module_weight_key = lora_name_to_t5xxl_key[lora_name] + state_dict = t5xxl_state_dict + else: + logger.warning( + f"no module found for LoRA weight: {key}. Skipping..." + f"LoRAの重みに対応するモジュールが見つかりませんでした。スキップします。" + ) continue down_weight = lora_sd.pop(key) @@ -93,11 +156,7 @@ def merge_to_flux_model( scale = alpha / dim # W <- W + U * D - module_weight_key = lora_name_to_module_key[lora_name] - if module_weight_key not in flux_state_dict: - weight = flux_file.get_tensor(module_weight_key) - else: - weight = flux_state_dict[module_weight_key] + weight = state_dict[module_weight_key] weight = weight.to(working_device, merge_dtype) up_weight = up_weight.to(working_device, merge_dtype) @@ -121,7 +180,7 @@ def merge_to_flux_model( # logger.info(conved.size(), weight.size(), module.stride, module.padding) weight = weight + ratio * conved * scale - flux_state_dict[module_weight_key] = weight.to(loading_device, save_dtype) + state_dict[module_weight_key] = weight.to(loading_device, save_dtype) del up_weight del down_weight del weight @@ -129,7 +188,7 @@ def merge_to_flux_model( if len(lora_sd) > 0: logger.warning(f"Unused keys in LoRA model: {list(lora_sd.keys())}") - return flux_state_dict + return flux_state_dict, clip_l_state_dict, t5xxl_state_dict def merge_to_flux_model_diffusers( @@ -508,17 +567,28 @@ def merge(args): if save_dtype is None: save_dtype = merge_dtype - dest_dir = os.path.dirname(args.save_to) + assert ( + args.save_to or args.clip_l_save_to or args.t5xxl_save_to + ), "save_to or clip_l_save_to or t5xxl_save_to must be specified / save_toまたはclip_l_save_toまたはt5xxl_save_toを指定してください" + dest_dir = os.path.dirname(args.save_to or args.clip_l_save_to or args.t5xxl_save_to) if not os.path.exists(dest_dir): logger.info(f"creating directory: {dest_dir}") os.makedirs(dest_dir) - if args.flux_model is not None: + if args.flux_model is not None or args.clip_l is not None or args.t5xxl is not None: if not args.diffusers: - state_dict = merge_to_flux_model( + assert (args.clip_l is None and args.clip_l_save_to is None) or ( + args.clip_l is not None and args.clip_l_save_to is not None + ), "clip_l_save_to must be specified if clip_l is specified / clip_lが指定されている場合はclip_l_save_toも指定してください" + assert (args.t5xxl is None and args.t5xxl_save_to is None) or ( + args.t5xxl is not None and args.t5xxl_save_to is not None + ), "t5xxl_save_to must be specified if t5xxl is specified / t5xxlが指定されている場合はt5xxl_save_toも指定してください" + flux_state_dict, clip_l_state_dict, t5xxl_state_dict = merge_to_flux_model( args.loading_device, args.working_device, args.flux_model, + args.clip_l, + args.t5xxl, args.models, args.ratios, merge_dtype, @@ -526,7 +596,10 @@ def merge(args): args.mem_eff_load_save, ) else: - state_dict = merge_to_flux_model_diffusers( + assert ( + args.clip_l is None and args.t5xxl is None + ), "clip_l and t5xxl are not supported with --diffusers / clip_l、t5xxlはDiffusersではサポートされていません" + flux_state_dict = merge_to_flux_model_diffusers( args.loading_device, args.working_device, args.flux_model, @@ -536,8 +609,10 @@ def merge(args): save_dtype, args.mem_eff_load_save, ) + clip_l_state_dict = None + t5xxl_state_dict = None - if args.no_metadata: + if args.no_metadata or (flux_state_dict is None or len(flux_state_dict) == 0): sai_metadata = None else: merged_from = sai_model_spec.build_merged_from([args.flux_model] + args.models) @@ -546,15 +621,24 @@ def merge(args): None, False, False, False, False, False, time.time(), title=title, merged_from=merged_from, flux="dev" ) - logger.info(f"saving FLUX model to: {args.save_to}") - save_to_file(args.save_to, state_dict, save_dtype, sai_metadata, args.mem_eff_load_save) + if flux_state_dict is not None and len(flux_state_dict) > 0: + logger.info(f"saving FLUX model to: {args.save_to}") + save_to_file(args.save_to, flux_state_dict, save_dtype, sai_metadata, args.mem_eff_load_save) + + if clip_l_state_dict is not None and len(clip_l_state_dict) > 0: + logger.info(f"saving clip_l model to: {args.clip_l_save_to}") + save_to_file(args.clip_l_save_to, clip_l_state_dict, save_dtype, None, args.mem_eff_load_save) + + if t5xxl_state_dict is not None and len(t5xxl_state_dict) > 0: + logger.info(f"saving t5xxl model to: {args.t5xxl_save_to}") + save_to_file(args.t5xxl_save_to, t5xxl_state_dict, save_dtype, None, args.mem_eff_load_save) else: - state_dict, metadata = merge_lora_models(args.models, args.ratios, merge_dtype, args.concat, args.shuffle) + flux_state_dict, metadata = merge_lora_models(args.models, args.ratios, merge_dtype, args.concat, args.shuffle) logger.info("calculating hashes and creating metadata...") - model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) + model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(flux_state_dict, metadata) metadata["sshs_model_hash"] = model_hash metadata["sshs_legacy_hash"] = legacy_hash @@ -562,12 +646,12 @@ def merge(args): merged_from = sai_model_spec.build_merged_from(args.models) title = os.path.splitext(os.path.basename(args.save_to))[0] sai_metadata = sai_model_spec.build_metadata( - state_dict, False, False, False, True, False, time.time(), title=title, merged_from=merged_from, flux="dev" + flux_state_dict, False, False, False, True, False, time.time(), title=title, merged_from=merged_from, flux="dev" ) metadata.update(sai_metadata) logger.info(f"saving model to: {args.save_to}") - save_to_file(args.save_to, state_dict, save_dtype, metadata) + save_to_file(args.save_to, flux_state_dict, save_dtype, metadata) def setup_parser() -> argparse.ArgumentParser: @@ -592,6 +676,18 @@ def setup_parser() -> argparse.ArgumentParser: default=None, help="FLUX.1 model to load, merge LoRA models if omitted / 読み込むモデル、指定しない場合はLoRAモデルをマージする", ) + parser.add_argument( + "--clip_l", + type=str, + default=None, + help="path to clip_l (*.sft or *.safetensors), should be float16 / clip_lのパス(*.sftまたは*.safetensors)", + ) + parser.add_argument( + "--t5xxl", + type=str, + default=None, + help="path to t5xxl (*.sft or *.safetensors), should be float16 / t5xxlのパス(*.sftまたは*.safetensors)", + ) parser.add_argument( "--mem_eff_load_save", action="store_true", @@ -617,6 +713,18 @@ def setup_parser() -> argparse.ArgumentParser: default=None, help="destination file name: safetensors file / 保存先のファイル名、safetensorsファイル", ) + parser.add_argument( + "--clip_l_save_to", + type=str, + default=None, + help="destination file name for clip_l: safetensors file / clip_lの保存先のファイル名、safetensorsファイル", + ) + parser.add_argument( + "--t5xxl_save_to", + type=str, + default=None, + help="destination file name for t5xxl: safetensors file / t5xxlの保存先のファイル名、safetensorsファイル", + ) parser.add_argument( "--models", type=str, From d9129522a6effea7077f18cdea0ee733a5ac7cb0 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 5 Sep 2024 12:20:07 +0900 Subject: [PATCH 33/87] set dtype before calling ae closes #1562 --- flux_train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flux_train.py b/flux_train.py index 32a36f03..0293b7be 100644 --- a/flux_train.py +++ b/flux_train.py @@ -651,7 +651,7 @@ def train(args): else: with torch.no_grad(): # encode images to latents. images are [-1, 1] - latents = ae.encode(batch["images"]) + latents = ae.encode(batch["images"].to(ae.dtype)).to(accelerator.device, dtype=weight_dtype) # NaNが含まれていれば警告を表示し0に置き換える if torch.any(torch.isnan(latents)): From 2889108d858880589d362e06e98eeadf4682476a Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 5 Sep 2024 20:58:33 +0900 Subject: [PATCH 34/87] feat: Add --cpu_offload_checkpointing option to LoRA training --- README.md | 7 +++++++ flux_train.py | 2 +- flux_train_network.py | 5 +++++ train_network.py | 12 +++++++++++- 4 files changed, 24 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index fa81f6c0..e8a12089 100644 --- a/README.md +++ b/README.md @@ -11,7 +11,12 @@ The command to install PyTorch is as follows: ### Recent Updates +Sep 5, 2024 (update 1): + +Added `--cpu_offload_checkpointing` option to LoRA training script. Offloads gradient checkpointing to CPU. This reduces up to 1GB of VRAM usage but slows down the training by about 15%. Cannot be used with `--split_mode`. + Sep 5, 2024: + The LoRA merge script now supports CLIP-L and T5XXL LoRA. Please specify `--clip_l` and `--t5xxl`. `--clip_l_save_to` and `--t5xxl_save_to` specify the save destination for CLIP-L and T5XXL. See [Merge LoRA to FLUX.1 checkpoint](#merge-lora-to-flux1-checkpoint) for details. Sep 4, 2024: @@ -72,6 +77,8 @@ The training can be done with 12GB VRAM GPUs with Adafactor optimizer, `--split_ --optimizer_type adafactor --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False" --split_mode --network_args "train_blocks=single" --lr_scheduler constant_with_warmup --max_grad_norm 0.0 ``` +`--cpu_offload_checkpointing` offloads gradient checkpointing to CPU. This reduces up to 1GB of VRAM usage but slows down the training by about 15%. Cannot be used with `--split_mode`. + We also not sure how many epochs are needed for convergence, and how the learning rate should be adjusted. The trained LoRA model can be used with ComfyUI. diff --git a/flux_train.py b/flux_train.py index 0293b7be..0edc83a9 100644 --- a/flux_train.py +++ b/flux_train.py @@ -261,7 +261,7 @@ def train(args): ) if args.gradient_checkpointing: - flux.enable_gradient_checkpointing(args.cpu_offload_checkpointing) + flux.enable_gradient_checkpointing(cpu_offload=args.cpu_offload_checkpointing) flux.requires_grad_(True) diff --git a/flux_train_network.py b/flux_train_network.py index 2fc0f323..a6e57eed 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -50,6 +50,11 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): if args.max_token_length is not None: logger.warning("max_token_length is not used in Flux training / max_token_lengthはFluxのトレーニングでは使用されません") + assert not args.split_mode or not args.cpu_offload_checkpointing, ( + "split_mode and cpu_offload_checkpointing cannot be used together" + " / split_modeとcpu_offload_checkpointingは同時に使用できません" + ) + train_dataset_group.verify_bucket_reso_steps(32) # TODO check this def get_flux_model_name(self, args): diff --git a/train_network.py b/train_network.py index a68ccfcc..ad97491d 100644 --- a/train_network.py +++ b/train_network.py @@ -451,7 +451,11 @@ class NetworkTrainer: accelerator.print(f"load network weights from {args.network_weights}: {info}") if args.gradient_checkpointing: - unet.enable_gradient_checkpointing() + if args.cpu_offload_checkpointing: + unet.enable_gradient_checkpointing(cpu_offload=True) + else: + unet.enable_gradient_checkpointing() + for t_enc, flag in zip(text_encoders, self.get_text_encoders_train_flags(args, text_encoders)): if flag: if t_enc.supports_gradient_checkpointing: @@ -1281,6 +1285,12 @@ def setup_parser() -> argparse.ArgumentParser: config_util.add_config_arguments(parser) custom_train_functions.add_custom_train_arguments(parser) + parser.add_argument( + "--cpu_offload_checkpointing", + action="store_true", + help="[EXPERIMENTAL] enable offloading of tensors to CPU during checkpointing for U-Net or DiT, if supported" + " / 勾配チェックポイント時にテンソルをCPUにオフロードする(U-NetまたはDiTのみ、サポートされている場合)", + ) parser.add_argument( "--no_metadata", action="store_true", help="do not save metadata in output model / メタデータを出力先モデルに保存しない" ) From 0005867ba509d2e1a5674b267e8286b561c0ed71 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 7 Sep 2024 10:45:18 +0900 Subject: [PATCH 35/87] update README, format code --- README.md | 5 +++++ library/train_util.py | 4 ++-- library/utils.py | 4 +++- 3 files changed, 10 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 81a54937..16ab80e7 100644 --- a/README.md +++ b/README.md @@ -139,7 +139,12 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser ### Working in progress +- When enlarging images in the script (when the size of the training image is small and bucket_no_upscale is not specified), it has been changed to use Pillow's resize and LANCZOS interpolation instead of OpenCV2's resize and Lanczos4 interpolation. The quality of the image enlargement may be slightly improved. PR [#1426](https://github.com/kohya-ss/sd-scripts/pull/1426) Thanks to sdbds! + +- Sample image generation during training now works on non-CUDA devices. PR [#1433](https://github.com/kohya-ss/sd-scripts/pull/1433) Thanks to millie-v! + - `--v_parameterization` is available in `sdxl_train.py`. The results are unpredictable, so use with caution. PR [#1505](https://github.com/kohya-ss/sd-scripts/pull/1505) Thanks to liesened! + - Fused optimizer is available for SDXL training. PR [#1259](https://github.com/kohya-ss/sd-scripts/pull/1259) Thanks to 2kpr! - The memory usage during training is significantly reduced by integrating the optimizer's backward pass with step. The training results are the same as before, but if you have plenty of memory, the speed will be slower. - Specify the `--fused_backward_pass` option in `sdxl_train.py`. At this time, only AdaFactor is supported. Gradient accumulation is not available. diff --git a/library/train_util.py b/library/train_util.py index 102d39ed..1441e74f 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2094,7 +2094,7 @@ class ControlNetDataset(BaseDataset): # ), f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}" # resize to target if cond_img.shape[0] != target_size_hw[0] or cond_img.shape[1] != target_size_hw[1]: - cond_img=pil_resize(cond_img,(int(target_size_hw[1]), int(target_size_hw[0]))) + cond_img = pil_resize(cond_img, (int(target_size_hw[1]), int(target_size_hw[0]))) if flipped: cond_img = cond_img[:, ::-1, :].copy() # copy to avoid negative stride @@ -2432,7 +2432,7 @@ def load_arbitrary_dataset(args, tokenizer) -> MinimalDataset: return train_dataset_group -def load_image(image_path, alpha=False): +def load_image(image_path, alpha=False): try: with Image.open(image_path) as image: if alpha: diff --git a/library/utils.py b/library/utils.py index a219f6cb..5b7e657b 100644 --- a/library/utils.py +++ b/library/utils.py @@ -11,6 +11,7 @@ import cv2 from PIL import Image import numpy as np + def fire_in_thread(f, *args, **kwargs): threading.Thread(target=f, args=args, kwargs=kwargs).start() @@ -80,8 +81,8 @@ def setup_logging(args=None, log_level=None, reset=False): logger = logging.getLogger(__name__) logger.info(msg_init) -def pil_resize(image, size, interpolation=Image.LANCZOS): +def pil_resize(image, size, interpolation=Image.LANCZOS): pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) # use Pillow resize @@ -92,6 +93,7 @@ def pil_resize(image, size, interpolation=Image.LANCZOS): return resized_cv2 + # TODO make inf_utils.py From d29af146b8d4c4d028f8752657bd1349c8cd3509 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 9 Sep 2024 23:01:15 +0900 Subject: [PATCH 36/87] add negative prompt for flux inference script --- README.md | 3 + flux_minimal_inference.py | 283 +++++++++++++++++++++++++++----------- 2 files changed, 203 insertions(+), 83 deletions(-) diff --git a/README.md b/README.md index 2f010f49..126516f9 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,9 @@ The command to install PyTorch is as follows: ### Recent Updates +Sep 9, 2024: +Added `--negative_prompt` and `--cfg_scale` to `flux_minimal_inference.py`. Negative prompts can be used. + Sep 5, 2024 (update 1): Added `--cpu_offload_checkpointing` option to LoRA training script. Offloads gradient checkpointing to CPU. This reduces up to 1GB of VRAM usage but slows down the training by about 15%. Cannot be used with `--split_mode`. diff --git a/flux_minimal_inference.py b/flux_minimal_inference.py index 1c194e7c..de607c52 100644 --- a/flux_minimal_inference.py +++ b/flux_minimal_inference.py @@ -71,22 +71,57 @@ def denoise( timesteps: list[float], guidance: float = 4.0, t5_attn_mask: Optional[torch.Tensor] = None, + neg_txt: Optional[torch.Tensor] = None, + neg_vec: Optional[torch.Tensor] = None, + neg_t5_attn_mask: Optional[torch.Tensor] = None, + cfg_scale: Optional[float] = None, ): # this is ignored for schnell + logger.info(f"guidance: {guidance}, cfg_scale: {cfg_scale}") guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype) + + # prepare classifier free guidance + if neg_txt is not None and neg_vec is not None: + b_img_ids = torch.cat([img_ids, img_ids], dim=0) + b_txt_ids = torch.cat([txt_ids, txt_ids], dim=0) + b_txt = torch.cat([neg_txt, txt], dim=0) + b_vec = torch.cat([neg_vec, vec], dim=0) + if t5_attn_mask is not None and neg_t5_attn_mask is not None: + b_t5_attn_mask = torch.cat([neg_t5_attn_mask, t5_attn_mask], dim=0) + else: + b_t5_attn_mask = None + else: + b_img_ids = img_ids + b_txt_ids = txt_ids + b_txt = txt + b_vec = vec + b_t5_attn_mask = t5_attn_mask + for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]): - t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device) + t_vec = torch.full((b_img_ids.shape[0],), t_curr, dtype=img.dtype, device=img.device) + + # classifier free guidance + if neg_txt is not None and neg_vec is not None: + b_img = torch.cat([img, img], dim=0) + else: + b_img = img + pred = model( - img=img, - img_ids=img_ids, - txt=txt, - txt_ids=txt_ids, - y=vec, + img=b_img, + img_ids=b_img_ids, + txt=b_txt, + txt_ids=b_txt_ids, + y=b_vec, timesteps=t_vec, guidance=guidance_vec, - txt_attention_mask=t5_attn_mask, + txt_attention_mask=b_t5_attn_mask, ) + # classifier free guidance + if neg_txt is not None and neg_vec is not None: + pred_uncond, pred = torch.chunk(pred, 2, dim=0) + pred = pred_uncond + cfg_scale * (pred - pred_uncond) + img = img + (t_prev - t_curr) * pred return img @@ -106,19 +141,48 @@ def do_sample( is_schnell: bool, device: torch.device, flux_dtype: torch.dtype, + neg_l_pooled: Optional[torch.Tensor] = None, + neg_t5_out: Optional[torch.Tensor] = None, + neg_t5_attn_mask: Optional[torch.Tensor] = None, + cfg_scale: Optional[float] = None, ): + logger.info(f"num_steps: {num_steps}") timesteps = get_schedule(num_steps, img.shape[1], shift=not is_schnell) # denoise initial noise if accelerator: with accelerator.autocast(), torch.no_grad(): x = denoise( - model, img, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=guidance, t5_attn_mask=t5_attn_mask + model, + img, + img_ids, + t5_out, + txt_ids, + l_pooled, + timesteps, + guidance, + t5_attn_mask, + neg_t5_out, + neg_l_pooled, + neg_t5_attn_mask, + cfg_scale, ) else: with torch.autocast(device_type=device.type, dtype=flux_dtype), torch.no_grad(): x = denoise( - model, img, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=guidance, t5_attn_mask=t5_attn_mask + model, + img, + img_ids, + t5_out, + txt_ids, + l_pooled, + timesteps, + guidance, + t5_attn_mask, + neg_t5_out, + neg_l_pooled, + neg_t5_attn_mask, + cfg_scale, ) return x @@ -135,6 +199,8 @@ def generate_image( image_height: int, steps: Optional[int], guidance: float, + negative_prompt: Optional[str], + cfg_scale: float, ): seed = seed if seed is not None else random.randint(0, 2**32 - 1) logger.info(f"Seed: {seed}") @@ -162,65 +228,73 @@ def generate_image( # txt2img only needs img_ids img_ids = flux_utils.prepare_img_ids(1, packed_latent_height, packed_latent_width) + # prepare fp8 models + if is_fp8(clip_l_dtype) and (not hasattr(clip_l, "fp8_prepared") or not clip_l.fp8_prepared): + logger.info(f"prepare CLIP-L for fp8: set to {clip_l_dtype}, set embeddings to {torch.bfloat16}") + clip_l.to(clip_l_dtype) # fp8 + clip_l.text_model.embeddings.to(dtype=torch.bfloat16) + clip_l.fp8_prepared = True + + if is_fp8(t5xxl_dtype) and (not hasattr(t5xxl, "fp8_prepared") or not t5xxl.fp8_prepared): + logger.info(f"prepare T5xxl for fp8: set to {t5xxl_dtype}") + + def prepare_fp8(text_encoder, target_dtype): + def forward_hook(module): + def forward(hidden_states): + hidden_gelu = module.act(module.wi_0(hidden_states)) + hidden_linear = module.wi_1(hidden_states) + hidden_states = hidden_gelu * hidden_linear + hidden_states = module.dropout(hidden_states) + + hidden_states = module.wo(hidden_states) + return hidden_states + + return forward + + for module in text_encoder.modules(): + if module.__class__.__name__ in ["T5LayerNorm", "Embedding"]: + # print("set", module.__class__.__name__, "to", target_dtype) + module.to(target_dtype) + if module.__class__.__name__ in ["T5DenseGatedActDense"]: + # print("set", module.__class__.__name__, "hooks") + module.forward = forward_hook(module) + + t5xxl.to(t5xxl_dtype) + prepare_fp8(t5xxl.encoder, torch.bfloat16) + t5xxl.fp8_prepared = True + # prepare embeddings logger.info("Encoding prompts...") - tokens_and_masks = tokenize_strategy.tokenize(prompt) clip_l = clip_l.to(device) t5xxl = t5xxl.to(device) - with torch.no_grad(): - if is_fp8(clip_l_dtype): - param_itr = clip_l.parameters() - param_itr.__next__() # skip first - param_2nd = param_itr.__next__() - if param_2nd.dtype != clip_l_dtype: - logger.info(f"prepare CLIP-L for fp8: set to {clip_l_dtype}, set embeddings to {torch.bfloat16}") - clip_l.to(clip_l_dtype) # fp8 - clip_l.text_model.embeddings.to(dtype=torch.bfloat16) - with accelerator.autocast(): - l_pooled, _, _, _ = encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, None], tokens_and_masks) + def encode(prpt: str): + tokens_and_masks = tokenize_strategy.tokenize(prpt) + with torch.no_grad(): + if is_fp8(clip_l_dtype): + with accelerator.autocast(): + l_pooled, _, _, _ = encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, None], tokens_and_masks) + else: + with torch.autocast(device_type=device.type, dtype=clip_l_dtype): + l_pooled, _, _, _ = encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, None], tokens_and_masks) - if is_fp8(t5xxl_dtype): - if flux_utils.get_t5xxl_actual_dtype(t5xxl) != t5xxl_dtype or not hasattr(t5xxl, "fp8_prepared"): - logger.info(f"prepare T5xxl for fp8: set to {t5xxl_dtype}") + if is_fp8(t5xxl_dtype): + with accelerator.autocast(): + _, t5_out, txt_ids, t5_attn_mask = encoding_strategy.encode_tokens( + tokenize_strategy, [clip_l, t5xxl], tokens_and_masks, args.apply_t5_attn_mask + ) + else: + with torch.autocast(device_type=device.type, dtype=t5xxl_dtype): + _, t5_out, txt_ids, t5_attn_mask = encoding_strategy.encode_tokens( + tokenize_strategy, [None, t5xxl], tokens_and_masks, args.apply_t5_attn_mask + ) + return l_pooled, t5_out, txt_ids, t5_attn_mask - def prepare_fp8(text_encoder, target_dtype): - def forward_hook(module): - def forward(hidden_states): - hidden_gelu = module.act(module.wi_0(hidden_states)) - hidden_linear = module.wi_1(hidden_states) - hidden_states = hidden_gelu * hidden_linear - hidden_states = module.dropout(hidden_states) - - hidden_states = module.wo(hidden_states) - return hidden_states - - return forward - - for module in text_encoder.modules(): - if module.__class__.__name__ in ["T5LayerNorm", "Embedding"]: - # print("set", module.__class__.__name__, "to", target_dtype) - module.to(target_dtype) - if module.__class__.__name__ in ["T5DenseGatedActDense"]: - # print("set", module.__class__.__name__, "hooks") - module.forward = forward_hook(module) - - text_encoder.fp8_prepared = True - - t5xxl.to(t5xxl_dtype) - prepare_fp8(t5xxl.encoder, torch.bfloat16) - - with accelerator.autocast(): - _, t5_out, txt_ids, t5_attn_mask = encoding_strategy.encode_tokens( - tokenize_strategy, [clip_l, t5xxl], tokens_and_masks, args.apply_t5_attn_mask - ) - else: - with torch.autocast(device_type=device.type, dtype=clip_l_dtype): - l_pooled, _, _, _ = encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, None], tokens_and_masks) - with torch.autocast(device_type=device.type, dtype=t5xxl_dtype): - _, t5_out, txt_ids, t5_attn_mask = encoding_strategy.encode_tokens( - tokenize_strategy, [None, t5xxl], tokens_and_masks, args.apply_t5_attn_mask - ) + l_pooled, t5_out, txt_ids, t5_attn_mask = encode(prompt) + if negative_prompt: + neg_l_pooled, neg_t5_out, _, neg_t5_attn_mask = encode(negative_prompt) + else: + neg_l_pooled, neg_t5_out, neg_t5_attn_mask = None, None, None # NaN check if torch.isnan(l_pooled).any(): @@ -244,7 +318,23 @@ def generate_image( t5_attn_mask = t5_attn_mask.to(device) if args.apply_t5_attn_mask else None x = do_sample( - accelerator, model, noise, img_ids, l_pooled, t5_out, txt_ids, steps, guidance, t5_attn_mask, is_schnell, device, flux_dtype + accelerator, + model, + noise, + img_ids, + l_pooled, + t5_out, + txt_ids, + steps, + guidance, + t5_attn_mask, + is_schnell, + device, + flux_dtype, + neg_l_pooled, + neg_t5_out, + neg_t5_attn_mask, + cfg_scale, ) if args.offload: model = model.cpu() @@ -307,6 +397,8 @@ if __name__ == "__main__": parser.add_argument("--seed", type=int, default=None) parser.add_argument("--steps", type=int, default=None, help="Number of steps. Default is 4 for schnell, 50 for dev") parser.add_argument("--guidance", type=float, default=3.5) + parser.add_argument("--negative_prompt", type=str, default=None) + parser.add_argument("--cfg_scale", type=float, default=1.0) parser.add_argument("--offload", action="store_true", help="Offload to CPU") parser.add_argument( "--lora_weights", @@ -403,19 +495,34 @@ if __name__ == "__main__": lora_model.to(device) lora_models.append(lora_model) - + if not args.interactive: - generate_image(model, clip_l, t5xxl, ae, args.prompt, args.seed, args.width, args.height, args.steps, args.guidance) + generate_image( + model, + clip_l, + t5xxl, + ae, + args.prompt, + args.seed, + args.width, + args.height, + args.steps, + args.guidance, + args.negative_prompt, + args.cfg_scale, + ) else: # loop for interactive width = target_width height = target_height steps = None guidance = args.guidance + cfg_scale = args.cfg_scale while True: print( "Enter prompt (empty to exit). Options: --w --h --s --d --g --m " + " --n , `-` for empty negative prompt --c " ) prompt = input() if prompt == "": @@ -425,26 +532,36 @@ if __name__ == "__main__": options = prompt.split("--") prompt = options[0].strip() seed = None + negative_prompt = None for opt in options[1:]: - opt = opt.strip() - if opt.startswith("w"): - width = int(opt[1:].strip()) - elif opt.startswith("h"): - height = int(opt[1:].strip()) - elif opt.startswith("s"): - steps = int(opt[1:].strip()) - elif opt.startswith("d"): - seed = int(opt[1:].strip()) - elif opt.startswith("g"): - guidance = float(opt[1:].strip()) - elif opt.startswith("m"): - mutipliers = opt[1:].strip().split(",") - if len(mutipliers) != len(lora_models): - logger.error(f"Invalid number of multipliers, expected {len(lora_models)}") - continue - for i, lora_model in enumerate(lora_models): - lora_model.set_multiplier(float(mutipliers[i])) + try: + opt = opt.strip() + if opt.startswith("w"): + width = int(opt[1:].strip()) + elif opt.startswith("h"): + height = int(opt[1:].strip()) + elif opt.startswith("s"): + steps = int(opt[1:].strip()) + elif opt.startswith("d"): + seed = int(opt[1:].strip()) + elif opt.startswith("g"): + guidance = float(opt[1:].strip()) + elif opt.startswith("m"): + mutipliers = opt[1:].strip().split(",") + if len(mutipliers) != len(lora_models): + logger.error(f"Invalid number of multipliers, expected {len(lora_models)}") + continue + for i, lora_model in enumerate(lora_models): + lora_model.set_multiplier(float(mutipliers[i])) + elif opt.startswith("n"): + negative_prompt = opt[1:].strip() + if negative_prompt == "-": + negative_prompt = "" + elif opt.startswith("c"): + cfg_scale = float(opt[1:].strip()) + except ValueError as e: + logger.error(f"Invalid option: {opt}, {e}") - generate_image(model, clip_l, t5xxl, ae, prompt, seed, width, height, steps, guidance) + generate_image(model, clip_l, t5xxl, ae, prompt, seed, width, height, steps, guidance, negative_prompt, cfg_scale) logger.info("Done!") From d10ff62a78b15d0bb55f443cc2849c460300131b Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 10 Sep 2024 20:32:09 +0900 Subject: [PATCH 37/87] support individual LR for CLIP-L/T5XXL --- README.md | 4 +++ networks/lora_flux.py | 71 +++++++++++++++---------------------------- train_network.py | 32 ++++++++++++------- 3 files changed, 49 insertions(+), 58 deletions(-) diff --git a/README.md b/README.md index 126516f9..b5799dd6 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,9 @@ The command to install PyTorch is as follows: ### Recent Updates +Sep 10, 2024: +In FLUX.1 LoRA training, individual learning rates can be specified for CLIP-L and T5XXL. By specifying multiple numbers in `--text_encoder_lr`, you can set the learning rates for CLIP-L and T5XXL separately. Specify like `--text_encoder_lr 1e-4 1e-5`. If you specify only one, the learning rates for CLIP-L and T5XXL will be the same. + Sep 9, 2024: Added `--negative_prompt` and `--cfg_scale` to `flux_minimal_inference.py`. Negative prompts can be used. @@ -142,6 +145,7 @@ The effect of `--timestep_sampling sigmoid` and `--sigmoid_scale` (when `--times - Remove `--network_train_unet_only` from your command. - Add `train_t5xxl=True` to `--network_args` to train T5XXL LoRA. CLIP-L is also trained at the same time. - T5XXL output can be cached for CLIP-L LoRA training. So, `--cache_text_encoder_outputs` or `--cache_text_encoder_outputs_to_disk` is also available. + - The learning rates for CLIP-L and T5XXL can be specified separately. Multiple numbers can be specified in `--text_encoder_lr`. For example, `--text_encoder_lr 1e-4 1e-5`. If you specify only one, the learning rates for CLIP-L and T5XXL will be the same. If `--text_encoder_lr` is not specified, the default learning rate `--learning_rate` is used for both CLIP-L and T5XXL. - The trained LoRA can be used with ComfyUI. - Note: `flux_extract_lora.py`, `convert_flux_lora.py`and `merge_flux_lora.py` do not support CLIP-L and T5XXL LoRA yet. diff --git a/networks/lora_flux.py b/networks/lora_flux.py index ab9ccc4d..d540c221 100644 --- a/networks/lora_flux.py +++ b/networks/lora_flux.py @@ -786,28 +786,23 @@ class LoRANetwork(torch.nn.Module): logger.info(f"LoRA+ UNet LR Ratio: {self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio}") logger.info(f"LoRA+ Text Encoder LR Ratio: {self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio}") - # 二つのText Encoderに別々の学習率を設定できるようにするといいかも - def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr): - # TODO warn if optimizer is not compatible with LoRA+ (but it will cause error so we don't need to check it here?) - # if ( - # self.loraplus_lr_ratio is not None - # or self.loraplus_text_encoder_lr_ratio is not None - # or self.loraplus_unet_lr_ratio is not None - # ): - # assert ( - # optimizer_type.lower() != "prodigy" and "dadapt" not in optimizer_type.lower() - # ), "LoRA+ and Prodigy/DAdaptation is not supported / LoRA+とProdigy/DAdaptationの組み合わせはサポートされていません" + def prepare_optimizer_params_with_multiple_te_lrs(self, text_encoder_lr, unet_lr, default_lr): + # make sure text_encoder_lr as list of two elements + if text_encoder_lr is None or len(text_encoder_lr) == 0: + text_encoder_lr = [default_lr, default_lr] + elif len(text_encoder_lr) == 1: + text_encoder_lr = [text_encoder_lr[0], text_encoder_lr[0]] self.requires_grad_(True) all_params = [] lr_descriptions = [] - def assemble_params(loras, lr, ratio): + def assemble_params(loras, lr, loraplus_ratio): param_groups = {"lora": {}, "plus": {}} for lora in loras: for name, param in lora.named_parameters(): - if ratio is not None and "lora_up" in name: + if loraplus_ratio is not None and "lora_up" in name: param_groups["plus"][f"{lora.lora_name}.{name}"] = param else: param_groups["lora"][f"{lora.lora_name}.{name}"] = param @@ -822,7 +817,7 @@ class LoRANetwork(torch.nn.Module): if lr is not None: if key == "plus": - param_data["lr"] = lr * ratio + param_data["lr"] = lr * loraplus_ratio else: param_data["lr"] = lr @@ -836,41 +831,23 @@ class LoRANetwork(torch.nn.Module): return params, descriptions if self.text_encoder_loras: - params, descriptions = assemble_params( - self.text_encoder_loras, - text_encoder_lr if text_encoder_lr is not None else default_lr, - self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio, - ) - all_params.extend(params) - lr_descriptions.extend(["textencoder" + (" " + d if d else "") for d in descriptions]) + loraplus_lr_ratio = self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio + + # split text encoder loras for te1 and te3 + te1_loras = [lora for lora in self.text_encoder_loras if lora.lora_name.startswith(self.LORA_PREFIX_TEXT_ENCODER_CLIP)] + te3_loras = [lora for lora in self.text_encoder_loras if lora.lora_name.startswith(self.LORA_PREFIX_TEXT_ENCODER_T5)] + if len(te1_loras) > 0: + logger.info(f"Text Encoder 1 (CLIP-L): {len(te1_loras)} modules, LR {text_encoder_lr[0]}") + params, descriptions = assemble_params(te1_loras, text_encoder_lr[0], loraplus_lr_ratio) + all_params.extend(params) + lr_descriptions.extend(["textencoder 1 " + (" " + d if d else "") for d in descriptions]) + if len(te3_loras) > 0: + logger.info(f"Text Encoder 2 (T5XXL): {len(te3_loras)} modules, LR {text_encoder_lr[1]}") + params, descriptions = assemble_params(te3_loras, text_encoder_lr[1], loraplus_lr_ratio) + all_params.extend(params) + lr_descriptions.extend(["textencoder 2 " + (" " + d if d else "") for d in descriptions]) if self.unet_loras: - # if self.block_lr: - # is_sdxl = False - # for lora in self.unet_loras: - # if "input_blocks" in lora.lora_name or "output_blocks" in lora.lora_name: - # is_sdxl = True - # break - - # # 学習率のグラフをblockごとにしたいので、blockごとにloraを分類 - # block_idx_to_lora = {} - # for lora in self.unet_loras: - # idx = get_block_index(lora.lora_name, is_sdxl) - # if idx not in block_idx_to_lora: - # block_idx_to_lora[idx] = [] - # block_idx_to_lora[idx].append(lora) - - # # blockごとにパラメータを設定する - # for idx, block_loras in block_idx_to_lora.items(): - # params, descriptions = assemble_params( - # block_loras, - # (unet_lr if unet_lr is not None else default_lr) * self.get_lr_weight(idx), - # self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio, - # ) - # all_params.extend(params) - # lr_descriptions.extend([f"unet_block{idx}" + (" " + d if d else "") for d in descriptions]) - - # else: params, descriptions = assemble_params( self.unet_loras, unet_lr if unet_lr is not None else default_lr, diff --git a/train_network.py b/train_network.py index ad97491d..e45db052 100644 --- a/train_network.py +++ b/train_network.py @@ -466,9 +466,17 @@ class NetworkTrainer: # 学習に必要なクラスを準備する accelerator.print("prepare optimizer, data loader etc.") - # 後方互換性を確保するよ + # make backward compatibility for text_encoder_lr + support_multiple_lrs = hasattr(network, "prepare_optimizer_params_with_multiple_te_lrs") + if support_multiple_lrs: + text_encoder_lr = args.text_encoder_lr + else: + text_encoder_lr = None if args.text_encoder_lr is None or len(args.text_encoder_lr) == 0 else args.text_encoder_lr[0] try: - results = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, args.learning_rate) + if support_multiple_lrs: + results = network.prepare_optimizer_params_with_multiple_te_lrs(text_encoder_lr, args.unet_lr, args.learning_rate) + else: + results = network.prepare_optimizer_params(text_encoder_lr, args.unet_lr, args.learning_rate) if type(results) is tuple: trainable_params = results[0] lr_descriptions = results[1] @@ -476,11 +484,7 @@ class NetworkTrainer: trainable_params = results lr_descriptions = None except TypeError as e: - # logger.warning(f"{e}") - # accelerator.print( - # "Deprecated: use prepare_optimizer_params(text_encoder_lr, unet_lr, learning_rate) instead of prepare_optimizer_params(text_encoder_lr, unet_lr)" - # ) - trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr) + trainable_params = network.prepare_optimizer_params(text_encoder_lr, args.unet_lr) lr_descriptions = None # if len(trainable_params) == 0: @@ -713,7 +717,7 @@ class NetworkTrainer: "ss_training_started_at": training_started_at, # unix timestamp "ss_output_name": args.output_name, "ss_learning_rate": args.learning_rate, - "ss_text_encoder_lr": args.text_encoder_lr, + "ss_text_encoder_lr": text_encoder_lr, "ss_unet_lr": args.unet_lr, "ss_num_train_images": train_dataset_group.num_train_images, "ss_num_reg_images": train_dataset_group.num_reg_images, @@ -760,8 +764,8 @@ class NetworkTrainer: "ss_loss_type": args.loss_type, "ss_huber_schedule": args.huber_schedule, "ss_huber_c": args.huber_c, - "ss_fp8_base": args.fp8_base, - "ss_fp8_base_unet": args.fp8_base_unet, + "ss_fp8_base": bool(args.fp8_base), + "ss_fp8_base_unet": bool(args.fp8_base_unet), } self.update_metadata(metadata, args) # architecture specific metadata @@ -1303,7 +1307,13 @@ def setup_parser() -> argparse.ArgumentParser: ) parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率") - parser.add_argument("--text_encoder_lr", type=float, default=None, help="learning rate for Text Encoder / Text Encoderの学習率") + parser.add_argument( + "--text_encoder_lr", + type=float, + default=None, + nargs="*", + help="learning rate for Text Encoder, can be multiple / Text Encoderの学習率、複数指定可能", + ) parser.add_argument( "--fp8_base_unet", action="store_true", From 65b8a064f6bb9a403374d4b08f4003037df42f8d Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 10 Sep 2024 21:20:38 +0900 Subject: [PATCH 38/87] update README --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index b5799dd6..caea59b7 100644 --- a/README.md +++ b/README.md @@ -12,7 +12,7 @@ The command to install PyTorch is as follows: ### Recent Updates Sep 10, 2024: -In FLUX.1 LoRA training, individual learning rates can be specified for CLIP-L and T5XXL. By specifying multiple numbers in `--text_encoder_lr`, you can set the learning rates for CLIP-L and T5XXL separately. Specify like `--text_encoder_lr 1e-4 1e-5`. If you specify only one, the learning rates for CLIP-L and T5XXL will be the same. +In FLUX.1 LoRA training, individual learning rates can be specified for CLIP-L and T5XXL. By specifying multiple numbers in `--text_encoder_lr`, you can set the learning rates for CLIP-L and T5XXL separately. Specify like `--text_encoder_lr 1e-4 1e-5`. The first value is the learning rate for CLIP-L, and the second value is for T5XXL. If you specify only one, the learning rates for CLIP-L and T5XXL will be the same. Sep 9, 2024: Added `--negative_prompt` and `--cfg_scale` to `flux_minimal_inference.py`. Negative prompts can be used. @@ -145,7 +145,7 @@ The effect of `--timestep_sampling sigmoid` and `--sigmoid_scale` (when `--times - Remove `--network_train_unet_only` from your command. - Add `train_t5xxl=True` to `--network_args` to train T5XXL LoRA. CLIP-L is also trained at the same time. - T5XXL output can be cached for CLIP-L LoRA training. So, `--cache_text_encoder_outputs` or `--cache_text_encoder_outputs_to_disk` is also available. - - The learning rates for CLIP-L and T5XXL can be specified separately. Multiple numbers can be specified in `--text_encoder_lr`. For example, `--text_encoder_lr 1e-4 1e-5`. If you specify only one, the learning rates for CLIP-L and T5XXL will be the same. If `--text_encoder_lr` is not specified, the default learning rate `--learning_rate` is used for both CLIP-L and T5XXL. + - The learning rates for CLIP-L and T5XXL can be specified separately. Multiple numbers can be specified in `--text_encoder_lr`. For example, `--text_encoder_lr 1e-4 1e-5`. The first value is the learning rate for CLIP-L, and the second value is for T5XXL. If you specify only one, the learning rates for CLIP-L and T5XXL will be the same. If `--text_encoder_lr` is not specified, the default learning rate `--learning_rate` is used for both CLIP-L and T5XXL. - The trained LoRA can be used with ComfyUI. - Note: `flux_extract_lora.py`, `convert_flux_lora.py`and `merge_flux_lora.py` do not support CLIP-L and T5XXL LoRA yet. From fd68703f3795b3e9c75409ac5452807d056b928f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9D=92=E9=BE=8D=E8=81=96=E8=80=85=40bdsqlsz?= <865105819@qq.com> Date: Wed, 11 Sep 2024 20:25:45 +0800 Subject: [PATCH 39/87] Add New lr scheduler (#1393) * add new lr scheduler * fix bugs and use num_cycles / 2 * Update requirements.txt * add num_cycles for min lr * keep PIECEWISE_CONSTANT * allow use float with warmup or decay ratio. * Update train_util.py --- library/train_util.py | 80 ++++++++++++++++++++++++++++++++++++++----- requirements.txt | 6 ++-- 2 files changed, 75 insertions(+), 11 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index c7b73ee3..340f6d64 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -42,7 +42,8 @@ from torch.optim import Optimizer from torchvision import transforms from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection import transformers -from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION +from diffusers.optimization import SchedulerType as DiffusersSchedulerType, TYPE_TO_SCHEDULER_FUNCTION as DIFFUSERS_TYPE_TO_SCHEDULER_FUNCTION +from transformers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION from diffusers import ( StableDiffusionPipeline, DDPMScheduler, @@ -2972,6 +2973,20 @@ def add_sd_models_arguments(parser: argparse.ArgumentParser): def add_optimizer_arguments(parser: argparse.ArgumentParser): + def int_or_float(value): + if value.endswith('%'): + try: + return float(value[:-1]) / 100.0 + except ValueError: + raise argparse.ArgumentTypeError(f"Value '{value}' is not a valid percentage") + try: + float_value = float(value) + if float_value >= 1: + return int(value) + return float(value) + except ValueError: + raise argparse.ArgumentTypeError(f"'{value}' is not an int or float") + parser.add_argument( "--optimizer_type", type=str, @@ -3024,9 +3039,15 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser): ) parser.add_argument( "--lr_warmup_steps", - type=int, + type=int_or_float, default=0, - help="Number of steps for the warmup in the lr scheduler (default is 0) / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)", + help="Int number of steps for the warmup in the lr scheduler (default is 0) or float with ratio of train steps / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)", + ) + parser.add_argument( + "--lr_decay_steps", + type=int_or_float, + default=0, + help="Int number of steps for the decay in the lr scheduler (default is 0) or float with ratio of train steps", ) parser.add_argument( "--lr_scheduler_num_cycles", @@ -3046,6 +3067,18 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser): help="Combines backward pass and optimizer step to reduce VRAM usage. Only available in SDXL" + " / バックワードパスとオプティマイザステップを組み合わせてVRAMの使用量を削減します。SDXLでのみ有効", ) + parser.add_argument( + "--lr_scheduler_timescale", + type=int, + default=None, + help="Inverse sqrt timescale for inverse sqrt scheduler,defaults to `num_warmup_steps`", + ) + parser.add_argument( + "--lr_scheduler_min_lr_ratio", + type=float, + default=None, + help="The minimum learning rate as a ratio of the initial learning rate for cosine with min lr scheduler and warmup decay scheduler", + ) def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool): @@ -4293,10 +4326,14 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int): Unified API to get any scheduler from its name. """ name = args.lr_scheduler - num_warmup_steps: Optional[int] = args.lr_warmup_steps num_training_steps = args.max_train_steps * num_processes # * args.gradient_accumulation_steps + num_warmup_steps: Optional[int] = int(args.lr_warmup_steps * num_training_steps) if isinstance(args.lr_warmup_steps, float) else args.lr_warmup_steps + num_decay_steps: Optional[int] = int(args.lr_decay_steps * num_training_steps) if isinstance(args.lr_decay_steps, float) else args.lr_decay_steps + num_stable_steps = num_training_steps - num_warmup_steps - num_decay_steps num_cycles = args.lr_scheduler_num_cycles power = args.lr_scheduler_power + timescale = args.lr_scheduler_timescale + min_lr_ratio = args.lr_scheduler_min_lr_ratio lr_scheduler_kwargs = {} # get custom lr_scheduler kwargs if args.lr_scheduler_args is not None and len(args.lr_scheduler_args) > 0: @@ -4332,13 +4369,13 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int): # logger.info(f"adafactor scheduler init lr {initial_lr}") return wrap_check_needless_num_warmup_steps(transformers.optimization.AdafactorSchedule(optimizer, initial_lr)) - name = SchedulerType(name) - schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name] + name = SchedulerType(name) or DiffusersSchedulerType(name) + schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name] or DIFFUSERS_TYPE_TO_SCHEDULER_FUNCTION[name] if name == SchedulerType.CONSTANT: return wrap_check_needless_num_warmup_steps(schedule_func(optimizer, **lr_scheduler_kwargs)) - if name == SchedulerType.PIECEWISE_CONSTANT: + if name == DiffusersSchedulerType.PIECEWISE_CONSTANT: return schedule_func(optimizer, **lr_scheduler_kwargs) # step_rules and last_epoch are given as kwargs # All other schedulers require `num_warmup_steps` @@ -4348,6 +4385,9 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int): if name == SchedulerType.CONSTANT_WITH_WARMUP: return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, **lr_scheduler_kwargs) + if name == SchedulerType.INVERSE_SQRT: + return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, timescale=timescale, **lr_scheduler_kwargs) + # All other schedulers require `num_training_steps` if num_training_steps is None: raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.") @@ -4366,7 +4406,31 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int): optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power, **lr_scheduler_kwargs ) - return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, **lr_scheduler_kwargs) + if name == SchedulerType.COSINE_WITH_MIN_LR: + return schedule_func( + optimizer, + num_warmup_steps=num_warmup_steps, + num_training_steps=num_training_steps, + num_cycles=num_cycles / 2, + min_lr_rate=min_lr_ratio, + **lr_scheduler_kwargs, + ) + + # All other schedulers require `num_decay_steps` + if num_decay_steps is None: + raise ValueError(f"{name} requires `num_decay_steps`, please provide that argument.") + if name == SchedulerType.WARMUP_STABLE_DECAY: + return schedule_func( + optimizer, + num_warmup_steps=num_warmup_steps, + num_stable_steps=num_stable_steps, + num_decay_steps=num_decay_steps, + num_cycles=num_cycles / 2, + min_lr_ratio=min_lr_ratio if min_lr_ratio is not None else 0.0, + **lr_scheduler_kwargs, + ) + + return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_decay_steps=num_decay_steps, **lr_scheduler_kwargs) def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool): diff --git a/requirements.txt b/requirements.txt index 977c5cd9..d2a2fbb8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ -accelerate==0.25.0 -transformers==4.36.2 +accelerate==0.30.0 +transformers==4.41.2 diffusers[torch]==0.25.0 ftfy==6.1.1 # albumentations==1.3.0 @@ -16,7 +16,7 @@ altair==4.2.2 easygui==0.98.3 toml==0.10.2 voluptuous==0.13.1 -huggingface-hub==0.20.1 +huggingface-hub==0.23.3 # for Image utils imagesize==1.4.1 # for BLIP captioning From 6dbfd47a59cdb91be2077e1d0dec0f94698348dd Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 11 Sep 2024 21:44:36 +0900 Subject: [PATCH 40/87] Fix to work PIECEWISE_CONSTANT, update requirement.txt and README #1393 --- README.md | 9 ++++++ library/train_util.py | 66 ++++++++++++++++++++++++++++--------------- requirements.txt | 4 +-- 3 files changed, 54 insertions(+), 25 deletions(-) diff --git a/README.md b/README.md index 16ab80e7..011141bf 100644 --- a/README.md +++ b/README.md @@ -139,6 +139,15 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser ### Working in progress +- __important__ The dependent libraries are updated. Please see [Upgrade](#upgrade) and update the libraries. + - transformers, accelerate and huggingface_hub are updated. + - If you encounter any issues, please report them. + +- en: The INVERSE_SQRT, COSINE_WITH_MIN_LR, and WARMUP_STABLE_DECAY learning rate schedules are now available in the transformers library. See PR [#1393](https://github.com/kohya-ss/sd-scripts/pull/1393) for details. Thanks to sdbds! + - See the [transformers documentation](https://huggingface.co/docs/transformers/v4.44.2/en/main_classes/optimizer_schedules#schedules) for details on each scheduler. + - `--lr_warmup_steps` and `--lr_decay_steps` can now be specified as a ratio of the number of training steps, not just the step value. Example: `--lr_warmup_steps=0.1` or `--lr_warmup_steps=10%`, etc. + +https://github.com/kohya-ss/sd-scripts/pull/1393 - When enlarging images in the script (when the size of the training image is small and bucket_no_upscale is not specified), it has been changed to use Pillow's resize and LANCZOS interpolation instead of OpenCV2's resize and Lanczos4 interpolation. The quality of the image enlargement may be slightly improved. PR [#1426](https://github.com/kohya-ss/sd-scripts/pull/1426) Thanks to sdbds! - Sample image generation during training now works on non-CUDA devices. PR [#1433](https://github.com/kohya-ss/sd-scripts/pull/1433) Thanks to millie-v! diff --git a/library/train_util.py b/library/train_util.py index 340f6d64..e65760ba 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -42,7 +42,10 @@ from torch.optim import Optimizer from torchvision import transforms from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection import transformers -from diffusers.optimization import SchedulerType as DiffusersSchedulerType, TYPE_TO_SCHEDULER_FUNCTION as DIFFUSERS_TYPE_TO_SCHEDULER_FUNCTION +from diffusers.optimization import ( + SchedulerType as DiffusersSchedulerType, + TYPE_TO_SCHEDULER_FUNCTION as DIFFUSERS_TYPE_TO_SCHEDULER_FUNCTION, +) from transformers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION from diffusers import ( StableDiffusionPipeline, @@ -2974,7 +2977,7 @@ def add_sd_models_arguments(parser: argparse.ArgumentParser): def add_optimizer_arguments(parser: argparse.ArgumentParser): def int_or_float(value): - if value.endswith('%'): + if value.endswith("%"): try: return float(value[:-1]) / 100.0 except ValueError: @@ -3041,13 +3044,15 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser): "--lr_warmup_steps", type=int_or_float, default=0, - help="Int number of steps for the warmup in the lr scheduler (default is 0) or float with ratio of train steps / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)", + help="Int number of steps for the warmup in the lr scheduler (default is 0) or float with ratio of train steps" + " / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)、または学習ステップの比率(1未満のfloat値の場合)", ) parser.add_argument( "--lr_decay_steps", type=int_or_float, default=0, - help="Int number of steps for the decay in the lr scheduler (default is 0) or float with ratio of train steps", + help="Int number of steps for the decay in the lr scheduler (default is 0) or float (<1) with ratio of train steps" + " / 学習率のスケジューラを減衰させるステップ数(デフォルト0)、または学習ステップの比率(1未満のfloat値の場合)", ) parser.add_argument( "--lr_scheduler_num_cycles", @@ -3071,13 +3076,16 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser): "--lr_scheduler_timescale", type=int, default=None, - help="Inverse sqrt timescale for inverse sqrt scheduler,defaults to `num_warmup_steps`", + help="Inverse sqrt timescale for inverse sqrt scheduler,defaults to `num_warmup_steps`" + " / 逆平方根スケジューラのタイムスケール、デフォルトは`num_warmup_steps`", + , ) parser.add_argument( "--lr_scheduler_min_lr_ratio", type=float, default=None, - help="The minimum learning rate as a ratio of the initial learning rate for cosine with min lr scheduler and warmup decay scheduler", + help="The minimum learning rate as a ratio of the initial learning rate for cosine with min lr scheduler and warmup decay scheduler" + " / 初期学習率の比率としての最小学習率を指定する、cosine with min lr と warmup decay スケジューラ で有効", ) @@ -4327,8 +4335,12 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int): """ name = args.lr_scheduler num_training_steps = args.max_train_steps * num_processes # * args.gradient_accumulation_steps - num_warmup_steps: Optional[int] = int(args.lr_warmup_steps * num_training_steps) if isinstance(args.lr_warmup_steps, float) else args.lr_warmup_steps - num_decay_steps: Optional[int] = int(args.lr_decay_steps * num_training_steps) if isinstance(args.lr_decay_steps, float) else args.lr_decay_steps + num_warmup_steps: Optional[int] = ( + int(args.lr_warmup_steps * num_training_steps) if isinstance(args.lr_warmup_steps, float) else args.lr_warmup_steps + ) + num_decay_steps: Optional[int] = ( + int(args.lr_decay_steps * num_training_steps) if isinstance(args.lr_decay_steps, float) else args.lr_decay_steps + ) num_stable_steps = num_training_steps - num_warmup_steps - num_decay_steps num_cycles = args.lr_scheduler_num_cycles power = args.lr_scheduler_power @@ -4369,15 +4381,17 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int): # logger.info(f"adafactor scheduler init lr {initial_lr}") return wrap_check_needless_num_warmup_steps(transformers.optimization.AdafactorSchedule(optimizer, initial_lr)) - name = SchedulerType(name) or DiffusersSchedulerType(name) - schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name] or DIFFUSERS_TYPE_TO_SCHEDULER_FUNCTION[name] + if name == DiffusersSchedulerType.PIECEWISE_CONSTANT.value: + name = DiffusersSchedulerType(name) + schedule_func = DIFFUSERS_TYPE_TO_SCHEDULER_FUNCTION[name] + return schedule_func(optimizer, **lr_scheduler_kwargs) # step_rules and last_epoch are given as kwargs + + name = SchedulerType(name) + schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name] if name == SchedulerType.CONSTANT: return wrap_check_needless_num_warmup_steps(schedule_func(optimizer, **lr_scheduler_kwargs)) - if name == DiffusersSchedulerType.PIECEWISE_CONSTANT: - return schedule_func(optimizer, **lr_scheduler_kwargs) # step_rules and last_epoch are given as kwargs - # All other schedulers require `num_warmup_steps` if num_warmup_steps is None: raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.") @@ -4408,11 +4422,11 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int): if name == SchedulerType.COSINE_WITH_MIN_LR: return schedule_func( - optimizer, - num_warmup_steps=num_warmup_steps, - num_training_steps=num_training_steps, + optimizer, + num_warmup_steps=num_warmup_steps, + num_training_steps=num_training_steps, num_cycles=num_cycles / 2, - min_lr_rate=min_lr_ratio, + min_lr_rate=min_lr_ratio, **lr_scheduler_kwargs, ) @@ -4421,16 +4435,22 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int): raise ValueError(f"{name} requires `num_decay_steps`, please provide that argument.") if name == SchedulerType.WARMUP_STABLE_DECAY: return schedule_func( - optimizer, - num_warmup_steps=num_warmup_steps, - num_stable_steps=num_stable_steps, - num_decay_steps=num_decay_steps, - num_cycles=num_cycles / 2, + optimizer, + num_warmup_steps=num_warmup_steps, + num_stable_steps=num_stable_steps, + num_decay_steps=num_decay_steps, + num_cycles=num_cycles / 2, min_lr_ratio=min_lr_ratio if min_lr_ratio is not None else 0.0, **lr_scheduler_kwargs, ) - return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_decay_steps=num_decay_steps, **lr_scheduler_kwargs) + return schedule_func( + optimizer, + num_warmup_steps=num_warmup_steps, + num_training_steps=num_training_steps, + num_decay_steps=num_decay_steps, + **lr_scheduler_kwargs, + ) def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool): diff --git a/requirements.txt b/requirements.txt index d2a2fbb8..15e6e58f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ accelerate==0.30.0 -transformers==4.41.2 +transformers==4.44.0 diffusers[torch]==0.25.0 ftfy==6.1.1 # albumentations==1.3.0 @@ -16,7 +16,7 @@ altair==4.2.2 easygui==0.98.3 toml==0.10.2 voluptuous==0.13.1 -huggingface-hub==0.23.3 +huggingface-hub==0.24.5 # for Image utils imagesize==1.4.1 # for BLIP captioning From 8311e88225fef377591e5be19eb1f50fe7a2941f Mon Sep 17 00:00:00 2001 From: cocktailpeanut Date: Wed, 11 Sep 2024 09:02:29 -0400 Subject: [PATCH 41/87] typo fix --- library/train_util.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index c38864fe..f682dcbf 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3355,15 +3355,14 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser): type=int, default=None, help="Inverse sqrt timescale for inverse sqrt scheduler,defaults to `num_warmup_steps`" - " / 逆平方根スケジューラのタイムスケール、デフォルトは`num_warmup_steps`", - , + + " / 逆平方根スケジューラのタイムスケール、デフォルトは`num_warmup_steps`", ) parser.add_argument( "--lr_scheduler_min_lr_ratio", type=float, default=None, help="The minimum learning rate as a ratio of the initial learning rate for cosine with min lr scheduler and warmup decay scheduler" - " / 初期学習率の比率としての最小学習率を指定する、cosine with min lr と warmup decay スケジューラ で有効", + + " / 初期学習率の比率としての最小学習率を指定する、cosine with min lr と warmup decay スケジューラ で有効", ) From c7c666b1829a7c1f3435558efa425b08b50fab41 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 11 Sep 2024 22:12:31 +0900 Subject: [PATCH 42/87] fix typo --- library/train_util.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index e65760ba..a46d9487 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3077,15 +3077,14 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser): type=int, default=None, help="Inverse sqrt timescale for inverse sqrt scheduler,defaults to `num_warmup_steps`" - " / 逆平方根スケジューラのタイムスケール、デフォルトは`num_warmup_steps`", - , + + " / 逆平方根スケジューラのタイムスケール、デフォルトは`num_warmup_steps`", ) parser.add_argument( "--lr_scheduler_min_lr_ratio", type=float, default=None, help="The minimum learning rate as a ratio of the initial learning rate for cosine with min lr scheduler and warmup decay scheduler" - " / 初期学習率の比率としての最小学習率を指定する、cosine with min lr と warmup decay スケジューラ で有効", + + " / 初期学習率の比率としての最小学習率を指定する、cosine with min lr と warmup decay スケジューラ で有効", ) From a823fd9fb8d219b5b4c57df12eed41ae34fdf843 Mon Sep 17 00:00:00 2001 From: Plat <60182057+p1atdev@users.noreply.github.com> Date: Wed, 11 Sep 2024 22:21:16 +0900 Subject: [PATCH 43/87] Improve wandb logging (#1576) * fix: wrong training steps were recorded to wandb, and no log was sent when logging_dir was not specified * fix: checking of whether wandb is enabled * feat: log images to wandb with their positive prompt as captions * feat: logging sample images' caption for sd3 and flux * fix: import wandb before use --- fine_tune.py | 7 +++++-- flux_train.py | 7 +++++-- library/flux_train_utils.py | 20 +++++++++++--------- library/sd3_train_utils.py | 20 +++++++++++--------- library/train_util.py | 20 +++++++++++--------- sd3_train.py | 7 +++++-- sdxl_train.py | 7 +++++-- sdxl_train_control_net_lllite.py | 4 ++-- sdxl_train_control_net_lllite_old.py | 4 ++-- train_controlnet.py | 7 +++++-- train_db.py | 7 +++++-- train_network.py | 7 +++++-- train_textual_inversion.py | 8 ++++++-- train_textual_inversion_XTI.py | 4 ++-- 14 files changed, 80 insertions(+), 49 deletions(-) diff --git a/fine_tune.py b/fine_tune.py index c9102f6c..fb6b3ed6 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -337,6 +337,9 @@ def train(args): train_util.sample_images( accelerator, args, 0, global_step, accelerator.device, vae, tokenize_strategy.tokenizer, text_encoder, unet ) + if len(accelerator.trackers) > 0: + # log empty object to commit the sample images to wandb + accelerator.log({}, step=0) loss_recorder = train_util.LossRecorder() for epoch in range(num_train_epochs): @@ -456,7 +459,7 @@ def train(args): ) current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = {"loss": current_loss} train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=True) accelerator.log(logs, step=global_step) @@ -469,7 +472,7 @@ def train(args): if global_step >= args.max_train_steps: break - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = {"loss/epoch": loss_recorder.moving_average} accelerator.log(logs, step=epoch + 1) diff --git a/flux_train.py b/flux_train.py index 0edc83a9..33481df8 100644 --- a/flux_train.py +++ b/flux_train.py @@ -629,6 +629,9 @@ def train(args): # For --sample_at_first flux_train_utils.sample_images(accelerator, args, 0, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs) + if len(accelerator.trackers) > 0: + # log empty object to commit the sample images to wandb + accelerator.log({}, step=0) loss_recorder = train_util.LossRecorder() epoch = 0 # avoid error when max_train_steps is 0 @@ -777,7 +780,7 @@ def train(args): ) current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = {"loss": current_loss} train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=True) @@ -791,7 +794,7 @@ def train(args): if global_step >= args.max_train_steps: break - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = {"loss/epoch": loss_recorder.moving_average} accelerator.log(logs, step=epoch + 1) diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index 0b5d4d90..f77d4b58 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -254,17 +254,19 @@ def sample_image_inference( img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png" image.save(os.path.join(save_dir, img_filename)) - # wandb有効時のみログを送信 - try: + # send images to wandb if enabled + if "wandb" in [tracker.name for tracker in accelerator.trackers]: wandb_tracker = accelerator.get_tracker("wandb") - try: - import wandb - except ImportError: # 事前に一度確認するのでここはエラー出ないはず - raise ImportError("No wandb / wandb がインストールされていないようです") - wandb_tracker.log({f"sample_{i}": wandb.Image(image)}) - except: # wandb 無効時 - pass + import wandb + # not to commit images to avoid inconsistency between training and logging steps + wandb_tracker.log( + {f"sample_{i}": wandb.Image( + image, + caption=prompt # positive prompt as a caption + )}, + commit=False + ) def time_shift(mu: float, sigma: float, t: torch.Tensor): diff --git a/library/sd3_train_utils.py b/library/sd3_train_utils.py index da072950..e819d440 100644 --- a/library/sd3_train_utils.py +++ b/library/sd3_train_utils.py @@ -604,17 +604,19 @@ def sample_image_inference( img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png" image.save(os.path.join(save_dir, img_filename)) - # wandb有効時のみログを送信 - try: + # send images to wandb if enabled + if "wandb" in [tracker.name for tracker in accelerator.trackers]: wandb_tracker = accelerator.get_tracker("wandb") - try: - import wandb - except ImportError: # 事前に一度確認するのでここはエラー出ないはず - raise ImportError("No wandb / wandb がインストールされていないようです") - wandb_tracker.log({f"sample_{i}": wandb.Image(image)}) - except: # wandb 無効時 - pass + import wandb + # not to commit images to avoid inconsistency between training and logging steps + wandb_tracker.log( + {f"sample_{i}": wandb.Image( + image, + caption=prompt # positive prompt as a caption + )}, + commit=False + ) # region Diffusers diff --git a/library/train_util.py b/library/train_util.py index f682dcbf..742d057e 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -5832,17 +5832,19 @@ def sample_image_inference( img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png" image.save(os.path.join(save_dir, img_filename)) - # wandb有効時のみログを送信 - try: + # send images to wandb if enabled + if "wandb" in [tracker.name for tracker in accelerator.trackers]: wandb_tracker = accelerator.get_tracker("wandb") - try: - import wandb - except ImportError: # 事前に一度確認するのでここはエラー出ないはず - raise ImportError("No wandb / wandb がインストールされていないようです") - wandb_tracker.log({f"sample_{i}": wandb.Image(image)}) - except: # wandb 無効時 - pass + import wandb + # not to commit images to avoid inconsistency between training and logging steps + wandb_tracker.log( + {f"sample_{i}": wandb.Image( + image, + caption=prompt # positive prompt as a caption + )}, + commit=False + ) # endregion diff --git a/sd3_train.py b/sd3_train.py index 87011b21..5120105f 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -682,6 +682,9 @@ def train(args): # For --sample_at_first sd3_train_utils.sample_images(accelerator, args, 0, global_step, mmdit, vae, [clip_l, clip_g, t5xxl], sample_prompts_te_outputs) + if len(accelerator.trackers) > 0: + # log empty object to commit the sample images to wandb + accelerator.log({}, step=0) # following function will be moved to sd3_train_utils @@ -901,7 +904,7 @@ def train(args): ) current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = {"loss": current_loss} train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=train_mmdit) @@ -915,7 +918,7 @@ def train(args): if global_step >= args.max_train_steps: break - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = {"loss/epoch": loss_recorder.moving_average} accelerator.log(logs, step=epoch + 1) diff --git a/sdxl_train.py b/sdxl_train.py index b2c62dd1..7291ddd2 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -617,6 +617,9 @@ def train(args): sdxl_train_util.sample_images( accelerator, args, 0, global_step, accelerator.device, vae, tokenizers, [text_encoder1, text_encoder2], unet ) + if len(accelerator.trackers) > 0: + # log empty object to commit the sample images to wandb + accelerator.log({}, step=0) loss_recorder = train_util.LossRecorder() for epoch in range(num_train_epochs): @@ -797,7 +800,7 @@ def train(args): ) current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = {"loss": current_loss} if block_lrs is None: train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=train_unet) @@ -814,7 +817,7 @@ def train(args): if global_step >= args.max_train_steps: break - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = {"loss/epoch": loss_recorder.moving_average} accelerator.log(logs, step=epoch + 1) diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py index 0eaec29b..9d1cfc63 100644 --- a/sdxl_train_control_net_lllite.py +++ b/sdxl_train_control_net_lllite.py @@ -541,14 +541,14 @@ def train(args): logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} progress_bar.set_postfix(**logs) - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler) accelerator.log(logs, step=global_step) if global_step >= args.max_train_steps: break - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = {"loss/epoch": loss_recorder.moving_average} accelerator.log(logs, step=epoch + 1) diff --git a/sdxl_train_control_net_lllite_old.py b/sdxl_train_control_net_lllite_old.py index 292a0463..6fa1d609 100644 --- a/sdxl_train_control_net_lllite_old.py +++ b/sdxl_train_control_net_lllite_old.py @@ -480,14 +480,14 @@ def train(args): logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} progress_bar.set_postfix(**logs) - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler) accelerator.log(logs, step=global_step) if global_step >= args.max_train_steps: break - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = {"loss/epoch": loss_recorder.moving_average} accelerator.log(logs, step=epoch + 1) diff --git a/train_controlnet.py b/train_controlnet.py index c9ac6c5a..57f0d263 100644 --- a/train_controlnet.py +++ b/train_controlnet.py @@ -409,6 +409,9 @@ def train(args): train_util.sample_images( accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, controlnet=controlnet ) + if len(accelerator.trackers) > 0: + # log empty object to commit the sample images to wandb + accelerator.log({}, step=0) # training loop for epoch in range(num_train_epochs): @@ -542,14 +545,14 @@ def train(args): logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} progress_bar.set_postfix(**logs) - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler) accelerator.log(logs, step=global_step) if global_step >= args.max_train_steps: break - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = {"loss/epoch": loss_recorder.moving_average} accelerator.log(logs, step=epoch + 1) diff --git a/train_db.py b/train_db.py index 7caee664..d42afd89 100644 --- a/train_db.py +++ b/train_db.py @@ -315,6 +315,9 @@ def train(args): train_util.sample_images( accelerator, args, 0, global_step, accelerator.device, vae, tokenize_strategy.tokenizer, text_encoder, unet ) + if len(accelerator.trackers) > 0: + # log empty object to commit the sample images to wandb + accelerator.log({}, step=0) loss_recorder = train_util.LossRecorder() for epoch in range(num_train_epochs): @@ -445,7 +448,7 @@ def train(args): ) current_loss = loss.detach().item() - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = {"loss": current_loss} train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=True) accelerator.log(logs, step=global_step) @@ -458,7 +461,7 @@ def train(args): if global_step >= args.max_train_steps: break - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = {"loss/epoch": loss_recorder.moving_average} accelerator.log(logs, step=epoch + 1) diff --git a/train_network.py b/train_network.py index e45db052..34385ae0 100644 --- a/train_network.py +++ b/train_network.py @@ -1038,6 +1038,9 @@ class NetworkTrainer: # For --sample_at_first self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizers, text_encoder, unet) + if len(accelerator.trackers) > 0: + # log empty object to commit the sample images to wandb + accelerator.log({}, step=0) # training loop if initial_step > 0: # only if skip_until_initial_step is specified @@ -1224,7 +1227,7 @@ class NetworkTrainer: if args.scale_weight_norms: progress_bar.set_postfix(**{**max_mean_logs, **logs}) - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = self.generate_step_logs( args, current_loss, avr_loss, lr_scheduler, lr_descriptions, keys_scaled, mean_norm, maximum_norm ) @@ -1233,7 +1236,7 @@ class NetworkTrainer: if global_step >= args.max_train_steps: break - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = {"loss/epoch": loss_recorder.moving_average} accelerator.log(logs, step=epoch + 1) diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 9044f50d..956c7860 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -550,6 +550,9 @@ class TextualInversionTrainer: unet, prompt_replacement, ) + if len(accelerator.trackers) > 0: + # log empty object to commit the sample images to wandb + accelerator.log({}, step=0) # training loop for epoch in range(num_train_epochs): @@ -684,7 +687,7 @@ class TextualInversionTrainer: remove_model(remove_ckpt_name) current_loss = loss.detach().item() - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])} if ( args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower() @@ -702,7 +705,7 @@ class TextualInversionTrainer: if global_step >= args.max_train_steps: break - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = {"loss/epoch": loss_total / len(train_dataloader)} accelerator.log(logs, step=epoch + 1) @@ -739,6 +742,7 @@ class TextualInversionTrainer: unet, prompt_replacement, ) + accelerator.log({}) # end of epoch diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py index efb59137..ca0b603f 100644 --- a/train_textual_inversion_XTI.py +++ b/train_textual_inversion_XTI.py @@ -538,7 +538,7 @@ def train(args): remove_model(remove_ckpt_name) current_loss = loss.detach().item() - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])} if ( args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower() @@ -556,7 +556,7 @@ def train(args): if global_step >= args.max_train_steps: break - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = {"loss/epoch": loss_total / len(train_dataloader)} accelerator.log(logs, step=epoch + 1) From 237317fffd060bcfb078b770ccd2df18bc4dd3a6 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 11 Sep 2024 22:23:43 +0900 Subject: [PATCH 44/87] update README --- README.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/README.md b/README.md index 2b3d0d5a..d3481b6a 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,9 @@ The command to install PyTorch is as follows: ### Recent Updates +Sep 11, 2024: +Logging to wandb is improved. See PR [#1576](https://github.com/kohya-ss/sd-scripts/pull/1576) for details. Thanks to p1atdev! + Sep 10, 2024: In FLUX.1 LoRA training, individual learning rates can be specified for CLIP-L and T5XXL. By specifying multiple numbers in `--text_encoder_lr`, you can set the learning rates for CLIP-L and T5XXL separately. Specify like `--text_encoder_lr 1e-4 1e-5`. The first value is the learning rate for CLIP-L, and the second value is for T5XXL. If you specify only one, the learning rates for CLIP-L and T5XXL will be the same. From cefe52629e1901dd8192b0487afd5e9f089e3519 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 12 Sep 2024 12:36:07 +0900 Subject: [PATCH 45/87] fix to work old notation for TE LR in .toml --- networks/lora_flux.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/networks/lora_flux.py b/networks/lora_flux.py index d540c221..dd267de0 100644 --- a/networks/lora_flux.py +++ b/networks/lora_flux.py @@ -788,8 +788,11 @@ class LoRANetwork(torch.nn.Module): def prepare_optimizer_params_with_multiple_te_lrs(self, text_encoder_lr, unet_lr, default_lr): # make sure text_encoder_lr as list of two elements - if text_encoder_lr is None or len(text_encoder_lr) == 0: + # if float, use the same value for both text encoders + if text_encoder_lr is None or (isinstance(text_encoder_lr, list) and len(text_encoder_lr) == 0): text_encoder_lr = [default_lr, default_lr] + elif isinstance(text_encoder_lr, float): + text_encoder_lr = [text_encoder_lr, text_encoder_lr] elif len(text_encoder_lr) == 1: text_encoder_lr = [text_encoder_lr[0], text_encoder_lr[0]] From 1d7118a62268f12ebfd81c10db53bd85ef9d7631 Mon Sep 17 00:00:00 2001 From: Maru-mee <151493593+Maru-mee@users.noreply.github.com> Date: Fri, 13 Sep 2024 19:01:36 +0900 Subject: [PATCH 46/87] Support : OFT merge to base model (#1580) * Support : OFT merge to base model * Fix typo * Fix typo_2 * Delete unused parameter 'eye' --- networks/sdxl_merge_lora.py | 190 +++++++++++++++++++++++++++--------- 1 file changed, 143 insertions(+), 47 deletions(-) diff --git a/networks/sdxl_merge_lora.py b/networks/sdxl_merge_lora.py index 3383a80d..2c998c8c 100644 --- a/networks/sdxl_merge_lora.py +++ b/networks/sdxl_merge_lora.py @@ -8,10 +8,12 @@ from tqdm import tqdm from library import sai_model_spec, sdxl_model_util, train_util import library.model_util as model_util import lora +import oft from library.utils import setup_logging setup_logging() import logging logger = logging.getLogger(__name__) +import concurrent.futures def load_state_dict(file_name, dtype): if os.path.splitext(file_name)[1] == ".safetensors": @@ -39,82 +41,176 @@ def save_to_file(file_name, model, state_dict, dtype, metadata): else: torch.save(model, file_name) +def detect_method_from_training_model(models, dtype): + for model in models: + lora_sd, _ = load_state_dict(model, dtype) + for key in tqdm(lora_sd.keys()): + if 'lora_up' in key or 'lora_down' in key: + return 'LoRA' + elif "oft_blocks" in key: + return 'OFT' def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_dtype): text_encoder1.to(merge_dtype) text_encoder1.to(merge_dtype) unet.to(merge_dtype) + + # detect the method: OFT or LoRA_module + method = detect_method_from_training_model(models, merge_dtype) + logger.info(f"method:{method}") # create module map name_to_module = {} for i, root_module in enumerate([text_encoder1, text_encoder2, unet]): - if i <= 1: - if i == 0: - prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER1 + if method == 'LoRA': + if i <= 1: + if i == 0: + prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER1 + else: + prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER2 + target_replace_modules = lora.LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE else: - prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER2 - target_replace_modules = lora.LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE - else: - prefix = lora.LoRANetwork.LORA_PREFIX_UNET - target_replace_modules = ( + prefix = lora.LoRANetwork.LORA_PREFIX_UNET + target_replace_modules = ( lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE + lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 + ) + elif method == 'OFT': + prefix = oft.OFTNetwork.OFT_PREFIX_UNET + target_replace_modules = ( + oft.OFTNetwork.UNET_TARGET_REPLACE_MODULE_ALL_LINEAR + oft.OFTNetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 ) for name, module in root_module.named_modules(): if module.__class__.__name__ in target_replace_modules: for child_name, child_module in module.named_modules(): - if child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "Conv2d": - lora_name = prefix + "." + name + "." + child_name - lora_name = lora_name.replace(".", "_") - name_to_module[lora_name] = child_module - + if method == 'LoRA': + if child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "Conv2d": + lora_name = prefix + "." + name + "." + child_name + lora_name = lora_name.replace(".", "_") + name_to_module[lora_name] = child_module + elif method == 'OFT': + if child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "Conv2d": + oft_name = prefix + "." + name + "." + child_name + oft_name = oft_name.replace(".", "_") + name_to_module[oft_name] = child_module + + for model, ratio in zip(models, ratios): logger.info(f"loading: {model}") lora_sd, _ = load_state_dict(model, merge_dtype) logger.info(f"merging...") - for key in tqdm(lora_sd.keys()): - if "lora_down" in key: - up_key = key.replace("lora_down", "lora_up") - alpha_key = key[: key.index("lora_down")] + "alpha" - # find original module for this lora - module_name = ".".join(key.split(".")[:-2]) # remove trailing ".lora_down.weight" + if method == 'LoRA': + for key in tqdm(lora_sd.keys()): + if "lora_down" in key: + up_key = key.replace("lora_down", "lora_up") + alpha_key = key[: key.index("lora_down")] + "alpha" + + # find original module for this lora + module_name = ".".join(key.split(".")[:-2]) # remove trailing ".lora_down.weight" + if module_name not in name_to_module: + logger.info(f"no module found for LoRA weight: {key}") + continue + module = name_to_module[module_name] + # logger.info(f"apply {key} to {module}") + + down_weight = lora_sd[key] + up_weight = lora_sd[up_key] + + dim = down_weight.size()[0] + alpha = lora_sd.get(alpha_key, dim) + scale = alpha / dim + + # W <- W + U * D + weight = module.weight + # logger.info(module_name, down_weight.size(), up_weight.size()) + if len(weight.size()) == 2: + # linear + weight = weight + ratio * (up_weight @ down_weight) * scale + elif down_weight.size()[2:4] == (1, 1): + # conv2d 1x1 + weight = ( + weight + + ratio + * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) + * scale + ) + else: + # conv2d 3x3 + conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) + # logger.info(conved.size(), weight.size(), module.stride, module.padding) + weight = weight + ratio * conved * scale + + module.weight = torch.nn.Parameter(weight) + + + elif method == 'OFT': + + multiplier=1.0 + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + for key in tqdm(lora_sd.keys()): + if "oft_blocks" in key: + oft_blocks = lora_sd[key] + dim = oft_blocks.shape[0] + break + for key in tqdm(lora_sd.keys()): + if "alpha" in key: + oft_blocks = lora_sd[key] + alpha = oft_blocks.item() + break + + def merge_to(key): + if "alpha" in key: + return + + # find original module for this OFT + module_name = ".".join(key.split(".")[:-1]) if module_name not in name_to_module: - logger.info(f"no module found for LoRA weight: {key}") - continue + return module = name_to_module[module_name] + # logger.info(f"apply {key} to {module}") + + oft_blocks = lora_sd[key] + + if isinstance(module, torch.nn.Linear): + out_dim = module.out_features + elif isinstance(module, torch.nn.Conv2d): + out_dim = module.out_channels + + num_blocks = dim + block_size = out_dim // dim + constraint = (0 if alpha is None else alpha) * out_dim + + block_Q = oft_blocks - oft_blocks.transpose(1, 2) + norm_Q = torch.norm(block_Q.flatten()) + new_norm_Q = torch.clamp(norm_Q, max=constraint) + block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8)) + I = torch.eye(block_size, device=oft_blocks.device).unsqueeze(0).repeat(num_blocks, 1, 1) + block_R = torch.matmul(I + block_Q, (I - block_Q).inverse()) + block_R_weighted = multiplier * block_R + (1 - multiplier) * I + R = torch.block_diag(*block_R_weighted) + + # get org weight + org_sd = module.state_dict() + org_weight = org_sd["weight"].to(device) - down_weight = lora_sd[key] - up_weight = lora_sd[up_key] - - dim = down_weight.size()[0] - alpha = lora_sd.get(alpha_key, dim) - scale = alpha / dim - - # W <- W + U * D - weight = module.weight - # logger.info(module_name, down_weight.size(), up_weight.size()) - if len(weight.size()) == 2: - # linear - weight = weight + ratio * (up_weight @ down_weight) * scale - elif down_weight.size()[2:4] == (1, 1): - # conv2d 1x1 - weight = ( - weight - + ratio - * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) - * scale - ) + R = R.to(org_weight.device, dtype=org_weight.dtype) + + if org_weight.dim() == 4: + weight = torch.einsum("oihw, op -> pihw", org_weight, R) else: - # conv2d 3x3 - conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) - # logger.info(conved.size(), weight.size(), module.stride, module.padding) - weight = weight + ratio * conved * scale - + weight = torch.einsum("oi, op -> pi", org_weight, R) + + weight = weight.contiguous() # Make Tensor contiguous; required due to ThreadPoolExecutor + module.weight = torch.nn.Parameter(weight) + with concurrent.futures.ThreadPoolExecutor() as executor: + list(tqdm(executor.map(merge_to, lora_sd.keys()), total=len(lora_sd.keys()))) + def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False): base_alphas = {} # alpha for merged model From 57ae44eb6138fe4a3864fffa62090f9d0113417d Mon Sep 17 00:00:00 2001 From: Kohya S Date: Fri, 13 Sep 2024 19:45:00 +0900 Subject: [PATCH 47/87] refactor to make safer --- networks/sdxl_merge_lora.py | 32 +++++++++++++++----------------- 1 file changed, 15 insertions(+), 17 deletions(-) diff --git a/networks/sdxl_merge_lora.py b/networks/sdxl_merge_lora.py index 2c998c8c..d5a54e02 100644 --- a/networks/sdxl_merge_lora.py +++ b/networks/sdxl_merge_lora.py @@ -44,11 +44,11 @@ def save_to_file(file_name, model, state_dict, dtype, metadata): def detect_method_from_training_model(models, dtype): for model in models: lora_sd, _ = load_state_dict(model, dtype) - for key in tqdm(lora_sd.keys()): - if 'lora_up' in key or 'lora_down' in key: - return 'LoRA' - elif "oft_blocks" in key: - return 'OFT' + for key in tqdm(lora_sd.keys()): + if 'lora_up' in key or 'lora_down' in key: + return 'LoRA' + elif "oft_blocks" in key: + return 'OFT' def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_dtype): text_encoder1.to(merge_dtype) @@ -76,6 +76,7 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_ ) elif method == 'OFT': prefix = oft.OFTNetwork.OFT_PREFIX_UNET + # ALL_LINEAR includes ATTN_ONLY, so we don't need to specify ATTN_ONLY target_replace_modules = ( oft.OFTNetwork.UNET_TARGET_REPLACE_MODULE_ALL_LINEAR + oft.OFTNetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 ) @@ -83,17 +84,11 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_ for name, module in root_module.named_modules(): if module.__class__.__name__ in target_replace_modules: for child_name, child_module in module.named_modules(): - if method == 'LoRA': - if child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "Conv2d": - lora_name = prefix + "." + name + "." + child_name - lora_name = lora_name.replace(".", "_") - name_to_module[lora_name] = child_module - elif method == 'OFT': - if child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "Conv2d": - oft_name = prefix + "." + name + "." + child_name - oft_name = oft_name.replace(".", "_") - name_to_module[oft_name] = child_module - + if child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "Conv2d": + lora_name = prefix + "." + name + "." + child_name + lora_name = lora_name.replace(".", "_") + name_to_module[lora_name] = child_module + for model, ratio in zip(models, ratios): logger.info(f"loading: {model}") @@ -168,6 +163,7 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_ # find original module for this OFT module_name = ".".join(key.split(".")[:-1]) if module_name not in name_to_module: + logger.info(f"no module found for OFT weight: {key}") return module = name_to_module[module_name] @@ -208,7 +204,9 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_ module.weight = torch.nn.Parameter(weight) - with concurrent.futures.ThreadPoolExecutor() as executor: + # TODO multi-threading may cause OOM on CPU if cpu_count is too high and RAM is not enough + max_workers = 1 if device.type != "cpu" else None # avoid OOM on GPU + with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: list(tqdm(executor.map(merge_to, lora_sd.keys()), total=len(lora_sd.keys()))) From 3387dc7306087b84646666e49323980c89d14945 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Fri, 13 Sep 2024 19:45:42 +0900 Subject: [PATCH 48/87] formatting, update README --- README.md | 6 +++ networks/sdxl_merge_lora.py | 86 +++++++++++++++++++++---------------- 2 files changed, 54 insertions(+), 38 deletions(-) diff --git a/README.md b/README.md index fd81a781..d5d2a7f7 100644 --- a/README.md +++ b/README.md @@ -137,6 +137,12 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser ## Change History +### Sep 13, 2024 / 2024-09-13: + +- `sdxl_merge_lora.py` now supports OFT. Thanks to Maru-mee for the PR [#1580](https://github.com/kohya-ss/sd-scripts/pull/1580). Will be included in the next release. + +- `sdxl_merge_lora.py` が OFT をサポートしました。PR [#1580](https://github.com/kohya-ss/sd-scripts/pull/1580) Maru-mee 氏に感謝します。次のリリースに含まれます。 + ### Jun 23, 2024 / 2024-06-23: - Fixed `cache_latents.py` and `cache_text_encoder_outputs.py` not working. (Will be included in the next release.) diff --git a/networks/sdxl_merge_lora.py b/networks/sdxl_merge_lora.py index d5a54e02..d5b6f7f3 100644 --- a/networks/sdxl_merge_lora.py +++ b/networks/sdxl_merge_lora.py @@ -10,11 +10,14 @@ import library.model_util as model_util import lora import oft from library.utils import setup_logging + setup_logging() import logging + logger = logging.getLogger(__name__) import concurrent.futures + def load_state_dict(file_name, dtype): if os.path.splitext(file_name)[1] == ".safetensors": sd = load_file(file_name) @@ -41,20 +44,22 @@ def save_to_file(file_name, model, state_dict, dtype, metadata): else: torch.save(model, file_name) + def detect_method_from_training_model(models, dtype): for model in models: lora_sd, _ = load_state_dict(model, dtype) for key in tqdm(lora_sd.keys()): - if 'lora_up' in key or 'lora_down' in key: - return 'LoRA' + if "lora_up" in key or "lora_down" in key: + return "LoRA" elif "oft_blocks" in key: - return 'OFT' + return "OFT" + def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_dtype): text_encoder1.to(merge_dtype) text_encoder1.to(merge_dtype) unet.to(merge_dtype) - + # detect the method: OFT or LoRA_module method = detect_method_from_training_model(models, merge_dtype) logger.info(f"method:{method}") @@ -62,7 +67,7 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_ # create module map name_to_module = {} for i, root_module in enumerate([text_encoder1, text_encoder2, unet]): - if method == 'LoRA': + if method == "LoRA": if i <= 1: if i == 0: prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER1 @@ -72,9 +77,9 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_ else: prefix = lora.LoRANetwork.LORA_PREFIX_UNET target_replace_modules = ( - lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE + lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 + lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE + lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 ) - elif method == 'OFT': + elif method == "OFT": prefix = oft.OFTNetwork.OFT_PREFIX_UNET # ALL_LINEAR includes ATTN_ONLY, so we don't need to specify ATTN_ONLY target_replace_modules = ( @@ -88,15 +93,14 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_ lora_name = prefix + "." + name + "." + child_name lora_name = lora_name.replace(".", "_") name_to_module[lora_name] = child_module - - + for model, ratio in zip(models, ratios): logger.info(f"loading: {model}") lora_sd, _ = load_state_dict(model, merge_dtype) logger.info(f"merging...") - if method == 'LoRA': + if method == "LoRA": for key in tqdm(lora_sd.keys()): if "lora_down" in key: up_key = key.replace("lora_down", "lora_up") @@ -139,12 +143,11 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_ module.weight = torch.nn.Parameter(weight) - - elif method == 'OFT': - - multiplier=1.0 - device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - + elif method == "OFT": + + multiplier = 1.0 + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + for key in tqdm(lora_sd.keys()): if "oft_blocks" in key: oft_blocks = lora_sd[key] @@ -154,12 +157,12 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_ if "alpha" in key: oft_blocks = lora_sd[key] alpha = oft_blocks.item() - break - + break + def merge_to(key): if "alpha" in key: return - + # find original module for this OFT module_name = ".".join(key.split(".")[:-1]) if module_name not in name_to_module: @@ -168,18 +171,18 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_ module = name_to_module[module_name] # logger.info(f"apply {key} to {module}") - + oft_blocks = lora_sd[key] - + if isinstance(module, torch.nn.Linear): out_dim = module.out_features elif isinstance(module, torch.nn.Conv2d): out_dim = module.out_channels - + num_blocks = dim block_size = out_dim // dim constraint = (0 if alpha is None else alpha) * out_dim - + block_Q = oft_blocks - oft_blocks.transpose(1, 2) norm_Q = torch.norm(block_Q.flatten()) new_norm_Q = torch.clamp(norm_Q, max=constraint) @@ -188,24 +191,24 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_ block_R = torch.matmul(I + block_Q, (I - block_Q).inverse()) block_R_weighted = multiplier * block_R + (1 - multiplier) * I R = torch.block_diag(*block_R_weighted) - + # get org weight org_sd = module.state_dict() org_weight = org_sd["weight"].to(device) R = R.to(org_weight.device, dtype=org_weight.dtype) - + if org_weight.dim() == 4: weight = torch.einsum("oihw, op -> pihw", org_weight, R) else: weight = torch.einsum("oi, op -> pi", org_weight, R) - - weight = weight.contiguous() # Make Tensor contiguous; required due to ThreadPoolExecutor - + + weight = weight.contiguous() # Make Tensor contiguous; required due to ThreadPoolExecutor + module.weight = torch.nn.Parameter(weight) # TODO multi-threading may cause OOM on CPU if cpu_count is too high and RAM is not enough - max_workers = 1 if device.type != "cpu" else None # avoid OOM on GPU + max_workers = 1 if device.type != "cpu" else None # avoid OOM on GPU with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: list(tqdm(executor.map(merge_to, lora_sd.keys()), total=len(lora_sd.keys()))) @@ -258,7 +261,7 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False): for key in tqdm(lora_sd.keys()): if "alpha" in key: continue - + if "lora_up" in key and concat: concat_dim = 1 elif "lora_down" in key and concat: @@ -272,8 +275,8 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False): alpha = alphas[lora_module_name] scale = math.sqrt(alpha / base_alpha) * ratio - scale = abs(scale) if "lora_up" in key else scale # マイナスの重みに対応する。 - + scale = abs(scale) if "lora_up" in key else scale # マイナスの重みに対応する。 + if key in merged_sd: assert ( merged_sd[key].size() == lora_sd[key].size() or concat_dim is not None @@ -295,7 +298,7 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False): dim = merged_sd[key_down].shape[0] perm = torch.randperm(dim) merged_sd[key_down] = merged_sd[key_down][perm] - merged_sd[key_up] = merged_sd[key_up][:,perm] + merged_sd[key_up] = merged_sd[key_up][:, perm] logger.info("merged model") logger.info(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}") @@ -323,7 +326,9 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False): def merge(args): - assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください" + assert len(args.models) == len( + args.ratios + ), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください" def str_to_dtype(p): if p == "float": @@ -410,10 +415,16 @@ def setup_parser() -> argparse.ArgumentParser: help="Stable Diffusion model to load: ckpt or safetensors file, merge LoRA models if omitted / 読み込むモデル、ckptまたはsafetensors。省略時はLoRAモデル同士をマージする", ) parser.add_argument( - "--save_to", type=str, default=None, help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors" + "--save_to", + type=str, + default=None, + help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors", ) parser.add_argument( - "--models", type=str, nargs="*", help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors" + "--models", + type=str, + nargs="*", + help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors", ) parser.add_argument("--ratios", type=float, nargs="*", help="ratios for each model / それぞれのLoRAモデルの比率") parser.add_argument( @@ -431,8 +442,7 @@ def setup_parser() -> argparse.ArgumentParser: parser.add_argument( "--shuffle", action="store_true", - help="shuffle lora weight./ " - + "LoRAの重みをシャッフルする", + help="shuffle lora weight./ " + "LoRAの重みをシャッフルする", ) return parser From 734d2e5b2b7a1551f3750a15e71060f3beed98e9 Mon Sep 17 00:00:00 2001 From: terracottahaniwa <57107346+terracottahaniwa@users.noreply.github.com> Date: Fri, 13 Sep 2024 20:45:35 +0900 Subject: [PATCH 49/87] Support Lora Block Weight (LBW) to svd_merge_lora.py (#1575) * support lora block weight * solve license incompatibility * Fix issue: lbw index calculation --- networks/svd_merge_lora.py | 150 ++++++++++++++++++++++++++++++++++++- 1 file changed, 146 insertions(+), 4 deletions(-) diff --git a/networks/svd_merge_lora.py b/networks/svd_merge_lora.py index cb00a600..6e163aec 100644 --- a/networks/svd_merge_lora.py +++ b/networks/svd_merge_lora.py @@ -1,5 +1,8 @@ import argparse +import itertools +import json import os +import re import time import torch from safetensors.torch import load_file, save_file @@ -14,6 +17,106 @@ logger = logging.getLogger(__name__) CLAMP_QUANTILE = 0.99 +ACCEPTABLE = [12, 17, 20, 26] +SDXL_LAYER_NUM = [12, 20] + +LAYER12 = { + "BASE": True, + "IN00": False, "IN01": False, "IN02": False, "IN03": False, "IN04": True, "IN05": True, + "IN06": False, "IN07": True, "IN08": True, "IN09": False, "IN10": False, "IN11": False, + "MID": True, + "OUT00": True, "OUT01": True, "OUT02": True, "OUT03": True, "OUT04": True, "OUT05": True, + "OUT06": False, "OUT07": False, "OUT08": False, "OUT09": False, "OUT10": False, "OUT11": False +} + +LAYER17 = { + "BASE": True, + "IN00": False, "IN01": True, "IN02": True, "IN03": False, "IN04": True, "IN05": True, + "IN06": False, "IN07": True, "IN08": True, "IN09": False, "IN10": False, "IN11": False, + "MID": True, + "OUT00": False, "OUT01": False, "OUT02": False, "OUT03": True, "OUT04": True, "OUT05": True, + "OUT06": True, "OUT07": True, "OUT08": True, "OUT09": True, "OUT10": True, "OUT11": True, +} + +LAYER20 = { + "BASE": True, + "IN00": True, "IN01": True, "IN02": True, "IN03": True, "IN04": True, "IN05": True, + "IN06": True, "IN07": True, "IN08": True, "IN09": False, "IN10": False, "IN11": False, + "MID": True, + "OUT00": True, "OUT01": True, "OUT02": True, "OUT03": True, "OUT04": True, "OUT05": True, + "OUT06": True, "OUT07": True, "OUT08": True, "OUT09": False, "OUT10": False, "OUT11": False, +} + +LAYER26 = { + "BASE": True, + "IN00": True, "IN01": True, "IN02": True, "IN03": True, "IN04": True, "IN05": True, + "IN06": True, "IN07": True, "IN08": True, "IN09": True, "IN10": True, "IN11": True, + "MID": True, + "OUT00": True, "OUT01": True, "OUT02": True, "OUT03": True, "OUT04": True, "OUT05": True, + "OUT06": True, "OUT07": True, "OUT08": True, "OUT09": True, "OUT10": True, "OUT11": True, +} + +assert len([v for v in LAYER12.values() if v]) == 12 +assert len([v for v in LAYER17.values() if v]) == 17 +assert len([v for v in LAYER20.values() if v]) == 20 +assert len([v for v in LAYER26.values() if v]) == 26 + +RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_") + + +def get_lbw_block_index(lora_name: str, is_sdxl: bool = False) -> int: + # lbw block index is 0-based, but 0 for text encoder, so we return 0 for text encoder + if "text_model_encoder_" in lora_name: # LoRA for text encoder + return 0 + + # lbw block index is 1-based for U-Net, and no "input_blocks.0" in CompVis SD, so "input_blocks.1" have index 2 + block_idx = -1 # invalid lora name + if not is_sdxl: + NUM_OF_BLOCKS = 12 # up/down blocks + m = RE_UPDOWN.search(lora_name) + if m: + g = m.groups() + up_down = g[0] + i = int(g[1]) + j = int(g[3]) + if up_down == "down": + if g[2] == "resnets" or g[2] == "attentions": + idx = 3 * i + j + 1 + elif g[2] == "downsamplers": + idx = 3 * (i + 1) + else: + return block_idx # invalid lora name + elif up_down == "up": + if g[2] == "resnets" or g[2] == "attentions": + idx = 3 * i + j + elif g[2] == "upsamplers": + idx = 3 * i + 2 + else: + return block_idx # invalid lora name + + if g[0] == "down": + block_idx = 1 + idx # 1-based index, down block index + elif g[0] == "up": + block_idx = 1 + NUM_OF_BLOCKS + 1 + idx # 1-based index, num blocks, mid block, up block index + + elif "mid_block_" in lora_name: + block_idx = 1 + NUM_OF_BLOCKS # 1-based index, num blocks, mid block + else: + if lora_name.startswith("lora_unet_"): + name = lora_name[len("lora_unet_") :] + if name.startswith("time_embed_") or name.startswith("label_emb_"): # 1, No LoRA in sd-scripts + block_idx = 1 + elif name.startswith("input_blocks_"): # 1-8 to 2-9 + block_idx = 1 + int(name.split("_")[2]) + elif name.startswith("middle_block_"): # 10 + block_idx = 10 + elif name.startswith("output_blocks_"): # 0-8 to 11-19 + block_idx = 11 + int(name.split("_")[2]) + elif name.startswith("out_"): # 20, No LoRA in sd-scripts + block_idx = 20 + + return block_idx + def load_state_dict(file_name, dtype): if os.path.splitext(file_name)[1] == ".safetensors": @@ -42,12 +145,34 @@ def save_to_file(file_name, state_dict, dtype, metadata): torch.save(state_dict, file_name) -def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dtype): +def merge_lora_models(models, ratios, lbws, new_rank, new_conv_rank, device, merge_dtype): logger.info(f"new rank: {new_rank}, new conv rank: {new_conv_rank}") merged_sd = {} - v2 = None + v2 = None # This is meaning LoRA Metadata v2, Not meaning SD2 base_model = None - for model, ratio in zip(models, ratios): + + if lbws: + try: + # lbwは"[1,1,1,1,1,1,1,1,1,1,1,1]"のような文字列で与えられることを期待している + lbws = [json.loads(lbw) for lbw in lbws] + except Exception: + raise ValueError(f"format of lbws are must be json / 層別適用率はJSON形式で書いてください") + assert all(isinstance(lbw, list) for lbw in lbws), f"lbws are must be list / 層別適用率はリストにしてください" + assert len(set(len(lbw) for lbw in lbws)) == 1, "all lbws should have the same length / 層別適用率は同じ長さにしてください" + assert all(len(lbw) in ACCEPTABLE for lbw in lbws), f"length of lbw are must be in {ACCEPTABLE} / 層別適用率の長さは{ACCEPTABLE}のいずれかにしてください" + assert all(all(isinstance(weight, (int, float)) for weight in lbw) for lbw in lbws), f"values of lbs are must be numbers / 層別適用率の値はすべて数値にしてください" + + layer_num = len(lbws[0]) + is_sdxl = True if layer_num in SDXL_LAYER_NUM else False + FLAGS = { + "12": LAYER12.values(), + "17": LAYER17.values(), + "20": LAYER20.values(), + "26": LAYER26.values(), + }[str(layer_num)] + LBW_TARGET_IDX = [i for i, flag in enumerate(FLAGS) if flag] + + for model, ratio, lbw in itertools.zip_longest(models, ratios, lbws): logger.info(f"loading: {model}") lora_sd, lora_metadata = load_state_dict(model, merge_dtype) @@ -57,6 +182,12 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty if base_model is None: base_model = lora_metadata.get(train_util.SS_METADATA_KEY_BASE_MODEL_VERSION, None) + if lbw: + lbw_weights = [1] * 26 + for index, value in zip(LBW_TARGET_IDX, lbw): + lbw_weights[index] = value + print(dict(zip(LAYER26.keys(), lbw_weights))) + # merge logger.info(f"merging...") for key in tqdm(list(lora_sd.keys())): @@ -93,6 +224,12 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty # W <- W + U * D scale = alpha / network_dim + if lbw: + index = get_lbw_block_index(key, is_sdxl) + is_lbw_target = index in LBW_TARGET_IDX + if is_lbw_target: + scale *= lbw_weights[index] # keyがlbwの対象であれば、lbwの重みを掛ける + if device: # and isinstance(scale, torch.Tensor): scale = scale.to(device) @@ -170,6 +307,10 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty def merge(args): assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください" + if args.lbws: + assert len(args.models) == len(args.lbws), f"number of models must be equal to number of ratios / モデルの数と層別適用率の数は合わせてください" + else: + args.lbws = [] # zip_longestで扱えるようにlbws未使用時には空のリストにしておく def str_to_dtype(p): if p == "float": @@ -187,7 +328,7 @@ def merge(args): new_conv_rank = args.new_conv_rank if args.new_conv_rank is not None else args.new_rank state_dict, metadata, v2, base_model = merge_lora_models( - args.models, args.ratios, args.new_rank, new_conv_rank, args.device, merge_dtype + args.models, args.ratios, args.lbws, args.new_rank, new_conv_rank, args.device, merge_dtype ) logger.info(f"calculating hashes and creating metadata...") @@ -237,6 +378,7 @@ def setup_parser() -> argparse.ArgumentParser: "--models", type=str, nargs="*", help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors" ) parser.add_argument("--ratios", type=float, nargs="*", help="ratios for each model / それぞれのLoRAモデルの比率") + parser.add_argument("--lbws", type=str, nargs="*", help="lbw for each model / それぞれのLoRAモデルの層別適用率") parser.add_argument("--new_rank", type=int, default=4, help="Specify rank of output LoRA / 出力するLoRAのrank (dim)") parser.add_argument( "--new_conv_rank", From f4a0bea6dce152e2210f611f94acfdfaa72068fe Mon Sep 17 00:00:00 2001 From: Kohya S Date: Fri, 13 Sep 2024 21:26:06 +0900 Subject: [PATCH 50/87] format by black --- networks/svd_merge_lora.py | 188 +++++++++++++++++++++++++++++-------- 1 file changed, 147 insertions(+), 41 deletions(-) diff --git a/networks/svd_merge_lora.py b/networks/svd_merge_lora.py index 6e163aec..0decd904 100644 --- a/networks/svd_merge_lora.py +++ b/networks/svd_merge_lora.py @@ -11,8 +11,10 @@ from library import sai_model_spec, train_util import library.model_util as model_util import lora from library.utils import setup_logging + setup_logging() import logging + logger = logging.getLogger(__name__) CLAMP_QUANTILE = 0.99 @@ -22,38 +24,118 @@ SDXL_LAYER_NUM = [12, 20] LAYER12 = { "BASE": True, - "IN00": False, "IN01": False, "IN02": False, "IN03": False, "IN04": True, "IN05": True, - "IN06": False, "IN07": True, "IN08": True, "IN09": False, "IN10": False, "IN11": False, + "IN00": False, + "IN01": False, + "IN02": False, + "IN03": False, + "IN04": True, + "IN05": True, + "IN06": False, + "IN07": True, + "IN08": True, + "IN09": False, + "IN10": False, + "IN11": False, "MID": True, - "OUT00": True, "OUT01": True, "OUT02": True, "OUT03": True, "OUT04": True, "OUT05": True, - "OUT06": False, "OUT07": False, "OUT08": False, "OUT09": False, "OUT10": False, "OUT11": False + "OUT00": True, + "OUT01": True, + "OUT02": True, + "OUT03": True, + "OUT04": True, + "OUT05": True, + "OUT06": False, + "OUT07": False, + "OUT08": False, + "OUT09": False, + "OUT10": False, + "OUT11": False, } LAYER17 = { "BASE": True, - "IN00": False, "IN01": True, "IN02": True, "IN03": False, "IN04": True, "IN05": True, - "IN06": False, "IN07": True, "IN08": True, "IN09": False, "IN10": False, "IN11": False, + "IN00": False, + "IN01": True, + "IN02": True, + "IN03": False, + "IN04": True, + "IN05": True, + "IN06": False, + "IN07": True, + "IN08": True, + "IN09": False, + "IN10": False, + "IN11": False, "MID": True, - "OUT00": False, "OUT01": False, "OUT02": False, "OUT03": True, "OUT04": True, "OUT05": True, - "OUT06": True, "OUT07": True, "OUT08": True, "OUT09": True, "OUT10": True, "OUT11": True, + "OUT00": False, + "OUT01": False, + "OUT02": False, + "OUT03": True, + "OUT04": True, + "OUT05": True, + "OUT06": True, + "OUT07": True, + "OUT08": True, + "OUT09": True, + "OUT10": True, + "OUT11": True, } LAYER20 = { "BASE": True, - "IN00": True, "IN01": True, "IN02": True, "IN03": True, "IN04": True, "IN05": True, - "IN06": True, "IN07": True, "IN08": True, "IN09": False, "IN10": False, "IN11": False, + "IN00": True, + "IN01": True, + "IN02": True, + "IN03": True, + "IN04": True, + "IN05": True, + "IN06": True, + "IN07": True, + "IN08": True, + "IN09": False, + "IN10": False, + "IN11": False, "MID": True, - "OUT00": True, "OUT01": True, "OUT02": True, "OUT03": True, "OUT04": True, "OUT05": True, - "OUT06": True, "OUT07": True, "OUT08": True, "OUT09": False, "OUT10": False, "OUT11": False, + "OUT00": True, + "OUT01": True, + "OUT02": True, + "OUT03": True, + "OUT04": True, + "OUT05": True, + "OUT06": True, + "OUT07": True, + "OUT08": True, + "OUT09": False, + "OUT10": False, + "OUT11": False, } LAYER26 = { "BASE": True, - "IN00": True, "IN01": True, "IN02": True, "IN03": True, "IN04": True, "IN05": True, - "IN06": True, "IN07": True, "IN08": True, "IN09": True, "IN10": True, "IN11": True, + "IN00": True, + "IN01": True, + "IN02": True, + "IN03": True, + "IN04": True, + "IN05": True, + "IN06": True, + "IN07": True, + "IN08": True, + "IN09": True, + "IN10": True, + "IN11": True, "MID": True, - "OUT00": True, "OUT01": True, "OUT02": True, "OUT03": True, "OUT04": True, "OUT05": True, - "OUT06": True, "OUT07": True, "OUT08": True, "OUT09": True, "OUT10": True, "OUT11": True, + "OUT00": True, + "OUT01": True, + "OUT02": True, + "OUT03": True, + "OUT04": True, + "OUT05": True, + "OUT06": True, + "OUT07": True, + "OUT08": True, + "OUT09": True, + "OUT10": True, + "OUT11": True, } assert len([v for v in LAYER12.values() if v]) == 12 @@ -145,6 +227,33 @@ def save_to_file(file_name, state_dict, dtype, metadata): torch.save(state_dict, file_name) +def format_lbws(lbws): + try: + # lbwは"[1,1,1,1,1,1,1,1,1,1,1,1]"のような文字列で与えられることを期待している + lbws = [json.loads(lbw) for lbw in lbws] + except Exception: + raise ValueError(f"format of lbws are must be json / 層別適用率はJSON形式で書いてください") + assert all(isinstance(lbw, list) for lbw in lbws), f"lbws are must be list / 層別適用率はリストにしてください" + assert len(set(len(lbw) for lbw in lbws)) == 1, "all lbws should have the same length / 層別適用率は同じ長さにしてください" + assert all( + len(lbw) in ACCEPTABLE for lbw in lbws + ), f"length of lbw are must be in {ACCEPTABLE} / 層別適用率の長さは{ACCEPTABLE}のいずれかにしてください" + assert all( + all(isinstance(weight, (int, float)) for weight in lbw) for lbw in lbws + ), f"values of lbs are must be numbers / 層別適用率の値はすべて数値にしてください" + + layer_num = len(lbws[0]) + is_sdxl = True if layer_num in SDXL_LAYER_NUM else False + FLAGS = { + "12": LAYER12.values(), + "17": LAYER17.values(), + "20": LAYER20.values(), + "26": LAYER26.values(), + }[str(layer_num)] + LBW_TARGET_IDX = [i for i, flag in enumerate(FLAGS) if flag] + return lbws, is_sdxl, LBW_TARGET_IDX + + def merge_lora_models(models, ratios, lbws, new_rank, new_conv_rank, device, merge_dtype): logger.info(f"new rank: {new_rank}, new conv rank: {new_conv_rank}") merged_sd = {} @@ -152,25 +261,10 @@ def merge_lora_models(models, ratios, lbws, new_rank, new_conv_rank, device, mer base_model = None if lbws: - try: - # lbwは"[1,1,1,1,1,1,1,1,1,1,1,1]"のような文字列で与えられることを期待している - lbws = [json.loads(lbw) for lbw in lbws] - except Exception: - raise ValueError(f"format of lbws are must be json / 層別適用率はJSON形式で書いてください") - assert all(isinstance(lbw, list) for lbw in lbws), f"lbws are must be list / 層別適用率はリストにしてください" - assert len(set(len(lbw) for lbw in lbws)) == 1, "all lbws should have the same length / 層別適用率は同じ長さにしてください" - assert all(len(lbw) in ACCEPTABLE for lbw in lbws), f"length of lbw are must be in {ACCEPTABLE} / 層別適用率の長さは{ACCEPTABLE}のいずれかにしてください" - assert all(all(isinstance(weight, (int, float)) for weight in lbw) for lbw in lbws), f"values of lbs are must be numbers / 層別適用率の値はすべて数値にしてください" - - layer_num = len(lbws[0]) - is_sdxl = True if layer_num in SDXL_LAYER_NUM else False - FLAGS = { - "12": LAYER12.values(), - "17": LAYER17.values(), - "20": LAYER20.values(), - "26": LAYER26.values(), - }[str(layer_num)] - LBW_TARGET_IDX = [i for i, flag in enumerate(FLAGS) if flag] + lbws, is_sdxl, LBW_TARGET_IDX = format_lbws(lbws) + else: + is_sdxl = False + LBW_TARGET_IDX = [] for model, ratio, lbw in itertools.zip_longest(models, ratios, lbws): logger.info(f"loading: {model}") @@ -186,7 +280,7 @@ def merge_lora_models(models, ratios, lbws, new_rank, new_conv_rank, device, mer lbw_weights = [1] * 26 for index, value in zip(LBW_TARGET_IDX, lbw): lbw_weights[index] = value - print(dict(zip(LAYER26.keys(), lbw_weights))) + logger.info(f"lbw: {dict(zip(LAYER26.keys(), lbw_weights))}") # merge logger.info(f"merging...") @@ -306,9 +400,13 @@ def merge_lora_models(models, ratios, lbws, new_rank, new_conv_rank, device, mer def merge(args): - assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください" + assert len(args.models) == len( + args.ratios + ), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください" if args.lbws: - assert len(args.models) == len(args.lbws), f"number of models must be equal to number of ratios / モデルの数と層別適用率の数は合わせてください" + assert len(args.models) == len( + args.lbws + ), f"number of models must be equal to number of ratios / モデルの数と層別適用率の数は合わせてください" else: args.lbws = [] # zip_longestで扱えるようにlbws未使用時には空のリストにしておく @@ -372,10 +470,16 @@ def setup_parser() -> argparse.ArgumentParser: help="precision in merging (float is recommended) / マージの計算時の精度(floatを推奨)", ) parser.add_argument( - "--save_to", type=str, default=None, help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors" + "--save_to", + type=str, + default=None, + help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors", ) parser.add_argument( - "--models", type=str, nargs="*", help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors" + "--models", + type=str, + nargs="*", + help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors", ) parser.add_argument("--ratios", type=float, nargs="*", help="ratios for each model / それぞれのLoRAモデルの比率") parser.add_argument("--lbws", type=str, nargs="*", help="lbw for each model / それぞれのLoRAモデルの層別適用率") @@ -386,7 +490,9 @@ def setup_parser() -> argparse.ArgumentParser: default=None, help="Specify rank of output LoRA for Conv2d 3x3, None for same as new_rank / 出力するConv2D 3x3 LoRAのrank (dim)、Noneでnew_rankと同じ", ) - parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う") + parser.add_argument( + "--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う" + ) parser.add_argument( "--no_metadata", action="store_true", From b755ebd0a4dd2967171b6b5909624325359a2aa0 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Fri, 13 Sep 2024 21:29:31 +0900 Subject: [PATCH 51/87] add LBW support for SDXL merge LoRA --- README.md | 12 +++++- networks/sdxl_merge_lora.py | 75 ++++++++++++++++++++++++++++++++----- 2 files changed, 76 insertions(+), 11 deletions(-) diff --git a/README.md b/README.md index d5d2a7f7..0be2f9a7 100644 --- a/README.md +++ b/README.md @@ -139,9 +139,17 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser ### Sep 13, 2024 / 2024-09-13: -- `sdxl_merge_lora.py` now supports OFT. Thanks to Maru-mee for the PR [#1580](https://github.com/kohya-ss/sd-scripts/pull/1580). Will be included in the next release. +- `sdxl_merge_lora.py` now supports OFT. Thanks to Maru-mee for the PR [#1580](https://github.com/kohya-ss/sd-scripts/pull/1580). +- `svd_merge_lora.py` now supports LBW. Thanks to terracottahaniwa. See PR [#1575](https://github.com/kohya-ss/sd-scripts/pull/1575) for details. +- `sdxl_merge_lora.py` also supports LBW. +- See [LoRA Block Weight](https://github.com/hako-mikan/sd-webui-lora-block-weight) by hako-mikan for details on LBW. +- These will be included in the next release. -- `sdxl_merge_lora.py` が OFT をサポートしました。PR [#1580](https://github.com/kohya-ss/sd-scripts/pull/1580) Maru-mee 氏に感謝します。次のリリースに含まれます。 +- `sdxl_merge_lora.py` が OFT をサポートされました。PR [#1580](https://github.com/kohya-ss/sd-scripts/pull/1580) Maru-mee 氏に感謝します。 +- `svd_merge_lora.py` で LBW がサポートされました。PR [#1575](https://github.com/kohya-ss/sd-scripts/pull/1575) terracottahaniwa 氏に感謝します。 +- `sdxl_merge_lora.py` でも LBW がサポートされました。 +- LBW の詳細は hako-mikan 氏の [LoRA Block Weight](https://github.com/hako-mikan/sd-webui-lora-block-weight) をご覧ください。 +- 以上は次回リリースに含まれます。 ### Jun 23, 2024 / 2024-06-23: diff --git a/networks/sdxl_merge_lora.py b/networks/sdxl_merge_lora.py index d5b6f7f3..62f5a87d 100644 --- a/networks/sdxl_merge_lora.py +++ b/networks/sdxl_merge_lora.py @@ -1,7 +1,9 @@ +import itertools import math import argparse import os import time +import concurrent.futures import torch from safetensors.torch import load_file, save_file from tqdm import tqdm @@ -9,13 +11,13 @@ from library import sai_model_spec, sdxl_model_util, train_util import library.model_util as model_util import lora import oft +from svd_merge_lora import format_lbws, get_lbw_block_index, LAYER26 from library.utils import setup_logging setup_logging() import logging logger = logging.getLogger(__name__) -import concurrent.futures def load_state_dict(file_name, dtype): @@ -47,6 +49,7 @@ def save_to_file(file_name, model, state_dict, dtype, metadata): def detect_method_from_training_model(models, dtype): for model in models: + # TODO It is better to use key names to detect the method lora_sd, _ = load_state_dict(model, dtype) for key in tqdm(lora_sd.keys()): if "lora_up" in key or "lora_down" in key: @@ -55,15 +58,20 @@ def detect_method_from_training_model(models, dtype): return "OFT" -def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_dtype): - text_encoder1.to(merge_dtype) +def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, lbws, merge_dtype): text_encoder1.to(merge_dtype) + text_encoder2.to(merge_dtype) unet.to(merge_dtype) # detect the method: OFT or LoRA_module method = detect_method_from_training_model(models, merge_dtype) logger.info(f"method:{method}") + if lbws: + lbws, _, LBW_TARGET_IDX = format_lbws(lbws) + else: + LBW_TARGET_IDX = [] + # create module map name_to_module = {} for i, root_module in enumerate([text_encoder1, text_encoder2, unet]): @@ -94,12 +102,18 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_ lora_name = lora_name.replace(".", "_") name_to_module[lora_name] = child_module - for model, ratio in zip(models, ratios): + for model, ratio, lbw in itertools.zip_longest(models, ratios, lbws): logger.info(f"loading: {model}") lora_sd, _ = load_state_dict(model, merge_dtype) logger.info(f"merging...") + if lbw: + lbw_weights = [1] * 26 + for index, value in zip(LBW_TARGET_IDX, lbw): + lbw_weights[index] = value + logger.info(f"lbw: {dict(zip(LAYER26.keys(), lbw_weights))}") + if method == "LoRA": for key in tqdm(lora_sd.keys()): if "lora_down" in key: @@ -121,6 +135,12 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_ alpha = lora_sd.get(alpha_key, dim) scale = alpha / dim + if lbw: + index = get_lbw_block_index(key, True) + is_lbw_target = index in LBW_TARGET_IDX + if is_lbw_target: + scale *= lbw_weights[index] # keyがlbwの対象であれば、lbwの重みを掛ける + # W <- W + U * D weight = module.weight # logger.info(module_name, down_weight.size(), up_weight.size()) @@ -145,7 +165,6 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_ elif method == "OFT": - multiplier = 1.0 device = torch.device("cuda" if torch.cuda.is_available() else "cpu") for key in tqdm(lora_sd.keys()): @@ -183,6 +202,13 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_ block_size = out_dim // dim constraint = (0 if alpha is None else alpha) * out_dim + multiplier = 1 + if lbw: + index = get_lbw_block_index(key, False) + is_lbw_target = index in LBW_TARGET_IDX + if is_lbw_target: + multiplier *= lbw_weights[index] + block_Q = oft_blocks - oft_blocks.transpose(1, 2) norm_Q = torch.norm(block_Q.flatten()) new_norm_Q = torch.clamp(norm_Q, max=constraint) @@ -213,17 +239,35 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_ list(tqdm(executor.map(merge_to, lora_sd.keys()), total=len(lora_sd.keys()))) -def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False): +def merge_lora_models(models, ratios, lbws, merge_dtype, concat=False, shuffle=False): base_alphas = {} # alpha for merged model base_dims = {} + # detect the method: OFT or LoRA_module + method = detect_method_from_training_model(models, merge_dtype) + if method == "OFT": + raise ValueError( + "OFT model is not supported for merging OFT models. / OFTモデルはOFTモデル同士のマージには対応していません" + ) + + if lbws: + lbws, _, LBW_TARGET_IDX = format_lbws(lbws) + else: + LBW_TARGET_IDX = [] + merged_sd = {} v2 = None base_model = None - for model, ratio in zip(models, ratios): + for model, ratio, lbw in itertools.zip_longest(models, ratios, lbws): logger.info(f"loading: {model}") lora_sd, lora_metadata = load_state_dict(model, merge_dtype) + if lbw: + lbw_weights = [1] * 26 + for index, value in zip(LBW_TARGET_IDX, lbw): + lbw_weights[index] = value + logger.info(f"lbw: {dict(zip(LAYER26.keys(), lbw_weights))}") + if lora_metadata is not None: if v2 is None: v2 = lora_metadata.get(train_util.SS_METADATA_KEY_V2, None) # returns string, SDXLはv2がないのでFalseのはず @@ -277,6 +321,12 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False): scale = math.sqrt(alpha / base_alpha) * ratio scale = abs(scale) if "lora_up" in key else scale # マイナスの重みに対応する。 + if lbw: + index = get_lbw_block_index(key, True) + is_lbw_target = index in LBW_TARGET_IDX + if is_lbw_target: + scale *= lbw_weights[index] # keyがlbwの対象であれば、lbwの重みを掛ける + if key in merged_sd: assert ( merged_sd[key].size() == lora_sd[key].size() or concat_dim is not None @@ -329,6 +379,12 @@ def merge(args): assert len(args.models) == len( args.ratios ), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください" + if args.lbws: + assert len(args.models) == len( + args.lbws + ), f"number of models must be equal to number of ratios / モデルの数と層別適用率の数は合わせてください" + else: + args.lbws = [] # zip_longestで扱えるようにlbws未使用時には空のリストにしておく def str_to_dtype(p): if p == "float": @@ -356,7 +412,7 @@ def merge(args): ckpt_info, ) = sdxl_model_util.load_models_from_sdxl_checkpoint(sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, args.sd_model, "cpu") - merge_to_sd_model(text_model1, text_model2, unet, args.models, args.ratios, merge_dtype) + merge_to_sd_model(text_model1, text_model2, unet, args.models, args.ratios, args.lbws, merge_dtype) if args.no_metadata: sai_metadata = None @@ -372,7 +428,7 @@ def merge(args): args.save_to, text_model1, text_model2, unet, 0, 0, ckpt_info, vae, logit_scale, sai_metadata, save_dtype ) else: - state_dict, metadata = merge_lora_models(args.models, args.ratios, merge_dtype, args.concat, args.shuffle) + state_dict, metadata = merge_lora_models(args.models, args.ratios, args.lbws, merge_dtype, args.concat, args.shuffle) logger.info(f"calculating hashes and creating metadata...") @@ -427,6 +483,7 @@ def setup_parser() -> argparse.ArgumentParser: help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors", ) parser.add_argument("--ratios", type=float, nargs="*", help="ratios for each model / それぞれのLoRAモデルの比率") + parser.add_argument("--lbws", type=str, nargs="*", help="lbw for each model / それぞれのLoRAモデルの層別適用率") parser.add_argument( "--no_metadata", action="store_true", From 93d9fbf60761fc1158e37f45f0d0c142913d70f5 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Fri, 13 Sep 2024 22:37:11 +0900 Subject: [PATCH 52/87] improve OFT implementation closes #944 --- README.md | 26 +++++++++- gen_img.py | 3 +- networks/check_lora_weights.py | 2 +- networks/oft.py | 94 ++++++++++++++++++++++------------ 4 files changed, 88 insertions(+), 37 deletions(-) diff --git a/README.md b/README.md index 0130ccff..def528a2 100644 --- a/README.md +++ b/README.md @@ -143,7 +143,31 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser - transformers, accelerate and huggingface_hub are updated. - If you encounter any issues, please report them. -- en: The INVERSE_SQRT, COSINE_WITH_MIN_LR, and WARMUP_STABLE_DECAY learning rate schedules are now available in the transformers library. See PR [#1393](https://github.com/kohya-ss/sd-scripts/pull/1393) for details. Thanks to sdbds! +- Improvements in OFT (Orthogonal Finetuning) Implementation + 1. Optimization of Calculation Order: + - Changed the calculation order in the forward method from (Wx)R to W(xR). + - This has improved computational efficiency and processing speed. + 2. Correction of Bias Application: + - In the previous implementation, R was incorrectly applied to the bias. + - The new implementation now correctly handles bias by using F.conv2d and F.linear. + 3. Efficiency Enhancement in Matrix Operations: + - Introduced einsum in both the forward and merge_to methods. + - This has optimized matrix operations, resulting in further speed improvements. + 4. Proper Handling of Data Types: + - Improved to use torch.float32 during calculations and convert results back to the original data type. + - This maintains precision while ensuring compatibility with the original model. + 5. Unified Processing for Conv2d and Linear Layers: + - Implemented a consistent method for applying OFT to both layer types. + - These changes have made the OFT implementation more efficient and accurate, potentially leading to improved model performance and training stability. + + - Additional Information + * Recommended α value for OFT constraint: We recommend using α values between 1e-4 and 1e-2. This differs slightly from the original implementation of "(α\*out_dim\*out_dim)". Our implementation uses "(α\*out_dim)", hence we recommend higher values than the 1e-5 suggested in the original implementation. + + * Performance Improvement: Training speed has been improved by approximately 30%. + + * Inference Environment: This implementation is compatible with and operates within Stable Diffusion web UI (SD1/2 and SDXL). + +- The INVERSE_SQRT, COSINE_WITH_MIN_LR, and WARMUP_STABLE_DECAY learning rate schedules are now available in the transformers library. See PR [#1393](https://github.com/kohya-ss/sd-scripts/pull/1393) for details. Thanks to sdbds! - See the [transformers documentation](https://huggingface.co/docs/transformers/v4.44.2/en/main_classes/optimizer_schedules#schedules) for details on each scheduler. - `--lr_warmup_steps` and `--lr_decay_steps` can now be specified as a ratio of the number of training steps, not just the step value. Example: `--lr_warmup_steps=0.1` or `--lr_warmup_steps=10%`, etc. diff --git a/gen_img.py b/gen_img.py index d0a8f814..59bcd5b0 100644 --- a/gen_img.py +++ b/gen_img.py @@ -86,7 +86,8 @@ CLIP_VISION_MODEL = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k" """ -def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers, sdpa): +# def replace_unet_modules(unet: diffusers.models.unets.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers, sdpa): +def replace_unet_modules(unet, mem_eff_attn, xformers, sdpa): if mem_eff_attn: logger.info("Enable memory efficient attention for U-Net") diff --git a/networks/check_lora_weights.py b/networks/check_lora_weights.py index 794659c9..f8eab53b 100644 --- a/networks/check_lora_weights.py +++ b/networks/check_lora_weights.py @@ -18,7 +18,7 @@ def main(file): keys = list(sd.keys()) for key in keys: - if "lora_up" in key or "lora_down" in key: + if "lora_up" in key or "lora_down" in key or "lora_A" in key or "lora_B" in key or "oft_" in key: values.append((key, sd[key])) print(f"number of LoRA modules: {len(values)}") diff --git a/networks/oft.py b/networks/oft.py index 461a9869..6321def3 100644 --- a/networks/oft.py +++ b/networks/oft.py @@ -4,13 +4,17 @@ import math import os from typing import Dict, List, Optional, Tuple, Type, Union from diffusers import AutoencoderKL +import einops from transformers import CLIPTextModel import numpy as np import torch +import torch.nn.functional as F import re from library.utils import setup_logging + setup_logging() import logging + logger = logging.getLogger(__name__) RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_") @@ -45,11 +49,16 @@ class OFTModule(torch.nn.Module): if type(alpha) == torch.Tensor: alpha = alpha.detach().numpy() - self.constraint = alpha * out_dim + + # constraint in original paper is alpha * out_dim * out_dim, but we use alpha * out_dim for backward compatibility + # original alpha is 1e-6, so we use 1e-3 or 1e-4 for alpha + self.constraint = alpha * out_dim + self.register_buffer("alpha", torch.tensor(alpha)) self.block_size = out_dim // self.num_blocks self.oft_blocks = torch.nn.Parameter(torch.zeros(self.num_blocks, self.block_size, self.block_size)) + self.I = torch.eye(self.block_size).unsqueeze(0).repeat(self.num_blocks, 1, 1) # cpu self.out_dim = out_dim self.shape = org_module.weight.shape @@ -69,27 +78,36 @@ class OFTModule(torch.nn.Module): norm_Q = torch.norm(block_Q.flatten()) new_norm_Q = torch.clamp(norm_Q, max=self.constraint) block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8)) - I = torch.eye(self.block_size, device=self.oft_blocks.device).unsqueeze(0).repeat(self.num_blocks, 1, 1) - block_R = torch.matmul(I + block_Q, (I - block_Q).inverse()) - block_R_weighted = self.multiplier * block_R + (1 - self.multiplier) * I - R = torch.block_diag(*block_R_weighted) - - return R + if self.I.device != block_Q.device: + self.I = self.I.to(block_Q.device) + I = self.I + block_R = torch.matmul(I + block_Q, (I - block_Q).float().inverse()) + block_R_weighted = self.multiplier * (block_R - I) + I + return block_R_weighted def forward(self, x, scale=None): - x = self.org_forward(x) if self.multiplier == 0.0: - return x + return self.org_forward(x) + org_module = self.org_module[0] + org_dtype = x.dtype - R = self.get_weight().to(x.device, dtype=x.dtype) - if x.dim() == 4: - x = x.permute(0, 2, 3, 1) - x = torch.matmul(x, R) - x = x.permute(0, 3, 1, 2) - else: - x = torch.matmul(x, R) - return x + R = self.get_weight().to(torch.float32) + W = org_module.weight.to(torch.float32) + + if len(W.shape) == 4: # Conv2d + W_reshaped = einops.rearrange(W, "(k n) ... -> k n ...", k=self.num_blocks, n=self.block_size) + RW = torch.einsum("k n m, k n ... -> k m ...", R, W_reshaped) + RW = einops.rearrange(RW, "k m ... -> (k m) ...") + result = F.conv2d( + x, RW.to(org_dtype), org_module.bias, org_module.stride, org_module.padding, org_module.dilation, org_module.groups + ) + else: # Linear + W_reshaped = einops.rearrange(W, "(k n) m -> k n m", k=self.num_blocks, n=self.block_size) + RW = torch.einsum("k n m, k n p -> k m p", R, W_reshaped) + RW = einops.rearrange(RW, "k m p -> (k m) p") + result = F.linear(x, RW.to(org_dtype), org_module.bias) + return result class OFTInfModule(OFTModule): @@ -115,18 +133,19 @@ class OFTInfModule(OFTModule): return self.org_forward(x) return super().forward(x, scale) - def merge_to(self, multiplier=None, sign=1): - R = self.get_weight(multiplier) * sign - + def merge_to(self, multiplier=None): # get org weight org_sd = self.org_module[0].state_dict() - org_weight = org_sd["weight"] - R = R.to(org_weight.device, dtype=org_weight.dtype) + org_weight = org_sd["weight"].to(torch.float32) - if org_weight.dim() == 4: - weight = torch.einsum("oihw, op -> pihw", org_weight, R) - else: - weight = torch.einsum("oi, op -> pi", org_weight, R) + R = self.get_weight(multiplier).to(torch.float32) + + weight = org_weight.reshape(self.num_blocks, self.block_size, -1) + weight = torch.einsum("k n m, k n ... -> k m ...", R, weight) + weight = weight.reshape(org_weight.shape) + + # convert back to original dtype + weight = weight.to(org_sd["weight"].dtype) # set weight to org_module org_sd["weight"] = weight @@ -145,8 +164,16 @@ def create_network( ): if network_dim is None: network_dim = 4 # default - if network_alpha is None: - network_alpha = 1.0 + if network_alpha is None: # should be set + logger.info( + "network_alpha is not set, use default value 1e-3 / network_alphaが設定されていないのでデフォルト値 1e-3 を使用します" + ) + network_alpha = 1e-3 + elif network_alpha >= 1: + logger.warning( + "network_alpha is too large (>=1, maybe default value is too large), please consider to set smaller value like 1e-3" + " / network_alphaが大きすぎるようです(>=1, デフォルト値が大きすぎる可能性があります)。1e-3のような小さな値を推奨" + ) enable_all_linear = kwargs.get("enable_all_linear", None) enable_conv = kwargs.get("enable_conv", None) @@ -190,12 +217,11 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh else: if dim is None: dim = param.size()[0] - if has_conv2d is None and param.dim() == 4: + if has_conv2d is None and "in_layers_2" in name: has_conv2d = True - if all_linear is None: - if param.dim() == 3 and "attn" not in name: - all_linear = True - if dim is not None and alpha is not None and has_conv2d is not None: + if all_linear is None and "_ff_" in name: + all_linear = True + if dim is not None and alpha is not None and has_conv2d is not None and all_linear is not None: break if has_conv2d is None: has_conv2d = False @@ -241,7 +267,7 @@ class OFTNetwork(torch.nn.Module): self.alpha = alpha logger.info( - f"create OFT network. num blocks: {self.dim}, constraint: {self.alpha}, multiplier: {self.multiplier}, enable_conv: {enable_conv}" + f"create OFT network. num blocks: {self.dim}, constraint: {self.alpha}, multiplier: {self.multiplier}, enable_conv: {enable_conv}, enable_all_linear: {enable_all_linear}" ) # create module instances From 2d8ee3c28007393386528cfeec0a9b714dafd85b Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 14 Sep 2024 15:48:16 +0900 Subject: [PATCH 53/87] OFT for FLUX.1 --- flux_minimal_inference.py | 20 +- networks/lora_flux.py | 6 +- networks/oft.py | 2 +- networks/oft_flux.py | 482 ++++++++++++++++++++++++++++++++++++++ 4 files changed, 504 insertions(+), 6 deletions(-) create mode 100644 networks/oft_flux.py diff --git a/flux_minimal_inference.py b/flux_minimal_inference.py index de607c52..2f1b9a37 100644 --- a/flux_minimal_inference.py +++ b/flux_minimal_inference.py @@ -14,9 +14,11 @@ from tqdm import tqdm from PIL import Image import accelerate from transformers import CLIPTextModel +from safetensors.torch import load_file from library import device_utils from library.device_utils import init_ipex, get_preferred_device +from networks import oft_flux init_ipex() @@ -405,7 +407,7 @@ if __name__ == "__main__": type=str, nargs="*", default=[], - help="LoRA weights, only supports networks.lora_flux, each argument is a `path;multiplier` (semi-colon separated)", + help="LoRA weights, only supports networks.lora_flux and lora_oft, each argument is a `path;multiplier` (semi-colon separated)", ) parser.add_argument("--merge_lora_weights", action="store_true", help="Merge LoRA weights to model") parser.add_argument("--width", type=int, default=target_width) @@ -482,9 +484,19 @@ if __name__ == "__main__": else: multiplier = 1.0 - lora_model, weights_sd = lora_flux.create_network_from_weights( - multiplier, weights_file, ae, [clip_l, t5xxl], model, None, True - ) + weights_sd = load_file(weights_file) + is_lora = is_oft = False + for key in weights_sd.keys(): + if key.startswith("lora"): + is_lora = True + if key.startswith("oft"): + is_oft = True + if is_lora or is_oft: + break + + module = lora_flux if is_lora else oft_flux + lora_model, _ = module.create_network_from_weights(multiplier, None, ae, [clip_l, t5xxl], model, weights_sd, True) + if args.merge_lora_weights: lora_model.merge_to([clip_l, t5xxl], model, weights_sd) else: diff --git a/networks/lora_flux.py b/networks/lora_flux.py index dd267de0..ea7df8b4 100644 --- a/networks/lora_flux.py +++ b/networks/lora_flux.py @@ -41,7 +41,11 @@ class LoRAModule(torch.nn.Module): module_dropout=None, split_dims: Optional[List[int]] = None, ): - """if alpha == 0 or None, alpha is rank (no scaling).""" + """ + if alpha == 0 or None, alpha is rank (no scaling). + + split_dims is used to mimic the split qkv of FLUX as same as Diffusers + """ super().__init__() self.lora_name = lora_name diff --git a/networks/oft.py b/networks/oft.py index 6321def3..0c3a5393 100644 --- a/networks/oft.py +++ b/networks/oft.py @@ -51,7 +51,7 @@ class OFTModule(torch.nn.Module): alpha = alpha.detach().numpy() # constraint in original paper is alpha * out_dim * out_dim, but we use alpha * out_dim for backward compatibility - # original alpha is 1e-6, so we use 1e-3 or 1e-4 for alpha + # original alpha is 1e-5, so we use 1e-2 or 1e-4 for alpha self.constraint = alpha * out_dim self.register_buffer("alpha", torch.tensor(alpha)) diff --git a/networks/oft_flux.py b/networks/oft_flux.py new file mode 100644 index 00000000..27b8b637 --- /dev/null +++ b/networks/oft_flux.py @@ -0,0 +1,482 @@ +# OFT network module + +import math +import os +from typing import Dict, List, Optional, Tuple, Type, Union +from diffusers import AutoencoderKL +import einops +from transformers import CLIPTextModel +import numpy as np +import torch +import torch.nn.functional as F +import re +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +class OFTModule(torch.nn.Module): + """ + replaces forward method of the original Linear, instead of replacing the original Linear module. + """ + + def __init__( + self, + oft_name, + org_module: torch.nn.Module, + multiplier=1.0, + dim=4, + alpha=1, + split_dims: Optional[List[int]] = None, + ): + """ + dim -> num blocks + alpha -> constraint + + split_dims is used to mimic the split qkv of FLUX as same as Diffusers + """ + super().__init__() + self.oft_name = oft_name + self.num_blocks = dim + + if type(alpha) == torch.Tensor: + alpha = alpha.detach().numpy() + self.register_buffer("alpha", torch.tensor(alpha)) + + # No conv2d in FLUX + # if "Linear" in org_module.__class__.__name__: + self.out_dim = org_module.out_features + # elif "Conv" in org_module.__class__.__name__: + # out_dim = org_module.out_channels + + if split_dims is None: + split_dims = [self.out_dim] + else: + assert sum(split_dims) == self.out_dim, "sum of split_dims must be equal to out_dim" + self.split_dims = split_dims + + # assert all dim is divisible by num_blocks + for split_dim in self.split_dims: + assert split_dim % self.num_blocks == 0, "split_dim must be divisible by num_blocks" + + self.constraint = [alpha * split_dim for split_dim in self.split_dims] + self.block_size = [split_dim // self.num_blocks for split_dim in self.split_dims] + self.oft_blocks = torch.nn.ParameterList( + [torch.nn.Parameter(torch.zeros(self.num_blocks, block_size, block_size)) for block_size in self.block_size] + ) + self.I = [torch.eye(block_size).unsqueeze(0).repeat(self.num_blocks, 1, 1) for block_size in self.block_size] + + self.shape = org_module.weight.shape + self.multiplier = multiplier + self.org_module = [org_module] # moduleにならないようにlistに入れる + + def apply_to(self): + self.org_forward = self.org_module[0].forward + self.org_module[0].forward = self.forward + + def get_weight(self, multiplier=None): + if multiplier is None: + multiplier = self.multiplier + + if self.I[0].device != self.oft_blocks[0].device: + self.I = [I.to(self.oft_blocks[0].device) for I in self.I] + + block_R_weighted_list = [] + for i in range(len(self.oft_blocks)): + block_Q = self.oft_blocks[i] - self.oft_blocks[i].transpose(1, 2) + norm_Q = torch.norm(block_Q.flatten()) + new_norm_Q = torch.clamp(norm_Q, max=self.constraint[i]) + block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8)) + + I = self.I[i] + block_R = torch.matmul(I + block_Q, (I - block_Q).float().inverse()) + block_R_weighted = self.multiplier * (block_R - I) + I + + block_R_weighted_list.append(block_R_weighted) + + return block_R_weighted_list + + def forward(self, x, scale=None): + if self.multiplier == 0.0: + return self.org_forward(x) + + org_module = self.org_module[0] + org_dtype = x.dtype + + R = self.get_weight() + W = org_module.weight.to(torch.float32) + B = org_module.bias.to(torch.float32) + + # split W to match R + results = [] + d2 = 0 + for i in range(len(R)): + d1 = d2 + d2 += self.split_dims[i] + + W1 = W[d1:d2] + W_reshaped = einops.rearrange(W1, "(k n) m -> k n m", k=self.num_blocks, n=self.block_size[i]) + RW_1 = torch.einsum("k n m, k n p -> k m p", R[i], W_reshaped) + RW_1 = einops.rearrange(RW_1, "k m p -> (k m) p") + + B1 = B[d1:d2] + result = F.linear(x, RW_1.to(org_dtype), B1.to(org_dtype)) + results.append(result) + + result = torch.cat(results, dim=-1) + return result + + +class OFTInfModule(OFTModule): + def __init__( + self, + oft_name, + org_module: torch.nn.Module, + multiplier=1.0, + dim=4, + alpha=1, + split_dims: Optional[List[int]] = None, + **kwargs, + ): + # no dropout for inference + super().__init__(oft_name, org_module, multiplier, dim, alpha, split_dims) + self.enabled = True + self.network: OFTNetwork = None + + def set_network(self, network): + self.network = network + + def forward(self, x, scale=None): + if not self.enabled: + return self.org_forward(x) + return super().forward(x, scale) + + def merge_to(self, multiplier=None): + # get org weight + org_sd = self.org_module[0].state_dict() + W = org_sd["weight"].to(torch.float32) + R = self.get_weight(multiplier).to(torch.float32) + + d2 = 0 + W_list = [] + for i in range(len(self.oft_blocks)): + d1 = d2 + d2 += self.split_dims[i] + + W1 = W[d1:d2] + W_reshaped = einops.rearrange(W1, "(k n) m -> k n m", k=self.num_blocks, n=self.block_size[i]) + W1 = torch.einsum("k n m, k n p -> k m p", R[i], W_reshaped) + W1 = einops.rearrange(W1, "k m p -> (k m) p") + + W_list.append(W1) + + W = torch.cat(W_list, dim=-1) + + # convert back to original dtype + W = W.to(org_sd["weight"].dtype) + + # set weight to org_module + org_sd["weight"] = W + self.org_module[0].load_state_dict(org_sd) + + +def create_network( + multiplier: float, + network_dim: Optional[int], + network_alpha: Optional[float], + vae: AutoencoderKL, + text_encoder: Union[CLIPTextModel, List[CLIPTextModel]], + unet, + neuron_dropout: Optional[float] = None, + **kwargs, +): + if network_dim is None: + network_dim = 4 # default + if network_alpha is None: # should be set + logger.info( + "network_alpha is not set, use default value 1e-3 / network_alphaが設定されていないのでデフォルト値 1e-3 を使用します" + ) + network_alpha = 1e-3 + elif network_alpha >= 1: + logger.warning( + "network_alpha is too large (>=1, maybe default value is too large), please consider to set smaller value like 1e-3" + " / network_alphaが大きすぎるようです(>=1, デフォルト値が大きすぎる可能性があります)。1e-3のような小さな値を推奨" + ) + + # attn only or all linear (FFN) layers + enable_all_linear = kwargs.get("enable_all_linear", None) + # enable_conv = kwargs.get("enable_conv", None) + if enable_all_linear is not None: + enable_all_linear = bool(enable_all_linear) + # if enable_conv is not None: + # enable_conv = bool(enable_conv) + + network = OFTNetwork( + text_encoder, + unet, + multiplier=multiplier, + dim=network_dim, + alpha=network_alpha, + enable_all_linear=enable_all_linear, + varbose=True, + ) + return network + + +# Create network from weights for inference, weights are not loaded here (because can be merged) +def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, for_inference=False, **kwargs): + if weights_sd is None: + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import load_file, safe_open + + weights_sd = load_file(file) + else: + weights_sd = torch.load(file, map_location="cpu") + + # check dim, alpha and if weights have for conv2d + dim = None + alpha = None + all_linear = None + for name, param in weights_sd.items(): + if name.endswith(".alpha"): + if alpha is None: + alpha = param.item() + elif "qkv" in name: + continue # ignore qkv + else: + if dim is None: + dim = param.size()[0] + if all_linear is None and "_mlp" in name: + all_linear = True + if dim is not None and alpha is not None and all_linear is not None: + break + if all_linear is None: + all_linear = False + + module_class = OFTInfModule if for_inference else OFTModule + network = OFTNetwork( + text_encoder, + unet, + multiplier=multiplier, + dim=dim, + alpha=alpha, + enable_all_linear=all_linear, + module_class=module_class, + ) + return network, weights_sd + + +class OFTNetwork(torch.nn.Module): + FLUX_TARGET_REPLACE_MODULE_ALL_LINEAR = ["DoubleStreamBlock", "SingleStreamBlock"] + FLUX_TARGET_REPLACE_MODULE_ATTN_ONLY = ["SelfAttention"] + OFT_PREFIX_UNET = "oft_unet" + + def __init__( + self, + text_encoder: Union[List[CLIPTextModel], CLIPTextModel], + unet, + multiplier: float = 1.0, + dim: int = 4, + alpha: float = 1, + enable_all_linear: Optional[bool] = False, + module_class: Union[Type[OFTModule], Type[OFTInfModule]] = OFTModule, + varbose: Optional[bool] = False, + ) -> None: + super().__init__() + self.train_t5xxl = False # make compatible with LoRA + self.multiplier = multiplier + + self.dim = dim + self.alpha = alpha + + logger.info( + f"create OFT network. num blocks: {self.dim}, constraint: {self.alpha}, multiplier: {self.multiplier}, enable_all_linear: {enable_all_linear}" + ) + + # create module instances + def create_modules( + root_module: torch.nn.Module, + target_replace_modules: List[torch.nn.Module], + ) -> List[OFTModule]: + prefix = self.OFT_PREFIX_UNET + ofts = [] + for name, module in root_module.named_modules(): + if module.__class__.__name__ in target_replace_modules: + for child_name, child_module in module.named_modules(): + is_linear = "Linear" in child_module.__class__.__name__ + + if is_linear: + oft_name = prefix + "." + name + "." + child_name + oft_name = oft_name.replace(".", "_") + # logger.info(oft_name) + + if "double" in oft_name and "qkv" in oft_name: + split_dims = [3072] * 3 + elif "single" in oft_name and "linear1" in oft_name: + split_dims = [3072] * 3 + [12288] + else: + split_dims = None + + oft = module_class(oft_name, child_module, self.multiplier, dim, alpha, split_dims) + ofts.append(oft) + return ofts + + # extend U-Net target modules if conv2d 3x3 is enabled, or load from weights + if enable_all_linear: + target_modules = OFTNetwork.FLUX_TARGET_REPLACE_MODULE_ALL_LINEAR + else: + target_modules = OFTNetwork.FLUX_TARGET_REPLACE_MODULE_ATTN_ONLY + + self.unet_ofts: List[OFTModule] = create_modules(unet, target_modules) + logger.info(f"create OFT for Flux: {len(self.unet_ofts)} modules.") + + # assertion + names = set() + for oft in self.unet_ofts: + assert oft.oft_name not in names, f"duplicated oft name: {oft.oft_name}" + names.add(oft.oft_name) + + def set_multiplier(self, multiplier): + self.multiplier = multiplier + for oft in self.unet_ofts: + oft.multiplier = self.multiplier + + def load_weights(self, file): + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import load_file + + weights_sd = load_file(file) + else: + weights_sd = torch.load(file, map_location="cpu") + + info = self.load_state_dict(weights_sd, False) + return info + + def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True): + assert apply_unet, "apply_unet must be True" + + for oft in self.unet_ofts: + oft.apply_to() + self.add_module(oft.oft_name, oft) + + # マージできるかどうかを返す + def is_mergeable(self): + return True + + # TODO refactor to common function with apply_to + def merge_to(self, text_encoder, unet, weights_sd, dtype, device): + logger.info("enable OFT for U-Net") + + for oft in self.unet_ofts: + sd_for_lora = {} + for key in weights_sd.keys(): + if key.startswith(oft.oft_name): + sd_for_lora[key[len(oft.oft_name) + 1 :]] = weights_sd[key] + oft.load_state_dict(sd_for_lora, False) + oft.merge_to() + + logger.info(f"weights are merged") + + # 二つのText Encoderに別々の学習率を設定できるようにするといいかも + def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr): + self.requires_grad_(True) + all_params = [] + + def enumerate_params(ofts): + params = [] + for oft in ofts: + params.extend(oft.parameters()) + + # logger.info num of params + num_params = 0 + for p in params: + num_params += p.numel() + logger.info(f"OFT params: {num_params}") + return params + + param_data = {"params": enumerate_params(self.unet_ofts)} + if unet_lr is not None: + param_data["lr"] = unet_lr + all_params.append(param_data) + + return all_params + + def enable_gradient_checkpointing(self): + # not supported + pass + + def prepare_grad_etc(self, text_encoder, unet): + self.requires_grad_(True) + + def on_epoch_start(self, text_encoder, unet): + self.train() + + def get_trainable_params(self): + return self.parameters() + + def save_weights(self, file, dtype, metadata): + if metadata is not None and len(metadata) == 0: + metadata = None + + state_dict = self.state_dict() + + if dtype is not None: + for key in list(state_dict.keys()): + v = state_dict[key] + v = v.detach().clone().to("cpu").to(dtype) + state_dict[key] = v + + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import save_file + from library import train_util + + # Precalculate model hashes to save time on indexing + if metadata is None: + metadata = {} + model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) + metadata["sshs_model_hash"] = model_hash + metadata["sshs_legacy_hash"] = legacy_hash + + save_file(state_dict, file, metadata) + else: + torch.save(state_dict, file) + + def backup_weights(self): + # 重みのバックアップを行う + ofts: List[OFTInfModule] = self.unet_ofts + for oft in ofts: + org_module = oft.org_module[0] + if not hasattr(org_module, "_lora_org_weight"): + sd = org_module.state_dict() + org_module._lora_org_weight = sd["weight"].detach().clone() + org_module._lora_restored = True + + def restore_weights(self): + # 重みのリストアを行う + ofts: List[OFTInfModule] = self.unet_ofts + for oft in ofts: + org_module = oft.org_module[0] + if not org_module._lora_restored: + sd = org_module.state_dict() + sd["weight"] = org_module._lora_org_weight + org_module.load_state_dict(sd) + org_module._lora_restored = True + + def pre_calculation(self): + # 事前計算を行う + ofts: List[OFTInfModule] = self.unet_ofts + for oft in ofts: + org_module = oft.org_module[0] + oft.merge_to() + # sd = org_module.state_dict() + # org_weight = sd["weight"] + # lora_weight = oft.get_weight().to(org_weight.device, dtype=org_weight.dtype) + # sd["weight"] = org_weight + lora_weight + # assert sd["weight"].shape == org_weight.shape + # org_module.load_state_dict(sd) + + org_module._lora_restored = False + oft.enabled = False From c9ff4de90597e933b441502d45c175fe46b99714 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 14 Sep 2024 22:17:52 +0900 Subject: [PATCH 54/87] Add support for specifying rank for each layer in FLUX.1 --- README.md | 61 ++++++++++++++++++++++++ networks/lora_flux.py | 107 +++++++++++++++++++++++++++++++++++++++--- 2 files changed, 161 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index 6e32fa31..9a979479 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,10 @@ The command to install PyTorch is as follows: ### Recent Updates +Sep 14, 2024: +- You can now specify the rank for each layer in FLUX.1. See [Specify rank for each layer in FLUX.1](#specify-rank-for-each-layer-in-flux1) for details. +- OFT is now supported with FLUX.1. See [FLUX.1 OFT training](#flux1-oft-training) for details. + Sep 11, 2024: Logging to wandb is improved. See PR [#1576](https://github.com/kohya-ss/sd-scripts/pull/1576) for details. Thanks to p1atdev! @@ -46,6 +50,7 @@ Please update `safetensors` to `0.4.4` to fix the error when using `--resume`. ` - [Key Options for FLUX.1 LoRA training](#key-options-for-flux1-lora-training) - [Inference for FLUX.1 LoRA model](#inference-for-flux1-lora-model) - [Key Features for FLUX.1 LoRA training](#key-features-for-flux1-lora-training) +- [FLUX.1 OFT training](#flux1-oft-training) - [FLUX.1 fine-tuning](#flux1-fine-tuning) - [Key Features for FLUX.1 fine-tuning](#key-features-for-flux1-fine-tuning) - [Extract LoRA from FLUX.1 Models](#extract-lora-from-flux1-models) @@ -191,6 +196,62 @@ In the implementation of Black Forest Labs' model, the projection layers of q/k/ The compatibility of the saved model (state dict) is ensured by concatenating the weights of multiple LoRAs. However, since there are zero weights in some parts, the model size will be large. +#### Specify rank for each layer in FLUX.1 + +You can specify the rank for each layer in FLUX.1 by specifying the following network_args. If you specify `0`, LoRA will not be applied to that layer. + +When network_args is not specified, the default value (`network_dim`) is applied, same as before. + +|network_args|target layer| +|---|---| +|img_attn_dim|img_attn in DoubleStreamBlock| +|txt_attn_dim|txt_attn in DoubleStreamBlock| +|img_mlp_dim|img_mlp in DoubleStreamBlock| +|txt_mlp_dim|txt_mlp in DoubleStreamBlock| +|img_mod_dim|img_mod in DoubleStreamBlock| +|txt_mod_dim|txt_mod in DoubleStreamBlock| +|single_dim|linear1 and linear2 in SingleStreamBlock| +|single_mod_dim|modulation in SingleStreamBlock| + +example: +``` +--network_args "img_attn_dim=4" "img_mlp_dim=8" "txt_attn_dim=2" "txt_mlp_dim=2" +"img_mod_dim=2" "txt_mod_dim=2" "single_dim=4" "single_mod_dim=2" +``` + +You can apply LoRA to the conditioning layers of Flux by specifying `in_dims` in network_args. When specifying, be sure to specify 5 numbers in `[]` as a comma-separated list. + +example: +``` +--network_args "in_dims=[4,2,2,2,4]" +``` + +Each number corresponds to `img_in`, `time_in`, `vector_in`, `guidance_in`, `txt_in`. The above example applies LoRA to all conditioning layers, with rank 4 for `img_in`, 2 for `time_in`, `vector_in`, `guidance_in`, and 4 for `txt_in`. + +If you specify `0`, LoRA will not be applied to that layer. For example, `[4,0,0,0,4]` applies LoRA only to `img_in` and `txt_in`. + +### FLUX.1 OFT training + +You can train OFT with almost the same options as LoRA, such as `--timestamp_sampling`. The following points are different. + +- Change `--network_module` from `networks.lora_flux` to `networks.oft_flux`. +- `--network_dim` is the number of OFT blocks. Unlike LoRA rank, the smaller the dim, the larger the model. We recommend about 64 or 128. Please make the output dimension of the target layer of OFT divisible by the value of `--network_dim` (an error will occur if it is not divisible). Valid values are 64, 128, 256, 512, 1024, etc. +- `--network_alpha` is treated as a constraint for OFT. We recommend about 1e-2 to 1e-4. The default value when omitted is 1, which is too large, so be sure to specify it. +- CLIP/T5XXL is not supported. Specify `--network_train_unet_only`. +- `--network_args` specifies the hyperparameters of OFT. The following are valid: + - Specify `enable_all_linear=True` to target all linear connections in the MLP layer. The default is False, which targets only attention. + +Currently, there is no environment to infer FLUX.1 OFT. Inference is only possible with `flux_minimal_inference.py` (specify OFT model with `--lora`). + +Sample command is below. It will work with 24GB VRAM GPUs with the batch size of 1. + +``` +--network_module networks.oft_flux --network_dim 128 --network_alpha 1e-3 +--network_args "enable_all_linear=True" --learning_rate 1e-5 +``` + +The training can be done with 16GB VRAM GPUs without `--enable_all_linear` option and with Adafactor optimizer. + ### Inference for FLUX.1 with LoRA model The inference script is also available. The script is `flux_minimal_inference.py`. See `--help` for options. diff --git a/networks/lora_flux.py b/networks/lora_flux.py index ea7df8b4..a34cde1a 100644 --- a/networks/lora_flux.py +++ b/networks/lora_flux.py @@ -316,6 +316,44 @@ def create_network( else: conv_alpha = float(conv_alpha) + # attn dim, mlp dim: only for DoubleStreamBlock. SingleStreamBlock is not supported because of combined qkv + img_attn_dim = kwargs.get("img_attn_dim", None) + txt_attn_dim = kwargs.get("txt_attn_dim", None) + img_mlp_dim = kwargs.get("img_mlp_dim", None) + txt_mlp_dim = kwargs.get("txt_mlp_dim", None) + img_mod_dim = kwargs.get("img_mod_dim", None) + txt_mod_dim = kwargs.get("txt_mod_dim", None) + single_dim = kwargs.get("single_dim", None) # SingleStreamBlock + single_mod_dim = kwargs.get("single_mod_dim", None) # SingleStreamBlock + if img_attn_dim is not None: + img_attn_dim = int(img_attn_dim) + if txt_attn_dim is not None: + txt_attn_dim = int(txt_attn_dim) + if img_mlp_dim is not None: + img_mlp_dim = int(img_mlp_dim) + if txt_mlp_dim is not None: + txt_mlp_dim = int(txt_mlp_dim) + if img_mod_dim is not None: + img_mod_dim = int(img_mod_dim) + if txt_mod_dim is not None: + txt_mod_dim = int(txt_mod_dim) + if single_dim is not None: + single_dim = int(single_dim) + if single_mod_dim is not None: + single_mod_dim = int(single_mod_dim) + type_dims = [img_attn_dim, txt_attn_dim, img_mlp_dim, txt_mlp_dim, img_mod_dim, txt_mod_dim, single_dim, single_mod_dim] + if all([d is None for d in type_dims]): + type_dims = None + + # in_dims [img, time, vector, guidance, txt] + in_dims = kwargs.get("in_dims", None) + if in_dims is not None: + in_dims = in_dims.strip() + if in_dims.startswith("[") and in_dims.endswith("]"): + in_dims = in_dims[1:-1] + in_dims = [int(d) for d in in_dims.split(",")] # is it better to use ast.literal_eval? + assert len(in_dims) == 5, f"invalid in_dims: {in_dims}, must be 5 dimensions (img, time, vector, guidance, txt)" + # rank/module dropout rank_dropout = kwargs.get("rank_dropout", None) if rank_dropout is not None: @@ -339,6 +377,11 @@ def create_network( if train_t5xxl is not None: train_t5xxl = True if train_t5xxl == "True" else False + # verbose + verbose = kwargs.get("verbose", False) + if verbose is not None: + verbose = True if verbose == "True" else False + # すごく引数が多いな ( ^ω^)・・・ network = LoRANetwork( text_encoders, @@ -354,7 +397,9 @@ def create_network( train_blocks=train_blocks, split_qkv=split_qkv, train_t5xxl=train_t5xxl, - varbose=True, + type_dims=type_dims, + in_dims=in_dims, + verbose=verbose, ) loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None) @@ -462,7 +507,9 @@ class LoRANetwork(torch.nn.Module): train_blocks: Optional[str] = None, split_qkv: bool = False, train_t5xxl: bool = False, - varbose: Optional[bool] = False, + type_dims: Optional[List[int]] = None, + in_dims: Optional[List[int]] = None, + verbose: Optional[bool] = False, ) -> None: super().__init__() self.multiplier = multiplier @@ -478,12 +525,17 @@ class LoRANetwork(torch.nn.Module): self.split_qkv = split_qkv self.train_t5xxl = train_t5xxl + self.type_dims = type_dims + self.in_dims = in_dims + self.loraplus_lr_ratio = None self.loraplus_unet_lr_ratio = None self.loraplus_text_encoder_lr_ratio = None if modules_dim is not None: logger.info(f"create LoRA network from weights") + self.in_dims = [0] * 5 # create in_dims + # verbose = True else: logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}") logger.info( @@ -502,7 +554,12 @@ class LoRANetwork(torch.nn.Module): # create module instances def create_modules( - is_flux: bool, text_encoder_idx: Optional[int], root_module: torch.nn.Module, target_replace_modules: List[str] + is_flux: bool, + text_encoder_idx: Optional[int], + root_module: torch.nn.Module, + target_replace_modules: List[str], + filter: Optional[str] = None, + default_dim: Optional[int] = None, ) -> List[LoRAModule]: prefix = ( self.LORA_PREFIX_FLUX @@ -513,16 +570,22 @@ class LoRANetwork(torch.nn.Module): loras = [] skipped = [] for name, module in root_module.named_modules(): - if module.__class__.__name__ in target_replace_modules: + if target_replace_modules is None or module.__class__.__name__ in target_replace_modules: + if target_replace_modules is None: # dirty hack for all modules + module = root_module # search all modules + for child_name, child_module in module.named_modules(): is_linear = child_module.__class__.__name__ == "Linear" is_conv2d = child_module.__class__.__name__ == "Conv2d" is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1) if is_linear or is_conv2d: - lora_name = prefix + "." + name + "." + child_name + lora_name = prefix + "." + (name + "." if name else "") + child_name lora_name = lora_name.replace(".", "_") + if filter is not None and not filter in lora_name: + continue + dim = None alpha = None @@ -534,8 +597,25 @@ class LoRANetwork(torch.nn.Module): else: # 通常、すべて対象とする if is_linear or is_conv2d_1x1: - dim = self.lora_dim + dim = default_dim if default_dim is not None else self.lora_dim alpha = self.alpha + + if type_dims is not None: + identifier = [ + ("img_attn",), + ("txt_attn",), + ("img_mlp",), + ("txt_mlp",), + ("img_mod",), + ("txt_mod",), + ("single_blocks", "linear"), + ("modulation",), + ] + for i, d in enumerate(type_dims): + if d is not None and all([id in lora_name for id in identifier[i]]): + dim = d + break + elif self.conv_lora_dim is not None: dim = self.conv_lora_dim alpha = self.conv_alpha @@ -566,6 +646,9 @@ class LoRANetwork(torch.nn.Module): split_dims=split_dims, ) loras.append(lora) + + if target_replace_modules is None: + break # all modules are searched return loras, skipped # create LoRA for text encoder @@ -594,10 +677,20 @@ class LoRANetwork(torch.nn.Module): self.unet_loras: List[Union[LoRAModule, LoRAInfModule]] self.unet_loras, skipped_un = create_modules(True, None, unet, target_replace_modules) + + # img, time, vector, guidance, txt + if self.in_dims: + for filter, in_dim in zip(["_img_in", "_time_in", "_vector_in", "_guidance_in", "_txt_in"], self.in_dims): + loras, _ = create_modules(True, None, unet, None, filter=filter, default_dim=in_dim) + self.unet_loras.extend(loras) + logger.info(f"create LoRA for FLUX {self.train_blocks} blocks: {len(self.unet_loras)} modules.") + if verbose: + for lora in self.unet_loras: + logger.info(f"\t{lora.lora_name:50} {lora.lora_dim}, {lora.alpha}") skipped = skipped_te + skipped_un - if varbose and len(skipped) > 0: + if verbose and len(skipped) > 0: logger.warning( f"because dim (rank) is 0, {len(skipped)} LoRA modules are skipped / dim (rank)が0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:" ) From 6445bb2bc974cec51256ae38c1be0900e90e6f87 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 14 Sep 2024 22:37:26 +0900 Subject: [PATCH 55/87] update README --- README.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 9a979479..c94ea359 100644 --- a/README.md +++ b/README.md @@ -213,10 +213,12 @@ When network_args is not specified, the default value (`network_dim`) is applied |single_dim|linear1 and linear2 in SingleStreamBlock| |single_mod_dim|modulation in SingleStreamBlock| +`"verbose=True"` is also available for debugging. It shows the rank of each layer. + example: ``` --network_args "img_attn_dim=4" "img_mlp_dim=8" "txt_attn_dim=2" "txt_mlp_dim=2" -"img_mod_dim=2" "txt_mod_dim=2" "single_dim=4" "single_mod_dim=2" +"img_mod_dim=2" "txt_mod_dim=2" "single_dim=4" "single_mod_dim=2" "verbose=True" ``` You can apply LoRA to the conditioning layers of Flux by specifying `in_dims` in network_args. When specifying, be sure to specify 5 numbers in `[]` as a comma-separated list. From 9f44ef133083c530874c6cf022a4de8fda3edae2 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 15 Sep 2024 13:52:23 +0900 Subject: [PATCH 56/87] add diffusers to FLUX.1 conversion script --- README.md | 19 ++- tools/convert_diffusers_to_flux.py | 223 +++++++++++++++++++++++++++++ 2 files changed, 241 insertions(+), 1 deletion(-) create mode 100644 tools/convert_diffusers_to_flux.py diff --git a/README.md b/README.md index c94ea359..7d6c336e 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,12 @@ The command to install PyTorch is as follows: ### Recent Updates +Sep 15, 2024: + +Added a script `convert_diffusers_to_flux.py` to convert Diffusers format FLUX.1 models (checkpoints) to BFL format. See `--help` for usage. Only Flux models are supported. AE/CLIP/T5XXL are not supported. + +The implementation is based on 2kpr's code. Thanks to 2kpr! + Sep 14, 2024: - You can now specify the rank for each layer in FLUX.1. See [Specify rank for each layer in FLUX.1](#specify-rank-for-each-layer-in-flux1) for details. - OFT is now supported with FLUX.1. See [FLUX.1 OFT training](#flux1-oft-training) for details. @@ -57,6 +63,7 @@ Please update `safetensors` to `0.4.4` to fix the error when using `--resume`. ` - [Convert FLUX LoRA](#convert-flux-lora) - [Merge LoRA to FLUX.1 checkpoint](#merge-lora-to-flux1-checkpoint) - [FLUX.1 Multi-resolution training](#flux1-multi-resolution-training) +- [Convert Diffusers to FLUX.1](#convert-diffusers-to-flux1) ### FLUX.1 LoRA training @@ -355,7 +362,7 @@ If you use LoRA in the inference environment, converting it to AI-toolkit format Note that re-conversion will increase the size of LoRA. -CLIP-L LoRA is not supported. +CLIP-L/T5XXL LoRA is not supported. ### Merge LoRA to FLUX.1 checkpoint @@ -435,6 +442,16 @@ resolution = [512, 512] num_repeats = 1 ``` +### Convert Diffusers to FLUX.1 + +Script: `convert_diffusers_to_flux1.py` + +Converts Diffusers models to FLUX.1 models. The script is experimental. See `--help` for options. schnell and dev models are supported. AE/CLIP/T5XXL are not supported. The diffusers folder is a parent folder of `transfomer` folder. + +``` +python tools/convert_diffusers_to_flux.py --diffusers_path path/to/diffusers_folder_or_00001_safetensors --save_to path/to/flux1.safetensors --mem_eff_load_save --save_precision bf16 +``` + ## SD3 training SD3 training is done with `sd3_train.py`. diff --git a/tools/convert_diffusers_to_flux.py b/tools/convert_diffusers_to_flux.py new file mode 100644 index 00000000..9d8f7c74 --- /dev/null +++ b/tools/convert_diffusers_to_flux.py @@ -0,0 +1,223 @@ +# This script converts the diffusers of a Flux model to a safetensors file of a Flux.1 model. +# It is based on the implementation by 2kpr. Thanks to 2kpr! +# Major changes: +# - Iterates over three safetensors files to reduce memory usage, not loading all tensors at once. +# - Makes reverse map from diffusers map to avoid loading all tensors. +# - Removes dependency on .json file for weights mapping. +# - Adds support for custom memory efficient load and save functions. +# - Supports saving with different precision. +# - Supports .safetensors file as input. + +# Copyright 2024 2kpr. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and + +import argparse +import os +from pathlib import Path +import safetensors +from safetensors.torch import safe_open +import torch +from tqdm import tqdm + +from library.utils import setup_logging, str_to_dtype, MemoryEfficientSafeOpen, mem_eff_save_file + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + +NUM_DOUBLE_BLOCKS = 19 +NUM_SINGLE_BLOCKS = 38 + +BFL_TO_DIFFUSERS_MAP = { + "time_in.in_layer.weight": ["time_text_embed.timestep_embedder.linear_1.weight"], + "time_in.in_layer.bias": ["time_text_embed.timestep_embedder.linear_1.bias"], + "time_in.out_layer.weight": ["time_text_embed.timestep_embedder.linear_2.weight"], + "time_in.out_layer.bias": ["time_text_embed.timestep_embedder.linear_2.bias"], + "vector_in.in_layer.weight": ["time_text_embed.text_embedder.linear_1.weight"], + "vector_in.in_layer.bias": ["time_text_embed.text_embedder.linear_1.bias"], + "vector_in.out_layer.weight": ["time_text_embed.text_embedder.linear_2.weight"], + "vector_in.out_layer.bias": ["time_text_embed.text_embedder.linear_2.bias"], + "guidance_in.in_layer.weight": ["time_text_embed.guidance_embedder.linear_1.weight"], + "guidance_in.in_layer.bias": ["time_text_embed.guidance_embedder.linear_1.bias"], + "guidance_in.out_layer.weight": ["time_text_embed.guidance_embedder.linear_2.weight"], + "guidance_in.out_layer.bias": ["time_text_embed.guidance_embedder.linear_2.bias"], + "txt_in.weight": ["context_embedder.weight"], + "txt_in.bias": ["context_embedder.bias"], + "img_in.weight": ["x_embedder.weight"], + "img_in.bias": ["x_embedder.bias"], + "double_blocks.().img_mod.lin.weight": ["norm1.linear.weight"], + "double_blocks.().img_mod.lin.bias": ["norm1.linear.bias"], + "double_blocks.().txt_mod.lin.weight": ["norm1_context.linear.weight"], + "double_blocks.().txt_mod.lin.bias": ["norm1_context.linear.bias"], + "double_blocks.().img_attn.qkv.weight": ["attn.to_q.weight", "attn.to_k.weight", "attn.to_v.weight"], + "double_blocks.().img_attn.qkv.bias": ["attn.to_q.bias", "attn.to_k.bias", "attn.to_v.bias"], + "double_blocks.().txt_attn.qkv.weight": ["attn.add_q_proj.weight", "attn.add_k_proj.weight", "attn.add_v_proj.weight"], + "double_blocks.().txt_attn.qkv.bias": ["attn.add_q_proj.bias", "attn.add_k_proj.bias", "attn.add_v_proj.bias"], + "double_blocks.().img_attn.norm.query_norm.scale": ["attn.norm_q.weight"], + "double_blocks.().img_attn.norm.key_norm.scale": ["attn.norm_k.weight"], + "double_blocks.().txt_attn.norm.query_norm.scale": ["attn.norm_added_q.weight"], + "double_blocks.().txt_attn.norm.key_norm.scale": ["attn.norm_added_k.weight"], + "double_blocks.().img_mlp.0.weight": ["ff.net.0.proj.weight"], + "double_blocks.().img_mlp.0.bias": ["ff.net.0.proj.bias"], + "double_blocks.().img_mlp.2.weight": ["ff.net.2.weight"], + "double_blocks.().img_mlp.2.bias": ["ff.net.2.bias"], + "double_blocks.().txt_mlp.0.weight": ["ff_context.net.0.proj.weight"], + "double_blocks.().txt_mlp.0.bias": ["ff_context.net.0.proj.bias"], + "double_blocks.().txt_mlp.2.weight": ["ff_context.net.2.weight"], + "double_blocks.().txt_mlp.2.bias": ["ff_context.net.2.bias"], + "double_blocks.().img_attn.proj.weight": ["attn.to_out.0.weight"], + "double_blocks.().img_attn.proj.bias": ["attn.to_out.0.bias"], + "double_blocks.().txt_attn.proj.weight": ["attn.to_add_out.weight"], + "double_blocks.().txt_attn.proj.bias": ["attn.to_add_out.bias"], + "single_blocks.().modulation.lin.weight": ["norm.linear.weight"], + "single_blocks.().modulation.lin.bias": ["norm.linear.bias"], + "single_blocks.().linear1.weight": ["attn.to_q.weight", "attn.to_k.weight", "attn.to_v.weight", "proj_mlp.weight"], + "single_blocks.().linear1.bias": ["attn.to_q.bias", "attn.to_k.bias", "attn.to_v.bias", "proj_mlp.bias"], + "single_blocks.().linear2.weight": ["proj_out.weight"], + "single_blocks.().norm.query_norm.scale": ["attn.norm_q.weight"], + "single_blocks.().norm.key_norm.scale": ["attn.norm_k.weight"], + "single_blocks.().linear2.weight": ["proj_out.weight"], + "single_blocks.().linear2.bias": ["proj_out.bias"], + "final_layer.linear.weight": ["proj_out.weight"], + "final_layer.linear.bias": ["proj_out.bias"], + "final_layer.adaLN_modulation.1.weight": ["norm_out.linear.weight"], + "final_layer.adaLN_modulation.1.bias": ["norm_out.linear.bias"], +} + + +def convert(args): + # if diffusers_path is folder, get safetensors file + diffusers_path = Path(args.diffusers_path) + if diffusers_path.is_dir(): + diffusers_path = Path.joinpath(diffusers_path, "transformer", "diffusion_pytorch_model-00001-of-00003.safetensors") + + flux_path = Path(args.save_to) + if not os.path.exists(flux_path.parent): + os.makedirs(flux_path.parent) + + if not diffusers_path.exists(): + logger.error(f"Error: Missing transformer safetensors file: {diffusers_path}") + return + + mem_eff_flag = args.mem_eff_load_save + save_dtype = str_to_dtype(args.save_precision) if args.save_precision is not None else None + + # make reverse map from diffusers map + diffusers_to_bfl_map = {} # key: diffusers_key, value: (index, bfl_key) + for b in range(NUM_DOUBLE_BLOCKS): + for key, weights in BFL_TO_DIFFUSERS_MAP.items(): + if key.startswith("double_blocks."): + block_prefix = f"transformer_blocks.{b}." + for i, weight in enumerate(weights): + diffusers_to_bfl_map[f"{block_prefix}{weight}"] = (i, key.replace("()", f"{b}")) + for b in range(NUM_SINGLE_BLOCKS): + for key, weights in BFL_TO_DIFFUSERS_MAP.items(): + if key.startswith("single_blocks."): + block_prefix = f"single_transformer_blocks.{b}." + for i, weight in enumerate(weights): + diffusers_to_bfl_map[f"{block_prefix}{weight}"] = (i, key.replace("()", f"{b}")) + for key, weights in BFL_TO_DIFFUSERS_MAP.items(): + if not (key.startswith("double_blocks.") or key.startswith("single_blocks.")): + for i, weight in enumerate(weights): + diffusers_to_bfl_map[weight] = (i, key) + + # iterate over three safetensors files to reduce memory usage + flux_sd = {} + for i in range(3): + # replace 00001 with 0000i + current_diffusers_path = Path(str(diffusers_path).replace("00001", f"0000{i+1}")) + logger.info(f"Loading diffusers file: {current_diffusers_path}") + + open_func = MemoryEfficientSafeOpen if mem_eff_flag else (lambda x: safe_open(x, framework="pt")) + with open_func(current_diffusers_path) as f: + for diffusers_key in tqdm(f.keys()): + if diffusers_key in diffusers_to_bfl_map: + tensor = f.get_tensor(diffusers_key).to("cpu") + if save_dtype is not None: + tensor = tensor.to(save_dtype) + + index, bfl_key = diffusers_to_bfl_map[diffusers_key] + if bfl_key not in flux_sd: + flux_sd[bfl_key] = [] + flux_sd[bfl_key].append((index, tensor)) + else: + logger.error(f"Error: Key not found in diffusers_to_bfl_map: {diffusers_key}") + return + + # concat tensors if multiple tensors are mapped to a single key, sort by index + for key, values in flux_sd.items(): + if len(values) == 1: + flux_sd[key] = values[0][1] + else: + flux_sd[key] = torch.cat([value[1] for value in sorted(values, key=lambda x: x[0])]) + + # special case for final_layer.adaLN_modulation.1.weight and final_layer.adaLN_modulation.1.bias + def swap_scale_shift(weight): + shift, scale = weight.chunk(2, dim=0) + new_weight = torch.cat([scale, shift], dim=0) + return new_weight + + if "final_layer.adaLN_modulation.1.weight" in flux_sd: + flux_sd["final_layer.adaLN_modulation.1.weight"] = swap_scale_shift(flux_sd["final_layer.adaLN_modulation.1.weight"]) + if "final_layer.adaLN_modulation.1.bias" in flux_sd: + flux_sd["final_layer.adaLN_modulation.1.bias"] = swap_scale_shift(flux_sd["final_layer.adaLN_modulation.1.bias"]) + + # save flux_sd to safetensors file + logger.info(f"Saving Flux safetensors file: {flux_path}") + if mem_eff_flag: + mem_eff_save_file(flux_sd, flux_path) + else: + safetensors.torch.save_file(flux_sd, flux_path) + + logger.info("Conversion completed.") + + +def setup_parser(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--diffusers_path", + default=None, + type=str, + required=True, + help="Path to the original Flux diffusers folder or *-00001-of-00003.safetensors file." + " / 元のFlux diffusersフォルダーまたは*-00001-of-00003.safetensorsファイルへのパス", + ) + parser.add_argument( + "--save_to", + default=None, + type=str, + required=True, + help="Output path for the Flux safetensors file. / Flux safetensorsファイルの出力先", + ) + parser.add_argument( + "--mem_eff_load_save", + action="store_true", + help="use custom memory efficient load and save functions for FLUX.1 model" + " / カスタムのメモリ効率の良い読み込みと保存関数をFLUX.1モデルに使用する", + ) + parser.add_argument( + "--save_precision", + type=str, + default=None, + help="precision in saving, default is same as loading precision" + "float32, fp16, bf16, fp8 (same as fp8_e4m3fn), fp8_e4m3fn, fp8_e4m3fnuz, fp8_e5m2, fp8_e5m2fnuz" + " / 保存時に精度を変更して保存する、デフォルトは読み込み時と同じ精度", + ) + return parser + + +if __name__ == "__main__": + parser = setup_parser() + args = parser.parse_args() + convert(args) From be078bdaca41084a20edb952b98a82f3e05d2dad Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 15 Sep 2024 13:59:17 +0900 Subject: [PATCH 57/87] fix typo --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 7d6c336e..f79fe21a 100644 --- a/README.md +++ b/README.md @@ -446,7 +446,7 @@ resolution = [512, 512] Script: `convert_diffusers_to_flux1.py` -Converts Diffusers models to FLUX.1 models. The script is experimental. See `--help` for options. schnell and dev models are supported. AE/CLIP/T5XXL are not supported. The diffusers folder is a parent folder of `transfomer` folder. +Converts Diffusers models to FLUX.1 models. The script is experimental. See `--help` for options. schnell and dev models are supported. AE/CLIP/T5XXL are not supported. The diffusers folder is a parent folder of `rmer` folder. ``` python tools/convert_diffusers_to_flux.py --diffusers_path path/to/diffusers_folder_or_00001_safetensors --save_to path/to/flux1.safetensors --mem_eff_load_save --save_precision bf16 From 96c677b4594ed6f28f3ef896f6deca7c3aced25d Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 16 Sep 2024 10:42:09 +0900 Subject: [PATCH 58/87] fix to work lienar/cosine lr scheduler closes #1602 ref #1393 --- library/train_util.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 742d057e..60afd421 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -4707,6 +4707,15 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int): **lr_scheduler_kwargs, ) + # these schedulers do not require `num_decay_steps` + if name == SchedulerType.LINEAR or name == SchedulerType.COSINE: + return schedule_func( + optimizer, + num_warmup_steps=num_warmup_steps, + num_training_steps=num_training_steps, + **lr_scheduler_kwargs, + ) + # All other schedulers require `num_decay_steps` if num_decay_steps is None: raise ValueError(f"{name} requires `num_decay_steps`, please provide that argument.") @@ -5837,14 +5846,9 @@ def sample_image_inference( wandb_tracker = accelerator.get_tracker("wandb") import wandb + # not to commit images to avoid inconsistency between training and logging steps - wandb_tracker.log( - {f"sample_{i}": wandb.Image( - image, - caption=prompt # positive prompt as a caption - )}, - commit=False - ) + wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False) # positive prompt as a caption # endregion From d8d15f1a7e09ca217930288b41bd239881126b93 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 16 Sep 2024 23:14:09 +0900 Subject: [PATCH 59/87] add support for specifying blocks in FLUX.1 LoRA training --- README.md | 24 ++++++++++++- networks/lora_flux.py | 82 +++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 103 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index f79fe21a..24217d8b 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,10 @@ The command to install PyTorch is as follows: ### Recent Updates +Sep 16, 2024: + + Added `train_double_block_indices` and `train_double_block_indices` to the LoRA training script to specify the indices of the blocks to train. See [Specify blocks to train in FLUX.1 LoRA training](#specify-blocks-to-train-in-flux1-lora-training) for details. + Sep 15, 2024: Added a script `convert_diffusers_to_flux.py` to convert Diffusers format FLUX.1 models (checkpoints) to BFL format. See `--help` for usage. Only Flux models are supported. AE/CLIP/T5XXL are not supported. @@ -54,9 +58,12 @@ Please update `safetensors` to `0.4.4` to fix the error when using `--resume`. ` - [FLUX.1 LoRA training](#flux1-lora-training) - [Key Options for FLUX.1 LoRA training](#key-options-for-flux1-lora-training) - - [Inference for FLUX.1 LoRA model](#inference-for-flux1-lora-model) + - [Distribution of timesteps](#distribution-of-timesteps) - [Key Features for FLUX.1 LoRA training](#key-features-for-flux1-lora-training) + - [Specify rank for each layer in FLUX.1](#specify-rank-for-each-layer-in-flux1) + - [Specify blocks to train in FLUX.1 LoRA training](#specify-blocks-to-train-in-flux1-lora-training) - [FLUX.1 OFT training](#flux1-oft-training) +- [Inference for FLUX.1 with LoRA model](#inference-for-flux1-with-lora-model) - [FLUX.1 fine-tuning](#flux1-fine-tuning) - [Key Features for FLUX.1 fine-tuning](#key-features-for-flux1-fine-tuning) - [Extract LoRA from FLUX.1 Models](#extract-lora-from-flux1-models) @@ -239,6 +246,21 @@ Each number corresponds to `img_in`, `time_in`, `vector_in`, `guidance_in`, `txt If you specify `0`, LoRA will not be applied to that layer. For example, `[4,0,0,0,4]` applies LoRA only to `img_in` and `txt_in`. +#### Specify blocks to train in FLUX.1 LoRA training + +You can specify the blocks to train in FLUX.1 LoRA training by specifying `train_double_block_indices` and `train_single_block_indices` in network_args. The indices are 0-based. The default (when omitted) is to train all blocks. The indices are specified as a list of integers or a range of integers, like `0,1,5,8` or `0,1,4-5,7`. The number of double blocks is 19, and the number of single blocks is 38, so the valid range is 0-18 and 0-37, respectively. `all` is also available to train all blocks, `none` is also available to train no blocks. + +example: +``` +--network_args "train_double_block_indices=0,1,8-12,18" "train_single_block_indices=3,10,20-25,37" +``` + +``` +--network_args "train_double_block_indices=none" "train_single_block_indices=10-15" +``` + +If you specify one of `train_double_block_indices` or `train_single_block_indices`, the other will be trained as usual. + ### FLUX.1 OFT training You can train OFT with almost the same options as LoRA, such as `--timestamp_sampling`. The following points are different. diff --git a/networks/lora_flux.py b/networks/lora_flux.py index a34cde1a..f549ac18 100644 --- a/networks/lora_flux.py +++ b/networks/lora_flux.py @@ -24,6 +24,10 @@ import logging logger = logging.getLogger(__name__) +NUM_DOUBLE_BLOCKS = 19 +NUM_SINGLE_BLOCKS = 38 + + class LoRAModule(torch.nn.Module): """ replaces forward method of the original Linear, instead of replacing the original Linear module. @@ -354,6 +358,50 @@ def create_network( in_dims = [int(d) for d in in_dims.split(",")] # is it better to use ast.literal_eval? assert len(in_dims) == 5, f"invalid in_dims: {in_dims}, must be 5 dimensions (img, time, vector, guidance, txt)" + # double/single train blocks + def parse_block_selection(selection: str, total_blocks: int) -> List[bool]: + """ + Parse a block selection string and return a list of booleans. + + Args: + selection (str): A string specifying which blocks to select. + total_blocks (int): The total number of blocks available. + + Returns: + List[bool]: A list of booleans indicating which blocks are selected. + """ + if selection == "all": + return [True] * total_blocks + if selection == "none" or selection == "": + return [False] * total_blocks + + selected = [False] * total_blocks + ranges = selection.split(",") + + for r in ranges: + if "-" in r: + start, end = map(str.strip, r.split("-")) + start = int(start) + end = int(end) + assert 0 <= start < total_blocks, f"invalid start index: {start}" + assert 0 <= end < total_blocks, f"invalid end index: {end}" + assert start <= end, f"invalid range: {start}-{end}" + for i in range(start, end + 1): + selected[i] = True + else: + index = int(r) + assert 0 <= index < total_blocks, f"invalid index: {index}" + selected[index] = True + + return selected + + train_double_block_indices = kwargs.get("train_double_block_indices", None) + train_single_block_indices = kwargs.get("train_single_block_indices", None) + if train_double_block_indices is not None: + train_double_block_indices = parse_block_selection(train_double_block_indices, NUM_DOUBLE_BLOCKS) + if train_single_block_indices is not None: + train_single_block_indices = parse_block_selection(train_single_block_indices, NUM_SINGLE_BLOCKS) + # rank/module dropout rank_dropout = kwargs.get("rank_dropout", None) if rank_dropout is not None: @@ -399,6 +447,8 @@ def create_network( train_t5xxl=train_t5xxl, type_dims=type_dims, in_dims=in_dims, + train_double_block_indices=train_double_block_indices, + train_single_block_indices=train_single_block_indices, verbose=verbose, ) @@ -509,6 +559,8 @@ class LoRANetwork(torch.nn.Module): train_t5xxl: bool = False, type_dims: Optional[List[int]] = None, in_dims: Optional[List[int]] = None, + train_double_block_indices: Optional[List[bool]] = None, + train_single_block_indices: Optional[List[bool]] = None, verbose: Optional[bool] = False, ) -> None: super().__init__() @@ -527,6 +579,8 @@ class LoRANetwork(torch.nn.Module): self.type_dims = type_dims self.in_dims = in_dims + self.train_double_block_indices = train_double_block_indices + self.train_single_block_indices = train_single_block_indices self.loraplus_lr_ratio = None self.loraplus_unet_lr_ratio = None @@ -600,7 +654,7 @@ class LoRANetwork(torch.nn.Module): dim = default_dim if default_dim is not None else self.lora_dim alpha = self.alpha - if type_dims is not None: + if is_flux and type_dims is not None: identifier = [ ("img_attn",), ("txt_attn",), @@ -613,9 +667,33 @@ class LoRANetwork(torch.nn.Module): ] for i, d in enumerate(type_dims): if d is not None and all([id in lora_name for id in identifier[i]]): - dim = d + dim = d # may be 0 for skip break + if ( + is_flux + and dim + and ( + self.train_double_block_indices is not None + or self.train_single_block_indices is not None + ) + and ("double" in lora_name or "single" in lora_name) + ): + # "lora_unet_double_blocks_0_..." or "lora_unet_single_blocks_0_..." + block_index = int(lora_name.split("_")[4]) # bit dirty + if ( + "double" in lora_name + and self.train_double_block_indices is not None + and not self.train_double_block_indices[block_index] + ): + dim = 0 + elif ( + "single" in lora_name + and self.train_single_block_indices is not None + and not self.train_single_block_indices[block_index] + ): + dim = 0 + elif self.conv_lora_dim is not None: dim = self.conv_lora_dim alpha = self.conv_alpha From 0cbe95bcc7e88f518802f29fe2b99da806963267 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 17 Sep 2024 21:21:28 +0900 Subject: [PATCH 60/87] fix text_encoder_lr to work with int closes #1608 --- networks/lora_flux.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/networks/lora_flux.py b/networks/lora_flux.py index f549ac18..91e9cd77 100644 --- a/networks/lora_flux.py +++ b/networks/lora_flux.py @@ -966,8 +966,8 @@ class LoRANetwork(torch.nn.Module): # if float, use the same value for both text encoders if text_encoder_lr is None or (isinstance(text_encoder_lr, list) and len(text_encoder_lr) == 0): text_encoder_lr = [default_lr, default_lr] - elif isinstance(text_encoder_lr, float): - text_encoder_lr = [text_encoder_lr, text_encoder_lr] + elif isinstance(text_encoder_lr, float) or isinstance(text_encoder_lr, int): + text_encoder_lr = [float(text_encoder_lr), float(text_encoder_lr)] elif len(text_encoder_lr) == 1: text_encoder_lr = [text_encoder_lr[0], text_encoder_lr[0]] From a2ad7e5644f08141fe053a2b63446d70d777bdcf Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 17 Sep 2024 21:42:14 +0900 Subject: [PATCH 61/87] blocks_to_swap=0 means no swap --- flux_train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flux_train.py b/flux_train.py index 33481df8..5d8326b1 100644 --- a/flux_train.py +++ b/flux_train.py @@ -265,7 +265,7 @@ def train(args): flux.requires_grad_(True) - is_swapping_blocks = args.double_blocks_to_swap is not None or args.single_blocks_to_swap is not None + is_swapping_blocks = args.double_blocks_to_swap or args.single_blocks_to_swap if is_swapping_blocks: # Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes. # This idea is based on 2kpr's great work. Thank you! From bbd160b4ca9293881c222f9b9e1d832af69699db Mon Sep 17 00:00:00 2001 From: "Kohya S." <52813779+kohya-ss@users.noreply.github.com> Date: Wed, 18 Sep 2024 07:55:04 +0900 Subject: [PATCH 62/87] sd3 schedule free opt (#1605) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * New ScheduleFree support for Flux (#1600) * init * use no schedule * fix typo * update for eval() * fix typo * update * Update train_util.py * Update requirements.txt * update sfwrapper WIP * no need to check schedulefree optimizer * remove debug print * comment out schedulefree wrapper * update readme --------- Co-authored-by: 青龍聖者@bdsqlsz <865105819@qq.com> --- README.md | 8 +++ library/train_util.py | 156 +++++++++++++++++++++++++++++++++++++++--- requirements.txt | 1 + 3 files changed, 156 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index 24217d8b..dc986292 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,14 @@ The command to install PyTorch is as follows: ### Recent Updates +Sep 18, 2024: + +- Schedule-free optimizer is added. Thanks to sdbds! See PR [#1600](https://github.com/kohya-ss/sd-scripts/pull/1600) for details. + - `schedulefree` is added to the dependencies. Please update the library if necessary. + - AdamWScheduleFree or SGDScheduleFree can be used. Specify `adamwschedulefree` or `sgdschedulefree` in `--optimizer_type`. + - Wrapper classes are not available for now. + - These can be used not only for FLUX.1 training but also for other training scripts after merging to the dev/main branch. + Sep 16, 2024: Added `train_double_block_indices` and `train_double_block_indices` to the LoRA training script to specify the indices of the blocks to train. See [Specify blocks to train in FLUX.1 LoRA training](#specify-blocks-to-train-in-flux1-lora-training) for details. diff --git a/library/train_util.py b/library/train_util.py index 60afd421..a54f23ff 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3303,6 +3303,20 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser): help='additional arguments for optimizer (like "weight_decay=0.01 betas=0.9,0.999 ...") / オプティマイザの追加引数(例: "weight_decay=0.01 betas=0.9,0.999 ...")', ) + # parser.add_argument( + # "--optimizer_schedulefree_wrapper", + # action="store_true", + # help="use schedulefree_wrapper any optimizer / 任意のオプティマイザにschedulefree_wrapperを使用", + # ) + + # parser.add_argument( + # "--schedulefree_wrapper_args", + # type=str, + # default=None, + # nargs="*", + # help='additional arguments for schedulefree_wrapper (like "momentum=0.9 weight_decay_at_y=0.1 ...") / オプティマイザの追加引数(例: "momentum=0.9 weight_decay_at_y=0.1 ...")', + # ) + parser.add_argument("--lr_scheduler_type", type=str, default="", help="custom scheduler module / 使用するスケジューラ") parser.add_argument( "--lr_scheduler_args", @@ -4582,26 +4596,146 @@ def get_optimizer(args, trainable_params): optimizer_class = torch.optim.AdamW optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) + elif optimizer_type.endswith("schedulefree".lower()): + try: + import schedulefree as sf + except ImportError: + raise ImportError("No schedulefree / schedulefreeがインストールされていないようです") + if optimizer_type == "AdamWScheduleFree".lower(): + optimizer_class = sf.AdamWScheduleFree + logger.info(f"use AdamWScheduleFree optimizer | {optimizer_kwargs}") + elif optimizer_type == "SGDScheduleFree".lower(): + optimizer_class = sf.SGDScheduleFree + logger.info(f"use SGDScheduleFree optimizer | {optimizer_kwargs}") + else: + raise ValueError(f"Unknown optimizer type: {optimizer_type}") + optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) + # make optimizer as train mode: we don't need to call train again, because eval will not be called in training loop + optimizer.train() + if optimizer is None: # 任意のoptimizerを使う - optimizer_type = args.optimizer_type # lowerでないやつ(微妙) - logger.info(f"use {optimizer_type} | {optimizer_kwargs}") - if "." not in optimizer_type: - optimizer_module = torch.optim - else: - values = optimizer_type.split(".") - optimizer_module = importlib.import_module(".".join(values[:-1])) - optimizer_type = values[-1] + case_sensitive_optimizer_type = args.optimizer_type # not lower + logger.info(f"use {case_sensitive_optimizer_type} | {optimizer_kwargs}") - optimizer_class = getattr(optimizer_module, optimizer_type) + if "." not in case_sensitive_optimizer_type: # from torch.optim + optimizer_module = torch.optim + else: # from other library + values = case_sensitive_optimizer_type.split(".") + optimizer_module = importlib.import_module(".".join(values[:-1])) + case_sensitive_optimizer_type = values[-1] + + optimizer_class = getattr(optimizer_module, case_sensitive_optimizer_type) optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) + """ + # wrap any of above optimizer with schedulefree, if optimizer is not schedulefree + if args.optimizer_schedulefree_wrapper and not optimizer_type.endswith("schedulefree".lower()): + try: + import schedulefree as sf + except ImportError: + raise ImportError("No schedulefree / schedulefreeがインストールされていないようです") + + schedulefree_wrapper_kwargs = {} + if args.schedulefree_wrapper_args is not None and len(args.schedulefree_wrapper_args) > 0: + for arg in args.schedulefree_wrapper_args: + key, value = arg.split("=") + value = ast.literal_eval(value) + schedulefree_wrapper_kwargs[key] = value + + sf_wrapper = sf.ScheduleFreeWrapper(optimizer, **schedulefree_wrapper_kwargs) + sf_wrapper.train() # make optimizer as train mode + + # we need to make optimizer as a subclass of torch.optim.Optimizer, we make another Proxy class over SFWrapper + class OptimizerProxy(torch.optim.Optimizer): + def __init__(self, sf_wrapper): + self._sf_wrapper = sf_wrapper + + def __getattr__(self, name): + return getattr(self._sf_wrapper, name) + + # override properties + @property + def state(self): + return self._sf_wrapper.state + + @state.setter + def state(self, state): + self._sf_wrapper.state = state + + @property + def param_groups(self): + return self._sf_wrapper.param_groups + + @param_groups.setter + def param_groups(self, param_groups): + self._sf_wrapper.param_groups = param_groups + + @property + def defaults(self): + return self._sf_wrapper.defaults + + @defaults.setter + def defaults(self, defaults): + self._sf_wrapper.defaults = defaults + + def add_param_group(self, param_group): + self._sf_wrapper.add_param_group(param_group) + + def load_state_dict(self, state_dict): + self._sf_wrapper.load_state_dict(state_dict) + + def state_dict(self): + return self._sf_wrapper.state_dict() + + def zero_grad(self): + self._sf_wrapper.zero_grad() + + def step(self, closure=None): + self._sf_wrapper.step(closure) + + def train(self): + self._sf_wrapper.train() + + def eval(self): + self._sf_wrapper.eval() + + # isinstance チェックをパスするためのメソッド + def __instancecheck__(self, instance): + return isinstance(instance, (type(self), Optimizer)) + + optimizer = OptimizerProxy(sf_wrapper) + + logger.info(f"wrap optimizer with ScheduleFreeWrapper | {schedulefree_wrapper_kwargs}") + """ + optimizer_name = optimizer_class.__module__ + "." + optimizer_class.__name__ optimizer_args = ",".join([f"{k}={v}" for k, v in optimizer_kwargs.items()]) return optimizer_name, optimizer_args, optimizer +def is_schedulefree_optimizer(optimizer: Optimizer, args: argparse.Namespace) -> bool: + return args.optimizer_type.lower().endswith("schedulefree".lower()) # or args.optimizer_schedulefree_wrapper + + +def get_dummy_scheduler(optimizer: Optimizer) -> Any: + # dummy scheduler for schedulefree optimizer. supports only empty step(), get_last_lr() and optimizers. + # this scheduler is used for logging only. + # this isn't be wrapped by accelerator because of this class is not a subclass of torch.optim.lr_scheduler._LRScheduler + class DummyScheduler: + def __init__(self, optimizer: Optimizer): + self.optimizer = optimizer + + def step(self): + pass + + def get_last_lr(self): + return [group["lr"] for group in self.optimizer.param_groups] + + return DummyScheduler(optimizer) + + # Modified version of get_scheduler() function from diffusers.optimizer.get_scheduler # Add some checking and features to the original function. @@ -4610,6 +4744,10 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int): """ Unified API to get any scheduler from its name. """ + # if schedulefree optimizer, return dummy scheduler + if is_schedulefree_optimizer(optimizer, args): + return get_dummy_scheduler(optimizer) + name = args.lr_scheduler num_training_steps = args.max_train_steps * num_processes # * args.gradient_accumulation_steps num_warmup_steps: Optional[int] = ( diff --git a/requirements.txt b/requirements.txt index 9a4fa0c1..bab53f20 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,6 +9,7 @@ pytorch-lightning==1.9.0 bitsandbytes==0.43.3 prodigyopt==1.0 lion-pytorch==0.0.6 +schedulefree==1.2.7 tensorboard safetensors==0.4.4 # gradio==3.16.2 From e74502117bcf161ef5698fb0adba4f9fa0171b8d Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 18 Sep 2024 08:04:32 +0900 Subject: [PATCH 63/87] update README --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index dc986292..034a260f 100644 --- a/README.md +++ b/README.md @@ -14,6 +14,7 @@ The command to install PyTorch is as follows: Sep 18, 2024: - Schedule-free optimizer is added. Thanks to sdbds! See PR [#1600](https://github.com/kohya-ss/sd-scripts/pull/1600) for details. + - Details of the schedule-free optimizer can be found in [facebookresearch/schedule_free](https://github.com/facebookresearch/schedule_free). - `schedulefree` is added to the dependencies. Please update the library if necessary. - AdamWScheduleFree or SGDScheduleFree can be used. Specify `adamwschedulefree` or `sgdschedulefree` in `--optimizer_type`. - Wrapper classes are not available for now. From 1286e00bb0fc34c296f24b7057777f1c37cf8e11 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 18 Sep 2024 21:31:54 +0900 Subject: [PATCH 64/87] fix to call train/eval in schedulefree #1605 --- README.md | 3 +++ flux_train.py | 10 ++++++++++ library/train_util.py | 15 ++++++++++++++- train_network.py | 6 ++++++ 4 files changed, 33 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 034a260f..843ae181 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,9 @@ The command to install PyTorch is as follows: ### Recent Updates +Sep 18, 2024 (update 1): +Fixed an issue where train()/eval() was not called properly with the schedule-free optimizer. The schedule-free optimizer can be used in FLUX.1 LoRA training and fine-tuning for now. + Sep 18, 2024: - Schedule-free optimizer is added. Thanks to sdbds! See PR [#1600](https://github.com/kohya-ss/sd-scripts/pull/1600) for details. diff --git a/flux_train.py b/flux_train.py index 5d8326b1..bc4e6279 100644 --- a/flux_train.py +++ b/flux_train.py @@ -347,8 +347,13 @@ def train(args): logger.info(f"using {len(optimizers)} optimizers for blockwise fused optimizers") + if train_util.is_schedulefree_optimizer(optimizers[0], args): + raise ValueError("Schedule-free optimizer is not supported with blockwise fused optimizers") + optimizer_train_fn = lambda: None # dummy function + optimizer_eval_fn = lambda: None # dummy function else: _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize) + optimizer_train_fn, optimizer_eval_fn = train_util.get_optimizer_train_eval_fn(optimizer, args) # prepare dataloader # strategies are set here because they cannot be referenced in another process. Copy them with the dataset @@ -760,6 +765,7 @@ def train(args): progress_bar.update(1) global_step += 1 + optimizer_eval_fn() flux_train_utils.sample_images( accelerator, args, None, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs ) @@ -778,6 +784,7 @@ def train(args): global_step, accelerator.unwrap_model(flux), ) + optimizer_train_fn() current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず if len(accelerator.trackers) > 0: @@ -800,6 +807,7 @@ def train(args): accelerator.wait_for_everyone() + optimizer_eval_fn() if args.save_every_n_epochs is not None: if accelerator.is_main_process: flux_train_utils.save_flux_model_on_epoch_end_or_stepwise( @@ -816,12 +824,14 @@ def train(args): flux_train_utils.sample_images( accelerator, args, epoch + 1, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs ) + optimizer_train_fn() is_main_process = accelerator.is_main_process # if is_main_process: flux = accelerator.unwrap_model(flux) accelerator.end_training() + optimizer_eval_fn() if args.save_state or args.save_state_on_train_end: train_util.save_state_on_train_end(args, accelerator) diff --git a/library/train_util.py b/library/train_util.py index a54f23ff..fe9deb94 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -13,6 +13,7 @@ import shutil import time from typing import ( Any, + Callable, Dict, List, NamedTuple, @@ -4715,8 +4716,20 @@ def get_optimizer(args, trainable_params): return optimizer_name, optimizer_args, optimizer +def get_optimizer_train_eval_fn(optimizer: Optimizer, args: argparse.Namespace) -> Tuple[Callable, Callable]: + if not is_schedulefree_optimizer(optimizer, args): + # return dummy func + return lambda: None, lambda: None + + # get train and eval functions from optimizer + train_fn = optimizer.train + eval_fn = optimizer.eval + + return train_fn, eval_fn + + def is_schedulefree_optimizer(optimizer: Optimizer, args: argparse.Namespace) -> bool: - return args.optimizer_type.lower().endswith("schedulefree".lower()) # or args.optimizer_schedulefree_wrapper + return args.optimizer_type.lower().endswith("schedulefree".lower()) # or args.optimizer_schedulefree_wrapper def get_dummy_scheduler(optimizer: Optimizer) -> Any: diff --git a/train_network.py b/train_network.py index 34385ae0..55faa143 100644 --- a/train_network.py +++ b/train_network.py @@ -498,6 +498,7 @@ class NetworkTrainer: # accelerator.print(f"trainable_params: {k} = {v}") optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params) + optimizer_train_fn, optimizer_eval_fn = train_util.get_optimizer_train_eval_fn(optimizer, args) # prepare dataloader # strategies are set here because they cannot be referenced in another process. Copy them with the dataset @@ -1199,6 +1200,7 @@ class NetworkTrainer: progress_bar.update(1) global_step += 1 + optimizer_eval_fn() self.sample_images( accelerator, args, None, global_step, accelerator.device, vae, tokenizers, text_encoder, unet ) @@ -1217,6 +1219,7 @@ class NetworkTrainer: if remove_step_no is not None: remove_ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, remove_step_no) remove_model(remove_ckpt_name) + optimizer_train_fn() current_loss = loss.detach().item() loss_recorder.add(epoch=epoch, step=step, loss=current_loss) @@ -1243,6 +1246,7 @@ class NetworkTrainer: accelerator.wait_for_everyone() # 指定エポックごとにモデルを保存 + optimizer_eval_fn() if args.save_every_n_epochs is not None: saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs if is_main_process and saving: @@ -1258,6 +1262,7 @@ class NetworkTrainer: train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1) self.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizers, text_encoder, unet) + optimizer_train_fn() # end of epoch @@ -1268,6 +1273,7 @@ class NetworkTrainer: network = accelerator.unwrap_model(network) accelerator.end_training() + optimizer_eval_fn() if is_main_process and (args.save_state or args.save_state_on_train_end): train_util.save_state_on_train_end(args, accelerator) From e7040669bc9a31706fe9fedec14978b05223f968 Mon Sep 17 00:00:00 2001 From: Maru-mee <151493593+Maru-mee@users.noreply.github.com> Date: Thu, 19 Sep 2024 15:47:06 +0900 Subject: [PATCH 65/87] Bug fix: alpha_mask load --- library/train_util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/library/train_util.py b/library/train_util.py index a46d9487..5a8da90e 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2207,7 +2207,7 @@ def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool, alph if alpha_mask: if "alpha_mask" not in npz: return False - if npz["alpha_mask"].shape[0:2] != reso: # HxW + if (npz["alpha_mask"].shape[1], npz["alpha_mask"].shape[0]) != reso: # HxW => WxH != reso return False else: if "alpha_mask" in npz: From 9c757c2fba43d4e91d773cf6e9b7e2e8e3e8b376 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 19 Sep 2024 21:14:57 +0900 Subject: [PATCH 66/87] fix SDXL block index to match LBW --- networks/svd_merge_lora.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/networks/svd_merge_lora.py b/networks/svd_merge_lora.py index 0decd904..b4b9e3bf 100644 --- a/networks/svd_merge_lora.py +++ b/networks/svd_merge_lora.py @@ -184,18 +184,19 @@ def get_lbw_block_index(lora_name: str, is_sdxl: bool = False) -> int: elif "mid_block_" in lora_name: block_idx = 1 + NUM_OF_BLOCKS # 1-based index, num blocks, mid block else: + # SDXL: some numbers are skipped if lora_name.startswith("lora_unet_"): name = lora_name[len("lora_unet_") :] if name.startswith("time_embed_") or name.startswith("label_emb_"): # 1, No LoRA in sd-scripts block_idx = 1 elif name.startswith("input_blocks_"): # 1-8 to 2-9 block_idx = 1 + int(name.split("_")[2]) - elif name.startswith("middle_block_"): # 10 - block_idx = 10 - elif name.startswith("output_blocks_"): # 0-8 to 11-19 - block_idx = 11 + int(name.split("_")[2]) - elif name.startswith("out_"): # 20, No LoRA in sd-scripts - block_idx = 20 + elif name.startswith("middle_block_"): # 13 + block_idx = 13 + elif name.startswith("output_blocks_"): # 0-8 to 14-22 + block_idx = 14 + int(name.split("_")[2]) + elif name.startswith("out_"): # 23, No LoRA in sd-scripts + block_idx = 23 return block_idx From 3957372ded6fda20553acaf169993a422b829bdc Mon Sep 17 00:00:00 2001 From: Ed McManus Date: Thu, 19 Sep 2024 14:30:03 -0700 Subject: [PATCH 67/87] Retain alpha in `pil_resize` MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Currently the alpha channel is dropped by `pil_resize()` when `--alpha_mask` is supplied and the image width does not exceed the bucket. This codepath is entered on the last line, here: ``` def trim_and_resize_if_required( random_crop: bool, image: np.ndarray, reso, resized_size: Tuple[int, int] ) -> Tuple[np.ndarray, Tuple[int, int], Tuple[int, int, int, int]]: image_height, image_width = image.shape[0:2] original_size = (image_width, image_height) # size before resize if image_width != resized_size[0] or image_height != resized_size[1]: # リサイズする if image_width > resized_size[0] and image_height > resized_size[1]: image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ else: image = pil_resize(image, resized_size) ``` --- library/utils.py | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/library/utils.py b/library/utils.py index a0bb1965..2171c719 100644 --- a/library/utils.py +++ b/library/utils.py @@ -305,13 +305,26 @@ class MemoryEfficientSafeOpen: raise ValueError(f"Unsupported float8 type: {dtype_str} (upgrade PyTorch to support float8 types)") def pil_resize(image, size, interpolation=Image.LANCZOS): - pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) + # Check if the image has an alpha channel + has_alpha = image.shape[2] == 4 if len(image.shape) == 3 else False - # use Pillow resize + if has_alpha: + # Convert BGRA to RGBA + pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA)) + else: + # Convert BGR to RGB + pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) + + # Resize the image resized_pil = pil_image.resize(size, interpolation) - # return cv2 image - resized_cv2 = cv2.cvtColor(np.array(resized_pil), cv2.COLOR_RGB2BGR) + # Convert back to cv2 format + if has_alpha: + # Convert RGBA to BGRA + resized_cv2 = cv2.cvtColor(np.array(resized_pil), cv2.COLOR_RGBA2BGRA) + else: + # Convert RGB to BGR + resized_cv2 = cv2.cvtColor(np.array(resized_pil), cv2.COLOR_RGB2BGR) return resized_cv2 From de4bb657b089cc28f4127e891b927895892e20b5 Mon Sep 17 00:00:00 2001 From: Ed McManus Date: Thu, 19 Sep 2024 14:38:32 -0700 Subject: [PATCH 68/87] Update utils.py Cleanup --- library/utils.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/library/utils.py b/library/utils.py index 2171c719..8a0c782c 100644 --- a/library/utils.py +++ b/library/utils.py @@ -305,25 +305,19 @@ class MemoryEfficientSafeOpen: raise ValueError(f"Unsupported float8 type: {dtype_str} (upgrade PyTorch to support float8 types)") def pil_resize(image, size, interpolation=Image.LANCZOS): - # Check if the image has an alpha channel has_alpha = image.shape[2] == 4 if len(image.shape) == 3 else False if has_alpha: - # Convert BGRA to RGBA pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA)) else: - # Convert BGR to RGB pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) - # Resize the image resized_pil = pil_image.resize(size, interpolation) # Convert back to cv2 format if has_alpha: - # Convert RGBA to BGRA resized_cv2 = cv2.cvtColor(np.array(resized_pil), cv2.COLOR_RGBA2BGRA) else: - # Convert RGB to BGR resized_cv2 = cv2.cvtColor(np.array(resized_pil), cv2.COLOR_RGB2BGR) return resized_cv2 From 0535cd29b926530255d5400374813432ec52c3df Mon Sep 17 00:00:00 2001 From: Akegarasu Date: Fri, 20 Sep 2024 10:05:22 +0800 Subject: [PATCH 69/87] fix: backward compatibility for text_encoder_lr --- train_network.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/train_network.py b/train_network.py index 55faa143..dfa51a9c 100644 --- a/train_network.py +++ b/train_network.py @@ -471,7 +471,11 @@ class NetworkTrainer: if support_multiple_lrs: text_encoder_lr = args.text_encoder_lr else: - text_encoder_lr = None if args.text_encoder_lr is None or len(args.text_encoder_lr) == 0 else args.text_encoder_lr[0] + # toml backward compatibility + if args.text_encoder_lr is None or isinstance(args.text_encoder_lr, float): + text_encoder_lr = args.text_encoder_lr + else: + text_encoder_lr = None if len(args.text_encoder_lr) == 0 else args.text_encoder_lr[0] try: if support_multiple_lrs: results = network.prepare_optimizer_params_with_multiple_te_lrs(text_encoder_lr, args.unet_lr, args.learning_rate) From 583d4a436c1cef57fce405d0167fb7ce575fc768 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Fri, 20 Sep 2024 22:22:24 +0900 Subject: [PATCH 70/87] add compatibility for int LR (D-Adaptation etc.) #1620 --- train_network.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train_network.py b/train_network.py index dfa51a9c..b24f89b1 100644 --- a/train_network.py +++ b/train_network.py @@ -472,7 +472,7 @@ class NetworkTrainer: text_encoder_lr = args.text_encoder_lr else: # toml backward compatibility - if args.text_encoder_lr is None or isinstance(args.text_encoder_lr, float): + if args.text_encoder_lr is None or isinstance(args.text_encoder_lr, float) or isinstance(args.text_encoder_lr, int): text_encoder_lr = args.text_encoder_lr else: text_encoder_lr = None if len(args.text_encoder_lr) == 0 else args.text_encoder_lr[0] From e1f23af1bc733a1a89c35cf1be1301006c744b4a Mon Sep 17 00:00:00 2001 From: recris Date: Sat, 21 Sep 2024 12:58:32 +0100 Subject: [PATCH 71/87] make timestep sampling behave in the standard way when huber loss is used --- library/train_util.py | 26 ++++++++++---------------- 1 file changed, 10 insertions(+), 16 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 5a8da90e..72d2d811 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -5124,34 +5124,27 @@ def save_sd_model_on_train_end_common( def get_timesteps_and_huber_c(args, min_timestep, max_timestep, noise_scheduler, b_size, device): - - # TODO: if a huber loss is selected, it will use constant timesteps for each batch - # as. In the future there may be a smarter way + timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device='cpu') if args.loss_type == "huber" or args.loss_type == "smooth_l1": - timesteps = torch.randint(min_timestep, max_timestep, (1,), device="cpu") - timestep = timesteps.item() - if args.huber_schedule == "exponential": alpha = -math.log(args.huber_c) / noise_scheduler.config.num_train_timesteps - huber_c = math.exp(-alpha * timestep) + huber_c = torch.exp(-alpha * timesteps) elif args.huber_schedule == "snr": - alphas_cumprod = noise_scheduler.alphas_cumprod[timestep] + alphas_cumprod = torch.index_select(noise_scheduler.alphas_cumprod, 0, timesteps) sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5 huber_c = (1 - args.huber_c) / (1 + sigmas) ** 2 + args.huber_c elif args.huber_schedule == "constant": - huber_c = args.huber_c + huber_c = torch.full((b_size,), args.huber_c) else: raise NotImplementedError(f"Unknown Huber loss schedule {args.huber_schedule}!") - - timesteps = timesteps.repeat(b_size).to(device) + huber_c = huber_c.to(device) elif args.loss_type == "l2": - timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device=device) - huber_c = 1 # may be anything, as it's not used + huber_c = None # may be anything, as it's not used else: raise NotImplementedError(f"Unknown loss type {args.loss_type}") - timesteps = timesteps.long() + timesteps = timesteps.long().to(device) return timesteps, huber_c @@ -5190,20 +5183,21 @@ def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents): return noise, noisy_latents, timesteps, huber_c -# NOTE: if you're using the scheduled version, huber_c has to depend on the timesteps already def conditional_loss( - model_pred: torch.Tensor, target: torch.Tensor, reduction: str = "mean", loss_type: str = "l2", huber_c: float = 0.1 + model_pred: torch.Tensor, target: torch.Tensor, reduction: str, loss_type: str, huber_c: Optional[torch.Tensor] ): if loss_type == "l2": loss = torch.nn.functional.mse_loss(model_pred, target, reduction=reduction) elif loss_type == "huber": + huber_c = huber_c.view(-1, 1, 1, 1) loss = 2 * huber_c * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c) if reduction == "mean": loss = torch.mean(loss) elif reduction == "sum": loss = torch.sum(loss) elif loss_type == "smooth_l1": + huber_c = huber_c.view(-1, 1, 1, 1) loss = 2 * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c) if reduction == "mean": loss = torch.mean(loss) From 29177d2f0389bd13e3f12c95d463fb0e1c58f9a1 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 23 Sep 2024 21:14:03 +0900 Subject: [PATCH 72/87] retain alpha in pil_resize backport #1619 --- library/utils.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/library/utils.py b/library/utils.py index 5b7e657b..49d46a54 100644 --- a/library/utils.py +++ b/library/utils.py @@ -83,13 +83,20 @@ def setup_logging(args=None, log_level=None, reset=False): def pil_resize(image, size, interpolation=Image.LANCZOS): - pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) + has_alpha = image.shape[2] == 4 if len(image.shape) == 3 else False + + if has_alpha: + pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA)) + else: + pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) - # use Pillow resize resized_pil = pil_image.resize(size, interpolation) - # return cv2 image - resized_cv2 = cv2.cvtColor(np.array(resized_pil), cv2.COLOR_RGB2BGR) + # Convert back to cv2 format + if has_alpha: + resized_cv2 = cv2.cvtColor(np.array(resized_pil), cv2.COLOR_RGBA2BGRA) + else: + resized_cv2 = cv2.cvtColor(np.array(resized_pil), cv2.COLOR_RGB2BGR) return resized_cv2 From ab7b23187062db86d34fc82db95f7266a68ab5c4 Mon Sep 17 00:00:00 2001 From: sdbds <865105819@qq.com> Date: Wed, 25 Sep 2024 19:38:52 +0800 Subject: [PATCH 73/87] init --- library/train_util.py | 21 ++++++++++++++++++--- requirements.txt | 2 +- 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 5a8da90e..bdf7774e 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2994,7 +2994,7 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser): "--optimizer_type", type=str, default="", - help="Optimizer to use / オプティマイザの種類: AdamW (default), AdamW8bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit, Lion8bit, PagedLion8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, AdaFactor", + help="Optimizer to use / オプティマイザの種類: AdamW (default), AdamW8bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit, Lion8bit, PagedLion8bit, Lion, AdEMAMix8bit, PagedAdEMAMix8bit, SGDNesterov, SGDNesterov8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, AdaFactor", ) # backward compatibility @@ -4032,7 +4032,7 @@ def resume_from_local_or_hf_if_specified(accelerator, args): def get_optimizer(args, trainable_params): - # "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit, Lion8bit, PagedLion8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, Adafactor" + # "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit, Lion8bit, PagedLion8bit, AdEMAMix8bit, PagedAdEMAMix8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, Adafactor" optimizer_type = args.optimizer_type if args.use_8bit_adam: @@ -4141,7 +4141,22 @@ def get_optimizer(args, trainable_params): raise AttributeError( "No PagedLion8bit. The version of bitsandbytes installed seems to be old. Please install 0.39.0 or later. / PagedLion8bitが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.39.0以上をインストールしてください" ) - + elif optimizer_type == "Ademamix8bit".lower(): + logger.info(f"use 8-bit Ademamix optimizer | {optimizer_kwargs}") + try: + optimizer_class = bnb.optim.AdEMAMix8bit + except AttributeError: + raise AttributeError( + "No Ademamix8bit. The version of bitsandbytes installed seems to be old. Please install 0.44.0 or later. / Ademamix8bitが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.39.0以上をインストールしてください" + ) + elif optimizer_type == "PagedAdemamix8bit".lower(): + logger.info(f"use 8-bit PagedAdemamix optimizer | {optimizer_kwargs}") + try: + optimizer_class = bnb.optim.PagedAdEMAMix8bit + except AttributeError: + raise AttributeError( + "No PagedAdemamix8bit. The version of bitsandbytes installed seems to be old. Please install 0.44.0 or later. / PagedAdemamix8bitが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.39.0以上をインストールしてください" + ) optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) elif optimizer_type == "PagedAdamW".lower(): diff --git a/requirements.txt b/requirements.txt index 15e6e58f..e6e1bf6f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,7 +6,7 @@ ftfy==6.1.1 opencv-python==4.8.1.78 einops==0.7.0 pytorch-lightning==1.9.0 -bitsandbytes==0.43.0 +bitsandbytes==0.44.0 prodigyopt==1.0 lion-pytorch==0.0.6 tensorboard From e74f58148c5994889463afa42bb6fc5d6447a75e Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 25 Sep 2024 20:55:50 +0900 Subject: [PATCH 74/87] update README --- README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.md b/README.md index def528a2..9eabdaee 100644 --- a/README.md +++ b/README.md @@ -143,6 +143,8 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser - transformers, accelerate and huggingface_hub are updated. - If you encounter any issues, please report them. +- Fixed an issue where the timesteps in the batch were the same when using Huber loss. PR [#1628](https://github.com/kohya-ss/sd-scripts/pull/1628) Thanks to recris! + - Improvements in OFT (Orthogonal Finetuning) Implementation 1. Optimization of Calculation Order: - Changed the calculation order in the forward method from (Wx)R to W(xR). From 1beddd84e5c4db729a84356db227d981dc18cf8d Mon Sep 17 00:00:00 2001 From: sdbds <865105819@qq.com> Date: Wed, 25 Sep 2024 22:58:26 +0800 Subject: [PATCH 75/87] delete code for cleaning --- library/train_util.py | 17 +---------------- 1 file changed, 1 insertion(+), 16 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index bdf7774e..c4845c54 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -4141,22 +4141,7 @@ def get_optimizer(args, trainable_params): raise AttributeError( "No PagedLion8bit. The version of bitsandbytes installed seems to be old. Please install 0.39.0 or later. / PagedLion8bitが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.39.0以上をインストールしてください" ) - elif optimizer_type == "Ademamix8bit".lower(): - logger.info(f"use 8-bit Ademamix optimizer | {optimizer_kwargs}") - try: - optimizer_class = bnb.optim.AdEMAMix8bit - except AttributeError: - raise AttributeError( - "No Ademamix8bit. The version of bitsandbytes installed seems to be old. Please install 0.44.0 or later. / Ademamix8bitが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.39.0以上をインストールしてください" - ) - elif optimizer_type == "PagedAdemamix8bit".lower(): - logger.info(f"use 8-bit PagedAdemamix optimizer | {optimizer_kwargs}") - try: - optimizer_class = bnb.optim.PagedAdEMAMix8bit - except AttributeError: - raise AttributeError( - "No PagedAdemamix8bit. The version of bitsandbytes installed seems to be old. Please install 0.44.0 or later. / PagedAdemamix8bitが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.39.0以上をインストールしてください" - ) + optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) elif optimizer_type == "PagedAdamW".lower(): From 56a7bc171d48089fb50f8638537e42d07c579db3 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 26 Sep 2024 08:26:31 +0900 Subject: [PATCH 76/87] new block swap for FLUX.1 fine tuning --- README.md | 47 ++++++-- flux_train.py | 241 +++++++++++++++++++++++++++-------------- library/flux_models.py | 172 ++++++++++++++++------------- 3 files changed, 294 insertions(+), 166 deletions(-) diff --git a/README.md b/README.md index ef691e91..7d623f90 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,10 @@ The command to install PyTorch is as follows: ### Recent Updates +Sep 26, 2024: +The implementation of block swap during FLUX.1 fine-tuning has been changed to improve speed about 10% (depends on the environment). A new `--blocks_to_swap` option has been added, and `--double_blocks_to_swap` and `--single_blocks_to_swap` are deprecated. `--double_blocks_to_swap` and `--single_blocks_to_swap` are working as before, but they will be removed in the future. See [FLUX.1 fine-tuning](#flux1-fine-tuning) for details. + + Sep 18, 2024 (update 1): Fixed an issue where train()/eval() was not called properly with the schedule-free optimizer. The schedule-free optimizer can be used in FLUX.1 LoRA training and fine-tuning for now. @@ -307,6 +311,8 @@ python flux_minimal_inference.py --ckpt flux1-dev.safetensors --clip_l sd3/clip_ The memory-efficient training with block swap is based on 2kpr's implementation. Thanks to 2kpr! +__`--double_blocks_to_swap` and `--single_blocks_to_swap` are deprecated. These options is still available, but they will be removed in the future. Please use `--blocks_to_swap` instead. These options are equivalent to specifying `double_blocks_to_swap + single_blocks_to_swap // 2` in `--blocks_to_swap`.__ + Sample command for FLUX.1 fine-tuning is below. This will work with 24GB VRAM GPUs, and 64GB main memory is recommended. ``` @@ -319,39 +325,62 @@ accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_t --optimizer_type adafactor --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False" --lr_scheduler constant_with_warmup --max_grad_norm 0.0 --timestep_sampling shift --discrete_flow_shift 3.1582 --model_prediction_type raw --guidance_scale 1.0 ---fused_backward_pass --double_blocks_to_swap 6 --cpu_offload_checkpointing --full_bf16 +--fused_backward_pass --blocks_to_swap 8 --full_bf16 ``` (The command is multi-line for readability. Please combine it into one line.) -Options are almost the same as LoRA training. The difference is `--full_bf16`, `--blockwise_fused_optimizers`, `--double_blocks_to_swap` and `--cpu_offload_checkpointing`. `--single_blocks_to_swap` is also available. +Options are almost the same as LoRA training. The difference is `--full_bf16`, `--fused_backward_pass` and `--blocks_to_swap`. `--cpu_offload_checkpointing` is also available. `--full_bf16` enables the training with bf16 (weights and gradients). `--fused_backward_pass` enables the fusing of the optimizer step into the backward pass for each parameter. This reduces the memory usage during training. Only Adafactor optimizer is supported for now. Stochastic rounding is also enabled when `--fused_backward_pass` and `--full_bf16` are specified. -`--blockwise_fused_optimizers` enables the fusing of the optimizer step into the backward pass for each block. This is similar to `--fused_backward_pass`. Any optimizer can be used, but Adafactor is recommended for memory efficiency. `--blockwise_fused_optimizers` cannot be used with `--fused_backward_pass`. Stochastic rounding is not supported for now. +`--blockwise_fused_optimizers` enables the fusing of the optimizer step into the backward pass for each block. This is similar to `--fused_backward_pass`. Any optimizer can be used, but Adafactor is recommended for memory efficiency and stochastic rounding. `--blockwise_fused_optimizers` cannot be used with `--fused_backward_pass`. Stochastic rounding is not supported for now. -`--double_blocks_to_swap` and `--single_blocks_to_swap` are the number of double blocks and single blocks to swap. The default is None (no swap). These options must be combined with `--fused_backward_pass` or `--blockwise_fused_optimizers`. `--double_blocks_to_swap` can be specified with `--single_blocks_to_swap`. The recommended maximum number of blocks to swap is 9 for double blocks and 18 for single blocks. Please see the next chapter for details. +`--blocks_to_swap` is the number of blocks to swap. The default is None (no swap). These options must be combined with `--fused_backward_pass` or `--blockwise_fused_optimizers`. The recommended maximum value is 36. -`--cpu_offload_checkpointing` is to offload the gradient checkpointing to CPU. This reduces about 2GB of VRAM usage. +`--cpu_offload_checkpointing` is to offload the gradient checkpointing to CPU. This reduces about 2GB of VRAM usage. All these options are experimental and may change in the future. The increasing the number of blocks to swap may reduce the memory usage, but the training speed will be slower. `--cpu_offload_checkpointing` also slows down the training. -Swap 6 double blocks and use cpu offload checkpointing may be a good starting point. Please try different settings according to VRAM usage and training speed. +Swap 8 blocks without cpu offload checkpointing may be a good starting point for 24GB VRAM GPUs. Please try different settings according to VRAM usage and training speed. The learning rate and the number of epochs are not optimized yet. Please adjust them according to the training results. +#### How to use block swap + +There are two possible ways to use block swap. It is unknown which is better. + +1. Swap the minimum number of blocks that fit in VRAM with batch size 1 and shorten the training speed of one step. + + The above command example is for this usage. + +2. Swap many blocks to increase the batch size and shorten the training speed per data. + + For example, swapping 20 blocks seems to increase the batch size to about 6. In this case, the training speed per data will be relatively faster than 1. + +#### Training with <24GB VRAM GPUs + +Swap 28 blocks without cpu offload checkpointing may be working with 12GB VRAM GPUs. Please try different settings according to VRAM size of your GPU. + +T5XXL requires about 10GB of VRAM, so 10GB of VRAM will be minimum requirement for FLUX.1 fine-tuning. + #### Key Features for FLUX.1 fine-tuning -1. Technical details of double/single block swap: +1. Technical details of block swap: - Reduce memory usage by transferring double and single blocks of FLUX.1 from GPU to CPU when they are not needed. - During forward pass, the weights of the blocks that have finished calculation are transferred to CPU, and the weights of the blocks to be calculated are transferred to GPU. - The same is true for the backward pass, but the order is reversed. The gradients remain on the GPU. - Since the transfer between CPU and GPU takes time, the training will be slower. - - `--double_blocks_to_swap` and `--single_blocks_to_swap` specify the number of blocks to swap. For example, `--double_blocks_to_swap 6` swaps 6 blocks at each step of training, but the remaining 13 blocks are always on the GPU. - - About 640MB of memory can be saved per double block, and about 320MB of memory can be saved per single block. + - `--blocks_to_swap` specify the number of blocks to swap. + - About 640MB of memory can be saved per block. + - Since the memory usage of one double block and two single blocks is almost the same, the transfer of single blocks is done in units of two. For example, consider the case of `--blocks_to_swap 6`. + - Before the forward pass, all double blocks and 26 (=38-12) single blocks are on the GPU. The last 12 single blocks are on the CPU. + - In the forward pass, the 6 double blocks that have finished calculation (the first 6 blocks) are transferred to the CPU, and the 12 single blocks to be calculated (the last 12 blocks) are transferred to the GPU. + - The same is true for the backward pass, but in reverse order. The 12 single blocks that have finished calculation are transferred to the CPU, and the 6 double blocks to be calculated are transferred to the GPU. + - After the backward pass, the blocks are back to their original locations. 2. Sample Image Generation: - Sample image generation during training is now supported. diff --git a/flux_train.py b/flux_train.py index bc4e6279..bf34208f 100644 --- a/flux_train.py +++ b/flux_train.py @@ -11,10 +11,12 @@ # - Per-block fused optimizer instances import argparse +from concurrent.futures import ThreadPoolExecutor import copy import math import os from multiprocessing import Value +import time from typing import List import toml @@ -265,14 +267,30 @@ def train(args): flux.requires_grad_(True) - is_swapping_blocks = args.double_blocks_to_swap or args.single_blocks_to_swap + # block swap + + # backward compatibility + if args.blocks_to_swap is None: + blocks_to_swap = args.double_blocks_to_swap or 0 + if args.single_blocks_to_swap is not None: + blocks_to_swap += args.single_blocks_to_swap // 2 + if blocks_to_swap > 0: + logger.warning( + "double_blocks_to_swap and single_blocks_to_swap are deprecated. Use blocks_to_swap instead." + " / double_blocks_to_swapとsingle_blocks_to_swapは非推奨です。blocks_to_swapを使ってください。" + ) + logger.info( + f"double_blocks_to_swap={args.double_blocks_to_swap} and single_blocks_to_swap={args.single_blocks_to_swap} are converted to blocks_to_swap={blocks_to_swap}." + ) + args.blocks_to_swap = blocks_to_swap + del blocks_to_swap + + is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0 if is_swapping_blocks: # Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes. # This idea is based on 2kpr's great work. Thank you! - logger.info( - f"enable block swap: double_blocks_to_swap={args.double_blocks_to_swap}, single_blocks_to_swap={args.single_blocks_to_swap}" - ) - flux.enable_block_swap(args.double_blocks_to_swap, args.single_blocks_to_swap) + logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}") + flux.enable_block_swap(args.blocks_to_swap) if not cache_latents: # load VAE here if not cached @@ -443,82 +461,120 @@ def train(args): # resumeする train_util.resume_from_local_or_hf_if_specified(accelerator, args) + # memory efficient block swapping + + def get_block_unit(dbl_blocks, sgl_blocks, index: int): + if index < len(dbl_blocks): + return (dbl_blocks[index],) + else: + index -= len(dbl_blocks) + index *= 2 + return (sgl_blocks[index], sgl_blocks[index + 1]) + + def submit_move_blocks(futures, thread_pool, block_idx_to_cpu, block_idx_to_cuda, dbl_blocks, sgl_blocks, device): + def move_blocks(bidx_to_cpu, blocks_to_cpu, bidx_to_cuda, blocks_to_cuda, dvc): + # print(f"Backward: Move block {bidx_to_cpu} to CPU") + for block in blocks_to_cpu: + block = block.to("cpu", non_blocking=True) + torch.cuda.empty_cache() + + # print(f"Backward: Move block {bidx_to_cuda} to CUDA") + for block in blocks_to_cuda: + block = block.to(dvc, non_blocking=True) + + torch.cuda.synchronize() + # print(f"Backward: Moved blocks {bidx_to_cpu} and {bidx_to_cuda}") + return bidx_to_cpu, bidx_to_cuda + + blocks_to_cpu = get_block_unit(dbl_blocks, sgl_blocks, block_idx_to_cpu) + blocks_to_cuda = get_block_unit(dbl_blocks, sgl_blocks, block_idx_to_cuda) + + futures[block_idx_to_cuda] = thread_pool.submit( + move_blocks, block_idx_to_cpu, blocks_to_cpu, block_idx_to_cuda, blocks_to_cuda, device + ) + + def wait_blocks_move(block_idx, futures): + if block_idx not in futures: + return + # print(f"Backward: Wait for block {block_idx}") + # start_time = time.perf_counter() + future = futures.pop(block_idx) + future.result() + # print(f"Backward: Waited for block {block_idx}: {time.perf_counter()-start_time:.2f}s") + # torch.cuda.synchronize() + # print(f"Backward: Synchronized: {time.perf_counter()-start_time:.2f}s") + if args.fused_backward_pass: # use fused optimizer for backward pass: other optimizers will be supported in the future import library.adafactor_fused library.adafactor_fused.patch_adafactor_fused(optimizer) - double_blocks_to_swap = args.double_blocks_to_swap - single_blocks_to_swap = args.single_blocks_to_swap + blocks_to_swap = args.blocks_to_swap num_double_blocks = 19 # len(flux.double_blocks) num_single_blocks = 38 # len(flux.single_blocks) - handled_double_block_indices = set() - handled_single_block_indices = set() + num_block_units = num_double_blocks + num_single_blocks // 2 + handled_unit_indices = set() + + n = 1 # only asyncronous purpose, no need to increase this number + # n = 2 + # n = max(1, os.cpu_count() // 2) + thread_pool = ThreadPoolExecutor(max_workers=n) + futures = {} for param_group, param_name_group in zip(optimizer.param_groups, param_names): for parameter, param_name in zip(param_group["params"], param_name_group): if parameter.requires_grad: grad_hook = None - if double_blocks_to_swap: - if param_name.startswith("double_blocks"): + if blocks_to_swap: + is_double = param_name.startswith("double_blocks") + is_single = param_name.startswith("single_blocks") + if is_double or is_single: block_idx = int(param_name.split(".")[1]) - if ( - block_idx not in handled_double_block_indices - and block_idx >= (num_double_blocks - double_blocks_to_swap) - 1 - and block_idx < num_double_blocks - 1 - ): - # swap next (already backpropagated) block - handled_double_block_indices.add(block_idx) - block_idx_cpu = block_idx + 1 - block_idx_cuda = double_blocks_to_swap - (num_double_blocks - block_idx_cpu) + unit_idx = block_idx if is_double else num_double_blocks + block_idx // 2 + if unit_idx not in handled_unit_indices: + # swap following (already backpropagated) block + handled_unit_indices.add(unit_idx) - # create swap hook - def create_double_swap_grad_hook(bidx, bidx_cuda): - def __grad_hook(tensor: torch.Tensor): - if accelerator.sync_gradients and args.max_grad_norm != 0.0: - accelerator.clip_grad_norm_(tensor, args.max_grad_norm) - optimizer.step_param(tensor, param_group) - tensor.grad = None + # if n blocks were already backpropagated + num_blocks_propagated = num_block_units - unit_idx - 1 + swapping = num_blocks_propagated > 0 and num_blocks_propagated <= blocks_to_swap + waiting = unit_idx > 0 and unit_idx <= blocks_to_swap + if swapping or waiting: + block_idx_to_cpu = num_block_units - num_blocks_propagated + block_idx_to_cuda = blocks_to_swap - num_blocks_propagated + block_idx_to_wait = unit_idx - 1 - # swap blocks if necessary - flux.double_blocks[bidx].to("cpu") - flux.double_blocks[bidx_cuda].to(accelerator.device) - # print(f"Move double block {bidx} to cpu and {bidx_cuda} to device") + # create swap hook + def create_swap_grad_hook( + bidx_to_cpu, bidx_to_cuda, bidx_to_wait, uidx: int, swpng: bool, wtng: bool + ): + def __grad_hook(tensor: torch.Tensor): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_(tensor, args.max_grad_norm) + optimizer.step_param(tensor, param_group) + tensor.grad = None - return __grad_hook + # print(f"Backward: {uidx}, {swpng}, {wtng}") + if swpng: + submit_move_blocks( + futures, + thread_pool, + bidx_to_cpu, + bidx_to_cuda, + flux.double_blocks, + flux.single_blocks, + accelerator.device, + ) + if wtng: + wait_blocks_move(bidx_to_wait, futures) - grad_hook = create_double_swap_grad_hook(block_idx_cpu, block_idx_cuda) - if single_blocks_to_swap: - if param_name.startswith("single_blocks"): - block_idx = int(param_name.split(".")[1]) - if ( - block_idx not in handled_single_block_indices - and block_idx >= (num_single_blocks - single_blocks_to_swap) - 1 - and block_idx < num_single_blocks - 1 - ): - handled_single_block_indices.add(block_idx) - block_idx_cpu = block_idx + 1 - block_idx_cuda = single_blocks_to_swap - (num_single_blocks - block_idx_cpu) - # print(param_name, block_idx_cpu, block_idx_cuda) + return __grad_hook - # create swap hook - def create_single_swap_grad_hook(bidx, bidx_cuda): - def __grad_hook(tensor: torch.Tensor): - if accelerator.sync_gradients and args.max_grad_norm != 0.0: - accelerator.clip_grad_norm_(tensor, args.max_grad_norm) - optimizer.step_param(tensor, param_group) - tensor.grad = None - - # swap blocks if necessary - flux.single_blocks[bidx].to("cpu") - flux.single_blocks[bidx_cuda].to(accelerator.device) - # print(f"Move single block {bidx} to cpu and {bidx_cuda} to device") - - return __grad_hook - - grad_hook = create_single_swap_grad_hook(block_idx_cpu, block_idx_cuda) + grad_hook = create_swap_grad_hook( + block_idx_to_cpu, block_idx_to_cuda, block_idx_to_wait, unit_idx, swapping, waiting + ) if grad_hook is None: @@ -547,10 +603,15 @@ def train(args): num_parameters_per_group = [0] * len(optimizers) parameter_optimizer_map = {} - double_blocks_to_swap = args.double_blocks_to_swap - single_blocks_to_swap = args.single_blocks_to_swap + blocks_to_swap = args.blocks_to_swap num_double_blocks = 19 # len(flux.double_blocks) num_single_blocks = 38 # len(flux.single_blocks) + num_block_units = num_double_blocks + num_single_blocks // 2 + + n = 1 # only asyncronous purpose, no need to increase this number + # n = max(1, os.cpu_count() // 2) + thread_pool = ThreadPoolExecutor(max_workers=n) + futures = {} for opt_idx, optimizer in enumerate(optimizers): for param_group in optimizer.param_groups: @@ -571,18 +632,30 @@ def train(args): optimizers[i].zero_grad(set_to_none=True) # swap blocks if necessary - if btype == "double" and double_blocks_to_swap: - if bidx >= num_double_blocks - double_blocks_to_swap: - bidx_cuda = double_blocks_to_swap - (num_double_blocks - bidx) - flux.double_blocks[bidx].to("cpu") - flux.double_blocks[bidx_cuda].to(accelerator.device) - # print(f"Move double block {bidx} to cpu and {bidx_cuda} to device") - elif btype == "single" and single_blocks_to_swap: - if bidx >= num_single_blocks - single_blocks_to_swap: - bidx_cuda = single_blocks_to_swap - (num_single_blocks - bidx) - flux.single_blocks[bidx].to("cpu") - flux.single_blocks[bidx_cuda].to(accelerator.device) - # print(f"Move single block {bidx} to cpu and {bidx_cuda} to device") + if blocks_to_swap and (btype == "double" or (btype == "single" and bidx % 2 == 0)): + unit_idx = bidx if btype == "double" else num_double_blocks + bidx // 2 + num_blocks_propagated = num_block_units - unit_idx + + swapping = num_blocks_propagated > 0 and num_blocks_propagated <= blocks_to_swap + waiting = unit_idx > 0 and unit_idx <= blocks_to_swap + + if swapping: + block_idx_to_cpu = num_block_units - num_blocks_propagated + block_idx_to_cuda = blocks_to_swap - num_blocks_propagated + # print(f"Backward: Swap blocks {block_idx_to_cpu} and {block_idx_to_cuda}") + submit_move_blocks( + futures, + thread_pool, + block_idx_to_cpu, + block_idx_to_cuda, + flux.double_blocks, + flux.single_blocks, + accelerator.device, + ) + + if waiting: + block_idx_to_wait = unit_idx - 1 + wait_blocks_move(block_idx_to_wait, futures) return optimizer_hook @@ -881,24 +954,26 @@ def setup_parser() -> argparse.ArgumentParser: help="skip latents validity check / latentsの正当性チェックをスキップする", ) parser.add_argument( - "--double_blocks_to_swap", + "--blocks_to_swap", type=int, default=None, help="[EXPERIMENTAL] " - "Sets the number of 'double_blocks' (~640MB) to swap during the forward and backward passes." + "Sets the number of blocks (~640MB) to swap during the forward and backward passes." "Increasing this number lowers the overall VRAM used during training at the expense of training speed (s/it)." - " / 順伝播および逆伝播中にスワップする'変換ブロック'(約640MB)の数を設定します。" + " / 順伝播および逆伝播中にスワップするブロック(約640MB)の数を設定します。" "この数を増やすと、トレーニング中のVRAM使用量が減りますが、トレーニング速度(s/it)も低下します。", ) + parser.add_argument( + "--double_blocks_to_swap", + type=int, + default=None, + help="[Deprecated] use 'blocks_to_swap' instead / 代わりに 'blocks_to_swap' を使用してください", + ) parser.add_argument( "--single_blocks_to_swap", type=int, default=None, - help="[EXPERIMENTAL] " - "Sets the number of 'single_blocks' (~320MB) to swap during the forward and backward passes." - "Increasing this number lowers the overall VRAM used during training at the expense of training speed (s/it)." - " / 順伝播および逆伝播中にスワップする'変換ブロック'(約320MB)の数を設定します。" - "この数を増やすと、トレーニング中のVRAM使用量が減りますが、トレーニング速度(s/it)も低下します。", + help="[Deprecated] use 'blocks_to_swap' instead / 代わりに 'blocks_to_swap' を使用してください", ) parser.add_argument( "--cpu_offload_checkpointing", diff --git a/library/flux_models.py b/library/flux_models.py index b5726c29..a35dbc10 100644 --- a/library/flux_models.py +++ b/library/flux_models.py @@ -2,9 +2,12 @@ # license: Apache-2.0 License +from concurrent.futures import Future, ThreadPoolExecutor from dataclasses import dataclass import math -from typing import Optional +import os +import time +from typing import Dict, List, Optional from library.device_utils import init_ipex, clean_memory_on_device @@ -917,8 +920,10 @@ class Flux(nn.Module): self.gradient_checkpointing = False self.cpu_offload_checkpointing = False - self.double_blocks_to_swap = None - self.single_blocks_to_swap = None + self.blocks_to_swap = None + + self.thread_pool: Optional[ThreadPoolExecutor] = None + self.num_block_units = len(self.double_blocks) + len(self.single_blocks) // 2 @property def device(self): @@ -956,38 +961,52 @@ class Flux(nn.Module): print("FLUX: Gradient checkpointing disabled.") - def enable_block_swap(self, double_blocks: Optional[int], single_blocks: Optional[int]): - self.double_blocks_to_swap = double_blocks - self.single_blocks_to_swap = single_blocks + def enable_block_swap(self, num_blocks: int): + self.blocks_to_swap = num_blocks + + n = 1 # async block swap. 1 is enough + # n = 2 + # n = max(1, os.cpu_count() // 2) + self.thread_pool = ThreadPoolExecutor(max_workers=n) def move_to_device_except_swap_blocks(self, device: torch.device): # assume model is on cpu - if self.double_blocks_to_swap: + if self.blocks_to_swap: save_double_blocks = self.double_blocks - self.double_blocks = None - if self.single_blocks_to_swap: save_single_blocks = self.single_blocks + self.double_blocks = None self.single_blocks = None self.to(device) - if self.double_blocks_to_swap: + if self.blocks_to_swap: self.double_blocks = save_double_blocks - if self.single_blocks_to_swap: self.single_blocks = save_single_blocks + def get_block_unit(self, index: int): + if index < len(self.double_blocks): + return (self.double_blocks[index],) + else: + index -= len(self.double_blocks) + index *= 2 + return self.single_blocks[index], self.single_blocks[index + 1] + + def get_unit_index(self, is_double: bool, index: int): + if is_double: + return index + else: + return len(self.double_blocks) + index // 2 + def prepare_block_swap_before_forward(self): - # move last n blocks to cpu: they are on cuda - if self.double_blocks_to_swap: - for i in range(len(self.double_blocks) - self.double_blocks_to_swap): - self.double_blocks[i].to(self.device) - for i in range(len(self.double_blocks) - self.double_blocks_to_swap, len(self.double_blocks)): - self.double_blocks[i].to("cpu") # , non_blocking=True) - if self.single_blocks_to_swap: - for i in range(len(self.single_blocks) - self.single_blocks_to_swap): - self.single_blocks[i].to(self.device) - for i in range(len(self.single_blocks) - self.single_blocks_to_swap, len(self.single_blocks)): - self.single_blocks[i].to("cpu") # , non_blocking=True) + # make: first n blocks are on cuda, and last n blocks are on cpu + if self.blocks_to_swap is None: + raise ValueError("Block swap is not enabled.") + for i in range(self.num_block_units - self.blocks_to_swap): + for b in self.get_block_unit(i): + b.to(self.device) + for i in range(self.num_block_units - self.blocks_to_swap, self.num_block_units): + for b in self.get_block_unit(i): + b.to("cpu") clean_memory_on_device(self.device) def forward( @@ -1017,69 +1036,73 @@ class Flux(nn.Module): ids = torch.cat((txt_ids, img_ids), dim=1) pe = self.pe_embedder(ids) - if not self.double_blocks_to_swap: + if not self.blocks_to_swap: for block in self.double_blocks: img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) - else: - # make sure first n blocks are on cuda, and last n blocks are on cpu at beginning - for block_idx in range(self.double_blocks_to_swap): - block = self.double_blocks[len(self.double_blocks) - self.double_blocks_to_swap + block_idx] - if block.parameters().__next__().device.type != "cpu": - block.to("cpu") # , non_blocking=True) - # print(f"Moved double block {len(self.double_blocks) - self.double_blocks_to_swap + block_idx} to cpu.") - - block = self.double_blocks[block_idx] - if block.parameters().__next__().device.type == "cpu": - block.to(self.device) - # print(f"Moved double block {block_idx} to cuda.") - - to_cpu_block_index = 0 - for block_idx, block in enumerate(self.double_blocks): - # move last n blocks to cuda: they are on cpu, and move first n blocks to cpu: they are on cuda - moving = block_idx >= len(self.double_blocks) - self.double_blocks_to_swap - if moving: - block.to(self.device) # move to cuda - # print(f"Moved double block {block_idx} to cuda.") - - img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) - - if moving: - self.double_blocks[to_cpu_block_index].to("cpu") # , non_blocking=True) - # print(f"Moved double block {to_cpu_block_index} to cpu.") - to_cpu_block_index += 1 - - img = torch.cat((txt, img), 1) - - if not self.single_blocks_to_swap: + img = torch.cat((txt, img), 1) for block in self.single_blocks: img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) else: - # make sure first n blocks are on cuda, and last n blocks are on cpu at beginning - for block_idx in range(self.single_blocks_to_swap): - block = self.single_blocks[len(self.single_blocks) - self.single_blocks_to_swap + block_idx] - if block.parameters().__next__().device.type != "cpu": - block.to("cpu") # , non_blocking=True) - # print(f"Moved single block {len(self.single_blocks) - self.single_blocks_to_swap + block_idx} to cpu.") + futures = {} - block = self.single_blocks[block_idx] - if block.parameters().__next__().device.type == "cpu": - block.to(self.device) - # print(f"Moved single block {block_idx} to cuda.") + def submit_move_blocks(block_idx_to_cpu, block_idx_to_cuda): + def move_blocks(bidx_to_cpu, blocks_to_cpu, bidx_to_cuda, blocks_to_cuda): + # print(f"Moving {bidx_to_cpu} to cpu.") + for block in blocks_to_cpu: + block.to("cpu", non_blocking=True) + torch.cuda.empty_cache() + + # print(f"Moving {bidx_to_cuda} to cuda.") + for block in blocks_to_cuda: + block.to(self.device, non_blocking=True) + + torch.cuda.synchronize() + # print(f"Block move done. {bidx_to_cpu} to cpu, {bidx_to_cuda} to cuda.") + return block_idx_to_cpu, block_idx_to_cuda + + blocks_to_cpu = self.get_block_unit(block_idx_to_cpu) + blocks_to_cuda = self.get_block_unit(block_idx_to_cuda) + # print(f"Submit move blocks. {block_idx_to_cpu} to cpu, {block_idx_to_cuda} to cuda.") + return self.thread_pool.submit(move_blocks, block_idx_to_cpu, blocks_to_cpu, block_idx_to_cuda, blocks_to_cuda) + + def wait_for_blocks_move(block_idx, ftrs): + if block_idx not in ftrs: + return + # print(f"Waiting for move blocks: {block_idx}") + # start_time = time.perf_counter() + ftr = ftrs.pop(block_idx) + ftr.result() + # torch.cuda.synchronize() + # print(f"Move blocks took {time.perf_counter() - start_time:.2f} seconds") + + for block_idx, block in enumerate(self.double_blocks): + # print(f"Double block {block_idx}") + unit_idx = self.get_unit_index(is_double=True, index=block_idx) + wait_for_blocks_move(unit_idx, futures) + + img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) + + if unit_idx < self.blocks_to_swap: + block_idx_to_cpu = unit_idx + block_idx_to_cuda = self.num_block_units - self.blocks_to_swap + unit_idx + future = submit_move_blocks(block_idx_to_cpu, block_idx_to_cuda) + futures[block_idx_to_cuda] = future + + img = torch.cat((txt, img), 1) - to_cpu_block_index = 0 for block_idx, block in enumerate(self.single_blocks): - # move last n blocks to cuda: they are on cpu, and move first n blocks to cpu: they are on cuda - moving = block_idx >= len(self.single_blocks) - self.single_blocks_to_swap - if moving: - block.to(self.device) # move to cuda - # print(f"Moved single block {block_idx} to cuda.") + # print(f"Single block {block_idx}") + unit_idx = self.get_unit_index(is_double=False, index=block_idx) + if block_idx % 2 == 0: + wait_for_blocks_move(unit_idx, futures) img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) - if moving: - self.single_blocks[to_cpu_block_index].to("cpu") # , non_blocking=True) - # print(f"Moved single block {to_cpu_block_index} to cpu.") - to_cpu_block_index += 1 + if block_idx % 2 == 1 and unit_idx < self.blocks_to_swap: + block_idx_to_cpu = unit_idx + block_idx_to_cuda = self.num_block_units - self.blocks_to_swap + unit_idx + future = submit_move_blocks(block_idx_to_cpu, block_idx_to_cuda) + futures[block_idx_to_cuda] = future img = img[:, txt.shape[1] :, ...] @@ -1088,6 +1111,7 @@ class Flux(nn.Module): vec = vec.to(self.device) img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) + return img From da94fd934eb4951d1cb132abc9d2a355e44d7abf Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 26 Sep 2024 08:27:48 +0900 Subject: [PATCH 77/87] fix typos --- flux_train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flux_train.py b/flux_train.py index bf34208f..022467ea 100644 --- a/flux_train.py +++ b/flux_train.py @@ -516,7 +516,7 @@ def train(args): num_block_units = num_double_blocks + num_single_blocks // 2 handled_unit_indices = set() - n = 1 # only asyncronous purpose, no need to increase this number + n = 1 # only asynchronous purpose, no need to increase this number # n = 2 # n = max(1, os.cpu_count() // 2) thread_pool = ThreadPoolExecutor(max_workers=n) @@ -608,7 +608,7 @@ def train(args): num_single_blocks = 38 # len(flux.single_blocks) num_block_units = num_double_blocks + num_single_blocks // 2 - n = 1 # only asyncronous purpose, no need to increase this number + n = 1 # only asynchronous purpose, no need to increase this number # n = max(1, os.cpu_count() // 2) thread_pool = ThreadPoolExecutor(max_workers=n) futures = {} From bf91bea2e4363e5b3e0db11f0955ab93a19a0452 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 26 Sep 2024 20:51:40 +0900 Subject: [PATCH 78/87] fix flip_aug, alpha_mask, random_crop issue in caching --- README.md | 2 ++ library/train_util.py | 44 +++++++++++++++++++++++++++++++------------ 2 files changed, 34 insertions(+), 12 deletions(-) diff --git a/README.md b/README.md index 9eabdaee..b67a2c4e 100644 --- a/README.md +++ b/README.md @@ -143,6 +143,8 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser - transformers, accelerate and huggingface_hub are updated. - If you encounter any issues, please report them. +- Fixed a bug in the cache of latents. When `flip_aug`, `alpha_mask`, and `random_crop` are different in multiple subsets in the dataset configuration file (.toml), the last subset is used instead of reflecting them correctly. + - Fixed an issue where the timesteps in the batch were the same when using Huber loss. PR [#1628](https://github.com/kohya-ss/sd-scripts/pull/1628) Thanks to recris! - Improvements in OFT (Orthogonal Finetuning) Implementation diff --git a/library/train_util.py b/library/train_util.py index 72d2d811..a31d00c6 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -998,9 +998,26 @@ class BaseDataset(torch.utils.data.Dataset): # sort by resolution image_infos.sort(key=lambda info: info.bucket_reso[0] * info.bucket_reso[1]) - # split by resolution - batches = [] - batch = [] + # split by resolution and some conditions + class Condition: + def __init__(self, reso, flip_aug, alpha_mask, random_crop): + self.reso = reso + self.flip_aug = flip_aug + self.alpha_mask = alpha_mask + self.random_crop = random_crop + + def __eq__(self, other): + return ( + self.reso == other.reso + and self.flip_aug == other.flip_aug + and self.alpha_mask == other.alpha_mask + and self.random_crop == other.random_crop + ) + + batches: List[Tuple[Condition, List[ImageInfo]]] = [] + batch: List[ImageInfo] = [] + current_condition = None + logger.info("checking cache validity...") for info in tqdm(image_infos): subset = self.image_to_subset[info.image_key] @@ -1021,28 +1038,31 @@ class BaseDataset(torch.utils.data.Dataset): if cache_available: # do not add to batch continue - # if last member of batch has different resolution, flush the batch - if len(batch) > 0 and batch[-1].bucket_reso != info.bucket_reso: - batches.append(batch) + # if batch is not empty and condition is changed, flush the batch. Note that current_condition is not None if batch is not empty + condition = Condition(info.bucket_reso, subset.flip_aug, subset.alpha_mask, subset.random_crop) + if len(batch) > 0 and current_condition != condition: + batches.append((current_condition, batch)) batch = [] batch.append(info) + current_condition = condition # if number of data in batch is enough, flush the batch if len(batch) >= vae_batch_size: - batches.append(batch) + batches.append((current_condition, batch)) batch = [] + current_condition = None if len(batch) > 0: - batches.append(batch) + batches.append((current_condition, batch)) if cache_to_disk and not is_main_process: # if cache to disk, don't cache latents in non-main process, set to info only return # iterate batches: batch doesn't have image, image will be loaded in cache_batch_latents and discarded logger.info("caching latents...") - for batch in tqdm(batches, smoothing=1, total=len(batches)): - cache_batch_latents(vae, cache_to_disk, batch, subset.flip_aug, subset.alpha_mask, subset.random_crop) + for condition, batch in tqdm(batches, smoothing=1, total=len(batches)): + cache_batch_latents(vae, cache_to_disk, batch, condition.flip_aug, condition.alpha_mask, condition.random_crop) # weight_dtypeを指定するとText Encoderそのもの、およひ出力がweight_dtypeになる # SDXLでのみ有効だが、datasetのメソッドとする必要があるので、sdxl_train_util.pyではなくこちらに実装する @@ -2315,7 +2335,7 @@ def debug_dataset(train_dataset, show_input_ids=False): if "alpha_masks" in example and example["alpha_masks"] is not None: alpha_mask = example["alpha_masks"][j] logger.info(f"alpha mask size: {alpha_mask.size()}") - alpha_mask = (alpha_mask[0].numpy() * 255.0).astype(np.uint8) + alpha_mask = (alpha_mask.numpy() * 255.0).astype(np.uint8) if os.name == "nt": cv2.imshow("alpha_mask", alpha_mask) @@ -5124,7 +5144,7 @@ def save_sd_model_on_train_end_common( def get_timesteps_and_huber_c(args, min_timestep, max_timestep, noise_scheduler, b_size, device): - timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device='cpu') + timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device="cpu") if args.loss_type == "huber" or args.loss_type == "smooth_l1": if args.huber_schedule == "exponential": From 392e8dedd84e469b125e2935e3ecf02e6270a5b2 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 26 Sep 2024 21:14:11 +0900 Subject: [PATCH 79/87] fix flip_aug, alpha_mask, random_crop issue in caching in caching strategy --- library/train_util.py | 41 ++++++++++++++++++++++++++++++----------- 1 file changed, 30 insertions(+), 11 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 319337a4..17dd447e 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -993,9 +993,26 @@ class BaseDataset(torch.utils.data.Dataset): # sort by resolution image_infos.sort(key=lambda info: info.bucket_reso[0] * info.bucket_reso[1]) - # split by resolution - batches = [] - batch = [] + # split by resolution and some conditions + class Condition: + def __init__(self, reso, flip_aug, alpha_mask, random_crop): + self.reso = reso + self.flip_aug = flip_aug + self.alpha_mask = alpha_mask + self.random_crop = random_crop + + def __eq__(self, other): + return ( + self.reso == other.reso + and self.flip_aug == other.flip_aug + and self.alpha_mask == other.alpha_mask + and self.random_crop == other.random_crop + ) + + batches: List[Tuple[Condition, List[ImageInfo]]] = [] + batch: List[ImageInfo] = [] + current_condition = None + logger.info("checking cache validity...") for info in tqdm(image_infos): subset = self.image_to_subset[info.image_key] @@ -1016,20 +1033,23 @@ class BaseDataset(torch.utils.data.Dataset): if cache_available: # do not add to batch continue - # if last member of batch has different resolution, flush the batch - if len(batch) > 0 and batch[-1].bucket_reso != info.bucket_reso: - batches.append(batch) + # if batch is not empty and condition is changed, flush the batch. Note that current_condition is not None if batch is not empty + condition = Condition(info.bucket_reso, subset.flip_aug, subset.alpha_mask, subset.random_crop) + if len(batch) > 0 and current_condition != condition: + batches.append((current_condition, batch)) batch = [] batch.append(info) + current_condition = condition # if number of data in batch is enough, flush the batch if len(batch) >= caching_strategy.batch_size: - batches.append(batch) + batches.append((current_condition, batch)) batch = [] + current_condition = None if len(batch) > 0: - batches.append(batch) + batches.append((current_condition, batch)) # if cache to disk, don't cache latents in non-main process, set to info only if caching_strategy.cache_to_disk and not is_main_process: @@ -1041,9 +1061,8 @@ class BaseDataset(torch.utils.data.Dataset): # iterate batches: batch doesn't have image here. image will be loaded in cache_batch_latents and discarded logger.info("caching latents...") - for batch in tqdm(batches, smoothing=1, total=len(batches)): - # cache_batch_latents(vae, cache_to_disk, batch, subset.flip_aug, subset.alpha_mask, subset.random_crop) - caching_strategy.cache_batch_latents(model, batch, subset.flip_aug, subset.alpha_mask, subset.random_crop) + for condition, batch in tqdm(batches, smoothing=1, total=len(batches)): + caching_strategy.cache_batch_latents(model, batch, condition.flip_aug, condition.alpha_mask, condition.random_crop) def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True, file_suffix=".npz"): # マルチGPUには対応していないので、そちらはtools/cache_latents.pyを使うこと From a94bc84dec8e85e8a71217b4d2570a52c6779b73 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 26 Sep 2024 21:37:31 +0900 Subject: [PATCH 80/87] fix to work bitsandbytes optimizers with full path #1640 --- library/train_util.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index b40945ab..47c36768 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3014,7 +3014,11 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser): "--optimizer_type", type=str, default="", - help="Optimizer to use / オプティマイザの種類: AdamW (default), AdamW8bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit, Lion8bit, PagedLion8bit, Lion, AdEMAMix8bit, PagedAdEMAMix8bit, SGDNesterov, SGDNesterov8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, AdaFactor", + help="Optimizer to use / オプティマイザの種類: AdamW (default), AdamW8bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit, " + "Lion8bit, PagedLion8bit, Lion, SGDNesterov, SGDNesterov8bit, " + "DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, " + "AdaFactor. " + "Also, you can use any optimizer by specifying the full path to the class, like 'bitsandbytes.optim.AdEMAMix8bit' or 'bitsandbytes.optim.PagedAdEMAMix8bit'.", ) # backward compatibility @@ -4105,6 +4109,7 @@ def get_optimizer(args, trainable_params): lr = args.learning_rate optimizer = None + optimizer_class = None if optimizer_type == "Lion".lower(): try: @@ -4162,7 +4167,8 @@ def get_optimizer(args, trainable_params): "No PagedLion8bit. The version of bitsandbytes installed seems to be old. Please install 0.39.0 or later. / PagedLion8bitが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.39.0以上をインストールしてください" ) - optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) + if optimizer_class is not None: + optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) elif optimizer_type == "PagedAdamW".lower(): logger.info(f"use PagedAdamW optimizer | {optimizer_kwargs}") @@ -4338,6 +4344,7 @@ def get_optimizer(args, trainable_params): optimizer_class = getattr(optimizer_module, optimizer_type) optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) + # for logging optimizer_name = optimizer_class.__module__ + "." + optimizer_class.__name__ optimizer_args = ",".join([f"{k}={v}" for k, v in optimizer_kwargs.items()]) From ce49ced699298aa885d9a64b969fe8c77f30893b Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 26 Sep 2024 21:37:40 +0900 Subject: [PATCH 81/87] update readme --- README.md | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index b67a2c4e..9f024c1c 100644 --- a/README.md +++ b/README.md @@ -140,9 +140,12 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser ### Working in progress - __important__ The dependent libraries are updated. Please see [Upgrade](#upgrade) and update the libraries. - - transformers, accelerate and huggingface_hub are updated. + - bitsandbytes, transformers, accelerate and huggingface_hub are updated. - If you encounter any issues, please report them. +- `bitsandbytes` is updated to 0.44.0. Now you can use `AdEMAMix8bit` and `PagedAdEMAMix8bit` in the training script. PR [#1640](https://github.com/kohya-ss/sd-scripts/pull/1640) Thanks to sdbds! + - There is no abbreviation, so please specify the full path like `--optimizer_type bitsandbytes.optim.AdEMAMix8bit` (not bnb but bitsandbytes). + - Fixed a bug in the cache of latents. When `flip_aug`, `alpha_mask`, and `random_crop` are different in multiple subsets in the dataset configuration file (.toml), the last subset is used instead of reflecting them correctly. - Fixed an issue where the timesteps in the batch were the same when using Huber loss. PR [#1628](https://github.com/kohya-ss/sd-scripts/pull/1628) Thanks to recris! From a9aa52658a0d9ba7910a1d1983b650bc9de7153e Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 28 Sep 2024 17:12:56 +0900 Subject: [PATCH 82/87] fix sample generation is not working in FLUX1 fine tuning #1647 --- library/flux_models.py | 5 +++-- library/flux_train_utils.py | 4 +++- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/library/flux_models.py b/library/flux_models.py index a35dbc10..0bc1c02b 100644 --- a/library/flux_models.py +++ b/library/flux_models.py @@ -999,8 +999,9 @@ class Flux(nn.Module): def prepare_block_swap_before_forward(self): # make: first n blocks are on cuda, and last n blocks are on cpu - if self.blocks_to_swap is None: - raise ValueError("Block swap is not enabled.") + if self.blocks_to_swap is None or self.blocks_to_swap == 0: + # raise ValueError("Block swap is not enabled.") + return for i in range(self.num_block_units - self.blocks_to_swap): for b in self.get_block_unit(i): b.to(self.device) diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index f77d4b58..1d1eb9d2 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -313,6 +313,7 @@ def denoise( guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype) for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]): t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device) + model.prepare_block_swap_before_forward() pred = model( img=img, img_ids=img_ids, @@ -325,7 +326,8 @@ def denoise( ) img = img + (t_prev - t_curr) * pred - + + model.prepare_block_swap_before_forward() return img From 822fe578591e44ac949830e03a8841e222483052 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 28 Sep 2024 20:57:27 +0900 Subject: [PATCH 83/87] add workaround for 'Some tensors share memory' error #1614 --- networks/convert_flux_lora.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/networks/convert_flux_lora.py b/networks/convert_flux_lora.py index bd4c1cf7..fe6466eb 100644 --- a/networks/convert_flux_lora.py +++ b/networks/convert_flux_lora.py @@ -412,6 +412,10 @@ def main(args): state_dict = convert_ai_toolkit_to_sd_scripts(state_dict) elif args.src == "sd-scripts" and args.dst == "ai-toolkit": state_dict = convert_sd_scripts_to_ai_toolkit(state_dict) + + # eliminate 'shared tensors' + for k in list(state_dict.keys()): + state_dict[k] = state_dict[k].detach().clone() else: raise NotImplementedError(f"Conversion from {args.src} to {args.dst} is not supported") From 1a0f5b0c389f4e9fab5edb06b36f203e8894d581 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 29 Sep 2024 00:35:29 +0900 Subject: [PATCH 84/87] re-fix sample generation is not working in FLUX1 split mode #1647 --- flux_train_network.py | 3 +++ library/flux_train_utils.py | 1 - 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/flux_train_network.py b/flux_train_network.py index a6e57eed..65b121e7 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -300,6 +300,9 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): self.flux_lower = flux_lower self.target_device = device + def prepare_block_swap_before_forward(self): + pass + def forward(self, img, img_ids, txt, txt_ids, timesteps, y, guidance=None, txt_attention_mask=None): self.flux_lower.to("cpu") clean_memory_on_device(self.target_device) diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index 1d1eb9d2..b3c9184f 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -196,7 +196,6 @@ def sample_image_inference( tokens_and_masks = tokenize_strategy.tokenize(prompt) # strategy has apply_t5_attn_mask option encoded_text_encoder_conds = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, tokens_and_masks) - print([x.shape if x is not None else None for x in encoded_text_encoder_conds]) # if text_encoder_conds is not cached, use encoded_text_encoder_conds if len(text_encoder_conds) == 0: From fe2aa32484a948f16955909e64c21da7fe1e4e0c Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 29 Sep 2024 09:49:25 +0900 Subject: [PATCH 85/87] adjust min/max bucket reso divisible by reso steps #1632 --- README.md | 2 ++ docs/config_README-en.md | 2 ++ docs/config_README-ja.md | 2 ++ fine_tune.py | 2 ++ library/train_util.py | 40 ++++++++++++++++++++++++++++++++------ train_controlnet.py | 2 ++ train_db.py | 2 ++ train_network.py | 2 +- train_textual_inversion.py | 2 +- 9 files changed, 48 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index 9f024c1c..de5cddb9 100644 --- a/README.md +++ b/README.md @@ -143,6 +143,8 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser - bitsandbytes, transformers, accelerate and huggingface_hub are updated. - If you encounter any issues, please report them. +- There was a bug where the min_bucket_reso/max_bucket_reso in the dataset configuration did not create the correct resolution bucket if it was not divisible by bucket_reso_steps. These values are now warned and automatically rounded to a divisible value. Thanks to Maru-mee for raising the issue. Related PR [#1632](https://github.com/kohya-ss/sd-scripts/pull/1632) + - `bitsandbytes` is updated to 0.44.0. Now you can use `AdEMAMix8bit` and `PagedAdEMAMix8bit` in the training script. PR [#1640](https://github.com/kohya-ss/sd-scripts/pull/1640) Thanks to sdbds! - There is no abbreviation, so please specify the full path like `--optimizer_type bitsandbytes.optim.AdEMAMix8bit` (not bnb but bitsandbytes). diff --git a/docs/config_README-en.md b/docs/config_README-en.md index 83bea329..66a50dc0 100644 --- a/docs/config_README-en.md +++ b/docs/config_README-en.md @@ -128,6 +128,8 @@ These are options related to the configuration of the data set. They cannot be d * `batch_size` * This corresponds to the command-line argument `--train_batch_size`. +* `max_bucket_reso`, `min_bucket_reso` + * Specify the maximum and minimum resolutions of the bucket. It must be divisible by `bucket_reso_steps`. These settings are fixed per dataset. That means that subsets belonging to the same dataset will share these settings. For example, if you want to prepare datasets with different resolutions, you can define them as separate datasets as shown in the example above, and set different resolutions for each. diff --git a/docs/config_README-ja.md b/docs/config_README-ja.md index cc74c341..0ed95e0e 100644 --- a/docs/config_README-ja.md +++ b/docs/config_README-ja.md @@ -118,6 +118,8 @@ DreamBooth の手法と fine tuning の手法の両方とも利用可能な学 * `batch_size` * コマンドライン引数の `--train_batch_size` と同等です。 +* `max_bucket_reso`, `min_bucket_reso` + * bucketの最大、最小解像度を指定します。`bucket_reso_steps` で割り切れる必要があります。 これらの設定はデータセットごとに固定です。 つまり、データセットに所属するサブセットはこれらの設定を共有することになります。 diff --git a/fine_tune.py b/fine_tune.py index d865cd2d..b556672d 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -91,6 +91,8 @@ def train(args): ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None collator = train_util.collator_class(current_epoch, current_step, ds_for_collator) + train_dataset_group.verify_bucket_reso_steps(64) + if args.debug_dataset: train_util.debug_dataset(train_dataset_group) return diff --git a/library/train_util.py b/library/train_util.py index 47c36768..0cb6383a 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -653,6 +653,34 @@ class BaseDataset(torch.utils.data.Dataset): # caching self.caching_mode = None # None, 'latents', 'text' + def adjust_min_max_bucket_reso_by_steps( + self, resolution: Tuple[int, int], min_bucket_reso: int, max_bucket_reso: int, bucket_reso_steps: int + ) -> Tuple[int, int]: + # make min/max bucket reso to be multiple of bucket_reso_steps + if min_bucket_reso % bucket_reso_steps != 0: + adjusted_min_bucket_reso = min_bucket_reso - min_bucket_reso % bucket_reso_steps + logger.warning( + f"min_bucket_reso is adjusted to be multiple of bucket_reso_steps" + f" / min_bucket_resoがbucket_reso_stepsの倍数になるように調整されました: {min_bucket_reso} -> {adjusted_min_bucket_reso}" + ) + min_bucket_reso = adjusted_min_bucket_reso + if max_bucket_reso % bucket_reso_steps != 0: + adjusted_max_bucket_reso = max_bucket_reso + bucket_reso_steps - max_bucket_reso % bucket_reso_steps + logger.warning( + f"max_bucket_reso is adjusted to be multiple of bucket_reso_steps" + f" / max_bucket_resoがbucket_reso_stepsの倍数になるように調整されました: {max_bucket_reso} -> {adjusted_max_bucket_reso}" + ) + max_bucket_reso = adjusted_max_bucket_reso + + assert ( + min(resolution) >= min_bucket_reso + ), f"min_bucket_reso must be equal or less than resolution / min_bucket_resoは最小解像度より大きくできません。解像度を大きくするかmin_bucket_resoを小さくしてください" + assert ( + max(resolution) <= max_bucket_reso + ), f"max_bucket_reso must be equal or greater than resolution / max_bucket_resoは最大解像度より小さくできません。解像度を小さくするかmin_bucket_resoを大きくしてください" + + return min_bucket_reso, max_bucket_reso + def set_seed(self, seed): self.seed = seed @@ -1533,12 +1561,9 @@ class DreamBoothDataset(BaseDataset): self.enable_bucket = enable_bucket if self.enable_bucket: - assert ( - min(resolution) >= min_bucket_reso - ), f"min_bucket_reso must be equal or less than resolution / min_bucket_resoは最小解像度より大きくできません。解像度を大きくするかmin_bucket_resoを小さくしてください" - assert ( - max(resolution) <= max_bucket_reso - ), f"max_bucket_reso must be equal or greater than resolution / max_bucket_resoは最大解像度より小さくできません。解像度を小さくするかmin_bucket_resoを大きくしてください" + min_bucket_reso, max_bucket_reso = self.adjust_min_max_bucket_reso_by_steps( + resolution, min_bucket_reso, max_bucket_reso, bucket_reso_steps + ) self.min_bucket_reso = min_bucket_reso self.max_bucket_reso = max_bucket_reso self.bucket_reso_steps = bucket_reso_steps @@ -1901,6 +1926,9 @@ class FineTuningDataset(BaseDataset): self.enable_bucket = enable_bucket if self.enable_bucket: + min_bucket_reso, max_bucket_reso = self.adjust_min_max_bucket_reso_by_steps( + resolution, min_bucket_reso, max_bucket_reso, bucket_reso_steps + ) self.min_bucket_reso = min_bucket_reso self.max_bucket_reso = max_bucket_reso self.bucket_reso_steps = bucket_reso_steps diff --git a/train_controlnet.py b/train_controlnet.py index c9ac6c5a..6938c4bc 100644 --- a/train_controlnet.py +++ b/train_controlnet.py @@ -107,6 +107,8 @@ def train(args): ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None collator = train_util.collator_class(current_epoch, current_step, ds_for_collator) + train_dataset_group.verify_bucket_reso_steps(64) + if args.debug_dataset: train_util.debug_dataset(train_dataset_group) return diff --git a/train_db.py b/train_db.py index 39d8ea6e..2c7f0258 100644 --- a/train_db.py +++ b/train_db.py @@ -93,6 +93,8 @@ def train(args): if args.no_token_padding: train_dataset_group.disable_token_padding() + train_dataset_group.verify_bucket_reso_steps(64) + if args.debug_dataset: train_util.debug_dataset(train_dataset_group) return diff --git a/train_network.py b/train_network.py index 7ba07385..044ec3aa 100644 --- a/train_network.py +++ b/train_network.py @@ -95,7 +95,7 @@ class NetworkTrainer: return logs def assert_extra_args(self, args, train_dataset_group): - pass + train_dataset_group.verify_bucket_reso_steps(64) def load_target_model(self, args, weight_dtype, accelerator): text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator) diff --git a/train_textual_inversion.py b/train_textual_inversion.py index ade077c3..96e7bd50 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -99,7 +99,7 @@ class TextualInversionTrainer: self.is_sdxl = False def assert_extra_args(self, args, train_dataset_group): - pass + train_dataset_group.verify_bucket_reso_steps(64) def load_target_model(self, args, weight_dtype, accelerator): text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator) From 1567549220b5936af0c534ca23656ecd2f4882f0 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 29 Sep 2024 09:51:36 +0900 Subject: [PATCH 86/87] update help text #1632 --- library/train_util.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 0cb6383a..422dceca 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3865,8 +3865,20 @@ def add_dataset_arguments( action="store_true", help="enable buckets for multi aspect ratio training / 複数解像度学習のためのbucketを有効にする", ) - parser.add_argument("--min_bucket_reso", type=int, default=256, help="minimum resolution for buckets / bucketの最小解像度") - parser.add_argument("--max_bucket_reso", type=int, default=1024, help="maximum resolution for buckets / bucketの最大解像度") + parser.add_argument( + "--min_bucket_reso", + type=int, + default=256, + help="minimum resolution for buckets, must be divisible by bucket_reso_steps " + " / bucketの最小解像度、bucket_reso_stepsで割り切れる必要があります", + ) + parser.add_argument( + "--max_bucket_reso", + type=int, + default=1024, + help="maximum resolution for buckets, must be divisible by bucket_reso_steps " + " / bucketの最大解像度、bucket_reso_stepsで割り切れる必要があります", + ) parser.add_argument( "--bucket_reso_steps", type=int, From 012e7e63a5b1acdf69c72eee4cb330a5a6defc41 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 29 Sep 2024 23:18:16 +0900 Subject: [PATCH 87/87] fix to work linear/cosine scheduler closes #1651 ref #1393 --- library/train_util.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/library/train_util.py b/library/train_util.py index 422dceca..27910dc9 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -4496,6 +4496,15 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int): **lr_scheduler_kwargs, ) + # these schedulers do not require `num_decay_steps` + if name == SchedulerType.LINEAR or name == SchedulerType.COSINE: + return schedule_func( + optimizer, + num_warmup_steps=num_warmup_steps, + num_training_steps=num_training_steps, + **lr_scheduler_kwargs, + ) + # All other schedulers require `num_decay_steps` if num_decay_steps is None: raise ValueError(f"{name} requires `num_decay_steps`, please provide that argument.")