From a9c5aa1f9336cedf1e294fd3c8c22bb649d51015 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 5 Jan 2025 22:28:51 +0900 Subject: [PATCH 01/71] add CFG to FLUX.1 sample image --- library/flux_train_utils.py | 152 ++++++++++++++++++++++++------------ 1 file changed, 104 insertions(+), 48 deletions(-) diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index f7f06c5c..9f954f58 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -40,7 +40,7 @@ def sample_images( text_encoders, sample_prompts_te_outputs, prompt_replacement=None, - controlnet=None + controlnet=None, ): if steps == 0: if not args.sample_at_first: @@ -101,7 +101,7 @@ def sample_images( steps, sample_prompts_te_outputs, prompt_replacement, - controlnet + controlnet, ) else: # Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processes available (number of devices available) @@ -125,7 +125,7 @@ def sample_images( steps, sample_prompts_te_outputs, prompt_replacement, - controlnet + controlnet, ) torch.set_rng_state(rng_state) @@ -147,14 +147,14 @@ def sample_image_inference( steps, sample_prompts_te_outputs, prompt_replacement, - controlnet + controlnet, ): assert isinstance(prompt_dict, dict) - # negative_prompt = prompt_dict.get("negative_prompt") + negative_prompt = prompt_dict.get("negative_prompt") sample_steps = prompt_dict.get("sample_steps", 20) width = prompt_dict.get("width", 512) height = prompt_dict.get("height", 512) - scale = prompt_dict.get("scale", 3.5) + scale = prompt_dict.get("scale", 1.0) # 1.0 means no guidance seed = prompt_dict.get("seed") controlnet_image = prompt_dict.get("controlnet_image") prompt: str = prompt_dict.get("prompt", "") @@ -162,8 +162,8 @@ def sample_image_inference( if prompt_replacement is not None: prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1]) - # if negative_prompt is not None: - # negative_prompt = negative_prompt.replace(prompt_replacement[0], 
prompt_replacement[1]) + if negative_prompt is not None: + negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1]) if seed is not None: torch.manual_seed(seed) @@ -173,16 +173,18 @@ def sample_image_inference( torch.seed() torch.cuda.seed() - # if negative_prompt is None: - # negative_prompt = "" + if negative_prompt is None: + negative_prompt = "" height = max(64, height - height % 16) # round to divisible by 16 width = max(64, width - width % 16) # round to divisible by 16 logger.info(f"prompt: {prompt}") - # logger.info(f"negative_prompt: {negative_prompt}") + if scale != 1.0: + logger.info(f"negative_prompt: {negative_prompt}") logger.info(f"height: {height}") logger.info(f"width: {width}") logger.info(f"sample_steps: {sample_steps}") - logger.info(f"scale: {scale}") + if scale != 1.0: + logger.info(f"scale: {scale}") # logger.info(f"sample_sampler: {sampler_name}") if seed is not None: logger.info(f"seed: {seed}") @@ -191,26 +193,37 @@ def sample_image_inference( tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy() encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy() - text_encoder_conds = [] - if sample_prompts_te_outputs and prompt in sample_prompts_te_outputs: - text_encoder_conds = sample_prompts_te_outputs[prompt] - print(f"Using cached text encoder outputs for prompt: {prompt}") - if text_encoders is not None: - print(f"Encoding prompt: {prompt}") - tokens_and_masks = tokenize_strategy.tokenize(prompt) - # strategy has apply_t5_attn_mask option - encoded_text_encoder_conds = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, tokens_and_masks) + def encode_prompt(prpt): + text_encoder_conds = [] + if sample_prompts_te_outputs and prpt in sample_prompts_te_outputs: + text_encoder_conds = sample_prompts_te_outputs[prpt] + print(f"Using cached text encoder outputs for prompt: {prpt}") + if text_encoders is not None: + print(f"Encoding prompt: {prpt}") + tokens_and_masks = 
tokenize_strategy.tokenize(prpt) + # strategy has apply_t5_attn_mask option + encoded_text_encoder_conds = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, tokens_and_masks) - # if text_encoder_conds is not cached, use encoded_text_encoder_conds - if len(text_encoder_conds) == 0: - text_encoder_conds = encoded_text_encoder_conds - else: - # if encoded_text_encoder_conds is not None, update cached text_encoder_conds - for i in range(len(encoded_text_encoder_conds)): - if encoded_text_encoder_conds[i] is not None: - text_encoder_conds[i] = encoded_text_encoder_conds[i] + # if text_encoder_conds is not cached, use encoded_text_encoder_conds + if len(text_encoder_conds) == 0: + text_encoder_conds = encoded_text_encoder_conds + else: + # if encoded_text_encoder_conds is not None, update cached text_encoder_conds + for i in range(len(encoded_text_encoder_conds)): + if encoded_text_encoder_conds[i] is not None: + text_encoder_conds[i] = encoded_text_encoder_conds[i] + return text_encoder_conds - l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds + l_pooled, t5_out, txt_ids, t5_attn_mask = encode_prompt(prompt) + # encode negative prompts + if scale != 1.0: + neg_l_pooled, neg_t5_out, _, neg_t5_attn_mask = encode_prompt(negative_prompt) + neg_t5_attn_mask = ( + neg_t5_attn_mask.to(accelerator.device) if args.apply_t5_attn_mask and neg_t5_attn_mask is not None else None + ) + neg_cond = (scale, neg_l_pooled, neg_t5_out, neg_t5_attn_mask) + else: + neg_cond = None # sample image weight_dtype = ae.dtype # TOFO give dtype as argument @@ -235,7 +248,20 @@ def sample_image_inference( controlnet_image = controlnet_image.permute(2, 0, 1).unsqueeze(0).to(weight_dtype).to(accelerator.device) with accelerator.autocast(), torch.no_grad(): - x = denoise(flux, noise, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=scale, t5_attn_mask=t5_attn_mask, controlnet=controlnet, controlnet_img=controlnet_image) + x = denoise( + flux, + noise, + 
img_ids, + t5_out, + txt_ids, + l_pooled, + timesteps=timesteps, + guidance=scale, + t5_attn_mask=t5_attn_mask, + controlnet=controlnet, + controlnet_img=controlnet_image, + neg_cond=neg_cond, + ) x = flux_utils.unpack_latents(x, packed_latent_height, packed_latent_width) @@ -305,22 +331,24 @@ def denoise( model: flux_models.Flux, img: torch.Tensor, img_ids: torch.Tensor, - txt: torch.Tensor, + txt: torch.Tensor, # t5_out txt_ids: torch.Tensor, - vec: torch.Tensor, + vec: torch.Tensor, # l_pooled timesteps: list[float], guidance: float = 4.0, t5_attn_mask: Optional[torch.Tensor] = None, controlnet: Optional[flux_models.ControlNetFlux] = None, controlnet_img: Optional[torch.Tensor] = None, + neg_cond: Optional[Tuple[float, torch.Tensor, torch.Tensor, torch.Tensor]] = None, ): # this is ignored for schnell guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype) - + do_cfg = neg_cond is not None for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]): t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device) model.prepare_block_swap_before_forward() + if controlnet is not None: block_samples, block_single_samples = controlnet( img=img, @@ -336,20 +364,48 @@ def denoise( else: block_samples = None block_single_samples = None - pred = model( - img=img, - img_ids=img_ids, - txt=txt, - txt_ids=txt_ids, - y=vec, - block_controlnet_hidden_states=block_samples, - block_controlnet_single_hidden_states=block_single_samples, - timesteps=t_vec, - guidance=guidance_vec, - txt_attention_mask=t5_attn_mask, - ) - img = img + (t_prev - t_curr) * pred + if not do_cfg: + pred = model( + img=img, + img_ids=img_ids, + txt=txt, + txt_ids=txt_ids, + y=vec, + block_controlnet_hidden_states=block_samples, + block_controlnet_single_hidden_states=block_single_samples, + timesteps=t_vec, + guidance=guidance_vec, + txt_attention_mask=t5_attn_mask, + ) + + img = img + (t_prev - t_curr) * pred + else: + cfg_scale, neg_l_pooled, 
neg_t5_out, neg_t5_attn_mask = neg_cond + nc_c_t5_attn_mask = None if t5_attn_mask is None else torch.cat([neg_t5_attn_mask, t5_attn_mask], dim=0) + + # TODO is it ok to use the same block samples for both cond and uncond? + block_samples = None if block_samples is None else torch.cat([block_samples, block_samples], dim=0) + block_single_samples = ( + None if block_single_samples is None else torch.cat([block_single_samples, block_single_samples], dim=0) + ) + + nc_c_pred = model( + img=torch.cat([img, img], dim=0), + img_ids=torch.cat([img_ids, img_ids], dim=0), + txt=torch.cat([neg_t5_out, txt], dim=0), + txt_ids=torch.cat([txt_ids, txt_ids], dim=0), + y=torch.cat([neg_l_pooled, vec], dim=0), + block_controlnet_hidden_states=block_samples, + block_controlnet_single_hidden_states=block_single_samples, + timesteps=t_vec, + guidance=guidance_vec, + txt_attention_mask=nc_c_t5_attn_mask, + ) + neg_pred, pred = torch.chunk(nc_c_pred, 2, dim=0) + pred = neg_pred + (pred - neg_pred) * cfg_scale + + img = img + (t_prev - t_curr) * pred model.prepare_block_swap_before_forward() return img @@ -567,7 +623,7 @@ def add_flux_train_arguments(parser: argparse.ArgumentParser): "--controlnet_model_name_or_path", type=str, default=None, - help="path to controlnet (*.sft or *.safetensors) / controlnetのパス(*.sftまたは*.safetensors)" + help="path to controlnet (*.sft or *.safetensors) / controlnetのパス(*.sftまたは*.safetensors)", ) parser.add_argument( "--t5xxl_max_token_length", From 58e9e146a3c72716af909191835d4f41521b4c27 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Fri, 14 Feb 2025 14:02:21 -0500 Subject: [PATCH 02/71] Add resize interpolation configuration --- library/config_util.py | 7 ++++- library/train_util.py | 70 ++++++++++++++++++++++++++++++++++++------ 2 files changed, 67 insertions(+), 10 deletions(-) diff --git a/library/config_util.py b/library/config_util.py index a2e07dc6..53727f25 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -75,6 +75,7 @@ class 
BaseSubsetParams: custom_attributes: Optional[Dict[str, Any]] = None validation_seed: int = 0 validation_split: float = 0.0 + resize_interpolation: Optional[str] = None @dataclass @@ -106,7 +107,7 @@ class BaseDatasetParams: debug_dataset: bool = False validation_seed: Optional[int] = None validation_split: float = 0.0 - + resize_interpolation: Optional[str] = None @dataclass class DreamBoothDatasetParams(BaseDatasetParams): @@ -196,6 +197,7 @@ class ConfigSanitizer: "caption_prefix": str, "caption_suffix": str, "custom_attributes": dict, + "resize_interpolation": str, } # DO means DropOut DO_SUBSET_ASCENDABLE_SCHEMA = { @@ -241,6 +243,7 @@ class ConfigSanitizer: "validation_split": float, "resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int), "network_multiplier": float, + "resize_interpolation": str, } # options handled by argparse but not handled by user config @@ -525,6 +528,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu [{dataset_type} {i}] batch_size: {dataset.batch_size} resolution: {(dataset.width, dataset.height)} + resize_interpolation: {dataset.resize_interpolation} enable_bucket: {dataset.enable_bucket} """) @@ -558,6 +562,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu token_warmup_min: {subset.token_warmup_min}, token_warmup_step: {subset.token_warmup_step}, alpha_mask: {subset.alpha_mask} + resize_interpolation: {subset.resize_interpolation} custom_attributes: {subset.custom_attributes} """), " ") diff --git a/library/train_util.py b/library/train_util.py index 39b4af85..a07834ad 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -210,6 +210,7 @@ class ImageInfo: self.text_encoder_pool2: Optional[torch.Tensor] = None self.alpha_mask: Optional[torch.Tensor] = None # alpha mask can be flipped in runtime + self.resize_interpolation: Optional[str] = None class BucketManager: @@ -434,6 +435,7 @@ class BaseSubset: custom_attributes: 
Optional[Dict[str, Any]] = None, validation_seed: Optional[int] = None, validation_split: Optional[float] = 0.0, + resize_interpolation: Optional[str] = None, ) -> None: self.image_dir = image_dir self.alpha_mask = alpha_mask if alpha_mask is not None else False @@ -464,6 +466,8 @@ class BaseSubset: self.validation_seed = validation_seed self.validation_split = validation_split + self.resize_interpolation = resize_interpolation + class DreamBoothSubset(BaseSubset): def __init__( @@ -495,6 +499,7 @@ class DreamBoothSubset(BaseSubset): custom_attributes: Optional[Dict[str, Any]] = None, validation_seed: Optional[int] = None, validation_split: Optional[float] = 0.0, + resize_interpolation: Optional[str] = None, ) -> None: assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です" @@ -522,6 +527,7 @@ class DreamBoothSubset(BaseSubset): custom_attributes=custom_attributes, validation_seed=validation_seed, validation_split=validation_split, + resize_interpolation=resize_interpolation, ) self.is_reg = is_reg @@ -564,6 +570,7 @@ class FineTuningSubset(BaseSubset): custom_attributes: Optional[Dict[str, Any]] = None, validation_seed: Optional[int] = None, validation_split: Optional[float] = 0.0, + resize_interpolation: Optional[str] = None, ) -> None: assert metadata_file is not None, "metadata_file must be specified / metadata_fileは指定が必須です" @@ -591,6 +598,7 @@ class FineTuningSubset(BaseSubset): custom_attributes=custom_attributes, validation_seed=validation_seed, validation_split=validation_split, + resize_interpolation=resize_interpolation, ) self.metadata_file = metadata_file @@ -629,6 +637,7 @@ class ControlNetSubset(BaseSubset): custom_attributes: Optional[Dict[str, Any]] = None, validation_seed: Optional[int] = None, validation_split: Optional[float] = 0.0, + resize_interpolation: Optional[str] = None, ) -> None: assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です" @@ -656,6 +665,7 @@ class ControlNetSubset(BaseSubset): 
custom_attributes=custom_attributes, validation_seed=validation_seed, validation_split=validation_split, + resize_interpolation=resize_interpolation, ) self.conditioning_data_dir = conditioning_data_dir @@ -676,6 +686,7 @@ class BaseDataset(torch.utils.data.Dataset): resolution: Optional[Tuple[int, int]], network_multiplier: float, debug_dataset: bool, + resize_interpolation: Optional[str] = None ) -> None: super().__init__() @@ -710,6 +721,10 @@ class BaseDataset(torch.utils.data.Dataset): self.image_transforms = IMAGE_TRANSFORMS + if resize_interpolation is not None: + assert validate_interpolation_fn(resize_interpolation), f"Resize interpolation \"{resize_interpolation}\" is not a valid interpolation" + self.resize_interpolation = resize_interpolation + self.image_data: Dict[str, ImageInfo] = {} self.image_to_subset: Dict[str, Union[DreamBoothSubset, FineTuningSubset]] = {} @@ -1499,7 +1514,9 @@ class BaseDataset(torch.utils.data.Dataset): nh = int(height * scale + 0.5) nw = int(width * scale + 0.5) assert nh >= self.height and nw >= self.width, f"internal error. 
small scale {scale}, {width}*{height}" - image = cv2.resize(image, (nw, nh), interpolation=cv2.INTER_AREA) + interpolation = get_cv2_interpolation(subset.resize_interpolation if subset.resize_interpolation is not None else self.resize_interpolation) + logger.info(f"Interpolation: {interpolation}") + image = cv2.resize(image, (nw, nh), interpolation=interpolation if interpolation is not None else cv2.INTER_AREA) face_cx = int(face_cx * scale + 0.5) face_cy = int(face_cy * scale + 0.5) height, width = nh, nw @@ -1596,7 +1613,7 @@ class BaseDataset(torch.utils.data.Dataset): if self.enable_bucket: img, original_size, crop_ltrb = trim_and_resize_if_required( - subset.random_crop, img, image_info.bucket_reso, image_info.resized_size + subset.random_crop, img, image_info.bucket_reso, image_info.resized_size, resize_interpolation=image_info.resize_interpolation ) else: if face_cx > 0: # 顔位置情報あり @@ -1857,8 +1874,9 @@ class DreamBoothDataset(BaseDataset): debug_dataset: bool, validation_split: float, validation_seed: Optional[int], + resize_interpolation: Optional[str], ) -> None: - super().__init__(resolution, network_multiplier, debug_dataset) + super().__init__(resolution, network_multiplier, debug_dataset, resize_interpolation) assert resolution is not None, f"resolution is required / resolution(解像度)指定は必須です" @@ -2087,6 +2105,7 @@ class DreamBoothDataset(BaseDataset): for img_path, caption, size in zip(img_paths, captions, sizes): info = ImageInfo(img_path, num_repeats, caption, subset.is_reg, img_path) + info.resize_interpolation = subset.resize_interpolation if subset.resize_interpolation is not None else self.resize_interpolation if size is not None: info.image_size = size if subset.is_reg: @@ -2370,8 +2389,9 @@ class ControlNetDataset(BaseDataset): debug_dataset: bool, validation_split: float, validation_seed: Optional[int], + resize_interpolation: Optional[str] = None, ) -> None: - super().__init__(resolution, network_multiplier, debug_dataset) + 
super().__init__(resolution, network_multiplier, debug_dataset, resize_interpolation) db_subsets = [] for subset in subsets: @@ -2403,6 +2423,7 @@ class ControlNetDataset(BaseDataset): subset.caption_suffix, subset.token_warmup_min, subset.token_warmup_step, + resize_interpolation=subset.resize_interpolation, ) db_subsets.append(db_subset) @@ -2421,6 +2442,7 @@ class ControlNetDataset(BaseDataset): debug_dataset, validation_split, validation_seed, + resize_interpolation, ) # config_util等から参照される値をいれておく(若干微妙なのでなんとかしたい) @@ -2430,6 +2452,7 @@ class ControlNetDataset(BaseDataset): self.num_reg_images = self.dreambooth_dataset_delegate.num_reg_images self.validation_split = validation_split self.validation_seed = validation_seed + self.resize_interpolation = resize_interpolation # assert all conditioning data exists missing_imgs = [] @@ -2517,8 +2540,10 @@ class ControlNetDataset(BaseDataset): assert ( cond_img.shape[0] == original_size_hw[0] and cond_img.shape[1] == original_size_hw[1] ), f"size of conditioning image is not match / 画像サイズが合いません: {image_info.absolute_path}" + + interpolation = get_cv2_interpolation(self.resize_interpolation) cond_img = cv2.resize( - cond_img, image_info.resized_size, interpolation=cv2.INTER_AREA + cond_img, image_info.resized_size, interpolation=interpolation if interpolation is not None else cv2.INTER_AREA ) # INTER_AREAでやりたいのでcv2でリサイズ # TODO support random crop @@ -2930,7 +2955,7 @@ def load_image(image_path, alpha=False): # 画像を読み込む。戻り値はnumpy.ndarray,(original width, original height),(crop left, crop top, crop right, crop bottom) def trim_and_resize_if_required( - random_crop: bool, image: np.ndarray, reso, resized_size: Tuple[int, int] + random_crop: bool, image: np.ndarray, reso, resized_size: Tuple[int, int], resize_interpolation: Optional[str] = None ) -> Tuple[np.ndarray, Tuple[int, int], Tuple[int, int, int, int]]: image_height, image_width = image.shape[0:2] original_size = (image_width, image_height) # size before resize @@ 
-2938,7 +2963,8 @@ def trim_and_resize_if_required( if image_width != resized_size[0] or image_height != resized_size[1]: # リサイズする if image_width > resized_size[0] and image_height > resized_size[1]: - image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ + interpolation = get_cv2_interpolation(resize_interpolation) + image = cv2.resize(image, resized_size, interpolation=interpolation if interpolation is not None else cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ else: image = pil_resize(image, resized_size) @@ -2985,7 +3011,7 @@ def load_images_and_masks_for_caching( for info in image_infos: image = load_image(info.absolute_path, use_alpha_mask) if info.image is None else np.array(info.image, np.uint8) # TODO 画像のメタデータが壊れていて、メタデータから割り当てたbucketと実際の画像サイズが一致しない場合があるのでチェック追加要 - image, original_size, crop_ltrb = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size) + image, original_size, crop_ltrb = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size, resize_interpolation=info.resize_interpolation) original_sizes.append(original_size) crop_ltrbs.append(crop_ltrb) @@ -3026,7 +3052,7 @@ def cache_batch_latents( for info in image_infos: image = load_image(info.absolute_path, use_alpha_mask) if info.image is None else np.array(info.image, np.uint8) # TODO 画像のメタデータが壊れていて、メタデータから割り当てたbucketと実際の画像サイズが一致しない場合があるのでチェック追加要 - image, original_size, crop_ltrb = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size) + image, original_size, crop_ltrb = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size, resize_interpolation=info.resize_interpolation) info.latents_original_size = original_size info.latents_crop_ltrb = crop_ltrb @@ -6533,3 +6559,29 @@ class LossRecorder: if losses == 0: return 0 return self.loss_total / losses + +def get_cv2_interpolation(interpolation: Optional[str]) -> Optional[int]: + """ + Convert 
interpolation ovalue to cv2 interpolation integer + """ + if interpolation is None: + return None + + if interpolation == "lanczos": + return cv2.INTER_LANCZOS4 + elif interpolation == "nearest": + return cv2.INTER_NEAREST + elif interpolation == "bilinear" or interpolation == "linear": + return cv2.INTER_LINEAR + elif interpolation == "bicubic" or interpolation == "cubic": + return cv2.INTER_CUBIC + elif interpolation == "area": + return cv2.INTER_AREA + else: + return None + +def validate_interpolation_fn(interpolation_str: str) -> bool: + """ + Check if a interpolation function is supported + """ + return interpolation_str in ["lanczos", "nearest", "bilinear", "linear", "bicubic", "cubic", "area"] From d0128d18be009c5e221db96d53b6045bdd5af04f Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Fri, 14 Feb 2025 14:14:57 -0500 Subject: [PATCH 03/71] Add resize interpolation CLI option --- library/train_util.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/library/train_util.py b/library/train_util.py index a07834ad..d41d1ff3 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -4530,7 +4530,13 @@ def add_dataset_arguments( action="store_true", help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します", ) - + parser.add_argument( + "--resize_interpolation", + type=str, + default=None, + choices=["lanczos", "nearest", "bilinear", "linear", "bicubic", "cubic", "area"], + help="Resize interpolation when required. 
Default: area Options: lanczos, nearest, bilinear, bicubic, area / 必要に応じてサイズ補間を変更します。デフォルト: area オプション: lanczos, nearest, bilinear, bicubic, area", + ) parser.add_argument( "--token_warmup_min", type=int, From 7729c4c8f962d4f0b5fd73fb86399e73ab9cce8b Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Fri, 14 Feb 2025 14:18:53 -0500 Subject: [PATCH 04/71] Add metadata --- train_network.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/train_network.py b/train_network.py index 674f1cb6..16c1774e 100644 --- a/train_network.py +++ b/train_network.py @@ -973,6 +973,7 @@ class NetworkTrainer: "ss_max_validation_steps": args.max_validation_steps, "ss_validate_every_n_epochs": args.validate_every_n_epochs, "ss_validate_every_n_steps": args.validate_every_n_steps, + "ss_resize_interpolation": train_dataset_group.resize_interpolation } self.update_metadata(metadata, args) # architecture specific metadata @@ -998,6 +999,7 @@ class NetworkTrainer: "max_bucket_reso": dataset.max_bucket_reso, "tag_frequency": dataset.tag_frequency, "bucket_info": dataset.bucket_info, + "resize_interpolation": dataset.resize_interpolation, } subsets_metadata = [] @@ -1015,6 +1017,7 @@ class NetworkTrainer: "enable_wildcard": bool(subset.enable_wildcard), "caption_prefix": subset.caption_prefix, "caption_suffix": subset.caption_suffix, + "resize_interpolation": subset.resize_interpolation, } image_dir_or_metadata_file = None From 545425c13e855838781f0d0af24c4c5df992c87d Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Fri, 14 Feb 2025 14:24:25 -0500 Subject: [PATCH 05/71] Typo --- library/train_util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/library/train_util.py b/library/train_util.py index d41d1ff3..94145cad 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -6568,7 +6568,7 @@ class LossRecorder: def get_cv2_interpolation(interpolation: Optional[str]) -> Optional[int]: """ - Convert interpolation ovalue to cv2 interpolation integer + Convert 
interpolation value to cv2 interpolation integer """ if interpolation is None: return None From ca1c129ffd2439dec3f00a6a78a5cc5858d08cb5 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Fri, 14 Feb 2025 16:18:24 -0500 Subject: [PATCH 06/71] Fix metadata --- train_network.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train_network.py b/train_network.py index 16c1774e..c625eee6 100644 --- a/train_network.py +++ b/train_network.py @@ -973,7 +973,7 @@ class NetworkTrainer: "ss_max_validation_steps": args.max_validation_steps, "ss_validate_every_n_epochs": args.validate_every_n_epochs, "ss_validate_every_n_steps": args.validate_every_n_steps, - "ss_resize_interpolation": train_dataset_group.resize_interpolation + "ss_resize_interpolation": args.resize_interpolation } self.update_metadata(metadata, args) # architecture specific metadata From 7f2747176bb01757b95e086c548f2bcf8f689005 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Wed, 19 Feb 2025 14:20:24 -0500 Subject: [PATCH 07/71] Use resize_image where resizing is required --- finetune/tag_images_by_wd14_tagger.py | 7 +- library/train_util.py | 45 ++---------- library/utils.py | 101 +++++++++++++++++++++++++- tools/detect_face_rotate.py | 11 +-- tools/resize_images_to_resolution.py | 20 +---- 5 files changed, 113 insertions(+), 71 deletions(-) diff --git a/finetune/tag_images_by_wd14_tagger.py b/finetune/tag_images_by_wd14_tagger.py index cbc3d2d6..406f12f2 100644 --- a/finetune/tag_images_by_wd14_tagger.py +++ b/finetune/tag_images_by_wd14_tagger.py @@ -11,7 +11,7 @@ from PIL import Image from tqdm import tqdm import library.train_util as train_util -from library.utils import setup_logging, pil_resize +from library.utils import setup_logging, resize_image setup_logging() import logging @@ -42,10 +42,7 @@ def preprocess_image(image): pad_t = pad_y // 2 image = np.pad(image, ((pad_t, pad_y - pad_t), (pad_l, pad_x - pad_l), (0, 0)), mode="constant", constant_values=255) - if size > IMAGE_SIZE: - 
image = cv2.resize(image, (IMAGE_SIZE, IMAGE_SIZE), cv2.INTER_AREA) - else: - image = pil_resize(image, (IMAGE_SIZE, IMAGE_SIZE)) + image = resize_image(image, image.shape[0], image.shape[1], IMAGE_SIZE, IMAGE_SIZE) image = image.astype(np.float32) return image diff --git a/library/train_util.py b/library/train_util.py index 94145cad..46219d4f 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -84,7 +84,7 @@ import library.model_util as model_util import library.huggingface_util as huggingface_util import library.sai_model_spec as sai_model_spec import library.deepspeed_utils as deepspeed_utils -from library.utils import setup_logging, pil_resize +from library.utils import setup_logging, resize_image setup_logging() import logging @@ -1514,9 +1514,7 @@ class BaseDataset(torch.utils.data.Dataset): nh = int(height * scale + 0.5) nw = int(width * scale + 0.5) assert nh >= self.height and nw >= self.width, f"internal error. small scale {scale}, {width}*{height}" - interpolation = get_cv2_interpolation(subset.resize_interpolation if subset.resize_interpolation is not None else self.resize_interpolation) - logger.info(f"Interpolation: {interpolation}") - image = cv2.resize(image, (nw, nh), interpolation=interpolation if interpolation is not None else cv2.INTER_AREA) + image = resize_image(image, width, height, nw, nh, subset.resize_interpolation) face_cx = int(face_cx * scale + 0.5) face_cy = int(face_cy * scale + 0.5) height, width = nh, nw @@ -2541,10 +2539,7 @@ class ControlNetDataset(BaseDataset): cond_img.shape[0] == original_size_hw[0] and cond_img.shape[1] == original_size_hw[1] ), f"size of conditioning image is not match / 画像サイズが合いません: {image_info.absolute_path}" - interpolation = get_cv2_interpolation(self.resize_interpolation) - cond_img = cv2.resize( - cond_img, image_info.resized_size, interpolation=interpolation if interpolation is not None else cv2.INTER_AREA - ) # INTER_AREAでやりたいのでcv2でリサイズ + cond_img = resize_image(cond_img, 
original_size_hw[1], original_size_hw[0], target_size_hw[1], target_size_hw[0], self.resize_interpolation) # TODO support random crop # 現在サポートしているcropはrandomではなく中央のみ @@ -2558,7 +2553,7 @@ class ControlNetDataset(BaseDataset): # ), f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}" # resize to target if cond_img.shape[0] != target_size_hw[0] or cond_img.shape[1] != target_size_hw[1]: - cond_img = pil_resize(cond_img, (int(target_size_hw[1]), int(target_size_hw[0]))) + cond_img = resize_image(cond_img, cond_img.shape[0], cond_img.shape[1], target_size_hw[1], target_size_hw[0], self.resize_interpolation) if flipped: cond_img = cond_img[:, ::-1, :].copy() # copy to avoid negative stride @@ -2961,12 +2956,7 @@ def trim_and_resize_if_required( original_size = (image_width, image_height) # size before resize if image_width != resized_size[0] or image_height != resized_size[1]: - # リサイズする - if image_width > resized_size[0] and image_height > resized_size[1]: - interpolation = get_cv2_interpolation(resize_interpolation) - image = cv2.resize(image, resized_size, interpolation=interpolation if interpolation is not None else cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ - else: - image = pil_resize(image, resized_size) + image = resize_image(image, image_width, image_height, resized_size[0], resized_size[1], resize_interpolation) image_height, image_width = image.shape[0:2] @@ -6566,28 +6556,3 @@ class LossRecorder: return 0 return self.loss_total / losses -def get_cv2_interpolation(interpolation: Optional[str]) -> Optional[int]: - """ - Convert interpolation value to cv2 interpolation integer - """ - if interpolation is None: - return None - - if interpolation == "lanczos": - return cv2.INTER_LANCZOS4 - elif interpolation == "nearest": - return cv2.INTER_NEAREST - elif interpolation == "bilinear" or interpolation == "linear": - return cv2.INTER_LINEAR - elif interpolation == "bicubic" or interpolation == "cubic": - return cv2.INTER_CUBIC - elif interpolation == 
"area": - return cv2.INTER_AREA - else: - return None - -def validate_interpolation_fn(interpolation_str: str) -> bool: - """ - Check if a interpolation function is supported - """ - return interpolation_str in ["lanczos", "nearest", "bilinear", "linear", "bicubic", "cubic", "area"] diff --git a/library/utils.py b/library/utils.py index 07079c6d..9156864e 100644 --- a/library/utils.py +++ b/library/utils.py @@ -16,7 +16,6 @@ from PIL import Image import numpy as np from safetensors.torch import load_file - def fire_in_thread(f, *args, **kwargs): threading.Thread(target=f, args=args, kwargs=kwargs).start() @@ -89,6 +88,8 @@ def setup_logging(args=None, log_level=None, reset=False): logger = logging.getLogger(__name__) logger.info(msg_init) +setup_logging() +logger = logging.getLogger(__name__) # endregion @@ -377,7 +378,7 @@ def load_safetensors( # region Image utils -def pil_resize(image, size, interpolation=Image.LANCZOS): +def pil_resize(image, size, interpolation): has_alpha = image.shape[2] == 4 if len(image.shape) == 3 else False if has_alpha: @@ -385,7 +386,7 @@ def pil_resize(image, size, interpolation=Image.LANCZOS): else: pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) - resized_pil = pil_image.resize(size, interpolation) + resized_pil = pil_image.resize(size, resample=interpolation) # Convert back to cv2 format if has_alpha: @@ -396,6 +397,100 @@ def pil_resize(image, size, interpolation=Image.LANCZOS): return resized_cv2 +def resize_image(image: np.ndarray, width: int, height: int, resized_width: int, resized_height: int, resize_interpolation: Optional[str] = None): + """ + Resize image with resize interpolation. 
Default interpolation to AREA if image is smaller, else LANCZOS + + Args: + image: numpy.ndarray + width: int Original image width + height: int Original image height + resized_width: int Resized image width + resized_height: int Resized image height + resize_interpolation: Optional[str] Resize interpolation method "lanczos", "area", "bilinear", "bicubic", "nearest", "box" + + Returns: + image + """ + interpolation = get_cv2_interpolation(resize_interpolation) + resized_size = (resized_width, resized_height) + if width > resized_width and height > resized_width: + image = cv2.resize(image, resized_size, interpolation=interpolation if interpolation is not None else cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ + logger.debug(f"resize image using {resize_interpolation}") + else: + image = cv2.resize(image, resized_size, interpolation=interpolation if interpolation is not None else cv2.INTER_LANCZOS4) # INTER_AREAでやりたいのでcv2でリサイズ + logger.debug(f"resize image using {resize_interpolation}") + + return image + + +def get_cv2_interpolation(interpolation: Optional[str]) -> Optional[int]: + """ + Convert interpolation value to cv2 interpolation integer + + https://docs.opencv.org/3.4/da/d54/group__imgproc__transform.html#ga5bb5a1fea74ea38e1a5445ca803ff121 + """ + if interpolation is None: + return None + + if interpolation == "lanczos" or interpolation == "lanczos4": + # Lanczos interpolation over 8x8 neighborhood + return cv2.INTER_LANCZOS4 + elif interpolation == "nearest": + # Bit exact nearest neighbor interpolation. This will produce same results as the nearest neighbor method in PIL, scikit-image or Matlab. + return cv2.INTER_NEAREST_EXACT + elif interpolation == "bilinear" or interpolation == "linear": + # bilinear interpolation + return cv2.INTER_LINEAR + elif interpolation == "bicubic" or interpolation == "cubic": + # bicubic interpolation + return cv2.INTER_CUBIC + elif interpolation == "area": + # resampling using pixel area relation. 
It may be a preferred method for image decimation, as it gives moire'-free results. But when the image is zoomed, it is similar to the INTER_NEAREST method. + return cv2.INTER_AREA + elif interpolation == "box": + # resampling using pixel area relation. It may be a preferred method for image decimation, as it gives moire'-free results. But when the image is zoomed, it is similar to the INTER_NEAREST method. + return cv2.INTER_AREA + else: + return None + +def get_pil_interpolation(interpolation: Optional[str]) -> Optional[Image.Resampling]: + """ + Convert interpolation value to PIL interpolation + + https://pillow.readthedocs.io/en/stable/handbook/concepts.html#concept-filters + """ + if interpolation is None: + return None + + if interpolation == "lanczos": + return Image.Resampling.LANCZOS + elif interpolation == "nearest": + # Pick one nearest pixel from the input image. Ignore all other input pixels. + return Image.Resampling.NEAREST + elif interpolation == "bilinear" or interpolation == "linear": + # For resize calculate the output pixel value using linear interpolation on all pixels that may contribute to the output value. For other transformations linear interpolation over a 2x2 environment in the input image is used. + return Image.Resampling.BILINEAR + elif interpolation == "bicubic" or interpolation == "cubic": + # For resize calculate the output pixel value using cubic interpolation on all pixels that may contribute to the output value. For other transformations cubic interpolation over a 4x4 environment in the input image is used. + return Image.Resampling.BICUBIC + elif interpolation == "area": + # Image.Resampling.BOX may be more appropriate if upscaling + # Area interpolation is related to cv2.INTER_AREA + # Produces a sharper image than Resampling.BILINEAR, doesn’t have dislocations on local level like with Resampling.BOX. 
+ return Image.Resampling.HAMMING + elif interpolation == "box": + # Each pixel of source image contributes to one pixel of the destination image with identical weights. For upscaling is equivalent of Resampling.NEAREST. + return Image.Resampling.BOX + else: + return None + +def validate_interpolation_fn(interpolation_str: str) -> bool: + """ + Check if a interpolation function is supported + """ + return interpolation_str in ["lanczos", "nearest", "bilinear", "linear", "bicubic", "cubic", "area", "box"] + # endregion # TODO make inf_utils.py diff --git a/tools/detect_face_rotate.py b/tools/detect_face_rotate.py index d2a4d9cf..16fd7d0b 100644 --- a/tools/detect_face_rotate.py +++ b/tools/detect_face_rotate.py @@ -15,7 +15,7 @@ import os from anime_face_detector import create_detector from tqdm import tqdm import numpy as np -from library.utils import setup_logging, pil_resize +from library.utils import setup_logging, resize_image setup_logging() import logging logger = logging.getLogger(__name__) @@ -170,12 +170,9 @@ def process(args): scale = max(cur_crop_width / w, cur_crop_height / h) if scale != 1.0: - w = int(w * scale + .5) - h = int(h * scale + .5) - if scale < 1.0: - face_img = cv2.resize(face_img, (w, h), interpolation=cv2.INTER_AREA) - else: - face_img = pil_resize(face_img, (w, h)) + rw = int(w * scale + .5) + rh = int(h * scale + .5) + face_img = resize_image(face_img, w, h, rw, rh) cx = int(cx * scale + .5) cy = int(cy * scale + .5) fw = int(fw * scale + .5) diff --git a/tools/resize_images_to_resolution.py b/tools/resize_images_to_resolution.py index 0f9e00b1..f5fbae2b 100644 --- a/tools/resize_images_to_resolution.py +++ b/tools/resize_images_to_resolution.py @@ -6,7 +6,7 @@ import shutil import math from PIL import Image import numpy as np -from library.utils import setup_logging, pil_resize +from library.utils import setup_logging, resize_image setup_logging() import logging logger = logging.getLogger(__name__) @@ -22,14 +22,6 @@ def 
resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divi if not os.path.exists(dst_img_folder): os.makedirs(dst_img_folder) - # Select interpolation method - if interpolation == 'lanczos4': - pil_interpolation = Image.LANCZOS - elif interpolation == 'cubic': - pil_interpolation = Image.BICUBIC - else: - cv2_interpolation = cv2.INTER_AREA - # Iterate through all files in src_img_folder img_exts = (".png", ".jpg", ".jpeg", ".webp", ".bmp") # copy from train_util.py for filename in os.listdir(src_img_folder): @@ -63,11 +55,7 @@ def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divi new_height = int(img.shape[0] * math.sqrt(scale_factor)) new_width = int(img.shape[1] * math.sqrt(scale_factor)) - # Resize image - if cv2_interpolation: - img = cv2.resize(img, (new_width, new_height), interpolation=cv2_interpolation) - else: - img = pil_resize(img, (new_width, new_height), interpolation=pil_interpolation) + img = resize_image(img, img.shape[0], img.shape[1], new_height, new_width, interpolation) else: new_height, new_width = img.shape[0:2] @@ -113,8 +101,8 @@ def setup_parser() -> argparse.ArgumentParser: help='Maximum resolution(s) in the format "512x512,384x384, etc, etc" / 最大画像サイズをカンマ区切りで指定 ("512x512,384x384, etc, etc" など)', default="512x512,384x384,256x256,128x128") parser.add_argument('--divisible_by', type=int, help='Ensure new dimensions are divisible by this value / リサイズ後の画像のサイズをこの値で割り切れるようにします', default=1) - parser.add_argument('--interpolation', type=str, choices=['area', 'cubic', 'lanczos4'], - default='area', help='Interpolation method for resizing / リサイズ時の補完方法') + parser.add_argument('--interpolation', type=str, choices=['area', 'cubic', 'lanczos4', 'nearest', 'linear', 'box'], + default=None, help='Interpolation method for resizing. 
Default to area if smaller, lanczos if larger / サイズ変更の補間方法。小さい場合はデフォルトでエリア、大きい場合はランチョスになります。') parser.add_argument('--save_as_png', action='store_true', help='Save as png format / png形式で保存') parser.add_argument('--copy_associated_files', action='store_true', help='Copy files with same base name to images (captions etc) / 画像と同じファイル名(拡張子を除く)のファイルもコピーする') From 9a415ba9651be2d4be7b3321ad244f9001b83738 Mon Sep 17 00:00:00 2001 From: Disty0 Date: Thu, 27 Feb 2025 00:21:57 +0300 Subject: [PATCH 08/71] JPEG XL support --- library/jpeg_xl_util.py | 184 ++++++++++++++++++++++++++++++++++++++++ library/train_util.py | 6 +- 2 files changed, 189 insertions(+), 1 deletion(-) create mode 100644 library/jpeg_xl_util.py diff --git a/library/jpeg_xl_util.py b/library/jpeg_xl_util.py new file mode 100644 index 00000000..82da10ef --- /dev/null +++ b/library/jpeg_xl_util.py @@ -0,0 +1,184 @@ +# Modifed from https://github.com/Fraetor/jxl_decode +# Added partial read support for 200x speedup +import os + +class JXLBitstream: + """ + A stream of bits with methods for easy handling. 
+ """ + + def __init__(self, file, offset=0, offsets=[]) -> None: + self.shift = 0 + self.bitstream = [] + self.file = file + self.offset = offset + self.offsets = offsets + if self.offsets: + self.offset = self.offsets[0][1] + self.previous_data_len = 0 + self.index = 0 + self.file.seek(self.offset) + + def get_bits(self, length: int = 1) -> int: + if self.offsets and self.shift + length > self.previous_data_len + self.offsets[self.index][2]: + self.partial_to_read_length = length + if self.shift < self.previous_data_len + self.offsets[self.index][2]: + self.partial_read(0, length) + self.bitstream += self.file.read(self.partial_to_read_length) + else: + self.bitstream += self.file.read(length) + bitmask = 2**length - 1 + bits = (int.from_bytes(self.bitstream, "little") >> self.shift) & bitmask + self.shift += length + return bits + + def partial_read(self, readed_length, length): + self.previous_data_len += self.offsets[self.index][2] + to_read_length = self.previous_data_len - (self.shift + readed_length) + self.bitstream += self.file.read(to_read_length) + readed_length += to_read_length + self.partial_to_read_length -= to_read_length + self.index += 1 + self.file.seek(self.offsets[self.index][1]) + if self.shift + length > self.previous_data_len + self.offsets[self.index][2]: + self.partial_read(readed_length, length) + + +def decode_codestream(file, offset=0, offsets=[]): + """ + Decodes the actual codestream. + JXL codestream specification: http://www-internal/2022/18181-1 + """ + + # Convert codestream to int within an object to get some handy methods. 
+ codestream = JXLBitstream(file, offset=offset, offsets=offsets) + + # Skip signature + codestream.get_bits(16) + + # SizeHeader + div8 = codestream.get_bits(1) + if div8: + height = 8 * (1 + codestream.get_bits(5)) + else: + distribution = codestream.get_bits(2) + match distribution: + case 0: + height = 1 + codestream.get_bits(9) + case 1: + height = 1 + codestream.get_bits(13) + case 2: + height = 1 + codestream.get_bits(18) + case 3: + height = 1 + codestream.get_bits(30) + ratio = codestream.get_bits(3) + if div8 and not ratio: + width = 8 * (1 + codestream.get_bits(5)) + elif not ratio: + distribution = codestream.get_bits(2) + match distribution: + case 0: + width = 1 + codestream.get_bits(9) + case 1: + width = 1 + codestream.get_bits(13) + case 2: + width = 1 + codestream.get_bits(18) + case 3: + width = 1 + codestream.get_bits(30) + else: + match ratio: + case 1: + width = height + case 2: + width = (height * 12) // 10 + case 3: + width = (height * 4) // 3 + case 4: + width = (height * 3) // 2 + case 5: + width = (height * 16) // 9 + case 6: + width = (height * 5) // 4 + case 7: + width = (height * 2) // 1 + return width, height + + +def decode_container(file): + """ + Parses the ISOBMFF container, extracts the codestream, and decodes it. 
+ JXL container specification: http://www-internal/2022/18181-2 + """ + + def parse_box(file, file_start) -> dict: + file.seek(file_start) + LBox = int.from_bytes(file.read(4), "big") + XLBox = None + if 1 < LBox <= 8: + raise ValueError(f"Invalid LBox at byte {file_start}.") + if LBox == 1: + file.seek(file_start + 8) + XLBox = int.from_bytes(file.read(8), "big") + if XLBox <= 16: + raise ValueError(f"Invalid XLBox at byte {file_start}.") + if XLBox: + header_length = 16 + box_length = XLBox + else: + header_length = 8 + if LBox == 0: + box_length = os.fstat(file.fileno()).st_size - file_start + else: + box_length = LBox + file.seek(file_start + 4) + box_type = file.read(4) + file.seek(file_start) + return { + "length": box_length, + "type": box_type, + "offset": header_length, + } + + file.seek(0) + # Reject files missing required boxes. These two boxes are required to be at + # the start and contain no values, so we can manually check there presence. + # Signature box. (Redundant as has already been checked.) + if file.read(12) != bytes.fromhex("0000000C 4A584C20 0D0A870A"): + raise ValueError("Invalid signature box.") + # File Type box. 
+ if file.read(20) != bytes.fromhex( + "00000014 66747970 6A786C20 00000000 6A786C20" + ): + raise ValueError("Invalid file type box.") + + offset = 0 + offsets = [] + data_offset_not_found = True + container_pointer = 32 + file_size = os.fstat(file.fileno()).st_size + while data_offset_not_found: + box = parse_box(file, container_pointer) + match box["type"]: + case b"jxlc": + offset = container_pointer + box["offset"] + data_offset_not_found = False + case b"jxlp": + file.seek(container_pointer + box["offset"]) + index = int.from_bytes(file.read(4), "big") + offsets.append([index, container_pointer + box["offset"] + 4, box["length"] - box["offset"] - 4]) + container_pointer += box["length"] + if container_pointer >= file_size: + data_offset_not_found = False + + if offsets: + offsets.sort(key=lambda i: i[0]) + file.seek(0) + + return decode_codestream(file, offset=offset, offsets=offsets) + + +def get_jxl_size(path): + with open(path, "rb") as file: + if file.read(2) == bytes.fromhex("FF0A"): + return decode_codestream(file) + return decode_container(file) diff --git a/library/train_util.py b/library/train_util.py index 100ef475..916b8834 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -118,14 +118,16 @@ except: # JPEG-XL on Linux try: from jxlpy import JXLImagePlugin + from library.jpeg_xl_util import get_jxl_size IMAGE_EXTENSIONS.extend([".jxl", ".JXL"]) except: pass -# JPEG-XL on Windows +# JPEG-XL on Linux and Windows try: import pillow_jxl + from library.jpeg_xl_util import get_jxl_size IMAGE_EXTENSIONS.extend([".jxl", ".JXL"]) except: @@ -1156,6 +1158,8 @@ class BaseDataset(torch.utils.data.Dataset): ) def get_image_size(self, image_path): + if image_path.endswith(".jxl"): + return get_jxl_size(image_path) return imagesize.get(image_path) def load_image_with_face_info(self, subset: BaseSubset, image_path: str, alpha_mask=False): From 2f69f4dbdb679ca887d3bc4438667360ae934ca9 Mon Sep 17 00:00:00 2001 From: Disty0 Date: Thu, 27 Feb 2025 
00:30:19 +0300 Subject: [PATCH 09/71] fix typo --- library/jpeg_xl_util.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/library/jpeg_xl_util.py b/library/jpeg_xl_util.py index 82da10ef..d127d44d 100644 --- a/library/jpeg_xl_util.py +++ b/library/jpeg_xl_util.py @@ -1,5 +1,5 @@ -# Modifed from https://github.com/Fraetor/jxl_decode -# Added partial read support for 200x speedup +# Modified from https://github.com/Fraetor/jxl_decode +# Added partial read support for up to 200x speedup import os class JXLBitstream: @@ -32,16 +32,16 @@ class JXLBitstream: self.shift += length return bits - def partial_read(self, readed_length, length): + def partial_read(self, current_length, length): self.previous_data_len += self.offsets[self.index][2] - to_read_length = self.previous_data_len - (self.shift + readed_length) + to_read_length = self.previous_data_len - (self.shift + current_length) self.bitstream += self.file.read(to_read_length) - readed_length += to_read_length + current_length += to_read_length self.partial_to_read_length -= to_read_length self.index += 1 self.file.seek(self.offsets[self.index][1]) if self.shift + length > self.previous_data_len + self.offsets[self.index][2]: - self.partial_read(readed_length, length) + self.partial_read(current_length, length) def decode_codestream(file, offset=0, offsets=[]): From acdca2abb781eb207cb760c759dc4d23f8ca5e72 Mon Sep 17 00:00:00 2001 From: Ivan Chikish Date: Sat, 1 Mar 2025 17:06:17 +0300 Subject: [PATCH 10/71] Fix [occasionally] missing text encoder attn modules Should fix #1952 I added alternative name for CLIPAttention. I have no idea why this name changed. Now it should accept both names. 
--- networks/dylora.py | 2 +- networks/lora.py | 2 +- networks/lora_diffusers.py | 2 +- networks/lora_fa.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/networks/dylora.py b/networks/dylora.py index b0925453..82d96f59 100644 --- a/networks/dylora.py +++ b/networks/dylora.py @@ -268,7 +268,7 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh class DyLoRANetwork(torch.nn.Module): UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"] UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"] - TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"] + TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPSdpaAttention", "CLIPMLP"] LORA_PREFIX_UNET = "lora_unet" LORA_PREFIX_TEXT_ENCODER = "lora_te" diff --git a/networks/lora.py b/networks/lora.py index 6f33f1a1..1699a60f 100644 --- a/networks/lora.py +++ b/networks/lora.py @@ -866,7 +866,7 @@ class LoRANetwork(torch.nn.Module): UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"] UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"] - TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"] + TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPSdpaAttention", "CLIPMLP"] LORA_PREFIX_UNET = "lora_unet" LORA_PREFIX_TEXT_ENCODER = "lora_te" diff --git a/networks/lora_diffusers.py b/networks/lora_diffusers.py index b99b0244..56b74d10 100644 --- a/networks/lora_diffusers.py +++ b/networks/lora_diffusers.py @@ -278,7 +278,7 @@ def merge_lora_weights(pipe, weights_sd: Dict, multiplier: float = 1.0): class LoRANetwork(torch.nn.Module): UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"] UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"] - TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"] + TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPSdpaAttention", "CLIPMLP"] LORA_PREFIX_UNET = "lora_unet" LORA_PREFIX_TEXT_ENCODER 
= "lora_te" diff --git a/networks/lora_fa.py b/networks/lora_fa.py index 919222ce..5fe778b4 100644 --- a/networks/lora_fa.py +++ b/networks/lora_fa.py @@ -755,7 +755,7 @@ class LoRANetwork(torch.nn.Module): UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"] UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"] - TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"] + TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPSdpaAttention", "CLIPMLP"] LORA_PREFIX_UNET = "lora_unet" LORA_PREFIX_TEXT_ENCODER = "lora_te" From 3f49053c9068a0dcfa3a360d032529d87f878f8b Mon Sep 17 00:00:00 2001 From: sdbds <865105819@qq.com> Date: Sun, 2 Mar 2025 19:32:06 +0800 Subject: [PATCH 11/71] fatser fix bug for SDXL super SD1.5 assert cant use 32 --- sdxl_train_network.py | 1 - 1 file changed, 1 deletion(-) diff --git a/sdxl_train_network.py b/sdxl_train_network.py index 83969bb1..3559ab88 100644 --- a/sdxl_train_network.py +++ b/sdxl_train_network.py @@ -18,7 +18,6 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer): self.is_sdxl = True def assert_extra_args(self, args, train_dataset_group): - super().assert_extra_args(args, train_dataset_group) sdxl_train_util.verify_sdxl_training_args(args) if args.cache_text_encoder_outputs: From ea53290f625b29c2cfc1c63cc83d6dcd1492731c Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Thu, 6 Mar 2025 00:00:38 -0500 Subject: [PATCH 12/71] Add LoRA-GGPO for Flux --- networks/lora_flux.py | 134 +++++++++++++++++++++++++++++++++++++++++- train_network.py | 4 ++ 2 files changed, 137 insertions(+), 1 deletion(-) diff --git a/networks/lora_flux.py b/networks/lora_flux.py index 91e9cd77..98cf8c55 100644 --- a/networks/lora_flux.py +++ b/networks/lora_flux.py @@ -9,6 +9,7 @@ import math import os +from contextlib import contextmanager from typing import Dict, List, Optional, Tuple, Type, Union from diffusers import AutoencoderKL from transformers import CLIPTextModel @@ -27,6 +28,42 @@ logger = 
logging.getLogger(__name__) NUM_DOUBLE_BLOCKS = 19 NUM_SINGLE_BLOCKS = 38 +@contextmanager +def temp_random_seed(seed, device=None): + """ + Context manager that temporarily sets a specific random seed and then + restores the original RNG state afterward. + + Args: + seed (int): The random seed to set temporarily + device (torch.device, optional): The device to set the seed for. + If None, will detect from the current context. + """ + # Save original RNG states + original_cpu_rng_state = torch.get_rng_state() + original_cuda_rng_states = None + if torch.cuda.is_available(): + original_cuda_rng_states = torch.cuda.get_rng_state_all() + + # Determine if we need to set CUDA seed + set_cuda = False + if device is not None: + set_cuda = device.type == 'cuda' + elif torch.cuda.is_available(): + set_cuda = True + + try: + # Set the temporary seed + torch.manual_seed(seed) + if set_cuda: + torch.cuda.manual_seed_all(seed) + yield + finally: + # Restore original RNG states + torch.set_rng_state(original_cpu_rng_state) + if torch.cuda.is_available() and original_cuda_rng_states is not None: + torch.cuda.set_rng_state_all(original_cuda_rng_states) + class LoRAModule(torch.nn.Module): """ @@ -44,6 +81,8 @@ class LoRAModule(torch.nn.Module): rank_dropout=None, module_dropout=None, split_dims: Optional[List[int]] = None, + ggpo_beta: Optional[float] = None, + ggpo_sigma: Optional[float] = None, ): """ if alpha == 0 or None, alpha is rank (no scaling). 
@@ -103,9 +142,16 @@ class LoRAModule(torch.nn.Module): self.rank_dropout = rank_dropout self.module_dropout = module_dropout + self.ggpo_sigma = ggpo_sigma + self.ggpo_beta = ggpo_beta + + self.perturbation_norm_factor = 1.0 / math.sqrt(org_module.weight.shape[0]) + self._org_module_weight = self.org_module.weight.detach() + def apply_to(self): self.org_forward = self.org_module.forward self.org_module.forward = self.forward + del self.org_module def forward(self, x): @@ -140,7 +186,15 @@ class LoRAModule(torch.nn.Module): lx = self.lora_up(lx) - return org_forwarded + lx * self.multiplier * scale + # LoRA Gradient-Guided Perturbation Optimization + if self.training and hasattr(self, 'perturbation_seed') and self.ggpo_sigma is not None and self.ggpo_beta is not None: + with torch.no_grad(), torch.autocast(self.device.type), temp_random_seed(self.perturbation_seed): + perturbation = torch.randn_like(self._org_module_weight, dtype=self.dtype, device=self.device) + perturbation.mul_(self.perturbation_scale_factor) + perturbation_output = x @ perturbation.T # Result: (batch × n) + return org_forwarded + (self.multiplier * scale * lx) + perturbation_output + else: + return org_forwarded + lx * self.multiplier * scale else: lxs = [lora_down(x) for lora_down in self.lora_down] @@ -167,6 +221,58 @@ class LoRAModule(torch.nn.Module): return org_forwarded + torch.cat(lxs, dim=-1) * self.multiplier * scale + @torch.no_grad() + def update_norms(self): + # Not running GGPO so not currently running update norms + if self.ggpo_beta is None or self.ggpo_sigma is None: + return + + # only update norms when we are training + if self.lora_down.weight.requires_grad is not True: + print(f"skipping update_norms for {self.lora_name}") + return + + lora_down_grad = None + lora_up_grad = None + + for name, param in self.named_parameters(): + if name == "lora_down.weight": + lora_down_grad = param.grad + elif name == "lora_up.weight": + lora_up_grad = param.grad + + with 
torch.autocast(self.device.type): + module_weights = self.scale * (self.lora_up.weight @ self.lora_down.weight) + org_device = self._org_module_weight.device + org_dtype = self._org_module_weight.dtype + org_weight = self._org_module_weight.to(device=self.device, dtype=self.dtype) + combined_weight = org_weight + module_weights + + self.combined_weight_norms = torch.norm(combined_weight, dim=1, keepdim=True) + + self._org_module_weight.to(device=org_device, dtype=org_dtype) + + + # Calculate gradient norms if we have both gradients + if lora_down_grad is not None and lora_up_grad is not None: + with torch.autocast(self.device.type): + approx_grad = self.scale * ((self.lora_up.weight @ lora_down_grad) + (lora_up_grad @ self.lora_down.weight)) + self.grad_norms = torch.norm(approx_grad, dim=1, keepdim=True) + + self.perturbation_scale = (self.ggpo_sigma * torch.sqrt(self.combined_weight_norms ** 2)) + (self.ggpo_beta * (self.grad_norms ** 2)) + self.perturbation_scale_factor = (self.perturbation_scale * self.perturbation_norm_factor).to(self.device) + + # LoRA Gradient-Guided Perturbation Optimization + self.perturbation_seed = torch.randint(0, 2**32 - 1, (1,)).detach().item() + + @property + def device(self): + return next(self.parameters()).device + + @property + def dtype(self): + return next(self.parameters()).dtype + class LoRAInfModule(LoRAModule): def __init__( @@ -420,6 +526,16 @@ def create_network( if split_qkv is not None: split_qkv = True if split_qkv == "True" else False + ggpo_beta = kwargs.get("ggpo_beta", None) + ggpo_sigma = kwargs.get("ggpo_sigma", None) + + if ggpo_beta is not None: + ggpo_beta = float(ggpo_beta) + + if ggpo_sigma is not None: + ggpo_sigma = float(ggpo_sigma) + + # train T5XXL train_t5xxl = kwargs.get("train_t5xxl", False) if train_t5xxl is not None: @@ -449,6 +565,8 @@ def create_network( in_dims=in_dims, train_double_block_indices=train_double_block_indices, train_single_block_indices=train_single_block_indices, + 
ggpo_beta=ggpo_beta, + ggpo_sigma=ggpo_sigma, verbose=verbose, ) @@ -561,6 +679,8 @@ class LoRANetwork(torch.nn.Module): in_dims: Optional[List[int]] = None, train_double_block_indices: Optional[List[bool]] = None, train_single_block_indices: Optional[List[bool]] = None, + ggpo_beta: Optional[float] = None, + ggpo_sigma: Optional[float] = None, verbose: Optional[bool] = False, ) -> None: super().__init__() @@ -599,10 +719,16 @@ class LoRANetwork(torch.nn.Module): # logger.info( # f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}" # ) + + if ggpo_beta is not None and ggpo_sigma is not None: + logger.info(f"LoRA-GGPO training sigma: {ggpo_sigma} beta: {ggpo_beta}") + if self.split_qkv: logger.info(f"split qkv for LoRA") if self.train_blocks is not None: logger.info(f"train {self.train_blocks} blocks only") + + if train_t5xxl: logger.info(f"train T5XXL as well") @@ -722,6 +848,8 @@ class LoRANetwork(torch.nn.Module): rank_dropout=rank_dropout, module_dropout=module_dropout, split_dims=split_dims, + ggpo_beta=ggpo_beta, + ggpo_sigma=ggpo_sigma, ) loras.append(lora) @@ -790,6 +918,10 @@ class LoRANetwork(torch.nn.Module): for lora in self.text_encoder_loras + self.unet_loras: lora.enabled = is_enabled + def update_norms(self): + for lora in self.text_encoder_loras + self.unet_loras: + lora.update_norms() + def load_weights(self, file): if os.path.splitext(file)[1] == ".safetensors": from safetensors.torch import load_file diff --git a/train_network.py b/train_network.py index 2d279b3b..9db335b0 100644 --- a/train_network.py +++ b/train_network.py @@ -1400,6 +1400,10 @@ class NetworkTrainer: params_to_clip = accelerator.unwrap_model(network).get_trainable_params() accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + if global_step % 5 == 0: + if hasattr(network, "update_norms"): + network.update_norms() + optimizer.step() lr_scheduler.step() optimizer.zero_grad(set_to_none=True) From 
e5b5c7e1db5a5c8d7e0628cd565e9619f9564adb Mon Sep 17 00:00:00 2001 From: gesen2egee Date: Sat, 15 Mar 2025 13:29:32 +0800 Subject: [PATCH 13/71] Update requirements.txt --- requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements.txt b/requirements.txt index de39f588..52c3b8c7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -43,4 +43,5 @@ rich==13.7.0 # for T5XXL tokenizer (SD3/FLUX) sentencepiece==0.2.0 # for kohya_ss library +pytorch-optimizer -e . From 7e90cdd47a6019739d59beb161741891b79eeaef Mon Sep 17 00:00:00 2001 From: Disty0 Date: Mon, 17 Mar 2025 17:26:08 +0300 Subject: [PATCH 14/71] use bytearray and add typing hints --- library/jpeg_xl_util.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/library/jpeg_xl_util.py b/library/jpeg_xl_util.py index d127d44d..3f5e7f72 100644 --- a/library/jpeg_xl_util.py +++ b/library/jpeg_xl_util.py @@ -1,15 +1,17 @@ # Modified from https://github.com/Fraetor/jxl_decode # Added partial read support for up to 200x speedup + import os +from typing import List, Tuple class JXLBitstream: """ A stream of bits with methods for easy handling. 
""" - def __init__(self, file, offset=0, offsets=[]) -> None: + def __init__(self, file, offset: int = 0, offsets: List[List[int]] = None): self.shift = 0 - self.bitstream = [] + self.bitstream = bytearray() self.file = file self.offset = offset self.offsets = offsets @@ -32,7 +34,7 @@ class JXLBitstream: self.shift += length return bits - def partial_read(self, current_length, length): + def partial_read(self, current_length: int, length: int) -> None: self.previous_data_len += self.offsets[self.index][2] to_read_length = self.previous_data_len - (self.shift + current_length) self.bitstream += self.file.read(to_read_length) @@ -44,7 +46,7 @@ class JXLBitstream: self.partial_read(current_length, length) -def decode_codestream(file, offset=0, offsets=[]): +def decode_codestream(file, offset: int = 0, offsets: List[List[int]] = None) -> Tuple[int,int]: """ Decodes the actual codestream. JXL codestream specification: http://www-internal/2022/18181-1 @@ -104,13 +106,13 @@ def decode_codestream(file, offset=0, offsets=[]): return width, height -def decode_container(file): +def decode_container(file) -> Tuple[int,int]: """ Parses the ISOBMFF container, extracts the codestream, and decodes it. 
JXL container specification: http://www-internal/2022/18181-2 """ - def parse_box(file, file_start) -> dict: + def parse_box(file, file_start: int) -> dict: file.seek(file_start) LBox = int.from_bytes(file.read(4), "big") XLBox = None @@ -177,7 +179,7 @@ def decode_container(file): return decode_codestream(file, offset=offset, offsets=offsets) -def get_jxl_size(path): +def get_jxl_size(path: str) -> Tuple[int,int]: with open(path, "rb") as file: if file.read(2) == bytes.fromhex("FF0A"): return decode_codestream(file) From 564ec5fb7f6027d89b565cff1e8ed81a9e89ae07 Mon Sep 17 00:00:00 2001 From: Disty0 Date: Mon, 17 Mar 2025 17:41:03 +0300 Subject: [PATCH 15/71] use extend instead of += --- library/jpeg_xl_util.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/library/jpeg_xl_util.py b/library/jpeg_xl_util.py index 3f5e7f72..ade24a05 100644 --- a/library/jpeg_xl_util.py +++ b/library/jpeg_xl_util.py @@ -26,9 +26,9 @@ class JXLBitstream: self.partial_to_read_length = length if self.shift < self.previous_data_len + self.offsets[self.index][2]: self.partial_read(0, length) - self.bitstream += self.file.read(self.partial_to_read_length) + self.bitstream.extend(self.file.read(self.partial_to_read_length)) else: - self.bitstream += self.file.read(length) + self.bitstream.extend(self.file.read(length)) bitmask = 2**length - 1 bits = (int.from_bytes(self.bitstream, "little") >> self.shift) & bitmask self.shift += length @@ -37,7 +37,7 @@ class JXLBitstream: def partial_read(self, current_length: int, length: int) -> None: self.previous_data_len += self.offsets[self.index][2] to_read_length = self.previous_data_len - (self.shift + current_length) - self.bitstream += self.file.read(to_read_length) + self.bitstream.extend(self.file.read(to_read_length)) current_length += to_read_length self.partial_to_read_length -= to_read_length self.index += 1 From 620a06f517032fff9842b81950795bb14c0ad361 Mon Sep 17 00:00:00 2001 From: Disty0 Date: Mon, 17 Mar 2025 
17:44:29 +0300 Subject: [PATCH 16/71] Check for uppercase file extension too --- library/train_util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/library/train_util.py b/library/train_util.py index 916b8834..3b6f7663 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1158,7 +1158,7 @@ class BaseDataset(torch.utils.data.Dataset): ) def get_image_size(self, image_path): - if image_path.endswith(".jxl"): + if image_path.endswith(".jxl") or image_path.endswith(".JXL"): return get_jxl_size(image_path) return imagesize.get(image_path) From 3647d065b50d74ade3642edd0ec99a2ce1041edf Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Tue, 18 Mar 2025 14:25:09 -0400 Subject: [PATCH 17/71] Cache weight norms estimate on initialization. Move to update norms every step --- networks/lora_flux.py | 142 ++++++++++++++++++++++++++++++++++-------- train_network.py | 36 ++++++++--- 2 files changed, 145 insertions(+), 33 deletions(-) diff --git a/networks/lora_flux.py b/networks/lora_flux.py index 98cf8c55..9f5f1916 100644 --- a/networks/lora_flux.py +++ b/networks/lora_flux.py @@ -15,6 +15,7 @@ from diffusers import AutoencoderKL from transformers import CLIPTextModel import numpy as np import torch +from torch import Tensor import re from library.utils import setup_logging from library.sdxl_original_unet import SdxlUNet2DConditionModel @@ -145,8 +146,13 @@ class LoRAModule(torch.nn.Module): self.ggpo_sigma = ggpo_sigma self.ggpo_beta = ggpo_beta - self.perturbation_norm_factor = 1.0 / math.sqrt(org_module.weight.shape[0]) - self._org_module_weight = self.org_module.weight.detach() + if self.ggpo_beta is not None and self.ggpo_sigma is not None: + self.combined_weight_norms = None + self.grad_norms = None + self.perturbation_norm_factor = 1.0 / math.sqrt(org_module.weight.shape[0]) + self.perturbation_seed = torch.randint(0, 2**32 - 1, (1,)).detach().item() + self.initialize_norm_cache(org_module.weight) + self.org_module_shape: tuple[int] = 
org_module.weight.shape def apply_to(self): self.org_forward = self.org_module.forward @@ -187,10 +193,12 @@ class LoRAModule(torch.nn.Module): lx = self.lora_up(lx) # LoRA Gradient-Guided Perturbation Optimization - if self.training and hasattr(self, 'perturbation_seed') and self.ggpo_sigma is not None and self.ggpo_beta is not None: - with torch.no_grad(), torch.autocast(self.device.type), temp_random_seed(self.perturbation_seed): - perturbation = torch.randn_like(self._org_module_weight, dtype=self.dtype, device=self.device) - perturbation.mul_(self.perturbation_scale_factor) + if self.training and hasattr(self, 'perturbation_seed') and self.ggpo_sigma is not None and self.ggpo_beta is not None and self.combined_weight_norms is not None and self.grad_norms is not None: + with torch.no_grad(), temp_random_seed(self.perturbation_seed): + perturbation_scale = (self.ggpo_sigma * torch.sqrt(self.combined_weight_norms ** 2)) + (self.ggpo_beta * (self.grad_norms ** 2)) + perturbation_scale_factor = (perturbation_scale * self.perturbation_norm_factor).to(self.device) + perturbation = torch.randn(self.org_module_shape, dtype=self.dtype, device=self.device) + perturbation.mul_(perturbation_scale_factor) perturbation_output = x @ perturbation.T # Result: (batch × n) return org_forwarded + (self.multiplier * scale * lx) + perturbation_output else: @@ -221,6 +229,69 @@ class LoRAModule(torch.nn.Module): return org_forwarded + torch.cat(lxs, dim=-1) * self.multiplier * scale + @torch.no_grad() + def initialize_norm_cache(self, org_module_weight: Tensor): + # Choose a reasonable sample size + n_rows = org_module_weight.shape[0] + sample_size = min(1000, n_rows) # Cap at 1000 samples or use all if smaller + + # Sample random indices across all rows + indices = torch.randperm(n_rows)[:sample_size] + + # Convert to a supported data type first, then index + # Use float32 for indexing operations + weights_float32 = org_module_weight.to(dtype=torch.float32) + sampled_weights = 
weights_float32[indices].to(device=self.device) + + # Calculate sampled norms + sampled_norms = torch.norm(sampled_weights, dim=1, keepdim=True) + + # Store the mean norm as our estimate + self.org_weight_norm_estimate = sampled_norms.mean() + + # Optional: store standard deviation for confidence intervals + self.org_weight_norm_std = sampled_norms.std() + + # Free memory + del sampled_weights, weights_float32 + + @torch.no_grad() + def validate_norm_approximation(self, org_module_weight: Tensor, verbose=True): + # Calculate the true norm (this will be slow but it's just for validation) + true_norms = [] + chunk_size = 1024 # Process in chunks to avoid OOM + + for i in range(0, org_module_weight.shape[0], chunk_size): + end_idx = min(i + chunk_size, org_module_weight.shape[0]) + chunk = org_module_weight[i:end_idx].to(device=self.device, dtype=self.dtype) + chunk_norms = torch.norm(chunk, dim=1, keepdim=True) + true_norms.append(chunk_norms.cpu()) + del chunk + + true_norms = torch.cat(true_norms, dim=0) + true_mean_norm = true_norms.mean().item() + + # Compare with our estimate + estimated_norm = self.org_weight_norm_estimate.item() + + # Calculate error metrics + absolute_error = abs(true_mean_norm - estimated_norm) + relative_error = absolute_error / true_mean_norm * 100 # as percentage + + if verbose: + logger.info(f"True mean norm: {true_mean_norm:.6f}") + logger.info(f"Estimated norm: {estimated_norm:.6f}") + logger.info(f"Absolute error: {absolute_error:.6f}") + logger.info(f"Relative error: {relative_error:.2f}%") + + return { + 'true_mean_norm': true_mean_norm, + 'estimated_norm': estimated_norm, + 'absolute_error': absolute_error, + 'relative_error': relative_error + } + + @torch.no_grad() def update_norms(self): # Not running GGPO so not currently running update norms @@ -228,8 +299,20 @@ class LoRAModule(torch.nn.Module): return # only update norms when we are training - if self.lora_down.weight.requires_grad is not True: - print(f"skipping update_norms 
for {self.lora_name}") + if self.training is False: + return + + module_weights = self.lora_up.weight @ self.lora_down.weight + module_weights.mul(self.scale) + + self.weight_norms = torch.norm(module_weights, dim=1, keepdim=True) + self.combined_weight_norms = torch.sqrt((self.org_weight_norm_estimate**2) + + torch.sum(module_weights**2, dim=1, keepdim=True)) + + @torch.no_grad() + def update_grad_norms(self): + if self.training is False: + print(f"skipping update_grad_norms for {self.lora_name}") return lora_down_grad = None @@ -241,29 +324,12 @@ class LoRAModule(torch.nn.Module): elif name == "lora_up.weight": lora_up_grad = param.grad - with torch.autocast(self.device.type): - module_weights = self.scale * (self.lora_up.weight @ self.lora_down.weight) - org_device = self._org_module_weight.device - org_dtype = self._org_module_weight.dtype - org_weight = self._org_module_weight.to(device=self.device, dtype=self.dtype) - combined_weight = org_weight + module_weights - - self.combined_weight_norms = torch.norm(combined_weight, dim=1, keepdim=True) - - self._org_module_weight.to(device=org_device, dtype=org_dtype) - - # Calculate gradient norms if we have both gradients if lora_down_grad is not None and lora_up_grad is not None: with torch.autocast(self.device.type): approx_grad = self.scale * ((self.lora_up.weight @ lora_down_grad) + (lora_up_grad @ self.lora_down.weight)) self.grad_norms = torch.norm(approx_grad, dim=1, keepdim=True) - self.perturbation_scale = (self.ggpo_sigma * torch.sqrt(self.combined_weight_norms ** 2)) + (self.ggpo_beta * (self.grad_norms ** 2)) - self.perturbation_scale_factor = (self.perturbation_scale * self.perturbation_norm_factor).to(self.device) - - # LoRA Gradient-Guided Perturbation Optimization - self.perturbation_seed = torch.randint(0, 2**32 - 1, (1,)).detach().item() @property def device(self): @@ -922,6 +988,32 @@ class LoRANetwork(torch.nn.Module): for lora in self.text_encoder_loras + self.unet_loras: lora.update_norms() + 
def update_grad_norms(self): + for lora in self.text_encoder_loras + self.unet_loras: + lora.update_grad_norms() + + def grad_norms(self) -> Tensor: + grad_norms = [] + for lora in self.text_encoder_loras + self.unet_loras: + if hasattr(lora, "grad_norms") and lora.grad_norms is not None: + grad_norms.append(lora.grad_norms.mean(dim=0)) + return torch.stack(grad_norms) if len(grad_norms) > 0 else torch.tensor([]) + + def weight_norms(self) -> Tensor: + weight_norms = [] + for lora in self.text_encoder_loras + self.unet_loras: + if hasattr(lora, "weight_norms") and lora.weight_norms is not None: + weight_norms.append(lora.weight_norms.mean(dim=0)) + return torch.stack(weight_norms) if len(weight_norms) > 0 else torch.tensor([]) + + def combined_weight_norms(self) -> Tensor: + combined_weight_norms = [] + for lora in self.text_encoder_loras + self.unet_loras: + if hasattr(lora, "combined_weight_norms") and lora.combined_weight_norms is not None: + combined_weight_norms.append(lora.combined_weight_norms.mean(dim=0)) + return torch.stack(combined_weight_norms) if len(combined_weight_norms) > 0 else torch.tensor([]) + + def load_weights(self, file): if os.path.splitext(file)[1] == ".safetensors": from safetensors.torch import load_file diff --git a/train_network.py b/train_network.py index 9db335b0..4898e798 100644 --- a/train_network.py +++ b/train_network.py @@ -69,13 +69,20 @@ class NetworkTrainer: keys_scaled=None, mean_norm=None, maximum_norm=None, + mean_grad_norm=None, + mean_combined_norm=None ): logs = {"loss/current": current_loss, "loss/average": avr_loss} if keys_scaled is not None: logs["max_norm/keys_scaled"] = keys_scaled - logs["max_norm/average_key_norm"] = mean_norm logs["max_norm/max_key_norm"] = maximum_norm + if mean_norm is not None: + logs["norm/avg_key_norm"] = mean_norm + if mean_grad_norm is not None: + logs["norm/avg_grad_norm"] = mean_grad_norm + if mean_combined_norm is not None: + logs["norm/avg_combined_norm"] = mean_combined_norm lrs = 
lr_scheduler.get_last_lr() for i, lr in enumerate(lrs): @@ -1400,10 +1407,12 @@ class NetworkTrainer: params_to_clip = accelerator.unwrap_model(network).get_trainable_params() accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) - if global_step % 5 == 0: + if hasattr(network, "update_grad_norms"): + network.update_grad_norms() if hasattr(network, "update_norms"): network.update_norms() + optimizer.step() lr_scheduler.step() optimizer.zero_grad(set_to_none=True) @@ -1412,9 +1421,23 @@ class NetworkTrainer: keys_scaled, mean_norm, maximum_norm = accelerator.unwrap_model(network).apply_max_norm_regularization( args.scale_weight_norms, accelerator.device ) + mean_grad_norm = None + mean_combined_norm = None max_mean_logs = {"Keys Scaled": keys_scaled, "Average key norm": mean_norm} else: - keys_scaled, mean_norm, maximum_norm = None, None, None + if hasattr(network, "weight_norms"): + mean_norm = network.weight_norms().mean().item() + mean_grad_norm = network.grad_norms().mean().item() + mean_combined_norm = network.combined_weight_norms().mean().item() + weight_norms = network.weight_norms() + maximum_norm = weight_norms.max().item() if weight_norms.numel() > 0 else None + keys_scaled = None + max_mean_logs = {"avg weight norm": mean_norm, "avg grad norm": mean_grad_norm, "avg comb norm": mean_combined_norm} + else: + keys_scaled, mean_norm, maximum_norm = None, None, None + mean_grad_norm = None + mean_combined_norm = None + max_mean_logs = {} # Checks if the accelerator has performed an optimization step behind the scenes if accelerator.sync_gradients: @@ -1446,14 +1469,11 @@ class NetworkTrainer: loss_recorder.add(epoch=epoch, step=step, loss=current_loss) avr_loss: float = loss_recorder.moving_average logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} - progress_bar.set_postfix(**logs) - - if args.scale_weight_norms: - progress_bar.set_postfix(**{**max_mean_logs, **logs}) + progress_bar.set_postfix(**{**max_mean_logs, **logs}) if 
is_tracking: logs = self.generate_step_logs( - args, current_loss, avr_loss, lr_scheduler, lr_descriptions, optimizer, keys_scaled, mean_norm, maximum_norm + args, current_loss, avr_loss, lr_scheduler, lr_descriptions, optimizer, keys_scaled, mean_norm, maximum_norm, mean_grad_norm, mean_combined_norm ) self.step_logging(accelerator, logs, global_step, epoch + 1) From 0b25a05e3c0b983d7a4fa74f40798705a00992e3 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Tue, 18 Mar 2025 15:40:40 -0400 Subject: [PATCH 18/71] Add IP noise gamma for Flux --- library/flux_train_utils.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index f7f06c5c..f866fd4a 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -415,6 +415,16 @@ def get_noisy_model_input_and_timesteps( bsz, _, h, w = latents.shape sigmas = None + ip_noise_gamma = 0.0 + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + if args.ip_noise_gamma: + if args.ip_noise_gamma_random_strength: + ip_noise_gamma = torch.rand(1, device=latents.device) * args.ip_noise_gamma + else: + ip_noise_gamma = args.ip_noise_gamma + if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid": # Simple random t-based noise sampling if args.timestep_sampling == "sigmoid": @@ -425,7 +435,7 @@ def get_noisy_model_input_and_timesteps( timesteps = t * 1000.0 t = t.view(-1, 1, 1, 1) - noisy_model_input = (1 - t) * latents + t * noise + noisy_model_input = (1 - t) * latents + t * noise + ip_noise_gamma elif args.timestep_sampling == "shift": shift = args.discrete_flow_shift logits_norm = torch.randn(bsz, device=device) @@ -435,7 +445,7 @@ def get_noisy_model_input_and_timesteps( t = timesteps.view(-1, 1, 1, 1) timesteps = timesteps * 1000.0 - noisy_model_input = (1 - t) * latents + t * noise + noisy_model_input = (1 - t) * latents + t * 
noise + ip_noise_gamma elif args.timestep_sampling == "flux_shift": logits_norm = torch.randn(bsz, device=device) logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling @@ -445,7 +455,7 @@ def get_noisy_model_input_and_timesteps( t = timesteps.view(-1, 1, 1, 1) timesteps = timesteps * 1000.0 - noisy_model_input = (1 - t) * latents + t * noise + noisy_model_input = (1 - t) * latents + t * noise + ip_noise_gamma else: # Sample a random timestep for each image # for weighting schemes where we sample timesteps non-uniformly @@ -461,7 +471,8 @@ def get_noisy_model_input_and_timesteps( # Add noise according to flow matching. sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype) - noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents + noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents + ip_noise_gamma + return noisy_model_input.to(dtype), timesteps.to(dtype), sigmas From c8be141ae0576119ecd8ae329f00700098ee83a2 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Tue, 18 Mar 2025 15:42:18 -0400 Subject: [PATCH 19/71] Apply IP gamma to noise fix --- library/flux_train_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index f866fd4a..557f61e7 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -471,7 +471,7 @@ def get_noisy_model_input_and_timesteps( # Add noise according to flow matching. 
sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype) - noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents + ip_noise_gamma + noisy_model_input = sigmas * noise + ip_noise_gamma + (1.0 - sigmas) * latents return noisy_model_input.to(dtype), timesteps.to(dtype), sigmas From b425466e7be64e12238b267862468dc9f0b0bb6e Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Tue, 18 Mar 2025 18:42:35 -0400 Subject: [PATCH 20/71] Fix IP noise gamma to use random values --- library/flux_train_utils.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index 557f61e7..f0744747 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -415,15 +415,15 @@ def get_noisy_model_input_and_timesteps( bsz, _, h, w = latents.shape sigmas = None - ip_noise_gamma = 0.0 - # Add noise to the latents according to the noise magnitude at each timestep # (this is the forward diffusion process) if args.ip_noise_gamma: if args.ip_noise_gamma_random_strength: - ip_noise_gamma = torch.rand(1, device=latents.device) * args.ip_noise_gamma + ip_noise = (torch.rand(1, device=latents.device) * args.ip_noise_gamma) * torch.randn_like(latents) else: - ip_noise_gamma = args.ip_noise_gamma + ip_noise = args.ip_noise_gamma * torch.randn_like(latents) + else: + ip_noise = torch.zeros_like(latents) if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid": # Simple random t-based noise sampling @@ -435,7 +435,7 @@ def get_noisy_model_input_and_timesteps( timesteps = t * 1000.0 t = t.view(-1, 1, 1, 1) - noisy_model_input = (1 - t) * latents + t * noise + ip_noise_gamma + noisy_model_input = (1 - t) * latents + t * (noise + ip_noise) elif args.timestep_sampling == "shift": shift = args.discrete_flow_shift logits_norm = torch.randn(bsz, device=device) @@ -445,7 +445,7 @@ def get_noisy_model_input_and_timesteps( t = timesteps.view(-1, 1, 1, 1) 
timesteps = timesteps * 1000.0 - noisy_model_input = (1 - t) * latents + t * noise + ip_noise_gamma + noisy_model_input = (1 - t) * latents + t * (noise + ip_noise) elif args.timestep_sampling == "flux_shift": logits_norm = torch.randn(bsz, device=device) logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling @@ -455,7 +455,7 @@ def get_noisy_model_input_and_timesteps( t = timesteps.view(-1, 1, 1, 1) timesteps = timesteps * 1000.0 - noisy_model_input = (1 - t) * latents + t * noise + ip_noise_gamma + noisy_model_input = (1 - t) * latents + t * (noise + ip_noise) else: # Sample a random timestep for each image # for weighting schemes where we sample timesteps non-uniformly @@ -471,7 +471,7 @@ def get_noisy_model_input_and_timesteps( # Add noise according to flow matching. sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype) - noisy_model_input = sigmas * noise + ip_noise_gamma + (1.0 - sigmas) * latents + noisy_model_input = sigmas * (noise + ip_noise) + (1.0 - sigmas) * latents return noisy_model_input.to(dtype), timesteps.to(dtype), sigmas From a4f3a9fc1a4f4f964a6971bc4b0ae15c94f0d672 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Tue, 18 Mar 2025 18:44:21 -0400 Subject: [PATCH 21/71] Use ones_like --- library/flux_train_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index f0744747..8cf95858 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -423,7 +423,7 @@ def get_noisy_model_input_and_timesteps( else: ip_noise = args.ip_noise_gamma * torch.randn_like(latents) else: - ip_noise = torch.zeros_like(latents) + ip_noise = torch.ones_like(latents) if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid": # Simple random t-based noise sampling From 6f4d3657756a9d679dfa76f7c6c7bd1c957130ca Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Tue, 18 Mar 2025 18:53:34 -0400 
Subject: [PATCH 22/71] zeros_like because we are adding --- library/flux_train_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index 8cf95858..f0744747 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -423,7 +423,7 @@ def get_noisy_model_input_and_timesteps( else: ip_noise = args.ip_noise_gamma * torch.randn_like(latents) else: - ip_noise = torch.ones_like(latents) + ip_noise = torch.zeros_like(latents) if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid": # Simple random t-based noise sampling From b81bcd0b01aa81bf616b6125ca1da4d6d3c9dd82 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Tue, 18 Mar 2025 21:36:55 -0400 Subject: [PATCH 23/71] Move IP noise gamma to noise creation to remove complexity and align noise for target loss --- flux_train_network.py | 9 +++++++++ library/flux_train_utils.py | 19 ++++--------------- 2 files changed, 13 insertions(+), 15 deletions(-) diff --git a/flux_train_network.py b/flux_train_network.py index def44155..d85584f5 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -350,6 +350,15 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): ): # Sample noise that we'll add to the latents noise = torch.randn_like(latents) + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + if args.ip_noise_gamma: + if args.ip_noise_gamma_random_strength: + noise = noise + (torch.rand(1, device=latents.device) * args.ip_noise_gamma) * torch.randn_like(latents) + else: + noise = noise + args.ip_noise_gamma * torch.randn_like(latents) + bsz = latents.shape[0] # get noisy model input and timesteps diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index f0744747..f7f06c5c 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -415,16 +415,6 @@ def get_noisy_model_input_and_timesteps( 
bsz, _, h, w = latents.shape sigmas = None - # Add noise to the latents according to the noise magnitude at each timestep - # (this is the forward diffusion process) - if args.ip_noise_gamma: - if args.ip_noise_gamma_random_strength: - ip_noise = (torch.rand(1, device=latents.device) * args.ip_noise_gamma) * torch.randn_like(latents) - else: - ip_noise = args.ip_noise_gamma * torch.randn_like(latents) - else: - ip_noise = torch.zeros_like(latents) - if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid": # Simple random t-based noise sampling if args.timestep_sampling == "sigmoid": @@ -435,7 +425,7 @@ def get_noisy_model_input_and_timesteps( timesteps = t * 1000.0 t = t.view(-1, 1, 1, 1) - noisy_model_input = (1 - t) * latents + t * (noise + ip_noise) + noisy_model_input = (1 - t) * latents + t * noise elif args.timestep_sampling == "shift": shift = args.discrete_flow_shift logits_norm = torch.randn(bsz, device=device) @@ -445,7 +435,7 @@ def get_noisy_model_input_and_timesteps( t = timesteps.view(-1, 1, 1, 1) timesteps = timesteps * 1000.0 - noisy_model_input = (1 - t) * latents + t * (noise + ip_noise) + noisy_model_input = (1 - t) * latents + t * noise elif args.timestep_sampling == "flux_shift": logits_norm = torch.randn(bsz, device=device) logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling @@ -455,7 +445,7 @@ def get_noisy_model_input_and_timesteps( t = timesteps.view(-1, 1, 1, 1) timesteps = timesteps * 1000.0 - noisy_model_input = (1 - t) * latents + t * (noise + ip_noise) + noisy_model_input = (1 - t) * latents + t * noise else: # Sample a random timestep for each image # for weighting schemes where we sample timesteps non-uniformly @@ -471,8 +461,7 @@ def get_noisy_model_input_and_timesteps( # Add noise according to flow matching. 
sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype) - noisy_model_input = sigmas * (noise + ip_noise) + (1.0 - sigmas) * latents - + noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents return noisy_model_input.to(dtype), timesteps.to(dtype), sigmas From 5b210ad7178c0b88c214686389b0afb03ba3813c Mon Sep 17 00:00:00 2001 From: gesen2egee Date: Wed, 19 Mar 2025 10:49:06 +0800 Subject: [PATCH 24/71] update prodigyopt and prodigy-plus-schedule-free --- requirements.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 52c3b8c7..7348647f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,7 +7,6 @@ opencv-python==4.8.1.78 einops==0.7.0 pytorch-lightning==1.9.0 bitsandbytes==0.44.0 -prodigyopt==1.0 lion-pytorch==0.0.6 schedulefree==1.4 tensorboard @@ -44,4 +43,6 @@ rich==13.7.0 sentencepiece==0.2.0 # for kohya_ss library pytorch-optimizer +prodigy-plus-schedule-free==1.9.0 +prodigyopt==1.1.2 -e . 
From 7197266703d8ac9219dda8b5a58bbd60d029d597 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Wed, 19 Mar 2025 00:25:51 -0400 Subject: [PATCH 25/71] Perturbed noise should be separate of input noise --- flux_train_network.py | 9 --------- library/flux_train_utils.py | 13 ++++++++++++- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/flux_train_network.py b/flux_train_network.py index d85584f5..def44155 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -350,15 +350,6 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): ): # Sample noise that we'll add to the latents noise = torch.randn_like(latents) - - # Add noise to the latents according to the noise magnitude at each timestep - # (this is the forward diffusion process) - if args.ip_noise_gamma: - if args.ip_noise_gamma_random_strength: - noise = noise + (torch.rand(1, device=latents.device) * args.ip_noise_gamma) * torch.randn_like(latents) - else: - noise = noise + args.ip_noise_gamma * torch.randn_like(latents) - bsz = latents.shape[0] # get noisy model input and timesteps diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index f7f06c5c..775e0c33 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -410,11 +410,22 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): def get_noisy_model_input_and_timesteps( - args, noise_scheduler, latents, noise, device, dtype + args, noise_scheduler, latents: torch.Tensor, input_noise: torch.Tensor, device, dtype ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: bsz, _, h, w = latents.shape sigmas = None + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + if args.ip_noise_gamma: + if args.ip_noise_gamma_random_strength: + noise = input_noise.detach().clone() + (torch.rand(1, device=latents.device) * args.ip_noise_gamma) * torch.randn_like(latents) + else: + noise = input_noise.detach().clone() + 
args.ip_noise_gamma * torch.randn_like(latents) + else: + noise = input_noise + + if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid": # Simple random t-based noise sampling if args.timestep_sampling == "sigmoid": From d93ad90a717beb2fd322d2fae73992e9ea5213ea Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Wed, 19 Mar 2025 00:37:27 -0400 Subject: [PATCH 26/71] Add perturbation on noisy_model_input if needed --- library/flux_train_utils.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index 775e0c33..0fe81da7 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -410,20 +410,11 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): def get_noisy_model_input_and_timesteps( - args, noise_scheduler, latents: torch.Tensor, input_noise: torch.Tensor, device, dtype + args, noise_scheduler, latents: torch.Tensor, noise: torch.Tensor, device, dtype ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: bsz, _, h, w = latents.shape sigmas = None - # Add noise to the latents according to the noise magnitude at each timestep - # (this is the forward diffusion process) - if args.ip_noise_gamma: - if args.ip_noise_gamma_random_strength: - noise = input_noise.detach().clone() + (torch.rand(1, device=latents.device) * args.ip_noise_gamma) * torch.randn_like(latents) - else: - noise = input_noise.detach().clone() + args.ip_noise_gamma * torch.randn_like(latents) - else: - noise = input_noise if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid": @@ -474,6 +465,15 @@ def get_noisy_model_input_and_timesteps( sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype) noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + if 
args.ip_noise_gamma: + if args.ip_noise_gamma_random_strength: + xi = noise.detach().clone() + (torch.rand(1, device=latents.device) * args.ip_noise_gamma) * torch.randn_like(latents) + else: + xi = noise.detach().clone() + args.ip_noise_gamma * torch.randn_like(latents) + noisy_model_input += xi + return noisy_model_input.to(dtype), timesteps.to(dtype), sigmas From 8e6817b0c2d6e312b8da0d84baa2ecc72c83767f Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Wed, 19 Mar 2025 00:45:13 -0400 Subject: [PATCH 27/71] Remove double noise --- library/flux_train_utils.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index 0fe81da7..9808ad0a 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -415,8 +415,6 @@ def get_noisy_model_input_and_timesteps( bsz, _, h, w = latents.shape sigmas = None - - if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid": # Simple random t-based noise sampling if args.timestep_sampling == "sigmoid": @@ -469,10 +467,10 @@ def get_noisy_model_input_and_timesteps( # (this is the forward diffusion process) if args.ip_noise_gamma: if args.ip_noise_gamma_random_strength: - xi = noise.detach().clone() + (torch.rand(1, device=latents.device) * args.ip_noise_gamma) * torch.randn_like(latents) + noise_perturbation = (torch.rand(1, device=latents.device) * args.ip_noise_gamma) * torch.randn_like(noise) else: - xi = noise.detach().clone() + args.ip_noise_gamma * torch.randn_like(latents) - noisy_model_input += xi + noise_perturbation = args.ip_noise_gamma * torch.randn_like(noise) + noisy_model_input += noise_perturbation return noisy_model_input.to(dtype), timesteps.to(dtype), sigmas From 1eddac26b010d23ce5f0eb6a8ac12fbca66ee50b Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Wed, 19 Mar 2025 00:49:42 -0400 Subject: [PATCH 28/71] Separate random to a variable, and make sure on device --- library/flux_train_utils.py | 7 ++++--- 1 
file changed, 4 insertions(+), 3 deletions(-) diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index 9808ad0a..107f351f 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -466,11 +466,12 @@ def get_noisy_model_input_and_timesteps( # Add noise to the latents according to the noise magnitude at each timestep # (this is the forward diffusion process) if args.ip_noise_gamma: + xi = torch.randn_like(latents, device=latents.device, dtype=dtype) if args.ip_noise_gamma_random_strength: - noise_perturbation = (torch.rand(1, device=latents.device) * args.ip_noise_gamma) * torch.randn_like(noise) + ip_noise_gamma = (torch.rand(1, device=latents.device, dtype=dtype) * args.ip_noise_gamma) else: - noise_perturbation = args.ip_noise_gamma * torch.randn_like(noise) - noisy_model_input += noise_perturbation + ip_noise_gamma = args.ip_noise_gamma + noisy_model_input += ip_noise_gamma * xi return noisy_model_input.to(dtype), timesteps.to(dtype), sigmas From 5d5a7d2acf884077b6a24db269c8f4facb5b7487 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Wed, 19 Mar 2025 13:50:04 -0400 Subject: [PATCH 29/71] Fix IP noise calculation --- library/flux_train_utils.py | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index 107f351f..0cb07e3d 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -423,29 +423,24 @@ def get_noisy_model_input_and_timesteps( else: t = torch.rand((bsz,), device=device) + sigmas = t.view(-1, 1, 1, 1) timesteps = t * 1000.0 - t = t.view(-1, 1, 1, 1) - noisy_model_input = (1 - t) * latents + t * noise elif args.timestep_sampling == "shift": shift = args.discrete_flow_shift logits_norm = torch.randn(bsz, device=device) logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling timesteps = logits_norm.sigmoid() timesteps = (timesteps * shift) / (1 + (shift - 1) * timesteps) - - t = 
timesteps.view(-1, 1, 1, 1) + sigmas = timesteps.view(-1, 1, 1, 1) timesteps = timesteps * 1000.0 - noisy_model_input = (1 - t) * latents + t * noise elif args.timestep_sampling == "flux_shift": logits_norm = torch.randn(bsz, device=device) logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling timesteps = logits_norm.sigmoid() mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2)) timesteps = time_shift(mu, 1.0, timesteps) - - t = timesteps.view(-1, 1, 1, 1) + sigmas = timesteps.view(-1, 1, 1, 1) timesteps = timesteps * 1000.0 - noisy_model_input = (1 - t) * latents + t * noise else: # Sample a random timestep for each image # for weighting schemes where we sample timesteps non-uniformly @@ -458,10 +453,7 @@ def get_noisy_model_input_and_timesteps( ) indices = (u * noise_scheduler.config.num_train_timesteps).long() timesteps = noise_scheduler.timesteps[indices].to(device=device) - - # Add noise according to flow matching. sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype) - noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents # Add noise to the latents according to the noise magnitude at each timestep # (this is the forward diffusion process) @@ -471,7 +463,9 @@ def get_noisy_model_input_and_timesteps( ip_noise_gamma = (torch.rand(1, device=latents.device, dtype=dtype) * args.ip_noise_gamma) else: ip_noise_gamma = args.ip_noise_gamma - noisy_model_input += ip_noise_gamma * xi + noisy_model_input = sigmas * (noise + ip_noise_gamma * xi) + (1.0 - sigmas) * latents + else: + noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents return noisy_model_input.to(dtype), timesteps.to(dtype), sigmas From f974c6b2577348acbe948bcc668dd7b061feb73e Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Wed, 19 Mar 2025 14:27:43 -0400 Subject: [PATCH 30/71] change order to match upstream --- library/flux_train_utils.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git 
a/library/flux_train_utils.py b/library/flux_train_utils.py index 0cb07e3d..7bf2faf0 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -413,8 +413,6 @@ def get_noisy_model_input_and_timesteps( args, noise_scheduler, latents: torch.Tensor, noise: torch.Tensor, device, dtype ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: bsz, _, h, w = latents.shape - sigmas = None - if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid": # Simple random t-based noise sampling if args.timestep_sampling == "sigmoid": @@ -463,9 +461,9 @@ def get_noisy_model_input_and_timesteps( ip_noise_gamma = (torch.rand(1, device=latents.device, dtype=dtype) * args.ip_noise_gamma) else: ip_noise_gamma = args.ip_noise_gamma - noisy_model_input = sigmas * (noise + ip_noise_gamma * xi) + (1.0 - sigmas) * latents + noisy_model_input = (1.0 - sigmas) * latents + sigmas * (noise + ip_noise_gamma * xi) else: - noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents + noisy_model_input = (1.0 - sigmas) * latents + sigmas * noise return noisy_model_input.to(dtype), timesteps.to(dtype), sigmas From d151833526f5f79414a995cbb416de8a31e000cb Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 20 Mar 2025 22:05:29 +0900 Subject: [PATCH 31/71] docs: update README with recent changes and specify version for pytorch-optimizer --- README.md | 4 ++++ requirements.txt | 6 +++--- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 426eaed8..59b0e676 100644 --- a/README.md +++ b/README.md @@ -14,6 +14,10 @@ The command to install PyTorch is as follows: ### Recent Updates +Mar 20, 2025: +- `pytorch-optimizer` is added to requirements.txt. Thank you to gesen2egee for PR [#1985](https://github.com/kohya-ss/sd-scripts/pull/1985). + - For example, you can use CAME optimizer with `--optimizer_type "pytorch_optimizer.CAME" --optimizer_args "weight_decay=0.01"`. 
+ Mar 6, 2025: - Added a utility script to merge the weights of SD3's DiT, VAE (optional), CLIP-L, CLIP-G, and T5XXL into a single .safetensors file. Run `tools/merge_sd3_safetensors.py`. See `--help` for usage. PR [#1960](https://github.com/kohya-ss/sd-scripts/pull/1960) diff --git a/requirements.txt b/requirements.txt index 7348647f..767d9e8e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,6 +9,9 @@ pytorch-lightning==1.9.0 bitsandbytes==0.44.0 lion-pytorch==0.0.6 schedulefree==1.4 +pytorch-optimizer==3.5.0 +prodigy-plus-schedule-free==1.9.0 +prodigyopt==1.1.2 tensorboard safetensors==0.4.4 # gradio==3.16.2 @@ -42,7 +45,4 @@ rich==13.7.0 # for T5XXL tokenizer (SD3/FLUX) sentencepiece==0.2.0 # for kohya_ss library -pytorch-optimizer -prodigy-plus-schedule-free==1.9.0 -prodigyopt==1.1.2 -e . From 16cef81aeaec1ebc07de30c7a1448982a61167e1 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Thu, 20 Mar 2025 14:32:56 -0400 Subject: [PATCH 32/71] Refactor sigmas and timesteps --- library/flux_train_utils.py | 41 ++++++++++++++++++------------------- 1 file changed, 20 insertions(+), 21 deletions(-) diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index 7bf2faf0..9110da89 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -366,8 +366,6 @@ def get_sigmas(noise_scheduler, timesteps, device, n_dim=4, dtype=torch.float32) step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] sigma = sigmas[step_indices].flatten() - while len(sigma.shape) < n_dim: - sigma = sigma.unsqueeze(-1) return sigma @@ -413,32 +411,30 @@ def get_noisy_model_input_and_timesteps( args, noise_scheduler, latents: torch.Tensor, noise: torch.Tensor, device, dtype ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: bsz, _, h, w = latents.shape + num_timesteps = noise_scheduler.config.num_train_timesteps if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid": - # Simple random t-based noise sampling + # 
Simple random sigma-based noise sampling if args.timestep_sampling == "sigmoid": # https://github.com/XLabs-AI/x-flux/tree/main - t = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=device)) + sigmas = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=device)) else: - t = torch.rand((bsz,), device=device) + sigmas = torch.rand((bsz,), device=device) - sigmas = t.view(-1, 1, 1, 1) - timesteps = t * 1000.0 + timesteps = sigmas * num_timesteps elif args.timestep_sampling == "shift": shift = args.discrete_flow_shift - logits_norm = torch.randn(bsz, device=device) - logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling - timesteps = logits_norm.sigmoid() - timesteps = (timesteps * shift) / (1 + (shift - 1) * timesteps) - sigmas = timesteps.view(-1, 1, 1, 1) - timesteps = timesteps * 1000.0 + sigmas = torch.randn(bsz, device=device) + sigmas = sigmas * args.sigmoid_scale # larger scale for more uniform sampling + sigmas = sigmas.sigmoid() + sigmas = (sigmas * shift) / (1 + (shift - 1) * sigmas) + timesteps = sigmas * num_timesteps elif args.timestep_sampling == "flux_shift": - logits_norm = torch.randn(bsz, device=device) - logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling - timesteps = logits_norm.sigmoid() - mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2)) - timesteps = time_shift(mu, 1.0, timesteps) - sigmas = timesteps.view(-1, 1, 1, 1) - timesteps = timesteps * 1000.0 + sigmas = torch.randn(bsz, device=device) + sigmas = sigmas * args.sigmoid_scale # larger scale for more uniform sampling + sigmas = sigmas.sigmoid() + mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2)) # we are pre-packed so must adjust for packed size + sigmas = time_shift(mu, 1.0, sigmas) + timesteps = sigmas * num_timesteps else: # Sample a random timestep for each image # for weighting schemes where we sample timesteps non-uniformly @@ -449,10 +445,13 @@ def 
get_noisy_model_input_and_timesteps( logit_std=args.logit_std, mode_scale=args.mode_scale, ) - indices = (u * noise_scheduler.config.num_train_timesteps).long() + indices = (u * num_timesteps).long() timesteps = noise_scheduler.timesteps[indices].to(device=device) sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype) + # Broadcast sigmas to latent shape + sigmas = sigmas.view(-1, 1, 1, 1) + # Add noise to the latents according to the noise magnitude at each timestep # (this is the forward diffusion process) if args.ip_noise_gamma: From e8b32548580ebf0001cd457d7b6f796e2eb169ff Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Thu, 20 Mar 2025 15:01:15 -0400 Subject: [PATCH 33/71] Add flux_train_utils tests for get get_noisy_model_input_and_timesteps --- library/flux_train_utils.py | 1 + tests/library/test_flux_train_utils.py | 220 +++++++++++++++++++++++++ 2 files changed, 221 insertions(+) create mode 100644 tests/library/test_flux_train_utils.py diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index 9110da89..0e73a01d 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -411,6 +411,7 @@ def get_noisy_model_input_and_timesteps( args, noise_scheduler, latents: torch.Tensor, noise: torch.Tensor, device, dtype ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: bsz, _, h, w = latents.shape + assert bsz > 0, "Batch size not large enough" num_timesteps = noise_scheduler.config.num_train_timesteps if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid": # Simple random sigma-based noise sampling diff --git a/tests/library/test_flux_train_utils.py b/tests/library/test_flux_train_utils.py new file mode 100644 index 00000000..a4c7ba3b --- /dev/null +++ b/tests/library/test_flux_train_utils.py @@ -0,0 +1,220 @@ +import pytest +import torch +from unittest.mock import MagicMock, patch +from library.flux_train_utils import ( + get_noisy_model_input_and_timesteps, +) + +# Mock 
classes and functions +class MockNoiseScheduler: + def __init__(self, num_train_timesteps=1000): + self.config = MagicMock() + self.config.num_train_timesteps = num_train_timesteps + self.timesteps = torch.arange(num_train_timesteps, dtype=torch.long) + + +# Create fixtures for commonly used objects +@pytest.fixture +def args(): + args = MagicMock() + args.timestep_sampling = "uniform" + args.weighting_scheme = "uniform" + args.logit_mean = 0.0 + args.logit_std = 1.0 + args.mode_scale = 1.0 + args.sigmoid_scale = 1.0 + args.discrete_flow_shift = 3.1582 + args.ip_noise_gamma = None + args.ip_noise_gamma_random_strength = False + return args + + +@pytest.fixture +def noise_scheduler(): + return MockNoiseScheduler(num_train_timesteps=1000) + + +@pytest.fixture +def latents(): + return torch.randn(2, 4, 8, 8) + + +@pytest.fixture +def noise(): + return torch.randn(2, 4, 8, 8) + + +@pytest.fixture +def device(): + # return "cuda" if torch.cuda.is_available() else "cpu" + return "cpu" + + +# Mock the required functions +@pytest.fixture(autouse=True) +def mock_functions(): + with ( + patch("torch.sigmoid", side_effect=torch.sigmoid), + patch("torch.rand", side_effect=torch.rand), + patch("torch.randn", side_effect=torch.randn), + ): + yield + + +# Test different timestep sampling methods +def test_uniform_sampling(args, noise_scheduler, latents, noise, device): + args.timestep_sampling = "uniform" + dtype = torch.float32 + + noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) + + assert noisy_input.shape == latents.shape + assert timesteps.shape == (latents.shape[0],) + assert sigmas.shape == (latents.shape[0], 1, 1, 1) + assert noisy_input.dtype == dtype + assert timesteps.dtype == dtype + + +def test_sigmoid_sampling(args, noise_scheduler, latents, noise, device): + args.timestep_sampling = "sigmoid" + args.sigmoid_scale = 10.0 + dtype = torch.float32 + + noisy_input, timesteps, sigmas = 
get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) + + assert noisy_input.shape == latents.shape + assert timesteps.shape == (latents.shape[0],) + assert sigmas.shape == (latents.shape[0], 1, 1, 1) + + +def test_shift_sampling(args, noise_scheduler, latents, noise, device): + args.timestep_sampling = "shift" + args.sigmoid_scale = 1.0 + args.discrete_flow_shift = 3.1582 + dtype = torch.float32 + + noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) + + assert noisy_input.shape == latents.shape + assert timesteps.shape == (latents.shape[0],) + assert sigmas.shape == (latents.shape[0], 1, 1, 1) + + +def test_flux_shift_sampling(args, noise_scheduler, latents, noise, device): + args.timestep_sampling = "flux_shift" + args.sigmoid_scale = 10.0 + dtype = torch.float32 + + noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) + + assert noisy_input.shape == latents.shape + assert timesteps.shape == (latents.shape[0],) + assert sigmas.shape == (latents.shape[0], 1, 1, 1) + + +def test_weighting_scheme(args, noise_scheduler, latents, noise, device): + # Mock the necessary functions for this specific test + with patch("library.flux_train_utils.compute_density_for_timestep_sampling", + return_value=torch.tensor([0.3, 0.7], device=device)), \ + patch("library.flux_train_utils.get_sigmas", + return_value=torch.tensor([[0.3], [0.7]], device=device).view(-1, 1, 1, 1)): + + args.timestep_sampling = "other" # Will trigger the weighting scheme path + args.weighting_scheme = "uniform" + args.logit_mean = 0.0 + args.logit_std = 1.0 + args.mode_scale = 1.0 + dtype = torch.float32 + + noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps( + args, noise_scheduler, latents, noise, device, dtype + ) + + assert noisy_input.shape == latents.shape + assert timesteps.shape == (latents.shape[0],) + 
assert sigmas.shape == (latents.shape[0], 1, 1, 1) + + +# Test IP noise options +def test_with_ip_noise(args, noise_scheduler, latents, noise, device): + args.ip_noise_gamma = 0.5 + args.ip_noise_gamma_random_strength = False + dtype = torch.float32 + + noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) + + assert noisy_input.shape == latents.shape + assert timesteps.shape == (latents.shape[0],) + assert sigmas.shape == (latents.shape[0], 1, 1, 1) + + +def test_with_random_ip_noise(args, noise_scheduler, latents, noise, device): + args.ip_noise_gamma = 0.1 + args.ip_noise_gamma_random_strength = True + dtype = torch.float32 + + noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) + + assert noisy_input.shape == latents.shape + assert timesteps.shape == (latents.shape[0],) + assert sigmas.shape == (latents.shape[0], 1, 1, 1) + + +# Test different data types +def test_float16_dtype(args, noise_scheduler, latents, noise, device): + dtype = torch.float16 + + noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) + + assert noisy_input.dtype == dtype + assert timesteps.dtype == dtype + + +# Test different batch sizes +def test_different_batch_size(args, noise_scheduler, device): + latents = torch.randn(5, 4, 8, 8) # batch size of 5 + noise = torch.randn(5, 4, 8, 8) + dtype = torch.float32 + + noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) + + assert noisy_input.shape == latents.shape + assert timesteps.shape == (5,) + assert sigmas.shape == (5, 1, 1, 1) + + +# Test different image sizes +def test_different_image_size(args, noise_scheduler, device): + latents = torch.randn(2, 4, 16, 16) # larger image size + noise = torch.randn(2, 4, 16, 16) + dtype = torch.float32 + + noisy_input, 
timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) + + assert noisy_input.shape == latents.shape + assert timesteps.shape == (2,) + assert sigmas.shape == (2, 1, 1, 1) + + +# Test edge cases +def test_zero_batch_size(args, noise_scheduler, device): + with pytest.raises(AssertionError): # expecting an error with zero batch size + latents = torch.randn(0, 4, 8, 8) + noise = torch.randn(0, 4, 8, 8) + dtype = torch.float32 + + get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) + + +def test_different_timestep_count(args, device): + noise_scheduler = MockNoiseScheduler(num_train_timesteps=500) # different timestep count + latents = torch.randn(2, 4, 8, 8) + noise = torch.randn(2, 4, 8, 8) + dtype = torch.float32 + + noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) + + assert noisy_input.shape == latents.shape + assert timesteps.shape == (2,) + # Check that timesteps are within the proper range + assert torch.all(timesteps < 500) From 8aa126582efbdf0472b0b8db800d50860870f3cd Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Thu, 20 Mar 2025 15:09:11 -0400 Subject: [PATCH 34/71] Scale sigmoid to default 1.0 --- pytest.ini | 1 + requirements.txt | 2 +- tests/library/test_flux_train_utils.py | 4 ++-- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/pytest.ini b/pytest.ini index 484d3aef..34b7e9c1 100644 --- a/pytest.ini +++ b/pytest.ini @@ -6,3 +6,4 @@ filterwarnings = ignore::DeprecationWarning ignore::UserWarning ignore::FutureWarning +pythonpath = . 
diff --git a/requirements.txt b/requirements.txt index de39f588..8fe8c762 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,7 +7,7 @@ opencv-python==4.8.1.78 einops==0.7.0 pytorch-lightning==1.9.0 bitsandbytes==0.44.0 -prodigyopt==1.0 +prodigyopt>=1.0 lion-pytorch==0.0.6 schedulefree==1.4 tensorboard diff --git a/tests/library/test_flux_train_utils.py b/tests/library/test_flux_train_utils.py index a4c7ba3b..2ad7ce4e 100644 --- a/tests/library/test_flux_train_utils.py +++ b/tests/library/test_flux_train_utils.py @@ -77,7 +77,7 @@ def test_uniform_sampling(args, noise_scheduler, latents, noise, device): def test_sigmoid_sampling(args, noise_scheduler, latents, noise, device): args.timestep_sampling = "sigmoid" - args.sigmoid_scale = 10.0 + args.sigmoid_scale = 1.0 dtype = torch.float32 noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) @@ -102,7 +102,7 @@ def test_shift_sampling(args, noise_scheduler, latents, noise, device): def test_flux_shift_sampling(args, noise_scheduler, latents, noise, device): args.timestep_sampling = "flux_shift" - args.sigmoid_scale = 10.0 + args.sigmoid_scale = 1.0 dtype = torch.float32 noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) From d40f5b1e4ef5e7e6b51df26914be3a661b006d34 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Thu, 20 Mar 2025 15:09:50 -0400 Subject: [PATCH 35/71] Revert "Scale sigmoid to default 1.0" This reverts commit 8aa126582efbdf0472b0b8db800d50860870f3cd. --- pytest.ini | 1 - requirements.txt | 2 +- tests/library/test_flux_train_utils.py | 4 ++-- 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/pytest.ini b/pytest.ini index 34b7e9c1..484d3aef 100644 --- a/pytest.ini +++ b/pytest.ini @@ -6,4 +6,3 @@ filterwarnings = ignore::DeprecationWarning ignore::UserWarning ignore::FutureWarning -pythonpath = . 
diff --git a/requirements.txt b/requirements.txt index 8fe8c762..de39f588 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,7 +7,7 @@ opencv-python==4.8.1.78 einops==0.7.0 pytorch-lightning==1.9.0 bitsandbytes==0.44.0 -prodigyopt>=1.0 +prodigyopt==1.0 lion-pytorch==0.0.6 schedulefree==1.4 tensorboard diff --git a/tests/library/test_flux_train_utils.py b/tests/library/test_flux_train_utils.py index 2ad7ce4e..a4c7ba3b 100644 --- a/tests/library/test_flux_train_utils.py +++ b/tests/library/test_flux_train_utils.py @@ -77,7 +77,7 @@ def test_uniform_sampling(args, noise_scheduler, latents, noise, device): def test_sigmoid_sampling(args, noise_scheduler, latents, noise, device): args.timestep_sampling = "sigmoid" - args.sigmoid_scale = 1.0 + args.sigmoid_scale = 10.0 dtype = torch.float32 noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) @@ -102,7 +102,7 @@ def test_shift_sampling(args, noise_scheduler, latents, noise, device): def test_flux_shift_sampling(args, noise_scheduler, latents, noise, device): args.timestep_sampling = "flux_shift" - args.sigmoid_scale = 1.0 + args.sigmoid_scale = 10.0 dtype = torch.float32 noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) From 89f0d27a5930ae0a355caacfedc546fb04a7345d Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Thu, 20 Mar 2025 15:10:33 -0400 Subject: [PATCH 36/71] Set sigmoid_scale to default 1.0 --- tests/library/test_flux_train_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/library/test_flux_train_utils.py b/tests/library/test_flux_train_utils.py index a4c7ba3b..2ad7ce4e 100644 --- a/tests/library/test_flux_train_utils.py +++ b/tests/library/test_flux_train_utils.py @@ -77,7 +77,7 @@ def test_uniform_sampling(args, noise_scheduler, latents, noise, device): def test_sigmoid_sampling(args, noise_scheduler, latents, noise, device): 
args.timestep_sampling = "sigmoid" - args.sigmoid_scale = 10.0 + args.sigmoid_scale = 1.0 dtype = torch.float32 noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) @@ -102,7 +102,7 @@ def test_shift_sampling(args, noise_scheduler, latents, noise, device): def test_flux_shift_sampling(args, noise_scheduler, latents, noise, device): args.timestep_sampling = "flux_shift" - args.sigmoid_scale = 10.0 + args.sigmoid_scale = 1.0 dtype = torch.float32 noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) From 8f4ee8fc343b047965cd8976fca65c3a35b7593a Mon Sep 17 00:00:00 2001 From: Kohya S Date: Fri, 21 Mar 2025 22:05:48 +0900 Subject: [PATCH 37/71] doc: update README for latest --- README.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/README.md b/README.md index 6beee5e3..7ed3a2f5 100644 --- a/README.md +++ b/README.md @@ -3,6 +3,8 @@ This repository contains training, generation and utility scripts for Stable Dif [__Change History__](#change-history) is moved to the bottom of the page. 更新履歴は[ページ末尾](#change-history)に移しました。 +Latest update: 2025-03-21 (Version 0.9.1) + [日本語版READMEはこちら](./README-ja.md) The development version is in the `dev` branch. Please check the dev branch for the latest changes. @@ -146,6 +148,11 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser ## Change History +### Mar 21, 2025 / 2025-03-21 Version 0.9.1 + +- Fixed a bug where some of LoRA modules for CLIP Text Encoder were not trained. Thank you Nekotekina for PR [#1964](https://github.com/kohya-ss/sd-scripts/pull/1964) + - The LoRA modules for CLIP Text Encoder are now 264 modules, which is the same as before. Only 88 modules were trained in the previous version. + ### Jan 17, 2025 / 2025-01-17 Version 0.9.0 - __important__ The dependent libraries are updated. Please see [Upgrade](#upgrade) and update the libraries. 
From 182544dcce383a433527e446bfc7fa8374e375a8 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Wed, 26 Mar 2025 14:23:04 -0400 Subject: [PATCH 38/71] Remove pertubation seed --- networks/lora_flux.py | 41 ++--------------------------------------- 1 file changed, 2 insertions(+), 39 deletions(-) diff --git a/networks/lora_flux.py b/networks/lora_flux.py index 9f5f1916..92b3979a 100644 --- a/networks/lora_flux.py +++ b/networks/lora_flux.py @@ -29,42 +29,6 @@ logger = logging.getLogger(__name__) NUM_DOUBLE_BLOCKS = 19 NUM_SINGLE_BLOCKS = 38 -@contextmanager -def temp_random_seed(seed, device=None): - """ - Context manager that temporarily sets a specific random seed and then - restores the original RNG state afterward. - - Args: - seed (int): The random seed to set temporarily - device (torch.device, optional): The device to set the seed for. - If None, will detect from the current context. - """ - # Save original RNG states - original_cpu_rng_state = torch.get_rng_state() - original_cuda_rng_states = None - if torch.cuda.is_available(): - original_cuda_rng_states = torch.cuda.get_rng_state_all() - - # Determine if we need to set CUDA seed - set_cuda = False - if device is not None: - set_cuda = device.type == 'cuda' - elif torch.cuda.is_available(): - set_cuda = True - - try: - # Set the temporary seed - torch.manual_seed(seed) - if set_cuda: - torch.cuda.manual_seed_all(seed) - yield - finally: - # Restore original RNG states - torch.set_rng_state(original_cpu_rng_state) - if torch.cuda.is_available() and original_cuda_rng_states is not None: - torch.cuda.set_rng_state_all(original_cuda_rng_states) - class LoRAModule(torch.nn.Module): """ @@ -150,7 +114,6 @@ class LoRAModule(torch.nn.Module): self.combined_weight_norms = None self.grad_norms = None self.perturbation_norm_factor = 1.0 / math.sqrt(org_module.weight.shape[0]) - self.perturbation_seed = torch.randint(0, 2**32 - 1, (1,)).detach().item() self.initialize_norm_cache(org_module.weight) self.org_module_shape: 
tuple[int] = org_module.weight.shape @@ -193,8 +156,8 @@ class LoRAModule(torch.nn.Module): lx = self.lora_up(lx) # LoRA Gradient-Guided Perturbation Optimization - if self.training and hasattr(self, 'perturbation_seed') and self.ggpo_sigma is not None and self.ggpo_beta is not None and self.combined_weight_norms is not None and self.grad_norms is not None: - with torch.no_grad(), temp_random_seed(self.perturbation_seed): + if self.training and self.ggpo_sigma is not None and self.ggpo_beta is not None and self.combined_weight_norms is not None and self.grad_norms is not None: + with torch.no_grad(): perturbation_scale = (self.ggpo_sigma * torch.sqrt(self.combined_weight_norms ** 2)) + (self.ggpo_beta * (self.grad_norms ** 2)) perturbation_scale_factor = (perturbation_scale * self.perturbation_norm_factor).to(self.device) perturbation = torch.randn(self.org_module_shape, dtype=self.dtype, device=self.device) From 0181b7a0425fd58012f7e3ece10345c86d9b6fc8 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Thu, 27 Mar 2025 03:28:33 -0400 Subject: [PATCH 39/71] Remove progress bar avg norms --- train_network.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train_network.py b/train_network.py index 4898e798..5b2f377a 100644 --- a/train_network.py +++ b/train_network.py @@ -1432,7 +1432,7 @@ class NetworkTrainer: weight_norms = network.weight_norms() maximum_norm = weight_norms.max().item() if weight_norms.numel() > 0 else None keys_scaled = None - max_mean_logs = {"avg weight norm": mean_norm, "avg grad norm": mean_grad_norm, "avg comb norm": mean_combined_norm} + max_mean_logs = {} else: keys_scaled, mean_norm, maximum_norm = None, None, None mean_grad_norm = None From 381303d64fa87c61da145b85b737c410d42555fa Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sat, 29 Mar 2025 02:26:18 +0800 Subject: [PATCH 40/71] Update train_network.py --- train_network.py | 22 +++++++++++++++------- 1 file changed, 15 
insertions(+), 7 deletions(-) diff --git a/train_network.py b/train_network.py index 7bf125dc..6953bb17 100644 --- a/train_network.py +++ b/train_network.py @@ -912,14 +912,22 @@ class NetworkTrainer: if "latents" in batch and batch["latents"] is not None: latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype) else: - with torch.no_grad(): - # latentに変換 - latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample().to(dtype=weight_dtype) - + if args.vae_batch_size is None or len(batch["images"]) <= args.vae_batch_size: + with torch.no_grad(): + # latentに変換 + latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample().to(dtype=weight_dtype) + else: + chunks = [batch["images"][i:i + args.vae_batch_size] for i in range(0, len(batch["images"]), args.vae_batch_size)] + list_latents = [] + for chunk in chunks: + with torch.no_grad(): + # latentに変換 + list_latents.append(vae.encode(chunk.to(dtype=vae_dtype)).latent_dist.sample().to(dtype=weight_dtype)) + latents = torch.cat(list_latents, dim=0) # NaNが含まれていれば警告を表示し0に置き換える - if torch.any(torch.isnan(latents)): - accelerator.print("NaN found in latents, replacing with zeros") - latents = torch.nan_to_num(latents, 0, out=latents) + if torch.any(torch.isnan(latents)): + accelerator.print("NaN found in latents, replacing with zeros") + latents = torch.nan_to_num(latents, 0, out=latents) latents = latents * self.vae_scale_factor # get multiplier for each sample From 1f432e2c0e5b583c09100c68ce59f30c9d39ecf6 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 30 Mar 2025 20:40:29 +0900 Subject: [PATCH 41/71] use PIL for lanczos and box --- docs/config_README-en.md | 3 +++ docs/config_README-ja.md | 4 ++++ library/train_util.py | 2 +- library/utils.py | 21 ++++++++++++++------- 4 files changed, 22 insertions(+), 8 deletions(-) diff --git a/docs/config_README-en.md b/docs/config_README-en.md index 66a50dc0..8c55903d 100644 --- a/docs/config_README-en.md +++ 
b/docs/config_README-en.md @@ -152,6 +152,7 @@ These options are related to subset configuration. | `keep_tokens_separator` | `“|||”` | o | o | o | | `secondary_separator` | `“;;;”` | o | o | o | | `enable_wildcard` | `true` | o | o | o | +| `resize_interpolation` | (not specified) | o | o | o | * `num_repeats` * Specifies the number of repeats for images in a subset. This is equivalent to `--dataset_repeats` in fine-tuning but can be specified for any training method. @@ -165,6 +166,8 @@ These options are related to subset configuration. * Specifies an additional separator. The part separated by this separator is treated as one tag and is shuffled and dropped. It is then replaced by `caption_separator`. For example, if you specify `aaa;;;bbb;;;ccc`, it will be replaced by `aaa,bbb,ccc` or dropped together. * `enable_wildcard` * Enables wildcard notation. This will be explained later. +* `resize_interpolation` + * Specifies the interpolation method used when resizing images. Normally, there is no need to specify this. The following options can be specified: `lanczos`, `nearest`, `bilinear`, `linear`, `bicubic`, `cubic`, `area`, `box`. By default (when not specified), `area` is used for downscaling, and `lanczos` is used for upscaling. If this option is specified, the same interpolation method will be used for both upscaling and downscaling. When `lanczos` or `box` is specified, PIL is used; for other options, OpenCV is used. 
### DreamBooth-specific options diff --git a/docs/config_README-ja.md b/docs/config_README-ja.md index 0ed95e0e..aec0eca5 100644 --- a/docs/config_README-ja.md +++ b/docs/config_README-ja.md @@ -144,6 +144,7 @@ DreamBooth の手法と fine tuning の手法の両方とも利用可能な学 | `keep_tokens_separator` | `“|||”` | o | o | o | | `secondary_separator` | `“;;;”` | o | o | o | | `enable_wildcard` | `true` | o | o | o | +| `resize_interpolation` |(通常は設定しません) | o | o | o | * `num_repeats` * サブセットの画像の繰り返し回数を指定します。fine tuning における `--dataset_repeats` に相当しますが、`num_repeats` はどの学習方法でも指定可能です。 @@ -162,6 +163,9 @@ DreamBooth の手法と fine tuning の手法の両方とも利用可能な学 * `enable_wildcard` * ワイルドカード記法および複数行キャプションを有効にします。ワイルドカード記法、複数行キャプションについては後述します。 +* `resize_interpolation` + * 画像のリサイズ時に使用する補間方法を指定します。通常は指定しなくて構いません。`lanczos`, `nearest`, `bilinear`, `linear`, `bicubic`, `cubic`, `area`, `box` が指定可能です。デフォルト(未指定時)は、縮小時は `area`、拡大時は `lanczos` になります。このオプションを指定すると、拡大時・縮小時とも同じ補間方法が使用されます。`lanczos`、`box`を指定するとPILが、それ以外を指定するとOpenCVが使用されます。 + ### DreamBooth 方式専用のオプション DreamBooth 方式のオプションは、サブセット向けオプションのみ存在します。 diff --git a/library/train_util.py b/library/train_util.py index e9c50688..1ed1d3c2 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -74,7 +74,7 @@ import library.model_util as model_util import library.huggingface_util as huggingface_util import library.sai_model_spec as sai_model_spec import library.deepspeed_utils as deepspeed_utils -from library.utils import setup_logging, resize_image +from library.utils import setup_logging, resize_image, validate_interpolation_fn setup_logging() import logging diff --git a/library/utils.py b/library/utils.py index 4fbc2627..0f535a87 100644 --- a/library/utils.py +++ b/library/utils.py @@ -400,7 +400,7 @@ def pil_resize(image, size, interpolation): def resize_image(image: np.ndarray, width: int, height: int, resized_width: int, resized_height: int, resize_interpolation: Optional[str] = None): """ - Resize image with resize interpolation. 
Default interpolation to AREA if image is smaller, else LANCZOS + Resize image with resize interpolation. Default interpolation to AREA if image is smaller, else LANCZOS. Args: image: numpy.ndarray @@ -413,14 +413,21 @@ def resize_image(image: np.ndarray, width: int, height: int, resized_width: int, Returns: image """ - interpolation = get_cv2_interpolation(resize_interpolation) + if resize_interpolation is None: + resize_interpolation = "lanczos" if width > resized_width and height > resized_height else "area" + + # we use PIL for lanczos (for backward compatibility) and box, cv2 for others + use_pil = resize_interpolation in ["lanczos", "lanczos4", "box"] + resized_size = (resized_width, resized_height) - if width > resized_width and height > resized_width: - image = cv2.resize(image, resized_size, interpolation=interpolation if interpolation is not None else cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ - logger.debug(f"resize image using {resize_interpolation}") + if use_pil: + interpolation = get_pil_interpolation(resize_interpolation) + image = pil_resize(image, resized_size, interpolation=interpolation) + logger.debug(f"resize image using {resize_interpolation} (PIL)") else: - image = cv2.resize(image, resized_size, interpolation=interpolation if interpolation is not None else cv2.INTER_LANCZOS4) # INTER_AREAでやりたいのでcv2でリサイズ - logger.debug(f"resize image using {resize_interpolation}") + interpolation = get_cv2_interpolation(resize_interpolation) + image = cv2.resize(image, resized_size, interpolation=interpolation) + logger.debug(f"resize image using {resize_interpolation} (cv2)") return image From 96a133c99850fe19544b62fbfde55a7d149802dd Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 30 Mar 2025 20:45:06 +0900 Subject: [PATCH 42/71] README.md: update recent updates section to include new interpolation method for resizing images --- README.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/README.md b/README.md index 7620e407..6e28b212 100644 --- 
a/README.md +++ b/README.md @@ -14,6 +14,9 @@ The command to install PyTorch is as follows: ### Recent Updates +Mar 30, 2025: +- The interpolation method for resizing the original image to the training size can now be specified. Thank you to rockerBOO for PR [#1939](https://github.com/kohya-ss/sd-scripts/pull/1939). + Mar 20, 2025: - `pytorch-optimizer` is added to requirements.txt. Thank you to gesen2egee for PR [#1985](https://github.com/kohya-ss/sd-scripts/pull/1985). - For example, you can use CAME optimizer with `--optimizer_type "pytorch_optimizer.CAME" --optimizer_args "weight_decay=0.01"`. From d0b5c0e5cfabf65a64d9d60712dc67bd8057336b Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 30 Mar 2025 21:15:37 +0900 Subject: [PATCH 43/71] chore: formatting, add TODO comment --- train_network.py | 31 ++++++++++++++++++++++--------- 1 file changed, 22 insertions(+), 9 deletions(-) diff --git a/train_network.py b/train_network.py index 3bab0cad..f66cdeb4 100644 --- a/train_network.py +++ b/train_network.py @@ -70,7 +70,7 @@ class NetworkTrainer: mean_norm=None, maximum_norm=None, mean_grad_norm=None, - mean_combined_norm=None + mean_combined_norm=None, ): logs = {"loss/current": current_loss, "loss/average": avr_loss} @@ -658,6 +658,10 @@ class NetworkTrainer: return network_has_multiplier = hasattr(network, "set_multiplier") + # TODO remove `hasattr`s by setting up methods if not defined in the network like (hacky but works): + # if not hasattr(network, "prepare_network"): + # network.prepare_network = lambda args: None + if hasattr(network, "prepare_network"): network.prepare_network(args) if args.scale_weight_norms and not hasattr(network, "apply_max_norm_regularization"): @@ -1019,12 +1023,12 @@ class NetworkTrainer: "ss_huber_c": args.huber_c, "ss_fp8_base": bool(args.fp8_base), "ss_fp8_base_unet": bool(args.fp8_base_unet), - "ss_validation_seed": args.validation_seed, - "ss_validation_split": args.validation_split, - "ss_max_validation_steps": 
args.max_validation_steps, - "ss_validate_every_n_epochs": args.validate_every_n_epochs, - "ss_validate_every_n_steps": args.validate_every_n_steps, - "ss_resize_interpolation": args.resize_interpolation + "ss_validation_seed": args.validation_seed, + "ss_validation_split": args.validation_split, + "ss_max_validation_steps": args.max_validation_steps, + "ss_validate_every_n_epochs": args.validate_every_n_epochs, + "ss_validate_every_n_steps": args.validate_every_n_steps, + "ss_resize_interpolation": args.resize_interpolation, } self.update_metadata(metadata, args) # architecture specific metadata @@ -1415,7 +1419,6 @@ class NetworkTrainer: if hasattr(network, "update_norms"): network.update_norms() - optimizer.step() lr_scheduler.step() optimizer.zero_grad(set_to_none=True) @@ -1476,7 +1479,17 @@ class NetworkTrainer: if is_tracking: logs = self.generate_step_logs( - args, current_loss, avr_loss, lr_scheduler, lr_descriptions, optimizer, keys_scaled, mean_norm, maximum_norm, mean_grad_norm, mean_combined_norm + args, + current_loss, + avr_loss, + lr_scheduler, + lr_descriptions, + optimizer, + keys_scaled, + mean_norm, + maximum_norm, + mean_grad_norm, + mean_combined_norm, ) self.step_logging(accelerator, logs, global_step, epoch + 1) From aaa26bb882ff23e1e35c84fed6e4f7a12ec420d4 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 30 Mar 2025 21:18:05 +0900 Subject: [PATCH 44/71] docs: update README to include LoRA-GGPO details for FLUX.1 training --- README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.md b/README.md index 6e28b212..4bc0c2b5 100644 --- a/README.md +++ b/README.md @@ -15,6 +15,8 @@ The command to install PyTorch is as follows: ### Recent Updates Mar 30, 2025: +- LoRA-GGPO is added for FLUX.1 LoRA training. Thank you to rockerBOO for PR [#1974](https://github.com/kohya-ss/sd-scripts/pull/1974). 
+ - Specify `--network_args ggpo_sigma=0.03 ggpo_beta=0.01` in the command line or `network_args = ["ggpo_sigma=0.03", "ggpo_beta=0.01"]` in .toml file. See PR for details. - The interpolation method for resizing the original image to the training size can now be specified. Thank you to rockerBOO for PR [#1939](https://github.com/kohya-ss/sd-scripts/pull/1939). Mar 20, 2025: From 583ab27b3ccf9e360be07d424e8c15f90d1041ad Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 31 Mar 2025 22:02:25 +0900 Subject: [PATCH 45/71] doc: update license information in jpeg_xl_util.py --- library/jpeg_xl_util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/library/jpeg_xl_util.py b/library/jpeg_xl_util.py index ade24a05..c2e3a393 100644 --- a/library/jpeg_xl_util.py +++ b/library/jpeg_xl_util.py @@ -1,4 +1,4 @@ -# Modified from https://github.com/Fraetor/jxl_decode +# Modified from https://github.com/Fraetor/jxl_decode Original license: MIT # Added partial read support for up to 200x speedup import os From ede34702609c59ba2256cf7e330bad67ce9c77d3 Mon Sep 17 00:00:00 2001 From: Lex Song Date: Wed, 2 Apr 2025 03:28:58 +0800 Subject: [PATCH 46/71] Ensure all size parameters are integers to prevent type errors --- library/utils.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/library/utils.py b/library/utils.py index 0f535a87..767de472 100644 --- a/library/utils.py +++ b/library/utils.py @@ -413,6 +413,13 @@ def resize_image(image: np.ndarray, width: int, height: int, resized_width: int, Returns: image """ + + # Ensure all size parameters are actual integers + width = int(width) + height = int(height) + resized_width = int(resized_width) + resized_height = int(resized_height) + if resize_interpolation is None: resize_interpolation = "lanczos" if width > resized_width and height > resized_height else "area" From b822b7e60b84a4fb32a8e1ffa966054f8fe96209 Mon Sep 17 00:00:00 2001 From: Lex Song Date: Wed, 2 Apr 2025 03:32:36 +0800 Subject: [PATCH 47/71] Fix 
the interpolation logic error in resize_image() The original code had a mistake. It used 'lanczos' when the image got smaller (width > resized_width and height > resized_height) and 'area' when it stayed the same or got bigger. This was the wrong way. 'area' is better for big shrinking. --- library/utils.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/library/utils.py b/library/utils.py index 767de472..d0586b84 100644 --- a/library/utils.py +++ b/library/utils.py @@ -421,8 +421,11 @@ def resize_image(image: np.ndarray, width: int, height: int, resized_width: int, resized_height = int(resized_height) if resize_interpolation is None: - resize_interpolation = "lanczos" if width > resized_width and height > resized_height else "area" - + if width >= resized_width and height >= resized_height: + resize_interpolation = "area" + else: + resize_interpolation = "lanczos" + # we use PIL for lanczos (for backward compatibility) and box, cv2 for others use_pil = resize_interpolation in ["lanczos", "lanczos4", "box"] From f1423a72298a12110192f59cebe26b39206268e5 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 3 Apr 2025 21:48:51 +0900 Subject: [PATCH 48/71] fix: add resize_interpolation parameter to FineTuningDataset constructor --- library/train_util.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/library/train_util.py b/library/train_util.py index 53799081..6c39f8d9 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2154,8 +2154,9 @@ class FineTuningDataset(BaseDataset): debug_dataset: bool, validation_seed: int, validation_split: float, + resize_interpolation: Optional[str], ) -> None: - super().__init__(resolution, network_multiplier, debug_dataset) + super().__init__(resolution, network_multiplier, debug_dataset, resize_interpolation) self.batch_size = batch_size From fd36fd1aa91a5c24bea820fae245b1cea7ac2b44 Mon Sep 17 00:00:00 2001 From: Dave Lage Date: Thu, 3 Apr 2025 16:09:45 -0400 Subject: [PATCH 
49/71] Fix resize PR link --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 4bc0c2b5..ae417d05 100644 --- a/README.md +++ b/README.md @@ -17,7 +17,7 @@ The command to install PyTorch is as follows: Mar 30, 2025: - LoRA-GGPO is added for FLUX.1 LoRA training. Thank you to rockerBOO for PR [#1974](https://github.com/kohya-ss/sd-scripts/pull/1974). - Specify `--network_args ggpo_sigma=0.03 ggpo_beta=0.01` in the command line or `network_args = ["ggpo_sigma=0.03", "ggpo_beta=0.01"]` in .toml file. See PR for details. -- The interpolation method for resizing the original image to the training size can now be specified. Thank you to rockerBOO for PR [#1939](https://github.com/kohya-ss/sd-scripts/pull/1939). +- The interpolation method for resizing the original image to the training size can now be specified. Thank you to rockerBOO for PR [#1936](https://github.com/kohya-ss/sd-scripts/pull/1936). Mar 20, 2025: - `pytorch-optimizer` is added to requirements.txt. Thank you to gesen2egee for PR [#1985](https://github.com/kohya-ss/sd-scripts/pull/1985). From 4589262f8f35914b592992f4144bde5e746a6e36 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 6 Apr 2025 21:34:27 +0900 Subject: [PATCH 50/71] README.md: Update recent updates section to include IP noise gamma feature for FLUX.1 --- README.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/README.md b/README.md index ae417d05..2e80a697 100644 --- a/README.md +++ b/README.md @@ -14,6 +14,10 @@ The command to install PyTorch is as follows: ### Recent Updates +Apr 6, 2025: +- IP noise gamma has been enabled in FLUX.1. Thanks to rockerBOO for PR [#1992](https://github.com/kohya-ss/sd-scripts/pull/1992). See the PR for details. + - `--ip_noise_gamma` and `--ip_noise_gamma_random_strength` are available. + Mar 30, 2025: - LoRA-GGPO is added for FLUX.1 LoRA training. Thank you to rockerBOO for PR [#1974](https://github.com/kohya-ss/sd-scripts/pull/1974). 
- Specify `--network_args ggpo_sigma=0.03 ggpo_beta=0.01` in the command line or `network_args = ["ggpo_sigma=0.03", "ggpo_beta=0.01"]` in .toml file. See PR for details. From 629073cd9dd21296ca8aa97a5267d4dc7f6e5fdb Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 16 Apr 2025 21:50:36 +0900 Subject: [PATCH 51/71] Add guidance scale for prompt param and flux sampling --- library/flux_train_utils.py | 10 +++++++--- library/train_util.py | 5 +++++ 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index ce381829..d2ff347d 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -154,6 +154,7 @@ def sample_image_inference( sample_steps = prompt_dict.get("sample_steps", 20) width = prompt_dict.get("width", 512) height = prompt_dict.get("height", 512) + guidance_scale = prompt_dict.get("guidance_scale", args.guidance_scale) scale = prompt_dict.get("scale", 1.0) # 1.0 means no guidance seed = prompt_dict.get("seed") controlnet_image = prompt_dict.get("controlnet_image") @@ -180,9 +181,12 @@ def sample_image_inference( logger.info(f"prompt: {prompt}") if scale != 1.0: logger.info(f"negative_prompt: {negative_prompt}") + elif negative_prompt != "": + logger.info(f"negative prompt is ignored because scale is 1.0") logger.info(f"height: {height}") logger.info(f"width: {width}") logger.info(f"sample_steps: {sample_steps}") + logger.info(f"guidance_scale: {guidance_scale}") if scale != 1.0: logger.info(f"scale: {scale}") # logger.info(f"sample_sampler: {sampler_name}") @@ -256,7 +260,7 @@ def sample_image_inference( txt_ids, l_pooled, timesteps=timesteps, - guidance=scale, + guidance=guidance_scale, t5_attn_mask=t5_attn_mask, controlnet=controlnet, controlnet_img=controlnet_image, @@ -489,7 +493,7 @@ def get_noisy_model_input_and_timesteps( sigmas = torch.randn(bsz, device=device) sigmas = sigmas * args.sigmoid_scale # larger scale for more uniform sampling sigmas = sigmas.sigmoid() - 
mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2)) # we are pre-packed so must adjust for packed size + mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2)) # we are pre-packed so must adjust for packed size sigmas = time_shift(mu, 1.0, sigmas) timesteps = sigmas * num_timesteps else: @@ -514,7 +518,7 @@ def get_noisy_model_input_and_timesteps( if args.ip_noise_gamma: xi = torch.randn_like(latents, device=latents.device, dtype=dtype) if args.ip_noise_gamma_random_strength: - ip_noise_gamma = (torch.rand(1, device=latents.device, dtype=dtype) * args.ip_noise_gamma) + ip_noise_gamma = torch.rand(1, device=latents.device, dtype=dtype) * args.ip_noise_gamma else: ip_noise_gamma = args.ip_noise_gamma noisy_model_input = (1.0 - sigmas) * latents + sigmas * (noise + ip_noise_gamma * xi) diff --git a/library/train_util.py b/library/train_util.py index 6c39f8d9..e152f30f 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -6178,6 +6178,11 @@ def line_to_prompt_dict(line: str) -> dict: prompt_dict["scale"] = float(m.group(1)) continue + m = re.match(r"g ([\d\.]+)", parg, re.IGNORECASE) + if m: # guidance scale + prompt_dict["guidance_scale"] = float(m.group(1)) + continue + m = re.match(r"n (.+)", parg, re.IGNORECASE) if m: # negative prompt prompt_dict["negative_prompt"] = m.group(1) From 26db64be17835de5bff22bfc6d671ae1a2ffb4a4 Mon Sep 17 00:00:00 2001 From: Glen Date: Sat, 19 Apr 2025 11:54:12 -0600 Subject: [PATCH 52/71] fix: update hf_hub_download parameters to fix wd14 tagger regression --- finetune/tag_images_by_wd14_tagger.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/finetune/tag_images_by_wd14_tagger.py b/finetune/tag_images_by_wd14_tagger.py index cbc3d2d6..f8f6ddd9 100644 --- a/finetune/tag_images_by_wd14_tagger.py +++ b/finetune/tag_images_by_wd14_tagger.py @@ -100,15 +100,19 @@ def main(args): else: for file in SUB_DIR_FILES: hf_hub_download( - args.repo_id, - file, + repo_id=args.repo_id, + 
filename=file, subfolder=SUB_DIR, - cache_dir=os.path.join(model_location, SUB_DIR), + local_dir=os.path.join(model_location, SUB_DIR), force_download=True, - force_filename=file, ) for file in files: - hf_hub_download(args.repo_id, file, cache_dir=model_location, force_download=True, force_filename=file) + hf_hub_download( + repo_id=args.repo_id, + filename=file, + local_dir=model_location, + force_download=True, + ) else: logger.info("using existing wd14 tagger model") From 7c61c0dfe0e879fd6b66ccb70273e4b99deaf1c5 Mon Sep 17 00:00:00 2001 From: saibit Date: Tue, 22 Apr 2025 16:06:55 +0800 Subject: [PATCH 53/71] Add autocast wrapper for forward functions in deepspeed_utils.py to try aligning precision when using mixed precision in training process --- library/deepspeed_utils.py | 32 ++++++++++++++++++++++++++++++++ library/flux_models.py | 2 +- library/train_util.py | 5 +++++ requirements.txt | 3 ++- 4 files changed, 40 insertions(+), 2 deletions(-) diff --git a/library/deepspeed_utils.py b/library/deepspeed_utils.py index 99a7b2b3..3018def7 100644 --- a/library/deepspeed_utils.py +++ b/library/deepspeed_utils.py @@ -94,6 +94,7 @@ def prepare_deepspeed_plugin(args: argparse.Namespace): deepspeed_plugin.deepspeed_config["train_batch_size"] = ( args.train_batch_size * args.gradient_accumulation_steps * int(os.environ["WORLD_SIZE"]) ) + deepspeed_plugin.set_mixed_precision(args.mixed_precision) if args.mixed_precision.lower() == "fp16": deepspeed_plugin.deepspeed_config["fp16"]["initial_scale_power"] = 0 # preventing overflow. 
@@ -122,18 +123,49 @@ def prepare_deepspeed_model(args: argparse.Namespace, **models): class DeepSpeedWrapper(torch.nn.Module): def __init__(self, **kw_models) -> None: super().__init__() + self.models = torch.nn.ModuleDict() + + warp_model_forward_with_torch_autocast = args.mixed_precision is not "no" for key, model in kw_models.items(): if isinstance(model, list): model = torch.nn.ModuleList(model) + assert isinstance( model, torch.nn.Module ), f"model must be an instance of torch.nn.Module, but got {key} is {type(model)}" + + if warp_model_forward_with_torch_autocast: + model = self.__warp_with_torch_autocast(model) + self.models.update(torch.nn.ModuleDict({key: model})) + def __warp_with_torch_autocast(self, model): + if isinstance(model, torch.nn.ModuleList): + for i in range(len(model)): + model[i] = self.__warp_model_forward_with_torch_autocast(model[i]) + else: + model = self.__warp_model_forward_with_torch_autocast(model) + return model + + def __warp_model_forward_with_torch_autocast(self, model): + + assert hasattr(model, "forward"), f"model must have a forward method." 
+ + forward_fn = model.forward + + def forward(*args, **kwargs): + device_type= "cuda" if torch.cuda.is_available() else "cpu" + with torch.autocast(device_type=device_type): + return forward_fn(*args, **kwargs) + model.forward = forward + + return model + def get_models(self): return self.models + ds_model = DeepSpeedWrapper(**models) return ds_model diff --git a/library/flux_models.py b/library/flux_models.py index 328ad481..12151ee8 100644 --- a/library/flux_models.py +++ b/library/flux_models.py @@ -1005,7 +1005,7 @@ class Flux(nn.Module): return self.offloader_double.prepare_block_devices_before_forward(self.double_blocks) self.offloader_single.prepare_block_devices_before_forward(self.single_blocks) - + def forward( self, img: Tensor, diff --git a/library/train_util.py b/library/train_util.py index 6c39f8d9..dbbfda3e 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -5495,6 +5495,11 @@ def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projectio def patch_accelerator_for_fp16_training(accelerator): + + from accelerate import DistributedType + if accelerator.distributed_type == DistributedType.DEEPSPEED: + return + org_unscale_grads = accelerator.scaler._unscale_grads_ def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16): diff --git a/requirements.txt b/requirements.txt index 767d9e8e..bead3f90 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,7 @@ accelerate==0.33.0 transformers==4.44.0 -diffusers[torch]==0.25.0 +diffusers==0.25.0 +deepspeed==0.16.7 ftfy==6.1.1 # albumentations==1.3.0 opencv-python==4.8.1.78 From d33d5eccd16970e489359ee02b89a6259559e4b9 Mon Sep 17 00:00:00 2001 From: saibit Date: Tue, 22 Apr 2025 16:12:06 +0800 Subject: [PATCH 54/71] # --- library/flux_models.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/library/flux_models.py b/library/flux_models.py index 12151ee8..d7840d51 100644 --- a/library/flux_models.py +++ b/library/flux_models.py @@ 
-1004,8 +1004,7 @@ class Flux(nn.Module): if self.blocks_to_swap is None or self.blocks_to_swap == 0: return self.offloader_double.prepare_block_devices_before_forward(self.double_blocks) - self.offloader_single.prepare_block_devices_before_forward(self.single_blocks) - + self.offloader_single.prepare_block_devices_before_forward(self.single_blocks) def forward( self, img: Tensor, From 7f984f47758f9e17f4a82b92cb9dbc97b3ba982f Mon Sep 17 00:00:00 2001 From: saibit Date: Tue, 22 Apr 2025 16:15:12 +0800 Subject: [PATCH 55/71] # --- library/flux_models.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/library/flux_models.py b/library/flux_models.py index d7840d51..328ad481 100644 --- a/library/flux_models.py +++ b/library/flux_models.py @@ -1004,7 +1004,8 @@ class Flux(nn.Module): if self.blocks_to_swap is None or self.blocks_to_swap == 0: return self.offloader_double.prepare_block_devices_before_forward(self.double_blocks) - self.offloader_single.prepare_block_devices_before_forward(self.single_blocks) + self.offloader_single.prepare_block_devices_before_forward(self.single_blocks) + def forward( self, img: Tensor, From c8af252a44a7dbc54a0c1622946faedef4e7c52b Mon Sep 17 00:00:00 2001 From: Robert Date: Tue, 22 Apr 2025 16:19:14 +0800 Subject: [PATCH 56/71] refactor --- library/deepspeed_utils.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/library/deepspeed_utils.py b/library/deepspeed_utils.py index 3018def7..f6eac367 100644 --- a/library/deepspeed_utils.py +++ b/library/deepspeed_utils.py @@ -126,7 +126,7 @@ def prepare_deepspeed_model(args: argparse.Namespace, **models): self.models = torch.nn.ModuleDict() - warp_model_forward_with_torch_autocast = args.mixed_precision is not "no" + wrap_model_forward_with_torch_autocast = args.mixed_precision is not "no" for key, model in kw_models.items(): if isinstance(model, list): @@ -136,31 +136,30 @@ def prepare_deepspeed_model(args: argparse.Namespace, 
**models): model, torch.nn.Module ), f"model must be an instance of torch.nn.Module, but got {key} is {type(model)}" - if warp_model_forward_with_torch_autocast: - model = self.__warp_with_torch_autocast(model) + if wrap_model_forward_with_torch_autocast: + model = self.__wrap_model_with_torch_autocast(model) self.models.update(torch.nn.ModuleDict({key: model})) - def __warp_with_torch_autocast(self, model): + def __wrap_model_with_torch_autocast(self, model): if isinstance(model, torch.nn.ModuleList): - for i in range(len(model)): - model[i] = self.__warp_model_forward_with_torch_autocast(model[i]) + model = [self.__wrap_model_forward_with_torch_autocast(m) for m in model] else: - model = self.__warp_model_forward_with_torch_autocast(model) + model = self.__wrap_model_forward_with_torch_autocast(model) return model - def __warp_model_forward_with_torch_autocast(self, model): + def __wrap_model_forward_with_torch_autocast(self, model): assert hasattr(model, "forward"), f"model must have a forward method." 
forward_fn = model.forward def forward(*args, **kwargs): - device_type= "cuda" if torch.cuda.is_available() else "cpu" + device_type = "cuda" if torch.cuda.is_available() else "cpu" with torch.autocast(device_type=device_type): return forward_fn(*args, **kwargs) + model.forward = forward - return model def get_models(self): From adb775c6165d93a856e33d0d9058efd629cf2a2d Mon Sep 17 00:00:00 2001 From: saibit Date: Wed, 23 Apr 2025 17:05:20 +0800 Subject: [PATCH 57/71] Update: requirement diffusers[torch]==0.25.0 --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index bead3f90..9e97eed3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ accelerate==0.33.0 transformers==4.44.0 -diffusers==0.25.0 +diffusers[torch]==0.25.0 deepspeed==0.16.7 ftfy==6.1.1 # albumentations==1.3.0 From abf2c44bc5650afef8bebbb1ef278c66f44c4dda Mon Sep 17 00:00:00 2001 From: sharlynxy Date: Wed, 23 Apr 2025 18:57:19 +0800 Subject: [PATCH 58/71] Dynamically set device in deepspeed wrapper (#2) * get device type from model * add logger warning * format * format * format --- library/deepspeed_utils.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/library/deepspeed_utils.py b/library/deepspeed_utils.py index f6eac367..09c6f7b9 100644 --- a/library/deepspeed_utils.py +++ b/library/deepspeed_utils.py @@ -5,6 +5,8 @@ from accelerate import DeepSpeedPlugin, Accelerator from .utils import setup_logging +from .device_utils import get_preferred_device + setup_logging() import logging @@ -153,13 +155,21 @@ def prepare_deepspeed_model(args: argparse.Namespace, **models): assert hasattr(model, "forward"), f"model must have a forward method." 
forward_fn = model.forward - + def forward(*args, **kwargs): - device_type = "cuda" if torch.cuda.is_available() else "cpu" - with torch.autocast(device_type=device_type): + try: + device_type = model.device.type + except AttributeError: + logger.warning( + "[DeepSpeed] model.device is not available. Using get_preferred_device() " + "to determine the device_type for torch.autocast()." + ) + device_type = get_preferred_device().type + + with torch.autocast(device_type = device_type): return forward_fn(*args, **kwargs) - model.forward = forward + model.forward = forward return model def get_models(self): From 46ad3be0593df1df9d485c3ac2efb5aebd87730c Mon Sep 17 00:00:00 2001 From: saibit Date: Thu, 24 Apr 2025 11:26:36 +0800 Subject: [PATCH 59/71] update deepspeed wrapper --- library/deepspeed_utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/library/deepspeed_utils.py b/library/deepspeed_utils.py index 09c6f7b9..a8a05c3a 100644 --- a/library/deepspeed_utils.py +++ b/library/deepspeed_utils.py @@ -134,18 +134,18 @@ def prepare_deepspeed_model(args: argparse.Namespace, **models): if isinstance(model, list): model = torch.nn.ModuleList(model) + if wrap_model_forward_with_torch_autocast: + model = self.__wrap_model_with_torch_autocast(model) + assert isinstance( model, torch.nn.Module ), f"model must be an instance of torch.nn.Module, but got {key} is {type(model)}" - if wrap_model_forward_with_torch_autocast: - model = self.__wrap_model_with_torch_autocast(model) - self.models.update(torch.nn.ModuleDict({key: model})) def __wrap_model_with_torch_autocast(self, model): if isinstance(model, torch.nn.ModuleList): - model = [self.__wrap_model_forward_with_torch_autocast(m) for m in model] + model = torch.nn.ModuleList([self.__wrap_model_forward_with_torch_autocast(m) for m in model]) else: model = self.__wrap_model_forward_with_torch_autocast(model) return model From 8387e0b95c1067e919f91a2abec11ddcd5ed15cb Mon Sep 17 00:00:00 2001 From: 
Kohya S Date: Sun, 27 Apr 2025 18:25:59 +0900 Subject: [PATCH 60/71] docs: update README to include CFG scale support in FLUX.1 training --- README.md | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 2e80a697..f9831aee 100644 --- a/README.md +++ b/README.md @@ -14,6 +14,10 @@ The command to install PyTorch is as follows: ### Recent Updates +Apr 27, 2025: +- FLUX.1 training now supports CFG scale in the sample generation during training. Please use `--g` option, to specify the CFG scale (note that `--l` is used as the embedded guidance scale.) PR [#2064](https://github.com/kohya-ss/sd-scripts/pull/2064). + - See [here](#sample-image-generation-during-training) for details. + Apr 6, 2025: - IP noise gamma has been enabled in FLUX.1. Thanks to rockerBOO for PR [#1992](https://github.com/kohya-ss/sd-scripts/pull/1992). See the PR for details. - `--ip_noise_gamma` and `--ip_noise_gamma_random_strength` are available. @@ -1344,11 +1348,13 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b Lines beginning with `#` are comments. You can specify options for the generated image with options like `--n` after the prompt. The following can be used. - * `--n` Negative prompt up to the next option. + * `--n` Negative prompt up to the next option. Ignored when CFG scale is `1.0`. * `--w` Specifies the width of the generated image. * `--h` Specifies the height of the generated image. * `--d` Specifies the seed of the generated image. * `--l` Specifies the CFG scale of the generated image. + * In guidance distillation models like FLUX.1, this value is used as the embedded guidance scale for backward compatibility. + * `--g` Specifies the CFG scale for the models with embedded guidance scale. The default is `1.0`, `1.0` means no CFG. In general, should not be changed unless you train the un-distilled FLUX.1 models. * `--s` Specifies the number of steps in the generation. 
The prompt weighting such as `( )` and `[ ]` are working. From f0b07c52abaf4ab33d619b427afabe17b69b7d05 Mon Sep 17 00:00:00 2001 From: "Kohya S." <52813779+kohya-ss@users.noreply.github.com> Date: Sun, 27 Apr 2025 21:28:38 +0900 Subject: [PATCH 61/71] Create FUNDING.yml --- .github/FUNDING.yml | 3 +++ 1 file changed, 3 insertions(+) create mode 100644 .github/FUNDING.yml diff --git a/.github/FUNDING.yml b/.github/FUNDING.yml new file mode 100644 index 00000000..3b8943c3 --- /dev/null +++ b/.github/FUNDING.yml @@ -0,0 +1,3 @@ +# These are supported funding model platforms + +github: kohya-ss From fd3a445769910ddc0c8c02d13e535cac37b85d2e Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 27 Apr 2025 22:50:27 +0900 Subject: [PATCH 62/71] fix: revert default emb guidance scale and CFG scale for FLUX.1 sampling --- library/flux_train_utils.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index d2ff347d..5f6867a8 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -154,8 +154,9 @@ def sample_image_inference( sample_steps = prompt_dict.get("sample_steps", 20) width = prompt_dict.get("width", 512) height = prompt_dict.get("height", 512) - guidance_scale = prompt_dict.get("guidance_scale", args.guidance_scale) - scale = prompt_dict.get("scale", 1.0) # 1.0 means no guidance + # TODO refactor variable names + cfg_scale = prompt_dict.get("guidance_scale", 1.0) + emb_guidance_scale = prompt_dict.get("scale", 3.5) seed = prompt_dict.get("seed") controlnet_image = prompt_dict.get("controlnet_image") prompt: str = prompt_dict.get("prompt", "") @@ -179,16 +180,16 @@ def sample_image_inference( height = max(64, height - height % 16) # round to divisible by 16 width = max(64, width - width % 16) # round to divisible by 16 logger.info(f"prompt: {prompt}") - if scale != 1.0: + if cfg_scale != 1.0: logger.info(f"negative_prompt: {negative_prompt}") elif 
negative_prompt != "": logger.info(f"negative prompt is ignored because scale is 1.0") logger.info(f"height: {height}") logger.info(f"width: {width}") logger.info(f"sample_steps: {sample_steps}") - logger.info(f"guidance_scale: {guidance_scale}") - if scale != 1.0: - logger.info(f"scale: {scale}") + logger.info(f"embedded guidance scale: {emb_guidance_scale}") + if cfg_scale != 1.0: + logger.info(f"CFG scale: {cfg_scale}") # logger.info(f"sample_sampler: {sampler_name}") if seed is not None: logger.info(f"seed: {seed}") @@ -220,12 +221,12 @@ def sample_image_inference( l_pooled, t5_out, txt_ids, t5_attn_mask = encode_prompt(prompt) # encode negative prompts - if scale != 1.0: + if cfg_scale != 1.0: neg_l_pooled, neg_t5_out, _, neg_t5_attn_mask = encode_prompt(negative_prompt) neg_t5_attn_mask = ( neg_t5_attn_mask.to(accelerator.device) if args.apply_t5_attn_mask and neg_t5_attn_mask is not None else None ) - neg_cond = (scale, neg_l_pooled, neg_t5_out, neg_t5_attn_mask) + neg_cond = (cfg_scale, neg_l_pooled, neg_t5_out, neg_t5_attn_mask) else: neg_cond = None @@ -260,7 +261,7 @@ def sample_image_inference( txt_ids, l_pooled, timesteps=timesteps, - guidance=guidance_scale, + guidance=emb_guidance_scale, t5_attn_mask=t5_attn_mask, controlnet=controlnet, controlnet_img=controlnet_image, From 29523c9b68bd56cdb1cce3f4985f2e45cefb1f2b Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 27 Apr 2025 23:34:37 +0900 Subject: [PATCH 63/71] docs: add note for user feedback on CFG scale in FLUX.1 training --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index f9831aee..18e8e659 100644 --- a/README.md +++ b/README.md @@ -17,6 +17,7 @@ The command to install PyTorch is as follows: Apr 27, 2025: - FLUX.1 training now supports CFG scale in the sample generation during training. Please use `--g` option, to specify the CFG scale (note that `--l` is used as the embedded guidance scale.) PR [#2064](https://github.com/kohya-ss/sd-scripts/pull/2064). 
- See [here](#sample-image-generation-during-training) for details. + - If you have any issues with this, please let us know. Apr 6, 2025: - IP noise gamma has been enabled in FLUX.1. Thanks to rockerBOO for PR [#1992](https://github.com/kohya-ss/sd-scripts/pull/1992). See the PR for details. From 4625b34f4ebcd30ffb24f1b03ece8a8362c7bfb2 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 29 Apr 2025 21:27:04 +0900 Subject: [PATCH 64/71] Fix mean image aspect ratio error calculation to avoid NaN values --- library/train_util.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 100ef475..b9d08f25 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -957,8 +957,11 @@ class BaseDataset(torch.utils.data.Dataset): self.bucket_info["buckets"][i] = {"resolution": reso, "count": len(bucket)} logger.info(f"bucket {i}: resolution {reso}, count: {len(bucket)}") - img_ar_errors = np.array(img_ar_errors) - mean_img_ar_error = np.mean(np.abs(img_ar_errors)) + if len(img_ar_errors) == 0: + mean_img_ar_error = 0 # avoid NaN + else: + img_ar_errors = np.array(img_ar_errors) + mean_img_ar_error = np.mean(np.abs(img_ar_errors)) self.bucket_info["mean_img_ar_error"] = mean_img_ar_error logger.info(f"mean ar error (without repeats): {mean_img_ar_error}") From 1684ababcd7fc4259c77f1471ef41d10e612a721 Mon Sep 17 00:00:00 2001 From: sharlynxy Date: Wed, 30 Apr 2025 19:51:09 +0800 Subject: [PATCH 65/71] remove deepspeed from requirements.txt --- requirements.txt | 1 - 1 file changed, 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 9e97eed3..767d9e8e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,6 @@ accelerate==0.33.0 transformers==4.44.0 diffusers[torch]==0.25.0 -deepspeed==0.16.7 ftfy==6.1.1 # albumentations==1.3.0 opencv-python==4.8.1.78 From a4fae93dce5b78a0c92ee328d6b2dd96be944a7d Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Thu, 1 May 2025 00:55:10 -0400 
Subject: [PATCH 66/71] Add pythonpath to pytest.ini --- pytest.ini | 1 + 1 file changed, 1 insertion(+) diff --git a/pytest.ini b/pytest.ini index 484d3aef..34b7e9c1 100644 --- a/pytest.ini +++ b/pytest.ini @@ -6,3 +6,4 @@ filterwarnings = ignore::DeprecationWarning ignore::UserWarning ignore::FutureWarning +pythonpath = . From f62c68df3c96639e83dcb7f5062d48c3067055e1 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Thu, 1 May 2025 01:37:57 -0400 Subject: [PATCH 67/71] Make grad_norm and combined_grad_norm None is not recording --- networks/lora_flux.py | 12 ++++++------ train_network.py | 6 ++++-- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/networks/lora_flux.py b/networks/lora_flux.py index 92b3979a..0b30f1b8 100644 --- a/networks/lora_flux.py +++ b/networks/lora_flux.py @@ -955,26 +955,26 @@ class LoRANetwork(torch.nn.Module): for lora in self.text_encoder_loras + self.unet_loras: lora.update_grad_norms() - def grad_norms(self) -> Tensor: + def grad_norms(self) -> Tensor | None: grad_norms = [] for lora in self.text_encoder_loras + self.unet_loras: if hasattr(lora, "grad_norms") and lora.grad_norms is not None: grad_norms.append(lora.grad_norms.mean(dim=0)) - return torch.stack(grad_norms) if len(grad_norms) > 0 else torch.tensor([]) + return torch.stack(grad_norms) if len(grad_norms) > 0 else None - def weight_norms(self) -> Tensor: + def weight_norms(self) -> Tensor | None: weight_norms = [] for lora in self.text_encoder_loras + self.unet_loras: if hasattr(lora, "weight_norms") and lora.weight_norms is not None: weight_norms.append(lora.weight_norms.mean(dim=0)) - return torch.stack(weight_norms) if len(weight_norms) > 0 else torch.tensor([]) + return torch.stack(weight_norms) if len(weight_norms) > 0 else None - def combined_weight_norms(self) -> Tensor: + def combined_weight_norms(self) -> Tensor | None: combined_weight_norms = [] for lora in self.text_encoder_loras + self.unet_loras: if hasattr(lora, "combined_weight_norms") and 
lora.combined_weight_norms is not None: combined_weight_norms.append(lora.combined_weight_norms.mean(dim=0)) - return torch.stack(combined_weight_norms) if len(combined_weight_norms) > 0 else torch.tensor([]) + return torch.stack(combined_weight_norms) if len(combined_weight_norms) > 0 else None def load_weights(self, file): diff --git a/train_network.py b/train_network.py index d6bc66ed..2b4e6d3f 100644 --- a/train_network.py +++ b/train_network.py @@ -1444,8 +1444,10 @@ class NetworkTrainer: else: if hasattr(network, "weight_norms"): mean_norm = network.weight_norms().mean().item() - mean_grad_norm = network.grad_norms().mean().item() - mean_combined_norm = network.combined_weight_norms().mean().item() + grad_norms = network.grad_norms() + mean_grad_norm = grad_norms.mean().item() if grad_norms is not None else None + combined_weight_norms = network.combined_weight_norms() + mean_combined_norm = combined_weight_norms.mean().item() if combined_weight_norms is not None else None weight_norms = network.weight_norms() maximum_norm = weight_norms.max().item() if weight_norms.numel() > 0 else None keys_scaled = None From b4a89c3cdf7319b6840f1e4a28a5a1001643bc22 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Thu, 1 May 2025 02:03:22 -0400 Subject: [PATCH 68/71] Fix None --- train_network.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/train_network.py b/train_network.py index 2b4e6d3f..1336a0b1 100644 --- a/train_network.py +++ b/train_network.py @@ -1443,13 +1443,13 @@ class NetworkTrainer: max_mean_logs = {"Keys Scaled": keys_scaled, "Average key norm": mean_norm} else: if hasattr(network, "weight_norms"): - mean_norm = network.weight_norms().mean().item() + weight_norms = network.weight_norms() + mean_norm = weight_norms.mean().item() if weight_norms is not None else None grad_norms = network.grad_norms() mean_grad_norm = grad_norms.mean().item() if grad_norms is not None else None combined_weight_norms = network.combined_weight_norms() 
mean_combined_norm = combined_weight_norms.mean().item() if combined_weight_norms is not None else None - weight_norms = network.weight_norms() - maximum_norm = weight_norms.max().item() if weight_norms.numel() > 0 else None + maximum_norm = weight_norms.max().item() if weight_norms is not None else None keys_scaled = None max_mean_logs = {} else: From 865c8d55e2b8cd9f0b6008a6d4ee4a07949d9acc Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 1 May 2025 23:29:19 +0900 Subject: [PATCH 69/71] README.md: Update recent updates and add DeepSpeed installation instructions --- README.md | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/README.md b/README.md index 18e8e659..13c2320c 100644 --- a/README.md +++ b/README.md @@ -14,6 +14,10 @@ The command to install PyTorch is as follows: ### Recent Updates +May 1, 2025: +- The error when training FLUX.1 with mixed precision in flux_train.py with DeepSpeed enabled has been resolved. Thanks to sharlynxy for PR [#2060](https://github.com/kohya-ss/sd-scripts/pull/2060). Please refer to the PR for details. + - If you enable DeepSpeed, please install deepseed with `pip install deepspeed==0.16.7`. + Apr 27, 2025: - FLUX.1 training now supports CFG scale in the sample generation during training. Please use `--g` option, to specify the CFG scale (note that `--l` is used as the embedded guidance scale.) PR [#2064](https://github.com/kohya-ss/sd-scripts/pull/2064). - See [here](#sample-image-generation-during-training) for details. @@ -875,6 +879,14 @@ Note: Some user reports ``ValueError: fp16 mixed precision requires a GPU`` is o (Single GPU with id `0` will be used.) 
+## DeepSpeed installation (experimental, Linux or WSL2 only) + +To install DeepSpeed, run the following command in your activated virtual environment: + +```bash +pip install deepspeed==0.16.7 +``` + ## Upgrade When a new release comes out you can upgrade your repo with the following command: From a27ace74d96d9519629283f4ff3d207c1ad8d98e Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 1 May 2025 23:31:23 +0900 Subject: [PATCH 70/71] doc: add DeepSpeed installation in header section --- README.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 13c2320c..497969ab 100644 --- a/README.md +++ b/README.md @@ -9,6 +9,8 @@ __Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchv The command to install PyTorch is as follows: `pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124` +If you are using DeepSpeed, please install DeepSpeed with `pip install deepspeed==0.16.7`. + - [FLUX.1 training](#flux1-training) - [SD3 training](#sd3-training) @@ -16,7 +18,7 @@ The command to install PyTorch is as follows: May 1, 2025: - The error when training FLUX.1 with mixed precision in flux_train.py with DeepSpeed enabled has been resolved. Thanks to sharlynxy for PR [#2060](https://github.com/kohya-ss/sd-scripts/pull/2060). Please refer to the PR for details. - - If you enable DeepSpeed, please install deepseed with `pip install deepspeed==0.16.7`. + - If you enable DeepSpeed, please install DeepSpeed with `pip install deepspeed==0.16.7`. Apr 27, 2025: - FLUX.1 training now supports CFG scale in the sample generation during training. Please use `--g` option, to specify the CFG scale (note that `--l` is used as the embedded guidance scale.) PR [#2064](https://github.com/kohya-ss/sd-scripts/pull/2064). 
From 2bfda1271bcdbe13e823579fc406f3eaa229573b Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Mon, 19 May 2025 20:25:42 -0400 Subject: [PATCH 71/71] Update workflows to read-all instead of write-all --- .github/workflows/tests.yml | 5 ++++- .github/workflows/typos.yml | 3 +++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 2eddedc7..9e037e53 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -12,6 +12,9 @@ on: - dev - sd3 +# CKV2_GHA_1: "Ensure top-level permissions are not set to write-all" +permissions: read-all + jobs: build: runs-on: ${{ matrix.os }} @@ -40,7 +43,7 @@ jobs: - name: Install dependencies run: | # Pre-install torch to pin version (requirements.txt has dependencies like transformers which requires pytorch) - pip install dadaptation==3.2 torch==${{ matrix.pytorch-version }} torchvision==0.19.0 pytest==8.3.4 + pip install dadaptation==3.2 torch==${{ matrix.pytorch-version }} torchvision pytest==8.3.4 pip install -r requirements.txt - name: Test with pytest diff --git a/.github/workflows/typos.yml b/.github/workflows/typos.yml index f53cda21..b9d6acc9 100644 --- a/.github/workflows/typos.yml +++ b/.github/workflows/typos.yml @@ -12,6 +12,9 @@ on: - synchronize - reopened +# CKV2_GHA_1: "Ensure top-level permissions are not set to write-all" +permissions: read-all + jobs: build: runs-on: ubuntu-latest