From 532f5c58a6e83a3400f82103f5854ff3f63d77d7 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 27 Jan 2025 20:50:42 +0900 Subject: [PATCH 01/33] formatting --- train_network.py | 229 ++++++++++++++++++++++------------------------- 1 file changed, 108 insertions(+), 121 deletions(-) diff --git a/train_network.py b/train_network.py index 2c3bb2aa..cc54be7c 100644 --- a/train_network.py +++ b/train_network.py @@ -100,9 +100,7 @@ class NetworkTrainer: if ( args.optimizer_type.lower().endswith("ProdigyPlusScheduleFree".lower()) and optimizer is not None ): # tracking d*lr value of unet. - logs["lr/d*lr"] = ( - optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"] - ) + logs["lr/d*lr"] = optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"] else: idx = 0 if not args.network_train_unet_only: @@ -115,16 +113,17 @@ class NetworkTrainer: logs[f"lr/d*lr/group{i}"] = ( lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"] ) - if ( - args.optimizer_type.lower().endswith("ProdigyPlusScheduleFree".lower()) and optimizer is not None - ): - logs[f"lr/d*lr/group{i}"] = ( - optimizer.param_groups[i]["d"] * optimizer.param_groups[i]["lr"] - ) + if args.optimizer_type.lower().endswith("ProdigyPlusScheduleFree".lower()) and optimizer is not None: + logs[f"lr/d*lr/group{i}"] = optimizer.param_groups[i]["d"] * optimizer.param_groups[i]["lr"] return logs - def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]): + def assert_extra_args( + self, + args, + train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], + val_dataset_group: Optional[train_util.DatasetGroup], + ): train_dataset_group.verify_bucket_reso_steps(64) if val_dataset_group is not None: val_dataset_group.verify_bucket_reso_steps(64) @@ -219,7 +218,7 @@ class NetworkTrainer: network, weight_dtype, train_unet, - is_train=True 
+ is_train=True, ): # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified @@ -315,22 +314,22 @@ class NetworkTrainer: # endregion def process_batch( - self, - batch, - text_encoders, - unet, - network, - vae, - noise_scheduler, - vae_dtype, - weight_dtype, - accelerator, - args, - text_encoding_strategy: strategy_base.TextEncodingStrategy, - tokenize_strategy: strategy_base.TokenizeStrategy, - is_train=True, - train_text_encoder=True, - train_unet=True + self, + batch, + text_encoders, + unet, + network, + vae, + noise_scheduler, + vae_dtype, + weight_dtype, + accelerator, + args, + text_encoding_strategy: strategy_base.TextEncodingStrategy, + tokenize_strategy: strategy_base.TokenizeStrategy, + is_train=True, + train_text_encoder=True, + train_unet=True, ) -> torch.Tensor: """ Process a batch for the network @@ -397,7 +396,7 @@ class NetworkTrainer: network, weight_dtype, train_unet, - is_train=is_train + is_train=is_train, ) huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) @@ -484,7 +483,7 @@ class NetworkTrainer: else: # use arbitrary dataset class train_dataset_group = train_util.load_arbitrary_dataset(args) - val_dataset_group = None # placeholder until validation dataset supported for arbitrary + val_dataset_group = None # placeholder until validation dataset supported for arbitrary current_epoch = Value("i", 0) current_step = Value("i", 0) @@ -701,7 +700,7 @@ class NetworkTrainer: num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers, ) - + val_dataloader = torch.utils.data.DataLoader( val_dataset_group if val_dataset_group is not None else [], shuffle=False, @@ -900,7 +899,9 @@ class NetworkTrainer: accelerator.print("running training / 学習開始") accelerator.print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}") - accelerator.print(f" num validation images * repeats / 
学習画像の数×繰り返し回数: {val_dataset_group.num_train_images if val_dataset_group is not None else 0}") + accelerator.print( + f" num validation images * repeats / 学習画像の数×繰り返し回数: {val_dataset_group.num_train_images if val_dataset_group is not None else 0}" + ) accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") accelerator.print(f" num epochs / epoch数: {num_train_epochs}") @@ -968,11 +969,11 @@ class NetworkTrainer: "ss_huber_c": args.huber_c, "ss_fp8_base": bool(args.fp8_base), "ss_fp8_base_unet": bool(args.fp8_base_unet), - "ss_validation_seed": args.validation_seed, - "ss_validation_split": args.validation_split, - "ss_max_validation_steps": args.max_validation_steps, - "ss_validate_every_n_epochs": args.validate_every_n_epochs, - "ss_validate_every_n_steps": args.validate_every_n_steps, + "ss_validation_seed": args.validation_seed, + "ss_validation_split": args.validation_split, + "ss_max_validation_steps": args.max_validation_steps, + "ss_validate_every_n_epochs": args.validate_every_n_epochs, + "ss_validate_every_n_steps": args.validate_every_n_steps, } self.update_metadata(metadata, args) # architecture specific metadata @@ -1248,9 +1249,7 @@ class NetworkTrainer: accelerator.log({}, step=0) validation_steps = ( - min(args.max_validation_steps, len(val_dataloader)) - if args.max_validation_steps is not None - else len(val_dataloader) + min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader) ) # training loop @@ -1298,21 +1297,21 @@ class NetworkTrainer: self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype) loss = self.process_batch( - batch, - text_encoders, - unet, - network, - vae, - noise_scheduler, - vae_dtype, - weight_dtype, - accelerator, - args, - text_encoding_strategy, - tokenize_strategy, - is_train=True, - train_text_encoder=train_text_encoder, 
- train_unet=train_unet + batch, + text_encoders, + unet, + network, + vae, + noise_scheduler, + vae_dtype, + weight_dtype, + accelerator, + args, + text_encoding_strategy, + tokenize_strategy, + is_train=True, + train_text_encoder=train_text_encoder, + train_unet=train_unet, ) accelerator.backward(loss) @@ -1369,32 +1368,21 @@ class NetworkTrainer: if args.scale_weight_norms: progress_bar.set_postfix(**{**max_mean_logs, **logs}) - if is_tracking: logs = self.generate_step_logs( - args, - current_loss, - avr_loss, - lr_scheduler, - lr_descriptions, - optimizer, - keys_scaled, - mean_norm, - maximum_norm + args, current_loss, avr_loss, lr_scheduler, lr_descriptions, optimizer, keys_scaled, mean_norm, maximum_norm ) accelerator.log(logs, step=global_step) # VALIDATION PER STEP should_validate_step = ( - args.validate_every_n_steps is not None - and global_step != 0 # Skip first step + args.validate_every_n_steps is not None + and global_step != 0 # Skip first step and global_step % args.validate_every_n_steps == 0 ) if accelerator.sync_gradients and validation_steps > 0 and should_validate_step: val_progress_bar = tqdm( - range(validation_steps), smoothing=0, - disable=not accelerator.is_local_main_process, - desc="validation steps" + range(validation_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="validation steps" ) for val_step, batch in enumerate(val_dataloader): if val_step >= validation_steps: @@ -1404,27 +1392,27 @@ class NetworkTrainer: self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype) loss = self.process_batch( - batch, - text_encoders, - unet, - network, - vae, - noise_scheduler, - vae_dtype, - weight_dtype, - accelerator, - args, - text_encoding_strategy, - tokenize_strategy, + batch, + text_encoders, + unet, + network, + vae, + noise_scheduler, + vae_dtype, + weight_dtype, + accelerator, + args, + text_encoding_strategy, + tokenize_strategy, is_train=False, - train_text_encoder=False, - 
train_unet=False + train_text_encoder=False, + train_unet=False, ) current_loss = loss.detach().item() val_step_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss) val_progress_bar.update(1) - val_progress_bar.set_postfix({ "val_avg_loss": val_step_loss_recorder.moving_average }) + val_progress_bar.set_postfix({"val_avg_loss": val_step_loss_recorder.moving_average}) if is_tracking: logs = { @@ -1436,26 +1424,25 @@ class NetworkTrainer: if is_tracking: loss_validation_divergence = val_step_loss_recorder.moving_average - loss_recorder.moving_average logs = { - "loss/validation/step_average": val_step_loss_recorder.moving_average, - "loss/validation/step_divergence": loss_validation_divergence, + "loss/validation/step_average": val_step_loss_recorder.moving_average, + "loss/validation/step_divergence": loss_validation_divergence, } accelerator.log(logs, step=global_step) - + if global_step >= args.max_train_steps: break # EPOCH VALIDATION should_validate_epoch = ( - (epoch + 1) % args.validate_every_n_epochs == 0 - if args.validate_every_n_epochs is not None - else True + (epoch + 1) % args.validate_every_n_epochs == 0 if args.validate_every_n_epochs is not None else True ) if should_validate_epoch and len(val_dataloader) > 0: val_progress_bar = tqdm( - range(validation_steps), smoothing=0, - disable=not accelerator.is_local_main_process, - desc="epoch validation steps" + range(validation_steps), + smoothing=0, + disable=not accelerator.is_local_main_process, + desc="epoch validation steps", ) for val_step, batch in enumerate(val_dataloader): @@ -1466,43 +1453,43 @@ class NetworkTrainer: self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype) loss = self.process_batch( - batch, - text_encoders, - unet, - network, - vae, - noise_scheduler, - vae_dtype, - weight_dtype, - accelerator, - args, - text_encoding_strategy, - tokenize_strategy, + batch, + text_encoders, + unet, + network, + vae, + noise_scheduler, + vae_dtype, + 
weight_dtype, + accelerator, + args, + text_encoding_strategy, + tokenize_strategy, is_train=False, - train_text_encoder=False, - train_unet=False + train_text_encoder=False, + train_unet=False, ) current_loss = loss.detach().item() val_epoch_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss) val_progress_bar.update(1) - val_progress_bar.set_postfix({ "val_epoch_avg_loss": val_epoch_loss_recorder.moving_average }) + val_progress_bar.set_postfix({"val_epoch_avg_loss": val_epoch_loss_recorder.moving_average}) if is_tracking: logs = { - "loss/validation/epoch_current": current_loss, - "epoch": epoch + 1, - "val_step": (epoch * validation_steps) + val_step + "loss/validation/epoch_current": current_loss, + "epoch": epoch + 1, + "val_step": (epoch * validation_steps) + val_step, } accelerator.log(logs, step=global_step) if is_tracking: avr_loss: float = val_epoch_loss_recorder.moving_average - loss_validation_divergence = val_step_loss_recorder.moving_average - avr_loss + loss_validation_divergence = val_step_loss_recorder.moving_average - avr_loss logs = { - "loss/validation/epoch_average": avr_loss, - "loss/validation/epoch_divergence": loss_validation_divergence, - "epoch": epoch + 1 + "loss/validation/epoch_average": avr_loss, + "loss/validation/epoch_divergence": loss_validation_divergence, + "epoch": epoch + 1, } accelerator.log(logs, step=global_step) @@ -1510,7 +1497,7 @@ class NetworkTrainer: if is_tracking: logs = {"loss/epoch_average": loss_recorder.moving_average, "epoch": epoch + 1} accelerator.log(logs, step=global_step) - + accelerator.wait_for_everyone() # 指定エポックごとにモデルを保存 @@ -1696,31 +1683,31 @@ def setup_parser() -> argparse.ArgumentParser: "--validation_seed", type=int, default=None, - help="Validation seed for shuffling validation dataset, training `--seed` used otherwise / 検証データセットをシャッフルするための検証シード、それ以外の場合はトレーニング `--seed` を使用する" + help="Validation seed for shuffling validation dataset, training `--seed` used otherwise / 
検証データセットをシャッフルするための検証シード、それ以外の場合はトレーニング `--seed` を使用する", ) parser.add_argument( "--validation_split", type=float, default=0.0, - help="Split for validation images out of the training dataset / 学習画像から検証画像に分割する割合" + help="Split for validation images out of the training dataset / 学習画像から検証画像に分割する割合", ) parser.add_argument( "--validate_every_n_steps", type=int, default=None, - help="Run validation on validation dataset every N steps. By default, validation will only occur every epoch if a validation dataset is available / 検証データセットの検証をNステップごとに実行します。デフォルトでは、検証データセットが利用可能な場合にのみ、検証はエポックごとに実行されます" + help="Run validation on validation dataset every N steps. By default, validation will only occur every epoch if a validation dataset is available / 検証データセットの検証をNステップごとに実行します。デフォルトでは、検証データセットが利用可能な場合にのみ、検証はエポックごとに実行されます", ) parser.add_argument( "--validate_every_n_epochs", type=int, default=None, - help="Run validation dataset every N epochs. By default, validation will run every epoch if a validation dataset is available / 検証データセットをNエポックごとに実行します。デフォルトでは、検証データセットが利用可能な場合、検証はエポックごとに実行されます" + help="Run validation dataset every N epochs. By default, validation will run every epoch if a validation dataset is available / 検証データセットをNエポックごとに実行します。デフォルトでは、検証データセットが利用可能な場合、検証はエポックごとに実行されます", ) parser.add_argument( "--max_validation_steps", type=int, default=None, - help="Max number of validation dataset items processed. By default, validation will run the entire validation dataset / 処理される検証データセット項目の最大数。デフォルトでは、検証は検証データセット全体を実行します" + help="Max number of validation dataset items processed. 
By default, validation will run the entire validation dataset / 処理される検証データセット項目の最大数。デフォルトでは、検証は検証データセット全体を実行します", ) return parser From 86a2f3fd262e52b3249d9f5508efe4774f1fa3ed Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 27 Jan 2025 21:10:52 +0900 Subject: [PATCH 02/33] Fix gradient handling when Text Encoders are trained --- flux_train_network.py | 43 ++----------------------------------------- sd3_train_network.py | 2 +- train_network.py | 10 +++++----- 3 files changed, 8 insertions(+), 47 deletions(-) diff --git a/flux_train_network.py b/flux_train_network.py index 5cd1b9d5..475bd751 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -376,9 +376,8 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): t5_attn_mask = None def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask): - # if not args.split_mode: - # normal forward - with torch.set_grad_enabled(is_train and train_unet), accelerator.autocast(): + # grad is enabled even if unet is not in train mode, because Text Encoder is in train mode + with torch.set_grad_enabled(is_train), accelerator.autocast(): # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing) model_pred = unet( img=img, @@ -390,44 +389,6 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): guidance=guidance_vec, txt_attention_mask=t5_attn_mask, ) - """ - else: - # split forward to reduce memory usage - assert network.train_blocks == "single", "train_blocks must be single for split mode" - with accelerator.autocast(): - # move flux lower to cpu, and then move flux upper to gpu - unet.to("cpu") - clean_memory_on_device(accelerator.device) - self.flux_upper.to(accelerator.device) - - # upper model does not require grad - with torch.no_grad(): - intermediate_img, intermediate_txt, vec, pe = self.flux_upper( - img=packed_noisy_model_input, - img_ids=img_ids, - 
txt=t5_out, - txt_ids=txt_ids, - y=l_pooled, - timesteps=timesteps / 1000, - guidance=guidance_vec, - txt_attention_mask=t5_attn_mask, - ) - - # move flux upper back to cpu, and then move flux lower to gpu - self.flux_upper.to("cpu") - clean_memory_on_device(accelerator.device) - unet.to(accelerator.device) - - # lower model requires grad - intermediate_img.requires_grad_(True) - intermediate_txt.requires_grad_(True) - vec.requires_grad_(True) - pe.requires_grad_(True) - - with torch.set_grad_enabled(is_train and train_unet): - model_pred = unet(img=intermediate_img, txt=intermediate_txt, vec=vec, pe=pe, txt_attention_mask=t5_attn_mask) - """ - return model_pred model_pred = call_dit( diff --git a/sd3_train_network.py b/sd3_train_network.py index dcf497f5..2f457949 100644 --- a/sd3_train_network.py +++ b/sd3_train_network.py @@ -345,7 +345,7 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer): t5_attn_mask = None # call model - with torch.set_grad_enabled(is_train and train_unet), accelerator.autocast(): + with torch.set_grad_enabled(is_train), accelerator.autocast(): # TODO support attention mask model_pred = unet(noisy_model_input, timesteps, context=context, y=lg_pooled) diff --git a/train_network.py b/train_network.py index cc54be7c..6f1652fd 100644 --- a/train_network.py +++ b/train_network.py @@ -232,7 +232,7 @@ class NetworkTrainer: t.requires_grad_(True) # Predict the noise residual - with torch.set_grad_enabled(is_train and train_unet), accelerator.autocast(): + with torch.set_grad_enabled(is_train), accelerator.autocast(): noise_pred = self.call_unet( args, accelerator, @@ -1405,8 +1405,8 @@ class NetworkTrainer: text_encoding_strategy, tokenize_strategy, is_train=False, - train_text_encoder=False, - train_unet=False, + train_text_encoder=train_text_encoder, # this is needed for validation because Text Encoders must be called if train_text_encoder is True + train_unet=train_unet, ) current_loss = loss.detach().item() @@ -1466,8 +1466,8 @@ class 
NetworkTrainer: text_encoding_strategy, tokenize_strategy, is_train=False, - train_text_encoder=False, - train_unet=False, + train_text_encoder=train_text_encoder, + train_unet=train_unet, ) current_loss = loss.detach().item() From b6a309321675b5d0a59b776ffb4d0ecdd3d28ec2 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 27 Jan 2025 21:22:11 +0900 Subject: [PATCH 03/33] call optimizer eval/train fn before/after validation --- train_network.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/train_network.py b/train_network.py index 6f1652fd..e735c582 100644 --- a/train_network.py +++ b/train_network.py @@ -1381,6 +1381,8 @@ class NetworkTrainer: and global_step % args.validate_every_n_steps == 0 ) if accelerator.sync_gradients and validation_steps > 0 and should_validate_step: + optimizer_eval_fn() + val_progress_bar = tqdm( range(validation_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="validation steps" ) @@ -1429,6 +1431,8 @@ class NetworkTrainer: } accelerator.log(logs, step=global_step) + optimizer_train_fn() + if global_step >= args.max_train_steps: break @@ -1438,6 +1442,8 @@ class NetworkTrainer: ) if should_validate_epoch and len(val_dataloader) > 0: + optimizer_eval_fn() + val_progress_bar = tqdm( range(validation_steps), smoothing=0, @@ -1493,6 +1499,8 @@ class NetworkTrainer: } accelerator.log(logs, step=global_step) + optimizer_train_fn() + # END OF EPOCH if is_tracking: logs = {"loss/epoch_average": loss_recorder.moving_average, "epoch": epoch + 1} From 29f31d005f12a08650389164fa9c60504928d451 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 27 Jan 2025 21:35:43 +0900 Subject: [PATCH 04/33] add network.train()/eval() for validation --- train_network.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/train_network.py b/train_network.py index e735c582..9b8036f8 100644 --- a/train_network.py +++ b/train_network.py @@ -1276,7 +1276,7 @@ class NetworkTrainer: metadata["ss_epoch"] = str(epoch + 1) 
- accelerator.unwrap_model(network).on_epoch_start(text_encoder, unet) + accelerator.unwrap_model(network).on_epoch_start(text_encoder, unet) # network.train() is called here # TRAINING skipped_dataloader = None @@ -1382,6 +1382,7 @@ class NetworkTrainer: ) if accelerator.sync_gradients and validation_steps > 0 and should_validate_step: optimizer_eval_fn() + accelerator.unwrap_model(network).eval() val_progress_bar = tqdm( range(validation_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="validation steps" @@ -1432,6 +1433,7 @@ class NetworkTrainer: accelerator.log(logs, step=global_step) optimizer_train_fn() + accelerator.unwrap_model(network).train() if global_step >= args.max_train_steps: break @@ -1443,6 +1445,7 @@ class NetworkTrainer: if should_validate_epoch and len(val_dataloader) > 0: optimizer_eval_fn() + accelerator.unwrap_model(network).eval() val_progress_bar = tqdm( range(validation_steps), @@ -1500,6 +1503,7 @@ class NetworkTrainer: accelerator.log(logs, step=global_step) optimizer_train_fn() + accelerator.unwrap_model(network).train() # END OF EPOCH if is_tracking: From 0750859133eec7858052cd3f79106113fa786e94 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 27 Jan 2025 21:56:59 +0900 Subject: [PATCH 05/33] validation: Implement timestep-based validation processing --- sd3_train_network.py | 1 + train_network.py | 185 +++++++++++++++++++++++++------------------ 2 files changed, 109 insertions(+), 77 deletions(-) diff --git a/sd3_train_network.py b/sd3_train_network.py index 2f457949..d4f13125 100644 --- a/sd3_train_network.py +++ b/sd3_train_network.py @@ -446,6 +446,7 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer): prepare_fp8(text_encoder, weight_dtype) def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype): + # TODO consider validation # drop cached text encoder outputs text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) if text_encoder_outputs_list is 
not None: diff --git a/train_network.py b/train_network.py index 9b8036f8..a63e9d1e 100644 --- a/train_network.py +++ b/train_network.py @@ -9,6 +9,7 @@ import random import time import json from multiprocessing import Value +import numpy as np import toml from tqdm import tqdm @@ -1248,10 +1249,6 @@ class NetworkTrainer: # log empty object to commit the sample images to wandb accelerator.log({}, step=0) - validation_steps = ( - min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader) - ) - # training loop if initial_step > 0: # only if skip_until_initial_step is specified for skip_epoch in range(epoch_to_start): # skip epochs @@ -1270,6 +1267,17 @@ class NetworkTrainer: clean_memory_on_device(accelerator.device) + validation_steps = ( + min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader) + ) + NUM_VALIDATION_TIMESTEPS = 4 # 200, 400, 600, 800 TODO make this configurable + min_timestep = 0 if args.min_timestep is None else args.min_timestep + max_timestep = noise_scheduler.num_train_timesteps if args.max_timestep is None else args.max_timestep + validation_timesteps = np.linspace(min_timestep, max_timestep, (NUM_VALIDATION_TIMESTEPS + 2), dtype=int)[1:-1] + validation_total_steps = validation_steps * len(validation_timesteps) + original_args_min_timestep = args.min_timestep + original_args_max_timestep = args.max_timestep + for epoch in range(epoch_to_start, num_train_epochs): accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}\n") current_epoch.value = epoch + 1 @@ -1385,12 +1393,96 @@ class NetworkTrainer: accelerator.unwrap_model(network).eval() val_progress_bar = tqdm( - range(validation_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="validation steps" + range(validation_total_steps), + smoothing=0, + disable=not accelerator.is_local_main_process, + desc="validation steps", ) + val_ts_step = 0 for 
val_step, batch in enumerate(val_dataloader): if val_step >= validation_steps: break + for timestep in validation_timesteps: + # temporary, for batch processing + self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype) + + args.min_timestep = args.max_timestep = timestep # dirty hack to change timestep + + loss = self.process_batch( + batch, + text_encoders, + unet, + network, + vae, + noise_scheduler, + vae_dtype, + weight_dtype, + accelerator, + args, + text_encoding_strategy, + tokenize_strategy, + is_train=False, + train_text_encoder=train_text_encoder, # this is needed for validation because Text Encoders must be called if train_text_encoder is True + train_unet=train_unet, + ) + + current_loss = loss.detach().item() + val_step_loss_recorder.add(epoch=epoch, step=val_ts_step, loss=current_loss) + val_progress_bar.update(1) + val_progress_bar.set_postfix( + {"val_avg_loss": val_step_loss_recorder.moving_average, "timestep": timestep} + ) + + if is_tracking: + logs = { + "loss/validation/step_current": current_loss, + "val_step": (epoch * validation_total_steps) + val_ts_step, + } + accelerator.log(logs, step=global_step) + + val_ts_step += 1 + + if is_tracking: + loss_validation_divergence = val_step_loss_recorder.moving_average - loss_recorder.moving_average + logs = { + "loss/validation/step_average": val_step_loss_recorder.moving_average, + "loss/validation/step_divergence": loss_validation_divergence, + } + accelerator.log(logs, step=global_step) + + args.min_timestep = original_args_min_timestep + args.max_timestep = original_args_max_timestep + optimizer_train_fn() + accelerator.unwrap_model(network).train() + + if global_step >= args.max_train_steps: + break + + # EPOCH VALIDATION + should_validate_epoch = ( + (epoch + 1) % args.validate_every_n_epochs == 0 if args.validate_every_n_epochs is not None else True + ) + + if should_validate_epoch and len(val_dataloader) > 0: + optimizer_eval_fn() + 
accelerator.unwrap_model(network).eval() + + val_progress_bar = tqdm( + range(validation_total_steps), + smoothing=0, + disable=not accelerator.is_local_main_process, + desc="epoch validation steps", + ) + + val_ts_step = 0 + for val_step, batch in enumerate(val_dataloader): + if val_step >= validation_steps: + break + + for timestep in validation_timesteps: + args.min_timestep = args.max_timestep = timestep + # temporary, for batch processing self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype) @@ -1408,89 +1500,26 @@ class NetworkTrainer: text_encoding_strategy, tokenize_strategy, is_train=False, - train_text_encoder=train_text_encoder, # this is needed for validation because Text Encoders must be called if train_text_encoder is True + train_text_encoder=train_text_encoder, train_unet=train_unet, ) current_loss = loss.detach().item() - val_step_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss) + val_epoch_loss_recorder.add(epoch=epoch, step=val_ts_step, loss=current_loss) val_progress_bar.update(1) - val_progress_bar.set_postfix({"val_avg_loss": val_step_loss_recorder.moving_average}) + val_progress_bar.set_postfix( + {"val_epoch_avg_loss": val_epoch_loss_recorder.moving_average, "timestep": timestep} + ) if is_tracking: logs = { - "loss/validation/step_current": current_loss, - "val_step": (epoch * validation_steps) + val_step, + "loss/validation/epoch_current": current_loss, + "epoch": epoch + 1, + "val_step": (epoch * validation_total_steps) + val_ts_step, } accelerator.log(logs, step=global_step) - if is_tracking: - loss_validation_divergence = val_step_loss_recorder.moving_average - loss_recorder.moving_average - logs = { - "loss/validation/step_average": val_step_loss_recorder.moving_average, - "loss/validation/step_divergence": loss_validation_divergence, - } - accelerator.log(logs, step=global_step) - - optimizer_train_fn() - accelerator.unwrap_model(network).train() - - if global_step >= 
args.max_train_steps: - break - - # EPOCH VALIDATION - should_validate_epoch = ( - (epoch + 1) % args.validate_every_n_epochs == 0 if args.validate_every_n_epochs is not None else True - ) - - if should_validate_epoch and len(val_dataloader) > 0: - optimizer_eval_fn() - accelerator.unwrap_model(network).eval() - - val_progress_bar = tqdm( - range(validation_steps), - smoothing=0, - disable=not accelerator.is_local_main_process, - desc="epoch validation steps", - ) - - for val_step, batch in enumerate(val_dataloader): - if val_step >= validation_steps: - break - - # temporary, for batch processing - self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype) - - loss = self.process_batch( - batch, - text_encoders, - unet, - network, - vae, - noise_scheduler, - vae_dtype, - weight_dtype, - accelerator, - args, - text_encoding_strategy, - tokenize_strategy, - is_train=False, - train_text_encoder=train_text_encoder, - train_unet=train_unet, - ) - - current_loss = loss.detach().item() - val_epoch_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss) - val_progress_bar.update(1) - val_progress_bar.set_postfix({"val_epoch_avg_loss": val_epoch_loss_recorder.moving_average}) - - if is_tracking: - logs = { - "loss/validation/epoch_current": current_loss, - "epoch": epoch + 1, - "val_step": (epoch * validation_steps) + val_step, - } - accelerator.log(logs, step=global_step) + val_ts_step += 1 if is_tracking: avr_loss: float = val_epoch_loss_recorder.moving_average @@ -1502,6 +1531,8 @@ class NetworkTrainer: } accelerator.log(logs, step=global_step) + args.min_timestep = original_args_min_timestep + args.max_timestep = original_args_max_timestep optimizer_train_fn() accelerator.unwrap_model(network).train() From 45ec02b2a8b5eb5af8f5b4877381dc4dcc596cb9 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 27 Jan 2025 22:10:38 +0900 Subject: [PATCH 06/33] use same noise for every validation --- flux_train_network.py | 1 - train_network.py | 6 
++++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/flux_train_network.py b/flux_train_network.py index aab02573..475bd751 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -377,7 +377,6 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask): # grad is enabled even if unet is not in train mode, because Text Encoder is in train mode - with torch.set_grad_enabled(is_train), accelerator.autocast(): # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing) model_pred = unet( diff --git a/train_network.py b/train_network.py index a63e9d1e..f0deb67a 100644 --- a/train_network.py +++ b/train_network.py @@ -1391,6 +1391,8 @@ class NetworkTrainer: if accelerator.sync_gradients and validation_steps > 0 and should_validate_step: optimizer_eval_fn() accelerator.unwrap_model(network).eval() + rng_state = torch.get_rng_state() + torch.manual_seed(args.validation_seed if args.validation_seed is not None else args.seed) val_progress_bar = tqdm( range(validation_total_steps), @@ -1451,6 +1453,7 @@ class NetworkTrainer: } accelerator.log(logs, step=global_step) + torch.set_rng_state(rng_state) args.min_timestep = original_args_min_timestep args.max_timestep = original_args_max_timestep optimizer_train_fn() @@ -1467,6 +1470,8 @@ class NetworkTrainer: if should_validate_epoch and len(val_dataloader) > 0: optimizer_eval_fn() accelerator.unwrap_model(network).eval() + rng_state = torch.get_rng_state() + torch.manual_seed(args.validation_seed if args.validation_seed is not None else args.seed) val_progress_bar = tqdm( range(validation_total_steps), @@ -1531,6 +1536,7 @@ class NetworkTrainer: } accelerator.log(logs, step=global_step) + torch.set_rng_state(rng_state) args.min_timestep = original_args_min_timestep args.max_timestep = 
original_args_max_timestep optimizer_train_fn() From de830b89416f0671d7a1364a9262fa850c0669df Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Wed, 29 Jan 2025 00:02:45 -0500 Subject: [PATCH 07/33] Move progress bar to account for sampling image first --- train_network.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/train_network.py b/train_network.py index c3879531..2deb736d 100644 --- a/train_network.py +++ b/train_network.py @@ -1163,10 +1163,6 @@ class NetworkTrainer: args.max_train_steps > initial_step ), f"max_train_steps should be greater than initial step / max_train_stepsは初期ステップより大きい必要があります: {args.max_train_steps} vs {initial_step}" - progress_bar = tqdm( - range(args.max_train_steps - initial_step), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps" - ) - epoch_to_start = 0 if initial_step > 0: if args.skip_until_initial_step: @@ -1271,6 +1267,10 @@ class NetworkTrainer: clean_memory_on_device(accelerator.device) + progress_bar = tqdm( + range(args.max_train_steps - initial_step), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps" + ) + for epoch in range(epoch_to_start, num_train_epochs): accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}\n") current_epoch.value = epoch + 1 From 4a71687d20787d78a30b7a0df327067f5c402999 Mon Sep 17 00:00:00 2001 From: tsukimiya <71832+tsukimiya@users.noreply.github.com> Date: Tue, 4 Feb 2025 00:42:27 +0900 Subject: [PATCH 08/33] =?UTF-8?q?=E4=B8=8D=E8=A6=81=E3=81=AA=E8=AD=A6?= =?UTF-8?q?=E5=91=8A=E3=81=AE=E5=89=8A=E9=99=A4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit (おそらく https://github.com/kohya-ss/sd-scripts/commit/be14c062674973d0e4fee1eb4527e04707bb72b8 の修正漏れ ) --- library/sdxl_train_util.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/library/sdxl_train_util.py b/library/sdxl_train_util.py index b74bea91..f78d9424 100644 --- a/library/sdxl_train_util.py +++ b/library/sdxl_train_util.py @@ 
-345,8 +345,6 @@ def add_sdxl_training_arguments(parser: argparse.ArgumentParser): def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCaching: bool = True): assert not args.v2, "v2 cannot be enabled in SDXL training / SDXL学習ではv2を有効にすることはできません" - if args.v_parameterization: - logger.warning("v_parameterization will be unexpected / SDXL学習ではv_parameterizationは想定外の動作になります") if args.clip_skip is not None: logger.warning("clip_skip will be unexpected / SDXL学習ではclip_skipは動作しません") From c5b803ce94bd70812e6979ac7b986a769659b14e Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 4 Feb 2025 21:59:09 +0900 Subject: [PATCH 09/33] rng state management: Implement functions to get and set RNG states for consistent validation --- train_network.py | 33 +++++++++++++++++++++++++++++---- 1 file changed, 29 insertions(+), 4 deletions(-) diff --git a/train_network.py b/train_network.py index f0deb67a..b3c7ff52 100644 --- a/train_network.py +++ b/train_network.py @@ -1278,6 +1278,31 @@ class NetworkTrainer: original_args_min_timestep = args.min_timestep original_args_max_timestep = args.max_timestep + def get_rng_state() -> tuple[torch.ByteTensor, Optional[torch.ByteTensor], tuple]: + cpu_rng_state = torch.get_rng_state() + if accelerator.device.type == "cuda": + gpu_rng_state = torch.cuda.get_rng_state() + elif accelerator.device.type == "xpu": + gpu_rng_state = torch.xpu.get_rng_state() + elif accelerator.device.type == "mps": + gpu_rng_state = torch.cuda.get_rng_state() + else: + gpu_rng_state = None + python_rng_state = random.getstate() + return (cpu_rng_state, gpu_rng_state, python_rng_state) + + def set_rng_state(rng_states: tuple[torch.ByteTensor, Optional[torch.ByteTensor], tuple]): + cpu_rng_state, gpu_rng_state, python_rng_state = rng_states + torch.set_rng_state(cpu_rng_state) + if gpu_rng_state is not None: + if accelerator.device.type == "cuda": + torch.cuda.set_rng_state(gpu_rng_state) + elif accelerator.device.type == "xpu": + 
torch.xpu.set_rng_state(gpu_rng_state) + elif accelerator.device.type == "mps": + torch.cuda.set_rng_state(gpu_rng_state) + random.setstate(python_rng_state) + for epoch in range(epoch_to_start, num_train_epochs): accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}\n") current_epoch.value = epoch + 1 @@ -1391,7 +1416,7 @@ class NetworkTrainer: if accelerator.sync_gradients and validation_steps > 0 and should_validate_step: optimizer_eval_fn() accelerator.unwrap_model(network).eval() - rng_state = torch.get_rng_state() + rng_states = get_rng_state() torch.manual_seed(args.validation_seed if args.validation_seed is not None else args.seed) val_progress_bar = tqdm( @@ -1453,7 +1478,7 @@ class NetworkTrainer: } accelerator.log(logs, step=global_step) - torch.set_rng_state(rng_state) + set_rng_state(rng_states) args.min_timestep = original_args_min_timestep args.max_timestep = original_args_max_timestep optimizer_train_fn() @@ -1470,7 +1495,7 @@ class NetworkTrainer: if should_validate_epoch and len(val_dataloader) > 0: optimizer_eval_fn() accelerator.unwrap_model(network).eval() - rng_state = torch.get_rng_state() + rng_states = get_rng_state() torch.manual_seed(args.validation_seed if args.validation_seed is not None else args.seed) val_progress_bar = tqdm( @@ -1536,7 +1561,7 @@ class NetworkTrainer: } accelerator.log(logs, step=global_step) - torch.set_rng_state(rng_state) + set_rng_state(rng_states) args.min_timestep = original_args_min_timestep args.max_timestep = original_args_max_timestep optimizer_train_fn() From a24db1d532a95cc9dd91aba25a06b8eb58db5cff Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 4 Feb 2025 22:02:42 +0900 Subject: [PATCH 10/33] fix: validation timestep generation fails on SD/SDXL training --- library/train_util.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/library/train_util.py b/library/train_util.py index 37ed0a99..01fa6467 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -5935,7 +5935,10 
@@ def save_sd_model_on_train_end_common( def get_timesteps(min_timestep: int, max_timestep: int, b_size: int, device: torch.device) -> torch.Tensor: - timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device="cpu") + if min_timestep < max_timestep: + timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device="cpu") + else: + timesteps = torch.full((b_size,), max_timestep, device="cpu") timesteps = timesteps.long().to(device) return timesteps From 0911683717e439676bba758a5f7a29356984966c Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 9 Feb 2025 20:53:49 +0900 Subject: [PATCH 11/33] set python random state --- train_network.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/train_network.py b/train_network.py index b3c7ff52..083e5993 100644 --- a/train_network.py +++ b/train_network.py @@ -1278,7 +1278,7 @@ class NetworkTrainer: original_args_min_timestep = args.min_timestep original_args_max_timestep = args.max_timestep - def get_rng_state() -> tuple[torch.ByteTensor, Optional[torch.ByteTensor], tuple]: + def switch_rng_state(seed:int) -> tuple[torch.ByteTensor, Optional[torch.ByteTensor], tuple]: cpu_rng_state = torch.get_rng_state() if accelerator.device.type == "cuda": gpu_rng_state = torch.cuda.get_rng_state() @@ -1289,9 +1289,13 @@ class NetworkTrainer: else: gpu_rng_state = None python_rng_state = random.getstate() + + torch.manual_seed(seed) + random.seed(seed) + return (cpu_rng_state, gpu_rng_state, python_rng_state) - def set_rng_state(rng_states: tuple[torch.ByteTensor, Optional[torch.ByteTensor], tuple]): + def restore_rng_state(rng_states: tuple[torch.ByteTensor, Optional[torch.ByteTensor], tuple]): cpu_rng_state, gpu_rng_state, python_rng_state = rng_states torch.set_rng_state(cpu_rng_state) if gpu_rng_state is not None: @@ -1416,8 +1420,7 @@ class NetworkTrainer: if accelerator.sync_gradients and validation_steps > 0 and should_validate_step: optimizer_eval_fn() 
accelerator.unwrap_model(network).eval() - rng_states = get_rng_state() - torch.manual_seed(args.validation_seed if args.validation_seed is not None else args.seed) + rng_states = switch_rng_state(args.validation_seed if args.validation_seed is not None else args.seed) val_progress_bar = tqdm( range(validation_total_steps), @@ -1478,7 +1481,7 @@ class NetworkTrainer: } accelerator.log(logs, step=global_step) - set_rng_state(rng_states) + restore_rng_state(rng_states) args.min_timestep = original_args_min_timestep args.max_timestep = original_args_max_timestep optimizer_train_fn() @@ -1495,8 +1498,7 @@ class NetworkTrainer: if should_validate_epoch and len(val_dataloader) > 0: optimizer_eval_fn() accelerator.unwrap_model(network).eval() - rng_states = get_rng_state() - torch.manual_seed(args.validation_seed if args.validation_seed is not None else args.seed) + rng_states = switch_rng_state(args.validation_seed if args.validation_seed is not None else args.seed) val_progress_bar = tqdm( range(validation_total_steps), @@ -1561,7 +1563,7 @@ class NetworkTrainer: } accelerator.log(logs, step=global_step) - set_rng_state(rng_states) + restore_rng_state(rng_states) args.min_timestep = original_args_min_timestep args.max_timestep = original_args_max_timestep optimizer_train_fn() From 344845b42941b48956dce94d614fbf32e900c70e Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 9 Feb 2025 21:25:40 +0900 Subject: [PATCH 12/33] fix: validation with block swap --- flux_train_network.py | 14 ++++++++++++-- sd3_train_network.py | 19 ++++++++++++++----- train_network.py | 18 +++++++++++------- 3 files changed, 37 insertions(+), 14 deletions(-) diff --git a/flux_train_network.py b/flux_train_network.py index 475bd751..e97dfc5b 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -36,7 +36,12 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): self.is_schnell: Optional[bool] = None self.is_swapping_blocks: bool = False - def assert_extra_args(self, args, 
train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]): + def assert_extra_args( + self, + args, + train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], + val_dataset_group: Optional[train_util.DatasetGroup], + ): super().assert_extra_args(args, train_dataset_group, val_dataset_group) # sdxl_train_util.verify_sdxl_training_args(args) @@ -341,7 +346,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): network, weight_dtype, train_unet, - is_train=True + is_train=True, ): # Sample noise that we'll add to the latents noise = torch.randn_like(latents) @@ -507,6 +512,11 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): text_encoder.to(te_weight_dtype) # fp8 prepare_fp8(text_encoder, weight_dtype) + def on_validation_step_end(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype): + if self.is_swapping_blocks: + # prepare for next forward: because backward pass is not called, we need to prepare it here + accelerator.unwrap_model(unet).prepare_block_swap_before_forward() + def prepare_unet_with_accelerator( self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module ) -> torch.nn.Module: diff --git a/sd3_train_network.py b/sd3_train_network.py index d4f13125..216d93c5 100644 --- a/sd3_train_network.py +++ b/sd3_train_network.py @@ -26,7 +26,12 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer): super().__init__() self.sample_prompts_te_outputs = None - def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]): + def assert_extra_args( + self, + args, + train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], + val_dataset_group: Optional[train_util.DatasetGroup], + ): # super().assert_extra_args(args, train_dataset_group) # sdxl_train_util.verify_sdxl_training_args(args) @@ 
-317,7 +322,7 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer): network, weight_dtype, train_unet, - is_train=True + is_train=True, ): # Sample noise that we'll add to the latents noise = torch.randn_like(latents) @@ -445,15 +450,19 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer): text_encoder.to(te_weight_dtype) # fp8 prepare_fp8(text_encoder, weight_dtype) - def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype): - # TODO consider validation - # drop cached text encoder outputs + def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=True): + # drop cached text encoder outputs: in validation, we drop cached outputs deterministically by fixed seed text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) if text_encoder_outputs_list is not None: text_encodoing_strategy: strategy_sd3.Sd3TextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy() text_encoder_outputs_list = text_encodoing_strategy.drop_cached_text_encoder_outputs(*text_encoder_outputs_list) batch["text_encoder_outputs_list"] = text_encoder_outputs_list + def on_validation_step_end(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype): + if self.is_swapping_blocks: + # prepare for next forward: because backward pass is not called, we need to prepare it here + accelerator.unwrap_model(unet).prepare_block_swap_before_forward() + def prepare_unet_with_accelerator( self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module ) -> torch.nn.Module: diff --git a/train_network.py b/train_network.py index 083e5993..49013c70 100644 --- a/train_network.py +++ b/train_network.py @@ -309,7 +309,10 @@ class NetworkTrainer: ) -> torch.nn.Module: return accelerator.prepare(unet) - def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype): + def on_step_start(self, args, accelerator, network, text_encoders, 
unet, batch, weight_dtype, is_train: bool = True): + pass + + def on_validation_step_end(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype): pass # endregion @@ -1278,7 +1281,7 @@ class NetworkTrainer: original_args_min_timestep = args.min_timestep original_args_max_timestep = args.max_timestep - def switch_rng_state(seed:int) -> tuple[torch.ByteTensor, Optional[torch.ByteTensor], tuple]: + def switch_rng_state(seed: int) -> tuple[torch.ByteTensor, Optional[torch.ByteTensor], tuple]: cpu_rng_state = torch.get_rng_state() if accelerator.device.type == "cuda": gpu_rng_state = torch.cuda.get_rng_state() @@ -1330,8 +1333,8 @@ class NetworkTrainer: with accelerator.accumulate(training_model): on_step_start_for_network(text_encoder, unet) - # temporary, for batch processing - self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype) + # preprocess batch for each model + self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=True) loss = self.process_batch( batch, @@ -1434,8 +1437,7 @@ class NetworkTrainer: break for timestep in validation_timesteps: - # temporary, for batch processing - self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype) + self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=False) args.min_timestep = args.max_timestep = timestep # dirty hack to change timestep @@ -1471,6 +1473,7 @@ class NetworkTrainer: } accelerator.log(logs, step=global_step) + self.on_validation_step_end(args, accelerator, network, text_encoders, unet, batch, weight_dtype) val_ts_step += 1 if is_tracking: @@ -1516,7 +1519,7 @@ class NetworkTrainer: args.min_timestep = args.max_timestep = timestep # temporary, for batch processing - self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype) + self.on_step_start(args, accelerator, network, text_encoders, unet, batch, 
weight_dtype, is_train=False) loss = self.process_batch( batch, @@ -1551,6 +1554,7 @@ class NetworkTrainer: } accelerator.log(logs, step=global_step) + self.on_validation_step_end(args, accelerator, network, text_encoders, unet, batch, weight_dtype) val_ts_step += 1 if is_tracking: From 177203818a024329efa74640a588674323363373 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 11 Feb 2025 21:42:46 +0900 Subject: [PATCH 13/33] fix: unpause training progress bar after validation --- train_network.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/train_network.py b/train_network.py index 49013c70..8bfb1925 100644 --- a/train_network.py +++ b/train_network.py @@ -1489,6 +1489,7 @@ class NetworkTrainer: args.max_timestep = original_args_max_timestep optimizer_train_fn() accelerator.unwrap_model(network).train() + progress_bar.unpause() if global_step >= args.max_train_steps: break @@ -1572,6 +1573,7 @@ class NetworkTrainer: args.min_timestep = original_args_min_timestep args.max_timestep = original_args_max_timestep optimizer_train_fn() accelerator.unwrap_model(network).train() + progress_bar.unpause() # END OF EPOCH if is_tracking: From cd80752175c663ede2cb7995da652ed5f5f7f749 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 11 Feb 2025 21:42:58 +0900 Subject: [PATCH 14/33] fix: remove unused parameter 'accelerator' from encode_images_to_latents method --- flux_train_network.py | 2 +- sd3_train_network.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/flux_train_network.py b/flux_train_network.py index e97dfc5b..def44155 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -328,7 +328,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): self.noise_scheduler_copy = copy.deepcopy(noise_scheduler) return noise_scheduler - def encode_images_to_latents(self, args, accelerator, vae, images): + def encode_images_to_latents(self, args, vae, images): return vae.encode(images) def shift_scale_latents(self, args, latents): diff --git a/sd3_train_network.py b/sd3_train_network.py 
index 216d93c5..cdb7aa4e 100644 --- a/sd3_train_network.py +++ b/sd3_train_network.py @@ -304,7 +304,7 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer): noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.training_shift) return noise_scheduler - def encode_images_to_latents(self, args, accelerator, vae, images): + def encode_images_to_latents(self, args, vae, images): return vae.encode(images) def shift_scale_latents(self, args, latents): From 76b761943b5166f496aa1cb8ffbcc2d04469346a Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 11 Feb 2025 21:53:57 +0900 Subject: [PATCH 15/33] fix: simplify validation step condition in NetworkTrainer --- train_network.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/train_network.py b/train_network.py index 8bfb1925..99c58f49 100644 --- a/train_network.py +++ b/train_network.py @@ -1414,12 +1414,9 @@ class NetworkTrainer: ) accelerator.log(logs, step=global_step) - # VALIDATION PER STEP - should_validate_step = ( - args.validate_every_n_steps is not None - and global_step != 0 # Skip first step - and global_step % args.validate_every_n_steps == 0 - ) + # VALIDATION PER STEP: global_step is already incremented + # for example, if validate_every_n_steps=100, validate at step 100, 200, 300, ... 
+ should_validate_step = args.validate_every_n_steps is not None and global_step % args.validate_every_n_steps == 0 if accelerator.sync_gradients and validation_steps > 0 and should_validate_step: optimizer_eval_fn() accelerator.unwrap_model(network).eval() From 9436b410617f22716eac64f7c604c8f53fa8c1a8 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Mon, 17 Feb 2025 14:28:41 -0500 Subject: [PATCH 16/33] Fix validation split and add test --- library/train_util.py | 8 ++++++-- tests/test_validation.py | 17 +++++++++++++++++ 2 files changed, 23 insertions(+), 2 deletions(-) create mode 100644 tests/test_validation.py diff --git a/library/train_util.py b/library/train_util.py index 39b4af85..b2329066 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -161,15 +161,19 @@ def split_train_val( [0:80] = 80 training images [80:] = 20 validation images """ + dataset = list(zip(paths, sizes)) if validation_seed is not None: logging.info(f"Using validation seed: {validation_seed}") prevstate = random.getstate() random.seed(validation_seed) - random.shuffle(paths) + random.shuffle(dataset) random.setstate(prevstate) else: - random.shuffle(paths) + random.shuffle(dataset) + paths, sizes = zip(*dataset) + paths = list(paths) + sizes = list(sizes) # Split the dataset between training and validation if is_training_dataset: # Training dataset we split to the first part diff --git a/tests/test_validation.py b/tests/test_validation.py new file mode 100644 index 00000000..f80686d8 --- /dev/null +++ b/tests/test_validation.py @@ -0,0 +1,17 @@ +from library.train_util import split_train_val + + +def test_split_train_val(): + paths = ["path1", "path2", "path3", "path4", "path5", "path6", "path7"] + sizes = [(1, 1), (2, 2), None, (4, 4), (5, 5), (6, 6), None] + result_paths, result_sizes = split_train_val(paths, sizes, True, 0.2, 1234) + assert result_paths == ["path2", "path3", "path6", "path5", "path1", "path4"], result_paths + assert result_sizes == [(2, 2), None, (6, 6), 
(5, 5), (1, 1), (4, 4)], result_sizes + + result_paths, result_sizes = split_train_val(paths, sizes, False, 0.2, 1234) + assert result_paths == ["path7"], result_paths + assert result_sizes == [None], result_sizes + + +if __name__ == "__main__": + test_split_train_val() From 4a369961346ca153a370728247449978d8a33415 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 18 Feb 2025 22:05:08 +0900 Subject: [PATCH 17/33] modify log step calculation --- train_network.py | 26 ++++++++++---------------- 1 file changed, 10 insertions(+), 16 deletions(-) diff --git a/train_network.py b/train_network.py index 47c4bb56..93558da4 100644 --- a/train_network.py +++ b/train_network.py @@ -1464,11 +1464,10 @@ class NetworkTrainer: ) if is_tracking: - logs = { - "loss/validation/step_current": current_loss, - "val_step": (epoch * validation_total_steps) + val_ts_step, - } - accelerator.log(logs, step=global_step) + logs = {"loss/validation/step_current": current_loss} + accelerator.log( + logs, step=global_step + val_ts_step + ) # a bit weird to log with global_step + val_ts_step self.on_validation_step_end(args, accelerator, network, text_encoders, unet, batch, weight_dtype) val_ts_step += 1 @@ -1545,25 +1544,20 @@ class NetworkTrainer: ) if is_tracking: - logs = { - "loss/validation/epoch_current": current_loss, - "epoch": epoch + 1, - "val_step": (epoch * validation_total_steps) + val_ts_step, - } - accelerator.log(logs, step=global_step) + logs = {"loss/validation/epoch_current": current_loss} + accelerator.log(logs, step=global_step + val_ts_step) self.on_validation_step_end(args, accelerator, network, text_encoders, unet, batch, weight_dtype) val_ts_step += 1 if is_tracking: avr_loss: float = val_epoch_loss_recorder.moving_average - loss_validation_divergence = val_epoch_loss_recorder.moving_average - loss_recorder.moving_average + loss_validation_divergence = val_epoch_loss_recorder.moving_average - loss_recorder.moving_average logs = { "loss/validation/epoch_average": avr_loss, 
"loss/validation/epoch_divergence": loss_validation_divergence, - "epoch": epoch + 1, } - accelerator.log(logs, step=global_step) + accelerator.log(logs, step=epoch + 1) restore_rng_state(rng_states) args.min_timestep = original_args_min_timestep @@ -1574,8 +1568,8 @@ class NetworkTrainer: # END OF EPOCH if is_tracking: - logs = {"loss/epoch_average": loss_recorder.moving_average, "epoch": epoch + 1} - accelerator.log(logs, step=global_step) + logs = {"loss/epoch_average": loss_recorder.moving_average} + accelerator.log(logs, step=epoch + 1) accelerator.wait_for_everyone() From 13df47516dda6e350b6aa79373b5a0e7287648b5 Mon Sep 17 00:00:00 2001 From: Yidi Date: Thu, 20 Feb 2025 04:49:51 -0500 Subject: [PATCH 18/33] Remove position_ids for V2 The position_ids cause errors for the newer version of transformer. This has already been fixed in convert_ldm_clip_checkpoint_v1() but not in v2. The new code applies the same fix to convert_ldm_clip_checkpoint_v2(). --- library/model_util.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/library/model_util.py b/library/model_util.py index be410a02..9918c7b2 100644 --- a/library/model_util.py +++ b/library/model_util.py @@ -643,16 +643,15 @@ def convert_ldm_clip_checkpoint_v2(checkpoint, max_length): new_sd[key_pfx + "k_proj" + key_suffix] = values[1] new_sd[key_pfx + "v_proj" + key_suffix] = values[2] - # rename or add position_ids + # remove position_ids for newer transformer, which causes error :( ANOTHER_POSITION_IDS_KEY = "text_model.encoder.text_model.embeddings.position_ids" if ANOTHER_POSITION_IDS_KEY in new_sd: # waifu diffusion v1.4 - position_ids = new_sd[ANOTHER_POSITION_IDS_KEY] del new_sd[ANOTHER_POSITION_IDS_KEY] - else: - position_ids = torch.Tensor([list(range(max_length))]).to(torch.int64) - new_sd["text_model.embeddings.position_ids"] = position_ids + if "text_model.embeddings.position_ids" in new_sd: + del new_sd["text_model.embeddings.position_ids"] + return new_sd From 
efb2a128cd0d2c6340a21bf544e77853a20b3453 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Fri, 21 Feb 2025 22:07:35 +0900 Subject: [PATCH 19/33] fix wandb val logging --- library/train_util.py | 57 +++++++++++++++------------------ train_network.py | 73 ++++++++++++++++++++++++++++++++----------- 2 files changed, 80 insertions(+), 50 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 25870198..1f591c42 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -13,17 +13,7 @@ import re import shutil import time import typing -from typing import ( - Any, - Callable, - Dict, - List, - NamedTuple, - Optional, - Sequence, - Tuple, - Union -) +from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs, PartialState import glob import math @@ -146,12 +136,13 @@ IMAGE_TRANSFORMS = transforms.Compose( TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX = "_te_outputs.npz" TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX_SD3 = "_sd3_te.npz" + def split_train_val( - paths: List[str], + paths: List[str], sizes: List[Optional[Tuple[int, int]]], - is_training_dataset: bool, - validation_split: float, - validation_seed: int | None + is_training_dataset: bool, + validation_split: float, + validation_seed: int | None, ) -> Tuple[List[str], List[Optional[Tuple[int, int]]]]: """ Split the dataset into train and validation @@ -1842,7 +1833,7 @@ class BaseDataset(torch.utils.data.Dataset): class DreamBoothDataset(BaseDataset): IMAGE_INFO_CACHE_FILE = "metadata_cache.json" - # The is_training_dataset defines the type of dataset, training or validation + # The is_training_dataset defines the type of dataset, training or validation # if is_training_dataset is True -> training dataset # if is_training_dataset is False -> validation dataset def __init__( @@ -1981,29 +1972,25 @@ class DreamBoothDataset(BaseDataset): logger.info(f"set image size from cache files: 
{size_set_count}/{len(img_paths)}") # We want to create a training and validation split. This should be improved in the future - # to allow a clearer distinction between training and validation. This can be seen as a + # to allow a clearer distinction between training and validation. This can be seen as a # short-term solution to limit what is necessary to implement validation datasets - # + # # We split the dataset for the subset based on if we are doing a validation split - # The self.is_training_dataset defines the type of dataset, training or validation + # The self.is_training_dataset defines the type of dataset, training or validation # if self.is_training_dataset is True -> training dataset # if self.is_training_dataset is False -> validation dataset if self.validation_split > 0.0: - # For regularization images we do not want to split this dataset. + # For regularization images we do not want to split this dataset. if subset.is_reg is True: # Skip any validation dataset for regularization images if self.is_training_dataset is False: img_paths = [] sizes = [] - # Otherwise the img_paths remain as original img_paths and no split + # Otherwise the img_paths remain as original img_paths and no split # required for training images dataset of regularization images else: img_paths, sizes = split_train_val( - img_paths, - sizes, - self.is_training_dataset, - self.validation_split, - self.validation_seed + img_paths, sizes, self.is_training_dataset, self.validation_split, self.validation_seed ) logger.info(f"found directory {subset.image_dir} contains {len(img_paths)} image files") @@ -2373,7 +2360,7 @@ class ControlNetDataset(BaseDataset): bucket_no_upscale: bool, debug_dataset: bool, validation_split: float, - validation_seed: Optional[int], + validation_seed: Optional[int], ) -> None: super().__init__(resolution, network_multiplier, debug_dataset) @@ -2431,9 +2418,9 @@ class ControlNetDataset(BaseDataset): self.image_data = 
self.dreambooth_dataset_delegate.image_data self.batch_size = batch_size self.num_train_images = self.dreambooth_dataset_delegate.num_train_images - self.num_reg_images = self.dreambooth_dataset_delegate.num_reg_images + self.num_reg_images = self.dreambooth_dataset_delegate.num_reg_images self.validation_split = validation_split - self.validation_seed = validation_seed + self.validation_seed = validation_seed # assert all conditioning data exists missing_imgs = [] @@ -5952,7 +5939,9 @@ def get_timesteps(min_timestep: int, max_timestep: int, b_size: int, device: tor return timesteps -def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents: torch.FloatTensor) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.IntTensor]: +def get_noise_noisy_latents_and_timesteps( + args, noise_scheduler, latents: torch.FloatTensor +) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.IntTensor]: # Sample noise that we'll add to the latents noise = torch.randn_like(latents, device=latents.device) if args.noise_offset: @@ -6444,7 +6433,7 @@ def sample_image_inference( wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False) # positive prompt as a caption -def init_trackers(accelerator: Accelerator, args: argparse.Namespace, default_tracker_name: str): +def init_trackers(accelerator: Accelerator, args: argparse.Namespace, default_tracker_name: str): """ Initialize experiment trackers with tracker specific behaviors """ @@ -6461,13 +6450,17 @@ def init_trackers(accelerator: Accelerator, args: argparse.Namespace, default_tr ) if "wandb" in [tracker.name for tracker in accelerator.trackers]: - import wandb + import wandb + wandb_tracker = accelerator.get_tracker("wandb", unwrap=True) # Define specific metrics to handle validation and epochs "steps" wandb_tracker.define_metric("epoch", hidden=True) wandb_tracker.define_metric("val_step", hidden=True) + wandb_tracker.define_metric("global_step", hidden=True) + + # endregion diff --git 
a/train_network.py b/train_network.py index 93558da4..ab5483de 100644 --- a/train_network.py +++ b/train_network.py @@ -119,6 +119,45 @@ class NetworkTrainer: return logs + def step_logging(self, accelerator: Accelerator, logs: dict, global_step: int, epoch: int): + self.accelerator_logging(accelerator, logs, global_step, global_step, epoch) + + def epoch_logging(self, accelerator: Accelerator, logs: dict, global_step: int, epoch: int): + self.accelerator_logging(accelerator, logs, epoch, global_step, epoch) + + def val_logging(self, accelerator: Accelerator, logs: dict, global_step: int, epoch: int, val_step: int): + self.accelerator_logging(accelerator, logs, global_step + val_step, global_step, epoch, val_step) + + def accelerator_logging( + self, accelerator: Accelerator, logs: dict, step_value: int, global_step: int, epoch: int, val_step: Optional[int] = None + ): + """ + step_value is for tensorboard, other values are for wandb + """ + tensorboard_tracker = None + wandb_tracker = None + other_trackers = [] + for tracker in accelerator.trackers: + if tracker.name == "tensorboard": + tensorboard_tracker = accelerator.get_tracker("tensorboard") + elif tracker.name == "wandb": + wandb_tracker = accelerator.get_tracker("wandb") + else: + other_trackers.append(accelerator.get_tracker(tracker.name)) + + if tensorboard_tracker is not None: + tensorboard_tracker.log(logs, step=step_value) + + if wandb_tracker is not None: + logs["global_step"] = global_step + logs["epoch"] = epoch + if val_step is not None: + logs["val_step"] = val_step + wandb_tracker.log(logs) + + for tracker in other_trackers: + tracker.log(logs, step=step_value) + def assert_extra_args( self, args, @@ -1412,7 +1451,7 @@ class NetworkTrainer: logs = self.generate_step_logs( args, current_loss, avr_loss, lr_scheduler, lr_descriptions, optimizer, keys_scaled, mean_norm, maximum_norm ) - accelerator.log(logs, step=global_step) + self.step_logging(accelerator, logs, global_step, epoch + 1) # VALIDATION 
PER STEP: global_step is already incremented # for example, if validate_every_n_steps=100, validate at step 100, 200, 300, ... @@ -1428,7 +1467,7 @@ class NetworkTrainer: disable=not accelerator.is_local_main_process, desc="validation steps", ) - val_ts_step = 0 + val_timesteps_step = 0 for val_step, batch in enumerate(val_dataloader): if val_step >= validation_steps: break @@ -1457,20 +1496,18 @@ class NetworkTrainer: ) current_loss = loss.detach().item() - val_step_loss_recorder.add(epoch=epoch, step=val_ts_step, loss=current_loss) + val_step_loss_recorder.add(epoch=epoch, step=val_timesteps_step, loss=current_loss) val_progress_bar.update(1) val_progress_bar.set_postfix( {"val_avg_loss": val_step_loss_recorder.moving_average, "timestep": timestep} ) - if is_tracking: - logs = {"loss/validation/step_current": current_loss} - accelerator.log( - logs, step=global_step + val_ts_step - ) # a bit weird to log with global_step + val_ts_step + # if is_tracking: + # logs = {f"loss/validation/step_current_{timestep}": current_loss} + # self.val_logging(accelerator, logs, global_step, epoch + 1, val_step) self.on_validation_step_end(args, accelerator, network, text_encoders, unet, batch, weight_dtype) - val_ts_step += 1 + val_timesteps_step += 1 if is_tracking: loss_validation_divergence = val_step_loss_recorder.moving_average - loss_recorder.moving_average @@ -1478,7 +1515,7 @@ class NetworkTrainer: "loss/validation/step_average": val_step_loss_recorder.moving_average, "loss/validation/step_divergence": loss_validation_divergence, } - accelerator.log(logs, step=global_step) + self.step_logging(accelerator, logs, global_step, epoch=epoch + 1) restore_rng_state(rng_states) args.min_timestep = original_args_min_timestep @@ -1507,7 +1544,7 @@ class NetworkTrainer: desc="epoch validation steps", ) - val_ts_step = 0 + val_timesteps_step = 0 for val_step, batch in enumerate(val_dataloader): if val_step >= validation_steps: break @@ -1537,18 +1574,18 @@ class NetworkTrainer: ) 
current_loss = loss.detach().item() - val_epoch_loss_recorder.add(epoch=epoch, step=val_ts_step, loss=current_loss) + val_epoch_loss_recorder.add(epoch=epoch, step=val_timesteps_step, loss=current_loss) val_progress_bar.update(1) val_progress_bar.set_postfix( {"val_epoch_avg_loss": val_epoch_loss_recorder.moving_average, "timestep": timestep} ) - if is_tracking: - logs = {"loss/validation/epoch_current": current_loss} - accelerator.log(logs, step=global_step + val_ts_step) + # if is_tracking: + # logs = {f"loss/validation/epoch_current_{timestep}": current_loss} + # self.val_logging(accelerator, logs, global_step, epoch + 1, val_step) self.on_validation_step_end(args, accelerator, network, text_encoders, unet, batch, weight_dtype) - val_ts_step += 1 + val_timesteps_step += 1 if is_tracking: avr_loss: float = val_epoch_loss_recorder.moving_average @@ -1557,7 +1594,7 @@ class NetworkTrainer: "loss/validation/epoch_average": avr_loss, "loss/validation/epoch_divergence": loss_validation_divergence, } - accelerator.log(logs, step=epoch + 1) + self.epoch_logging(accelerator, logs, global_step, epoch + 1) restore_rng_state(rng_states) args.min_timestep = original_args_min_timestep @@ -1569,7 +1606,7 @@ class NetworkTrainer: # END OF EPOCH if is_tracking: logs = {"loss/epoch_average": loss_recorder.moving_average} - accelerator.log(logs, step=epoch + 1) + self.epoch_logging(accelerator, logs, global_step, epoch + 1) accelerator.wait_for_everyone() From f68702f71c16719d0f85820a2a4585f19b96552f Mon Sep 17 00:00:00 2001 From: Disty0 Date: Tue, 25 Feb 2025 21:27:41 +0300 Subject: [PATCH 20/33] Update IPEX libs --- library/device_utils.py | 11 +- library/ipex/__init__.py | 170 +++++++++++------- library/ipex/attention.py | 220 +++++++++-------------- library/ipex/diffusers.py | 349 +++++-------------------------------- library/ipex/gradscaler.py | 2 +- library/ipex/hijacks.py | 134 +++++++++----- 6 files changed, 330 insertions(+), 556 deletions(-) diff --git 
a/library/device_utils.py b/library/device_utils.py index 8823c5d9..d2e19745 100644 --- a/library/device_utils.py +++ b/library/device_utils.py @@ -2,6 +2,13 @@ import functools import gc import torch +try: + # intel gpu support for pytorch older than 2.5 + # ipex is not needed after pytorch 2.5 + import intel_extension_for_pytorch as ipex # noqa +except Exception: + pass + try: HAS_CUDA = torch.cuda.is_available() @@ -14,8 +21,6 @@ except Exception: HAS_MPS = False try: - import intel_extension_for_pytorch as ipex # noqa - HAS_XPU = torch.xpu.is_available() except Exception: HAS_XPU = False @@ -69,7 +74,7 @@ def init_ipex(): This function should run right after importing torch and before doing anything else. - If IPEX is not available, this function does nothing. + If xpu is not available, this function does nothing. """ try: if HAS_XPU: diff --git a/library/ipex/__init__.py b/library/ipex/__init__.py index e5aba693..a36664bb 100644 --- a/library/ipex/__init__.py +++ b/library/ipex/__init__.py @@ -2,7 +2,11 @@ import os import sys import contextlib import torch -import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import +try: + import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import + legacy = True +except Exception: + legacy = False from .hijacks import ipex_hijacks # pylint: disable=protected-access, missing-function-docstring, line-too-long @@ -12,6 +16,13 @@ def ipex_init(): # pylint: disable=too-many-statements if hasattr(torch, "cuda") and hasattr(torch.cuda, "is_xpu_hijacked") and torch.cuda.is_xpu_hijacked: return True, "Skipping IPEX hijack" else: + try: # force xpu device on torch compile and triton + torch._inductor.utils.GPU_TYPES = ["xpu"] + torch._inductor.utils.get_gpu_type = lambda *args, **kwargs: "xpu" + from triton import backends as triton_backends # pylint: disable=import-error + triton_backends.backends["nvidia"].driver.is_active = lambda *args, **kwargs: False + except Exception: 
+ pass # Replace cuda with xpu: torch.cuda.current_device = torch.xpu.current_device torch.cuda.current_stream = torch.xpu.current_stream @@ -26,84 +37,99 @@ def ipex_init(): # pylint: disable=too-many-statements torch.cuda.is_current_stream_capturing = lambda: False torch.cuda.set_device = torch.xpu.set_device torch.cuda.stream = torch.xpu.stream - torch.cuda.synchronize = torch.xpu.synchronize torch.cuda.Event = torch.xpu.Event torch.cuda.Stream = torch.xpu.Stream - torch.cuda.FloatTensor = torch.xpu.FloatTensor torch.Tensor.cuda = torch.Tensor.xpu torch.Tensor.is_cuda = torch.Tensor.is_xpu torch.nn.Module.cuda = torch.nn.Module.xpu - torch.UntypedStorage.cuda = torch.UntypedStorage.xpu - torch.cuda._initialization_lock = torch.xpu.lazy_init._initialization_lock - torch.cuda._initialized = torch.xpu.lazy_init._initialized - torch.cuda._lazy_seed_tracker = torch.xpu.lazy_init._lazy_seed_tracker - torch.cuda._queued_calls = torch.xpu.lazy_init._queued_calls - torch.cuda._tls = torch.xpu.lazy_init._tls - torch.cuda.threading = torch.xpu.lazy_init.threading - torch.cuda.traceback = torch.xpu.lazy_init.traceback torch.cuda.Optional = torch.xpu.Optional torch.cuda.__cached__ = torch.xpu.__cached__ torch.cuda.__loader__ = torch.xpu.__loader__ - torch.cuda.ComplexFloatStorage = torch.xpu.ComplexFloatStorage torch.cuda.Tuple = torch.xpu.Tuple torch.cuda.streams = torch.xpu.streams - torch.cuda._lazy_new = torch.xpu._lazy_new - torch.cuda.FloatStorage = torch.xpu.FloatStorage torch.cuda.Any = torch.xpu.Any torch.cuda.__doc__ = torch.xpu.__doc__ torch.cuda.default_generators = torch.xpu.default_generators - torch.cuda.HalfTensor = torch.xpu.HalfTensor torch.cuda._get_device_index = torch.xpu._get_device_index torch.cuda.__path__ = torch.xpu.__path__ - torch.cuda.Device = torch.xpu.Device - torch.cuda.IntTensor = torch.xpu.IntTensor - torch.cuda.ByteStorage = torch.xpu.ByteStorage torch.cuda.set_stream = torch.xpu.set_stream - torch.cuda.BoolStorage = torch.xpu.BoolStorage - 
torch.cuda.os = torch.xpu.os torch.cuda.torch = torch.xpu.torch - torch.cuda.BFloat16Storage = torch.xpu.BFloat16Storage torch.cuda.Union = torch.xpu.Union - torch.cuda.DoubleTensor = torch.xpu.DoubleTensor - torch.cuda.ShortTensor = torch.xpu.ShortTensor - torch.cuda.LongTensor = torch.xpu.LongTensor - torch.cuda.IntStorage = torch.xpu.IntStorage - torch.cuda.LongStorage = torch.xpu.LongStorage torch.cuda.__annotations__ = torch.xpu.__annotations__ torch.cuda.__package__ = torch.xpu.__package__ torch.cuda.__builtins__ = torch.xpu.__builtins__ - torch.cuda.CharTensor = torch.xpu.CharTensor torch.cuda.List = torch.xpu.List torch.cuda._lazy_init = torch.xpu._lazy_init - torch.cuda.BFloat16Tensor = torch.xpu.BFloat16Tensor - torch.cuda.DoubleStorage = torch.xpu.DoubleStorage - torch.cuda.ByteTensor = torch.xpu.ByteTensor torch.cuda.StreamContext = torch.xpu.StreamContext - torch.cuda.ComplexDoubleStorage = torch.xpu.ComplexDoubleStorage - torch.cuda.ShortStorage = torch.xpu.ShortStorage torch.cuda._lazy_call = torch.xpu._lazy_call - torch.cuda.HalfStorage = torch.xpu.HalfStorage torch.cuda.random = torch.xpu.random torch.cuda._device = torch.xpu._device - torch.cuda.classproperty = torch.xpu.classproperty torch.cuda.__name__ = torch.xpu.__name__ torch.cuda._device_t = torch.xpu._device_t - torch.cuda.warnings = torch.xpu.warnings torch.cuda.__spec__ = torch.xpu.__spec__ - torch.cuda.BoolTensor = torch.xpu.BoolTensor - torch.cuda.CharStorage = torch.xpu.CharStorage torch.cuda.__file__ = torch.xpu.__file__ - torch.cuda._is_in_bad_fork = torch.xpu.lazy_init._is_in_bad_fork # torch.cuda.is_current_stream_capturing = torch.xpu.is_current_stream_capturing + if legacy: + torch.cuda.os = torch.xpu.os + torch.cuda.Device = torch.xpu.Device + torch.cuda.warnings = torch.xpu.warnings + torch.cuda.classproperty = torch.xpu.classproperty + torch.UntypedStorage.cuda = torch.UntypedStorage.xpu + if float(ipex.__version__[:3]) < 2.3: + torch.cuda._initialization_lock = 
torch.xpu.lazy_init._initialization_lock + torch.cuda._initialized = torch.xpu.lazy_init._initialized + torch.cuda._is_in_bad_fork = torch.xpu.lazy_init._is_in_bad_fork + torch.cuda._lazy_seed_tracker = torch.xpu.lazy_init._lazy_seed_tracker + torch.cuda._queued_calls = torch.xpu.lazy_init._queued_calls + torch.cuda._tls = torch.xpu.lazy_init._tls + torch.cuda.threading = torch.xpu.lazy_init.threading + torch.cuda.traceback = torch.xpu.lazy_init.traceback + torch.cuda._lazy_new = torch.xpu._lazy_new + + torch.cuda.FloatTensor = torch.xpu.FloatTensor + torch.cuda.FloatStorage = torch.xpu.FloatStorage + torch.cuda.BFloat16Tensor = torch.xpu.BFloat16Tensor + torch.cuda.BFloat16Storage = torch.xpu.BFloat16Storage + torch.cuda.HalfTensor = torch.xpu.HalfTensor + torch.cuda.HalfStorage = torch.xpu.HalfStorage + torch.cuda.ByteTensor = torch.xpu.ByteTensor + torch.cuda.ByteStorage = torch.xpu.ByteStorage + torch.cuda.DoubleTensor = torch.xpu.DoubleTensor + torch.cuda.DoubleStorage = torch.xpu.DoubleStorage + torch.cuda.ShortTensor = torch.xpu.ShortTensor + torch.cuda.ShortStorage = torch.xpu.ShortStorage + torch.cuda.LongTensor = torch.xpu.LongTensor + torch.cuda.LongStorage = torch.xpu.LongStorage + torch.cuda.IntTensor = torch.xpu.IntTensor + torch.cuda.IntStorage = torch.xpu.IntStorage + torch.cuda.CharTensor = torch.xpu.CharTensor + torch.cuda.CharStorage = torch.xpu.CharStorage + torch.cuda.BoolTensor = torch.xpu.BoolTensor + torch.cuda.BoolStorage = torch.xpu.BoolStorage + torch.cuda.ComplexFloatStorage = torch.xpu.ComplexFloatStorage + torch.cuda.ComplexDoubleStorage = torch.xpu.ComplexDoubleStorage + + if not legacy or float(ipex.__version__[:3]) >= 2.3: + torch.cuda._initialization_lock = torch.xpu._initialization_lock + torch.cuda._initialized = torch.xpu._initialized + torch.cuda._is_in_bad_fork = torch.xpu._is_in_bad_fork + torch.cuda._lazy_seed_tracker = torch.xpu._lazy_seed_tracker + torch.cuda._queued_calls = torch.xpu._queued_calls + torch.cuda._tls = 
torch.xpu._tls + torch.cuda.threading = torch.xpu.threading + torch.cuda.traceback = torch.xpu.traceback + # Memory: - torch.cuda.memory = torch.xpu.memory if 'linux' in sys.platform and "WSL2" in os.popen("uname -a").read(): torch.xpu.empty_cache = lambda: None torch.cuda.empty_cache = torch.xpu.empty_cache + + if legacy: + torch.cuda.memory_summary = torch.xpu.memory_summary + torch.cuda.memory_snapshot = torch.xpu.memory_snapshot + torch.cuda.memory = torch.xpu.memory torch.cuda.memory_stats = torch.xpu.memory_stats - torch.cuda.memory_summary = torch.xpu.memory_summary - torch.cuda.memory_snapshot = torch.xpu.memory_snapshot torch.cuda.memory_allocated = torch.xpu.memory_allocated torch.cuda.max_memory_allocated = torch.xpu.max_memory_allocated torch.cuda.memory_reserved = torch.xpu.memory_reserved @@ -128,32 +154,44 @@ def ipex_init(): # pylint: disable=too-many-statements torch.cuda.initial_seed = torch.xpu.initial_seed # AMP: - torch.cuda.amp = torch.xpu.amp - torch.is_autocast_enabled = torch.xpu.is_autocast_xpu_enabled - torch.get_autocast_gpu_dtype = torch.xpu.get_autocast_xpu_dtype + if legacy: + torch.xpu.amp.custom_fwd = torch.cuda.amp.custom_fwd + torch.xpu.amp.custom_bwd = torch.cuda.amp.custom_bwd + torch.cuda.amp = torch.xpu.amp + if float(ipex.__version__[:3]) < 2.3: + torch.is_autocast_enabled = torch.xpu.is_autocast_xpu_enabled + torch.get_autocast_gpu_dtype = torch.xpu.get_autocast_xpu_dtype - if not hasattr(torch.cuda.amp, "common"): - torch.cuda.amp.common = contextlib.nullcontext() - torch.cuda.amp.common.amp_definitely_not_available = lambda: False + if not hasattr(torch.cuda.amp, "common"): + torch.cuda.amp.common = contextlib.nullcontext() + torch.cuda.amp.common.amp_definitely_not_available = lambda: False - try: - torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler - except Exception: # pylint: disable=broad-exception-caught try: - from .gradscaler import gradscaler_init # pylint: disable=import-outside-toplevel, import-error - 
gradscaler_init() torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler except Exception: # pylint: disable=broad-exception-caught - torch.cuda.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler + try: + from .gradscaler import gradscaler_init # pylint: disable=import-outside-toplevel, import-error + gradscaler_init() + torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler + except Exception: # pylint: disable=broad-exception-caught + torch.cuda.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler # C - torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentStream - ipex._C._DeviceProperties.multi_processor_count = ipex._C._DeviceProperties.gpu_subslice_count - ipex._C._DeviceProperties.major = 2024 - ipex._C._DeviceProperties.minor = 0 + if legacy and float(ipex.__version__[:3]) < 2.3: + torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentRawStream + ipex._C._DeviceProperties.multi_processor_count = ipex._C._DeviceProperties.gpu_subslice_count + ipex._C._DeviceProperties.major = 12 + ipex._C._DeviceProperties.minor = 1 + else: + torch._C._cuda_getCurrentRawStream = torch._C._xpu_getCurrentRawStream + torch._C._XpuDeviceProperties.multi_processor_count = torch._C._XpuDeviceProperties.gpu_subslice_count + torch._C._XpuDeviceProperties.major = 12 + torch._C._XpuDeviceProperties.minor = 1 # Fix functions with ipex: - torch.cuda.mem_get_info = lambda device=None: [(torch.xpu.get_device_properties(device).total_memory - torch.xpu.memory_reserved(device)), torch.xpu.get_device_properties(device).total_memory] + # torch.xpu.mem_get_info always returns the total memory as free memory + torch.xpu.mem_get_info = lambda device=None: [(torch.xpu.get_device_properties(device).total_memory - torch.xpu.memory_reserved(device)), torch.xpu.get_device_properties(device).total_memory] + torch.cuda.mem_get_info = torch.xpu.mem_get_info torch._utils._get_available_device_type = lambda: "xpu" torch.has_cuda = True torch.cuda.has_half = True @@ -161,19 +199,19 @@ def 
ipex_init(): # pylint: disable=too-many-statements torch.cuda.is_fp16_supported = lambda *args, **kwargs: True torch.backends.cuda.is_built = lambda *args, **kwargs: True torch.version.cuda = "12.1" - torch.cuda.get_device_capability = lambda *args, **kwargs: [12,1] + torch.cuda.get_arch_list = lambda: ["ats-m150", "pvc"] + torch.cuda.get_device_capability = lambda *args, **kwargs: (12,1) torch.cuda.get_device_properties.major = 12 torch.cuda.get_device_properties.minor = 1 torch.cuda.ipc_collect = lambda *args, **kwargs: None torch.cuda.utilization = lambda *args, **kwargs: 0 - ipex_hijacks() - if not torch.xpu.has_fp64_dtype() or os.environ.get('IPEX_FORCE_ATTENTION_SLICE', None) is not None: - try: - from .diffusers import ipex_diffusers - ipex_diffusers() - except Exception: # pylint: disable=broad-exception-caught - pass + device_supports_fp64, can_allocate_plus_4gb = ipex_hijacks(legacy=legacy) + try: + from .diffusers import ipex_diffusers + ipex_diffusers(device_supports_fp64=device_supports_fp64, can_allocate_plus_4gb=can_allocate_plus_4gb) + except Exception: # pylint: disable=broad-exception-caught + pass torch.cuda.is_xpu_hijacked = True except Exception as e: return False, e diff --git a/library/ipex/attention.py b/library/ipex/attention.py index 2bc62f65..400b59b6 100644 --- a/library/ipex/attention.py +++ b/library/ipex/attention.py @@ -1,177 +1,119 @@ import os import torch -import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import -from functools import cache +from functools import cache, wraps # pylint: disable=protected-access, missing-function-docstring, line-too-long # ARC GPUs can't allocate more than 4GB to a single block so we slice the attention layers -sdpa_slice_trigger_rate = float(os.environ.get('IPEX_SDPA_SLICE_TRIGGER_RATE', 4)) -attention_slice_rate = float(os.environ.get('IPEX_ATTENTION_SLICE_RATE', 4)) +sdpa_slice_trigger_rate = float(os.environ.get('IPEX_SDPA_SLICE_TRIGGER_RATE', 1)) 
+attention_slice_rate = float(os.environ.get('IPEX_ATTENTION_SLICE_RATE', 0.5)) # Find something divisible with the input_tokens @cache -def find_slice_size(slice_size, slice_block_size): - while (slice_size * slice_block_size) > attention_slice_rate: - slice_size = slice_size // 2 - if slice_size <= 1: - slice_size = 1 - break - return slice_size +def find_split_size(original_size, slice_block_size, slice_rate=2): + split_size = original_size + while True: + if (split_size * slice_block_size) <= slice_rate and original_size % split_size == 0: + return split_size + split_size = split_size - 1 + if split_size <= 1: + return 1 + return split_size + # Find slice sizes for SDPA @cache -def find_sdpa_slice_sizes(query_shape, query_element_size): - if len(query_shape) == 3: - batch_size_attention, query_tokens, shape_three = query_shape - shape_four = 1 - else: - batch_size_attention, query_tokens, shape_three, shape_four = query_shape +def find_sdpa_slice_sizes(query_shape, key_shape, query_element_size, slice_rate=2, trigger_rate=3): + batch_size, attn_heads, query_len, _ = query_shape + _, _, key_len, _ = key_shape - slice_block_size = query_tokens * shape_three * shape_four / 1024 / 1024 * query_element_size - block_size = batch_size_attention * slice_block_size + slice_batch_size = attn_heads * (query_len * key_len) * query_element_size / 1024 / 1024 / 1024 - split_slice_size = batch_size_attention - split_2_slice_size = query_tokens - split_3_slice_size = shape_three + split_batch_size = batch_size + split_head_size = attn_heads + split_query_size = query_len - do_split = False - do_split_2 = False - do_split_3 = False + do_batch_split = False + do_head_split = False + do_query_split = False - if block_size > sdpa_slice_trigger_rate: - do_split = True - split_slice_size = find_slice_size(split_slice_size, slice_block_size) - if split_slice_size * slice_block_size > attention_slice_rate: - slice_2_block_size = split_slice_size * shape_three * shape_four / 1024 / 
1024 * query_element_size - do_split_2 = True - split_2_slice_size = find_slice_size(split_2_slice_size, slice_2_block_size) - if split_2_slice_size * slice_2_block_size > attention_slice_rate: - slice_3_block_size = split_slice_size * split_2_slice_size * shape_four / 1024 / 1024 * query_element_size - do_split_3 = True - split_3_slice_size = find_slice_size(split_3_slice_size, slice_3_block_size) + if batch_size * slice_batch_size >= trigger_rate: + do_batch_split = True + split_batch_size = find_split_size(batch_size, slice_batch_size, slice_rate=slice_rate) - return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size + if split_batch_size * slice_batch_size > slice_rate: + slice_head_size = split_batch_size * (query_len * key_len) * query_element_size / 1024 / 1024 / 1024 + do_head_split = True + split_head_size = find_split_size(attn_heads, slice_head_size, slice_rate=slice_rate) -# Find slice sizes for BMM -@cache -def find_bmm_slice_sizes(input_shape, input_element_size, mat2_shape): - batch_size_attention, input_tokens, mat2_atten_shape = input_shape[0], input_shape[1], mat2_shape[2] - slice_block_size = input_tokens * mat2_atten_shape / 1024 / 1024 * input_element_size - block_size = batch_size_attention * slice_block_size + if split_head_size * slice_head_size > slice_rate: + slice_query_size = split_batch_size * split_head_size * (key_len) * query_element_size / 1024 / 1024 / 1024 + do_query_split = True + split_query_size = find_split_size(query_len, slice_query_size, slice_rate=slice_rate) - split_slice_size = batch_size_attention - split_2_slice_size = input_tokens - split_3_slice_size = mat2_atten_shape + return do_batch_split, do_head_split, do_query_split, split_batch_size, split_head_size, split_query_size - do_split = False - do_split_2 = False - do_split_3 = False - - if block_size > attention_slice_rate: - do_split = True - split_slice_size = find_slice_size(split_slice_size, slice_block_size) - if 
split_slice_size * slice_block_size > attention_slice_rate: - slice_2_block_size = split_slice_size * mat2_atten_shape / 1024 / 1024 * input_element_size - do_split_2 = True - split_2_slice_size = find_slice_size(split_2_slice_size, slice_2_block_size) - if split_2_slice_size * slice_2_block_size > attention_slice_rate: - slice_3_block_size = split_slice_size * split_2_slice_size / 1024 / 1024 * input_element_size - do_split_3 = True - split_3_slice_size = find_slice_size(split_3_slice_size, slice_3_block_size) - - return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size - - -original_torch_bmm = torch.bmm -def torch_bmm_32_bit(input, mat2, *, out=None): - if input.device.type != "xpu": - return original_torch_bmm(input, mat2, out=out) - do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_bmm_slice_sizes(input.shape, input.element_size(), mat2.shape) - - # Slice BMM - if do_split: - batch_size_attention, input_tokens, mat2_atten_shape = input.shape[0], input.shape[1], mat2.shape[2] - hidden_states = torch.zeros(input.shape[0], input.shape[1], mat2.shape[2], device=input.device, dtype=input.dtype) - for i in range(batch_size_attention // split_slice_size): - start_idx = i * split_slice_size - end_idx = (i + 1) * split_slice_size - if do_split_2: - for i2 in range(input_tokens // split_2_slice_size): # pylint: disable=invalid-name - start_idx_2 = i2 * split_2_slice_size - end_idx_2 = (i2 + 1) * split_2_slice_size - if do_split_3: - for i3 in range(mat2_atten_shape // split_3_slice_size): # pylint: disable=invalid-name - start_idx_3 = i3 * split_3_slice_size - end_idx_3 = (i3 + 1) * split_3_slice_size - hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = original_torch_bmm( - input[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3], - mat2[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3], - out=out - ) - else: - 
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_torch_bmm( - input[start_idx:end_idx, start_idx_2:end_idx_2], - mat2[start_idx:end_idx, start_idx_2:end_idx_2], - out=out - ) - else: - hidden_states[start_idx:end_idx] = original_torch_bmm( - input[start_idx:end_idx], - mat2[start_idx:end_idx], - out=out - ) - torch.xpu.synchronize(input.device) - else: - return original_torch_bmm(input, mat2, out=out) - return hidden_states original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention -def scaled_dot_product_attention_32_bit(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, **kwargs): +@wraps(torch.nn.functional.scaled_dot_product_attention) +def dynamic_scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, **kwargs): if query.device.type != "xpu": return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs) - do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_sdpa_slice_sizes(query.shape, query.element_size()) + is_unsqueezed = False + if len(query.shape) == 3: + query = query.unsqueeze(0) + is_unsqueezed = True + if len(key.shape) == 3: + key = key.unsqueeze(0) + if len(value.shape) == 3: + value = value.unsqueeze(0) + do_batch_split, do_head_split, do_query_split, split_batch_size, split_head_size, split_query_size = find_sdpa_slice_sizes(query.shape, key.shape, query.element_size(), slice_rate=attention_slice_rate, trigger_rate=sdpa_slice_trigger_rate) # Slice SDPA - if do_split: - batch_size_attention, query_tokens, shape_three = query.shape[0], query.shape[1], query.shape[2] - hidden_states = torch.zeros(query.shape, device=query.device, dtype=query.dtype) - for i in range(batch_size_attention // split_slice_size): - start_idx = i * split_slice_size - end_idx = (i + 1) * split_slice_size - if do_split_2: - for i2 in range(query_tokens 
// split_2_slice_size): # pylint: disable=invalid-name - start_idx_2 = i2 * split_2_slice_size - end_idx_2 = (i2 + 1) * split_2_slice_size - if do_split_3: - for i3 in range(shape_three // split_3_slice_size): # pylint: disable=invalid-name - start_idx_3 = i3 * split_3_slice_size - end_idx_3 = (i3 + 1) * split_3_slice_size - hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = original_scaled_dot_product_attention( - query[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3], - key[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3], - value[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3], - attn_mask=attn_mask[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] if attn_mask is not None else attn_mask, + if do_batch_split: + batch_size, attn_heads, query_len, _ = query.shape + _, _, _, head_dim = value.shape + hidden_states = torch.zeros((batch_size, attn_heads, query_len, head_dim), device=query.device, dtype=query.dtype) + if attn_mask is not None: + attn_mask = attn_mask.expand((query.shape[0], query.shape[1], query.shape[2], key.shape[-2])) + for ib in range(batch_size // split_batch_size): + start_idx = ib * split_batch_size + end_idx = (ib + 1) * split_batch_size + if do_head_split: + for ih in range(attn_heads // split_head_size): # pylint: disable=invalid-name + start_idx_h = ih * split_head_size + end_idx_h = (ih + 1) * split_head_size + if do_query_split: + for iq in range(query_len // split_query_size): # pylint: disable=invalid-name + start_idx_q = iq * split_query_size + end_idx_q = (iq + 1) * split_query_size + hidden_states[start_idx:end_idx, start_idx_h:end_idx_h, start_idx_q:end_idx_q, :] = original_scaled_dot_product_attention( + query[start_idx:end_idx, start_idx_h:end_idx_h, start_idx_q:end_idx_q, :], + key[start_idx:end_idx, start_idx_h:end_idx_h, :, :], + value[start_idx:end_idx, start_idx_h:end_idx_h, :, :], + attn_mask=attn_mask[start_idx:end_idx, 
start_idx_h:end_idx_h, start_idx_q:end_idx_q, :] if attn_mask is not None else attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs ) else: - hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_scaled_dot_product_attention( - query[start_idx:end_idx, start_idx_2:end_idx_2], - key[start_idx:end_idx, start_idx_2:end_idx_2], - value[start_idx:end_idx, start_idx_2:end_idx_2], - attn_mask=attn_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attn_mask is not None else attn_mask, + hidden_states[start_idx:end_idx, start_idx_h:end_idx_h, :, :] = original_scaled_dot_product_attention( + query[start_idx:end_idx, start_idx_h:end_idx_h, :, :], + key[start_idx:end_idx, start_idx_h:end_idx_h, :, :], + value[start_idx:end_idx, start_idx_h:end_idx_h, :, :], + attn_mask=attn_mask[start_idx:end_idx, start_idx_h:end_idx_h, :, :] if attn_mask is not None else attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs ) else: - hidden_states[start_idx:end_idx] = original_scaled_dot_product_attention( - query[start_idx:end_idx], - key[start_idx:end_idx], - value[start_idx:end_idx], - attn_mask=attn_mask[start_idx:end_idx] if attn_mask is not None else attn_mask, + hidden_states[start_idx:end_idx, :, :, :] = original_scaled_dot_product_attention( + query[start_idx:end_idx, :, :, :], + key[start_idx:end_idx, :, :, :], + value[start_idx:end_idx, :, :, :], + attn_mask=attn_mask[start_idx:end_idx, :, :, :] if attn_mask is not None else attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs ) torch.xpu.synchronize(query.device) else: - return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs) + hidden_states = original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs) + if is_unsqueezed: + hidden_states.squeeze(0) return hidden_states diff --git a/library/ipex/diffusers.py b/library/ipex/diffusers.py index 
732a1856..75715d16 100644 --- a/library/ipex/diffusers.py +++ b/library/ipex/diffusers.py @@ -1,312 +1,47 @@ -import os +from functools import wraps import torch -import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import -import diffusers #0.24.0 # pylint: disable=import-error -from diffusers.models.attention_processor import Attention -from diffusers.utils import USE_PEFT_BACKEND -from functools import cache +import diffusers # pylint: disable=import-error # pylint: disable=protected-access, missing-function-docstring, line-too-long -attention_slice_rate = float(os.environ.get('IPEX_ATTENTION_SLICE_RATE', 4)) -@cache -def find_slice_size(slice_size, slice_block_size): - while (slice_size * slice_block_size) > attention_slice_rate: - slice_size = slice_size // 2 - if slice_size <= 1: - slice_size = 1 - break - return slice_size - -@cache -def find_attention_slice_sizes(query_shape, query_element_size, query_device_type, slice_size=None): - if len(query_shape) == 3: - batch_size_attention, query_tokens, shape_three = query_shape - shape_four = 1 - else: - batch_size_attention, query_tokens, shape_three, shape_four = query_shape - if slice_size is not None: - batch_size_attention = slice_size - - slice_block_size = query_tokens * shape_three * shape_four / 1024 / 1024 * query_element_size - block_size = batch_size_attention * slice_block_size - - split_slice_size = batch_size_attention - split_2_slice_size = query_tokens - split_3_slice_size = shape_three - - do_split = False - do_split_2 = False - do_split_3 = False - - if query_device_type != "xpu": - return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size - - if block_size > attention_slice_rate: - do_split = True - split_slice_size = find_slice_size(split_slice_size, slice_block_size) - if split_slice_size * slice_block_size > attention_slice_rate: - slice_2_block_size = split_slice_size * shape_three * shape_four / 1024 / 1024 * 
query_element_size - do_split_2 = True - split_2_slice_size = find_slice_size(split_2_slice_size, slice_2_block_size) - if split_2_slice_size * slice_2_block_size > attention_slice_rate: - slice_3_block_size = split_slice_size * split_2_slice_size * shape_four / 1024 / 1024 * query_element_size - do_split_3 = True - split_3_slice_size = find_slice_size(split_3_slice_size, slice_3_block_size) - - return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size - -class SlicedAttnProcessor: # pylint: disable=too-few-public-methods - r""" - Processor for implementing sliced attention. - - Args: - slice_size (`int`, *optional*): - The number of steps to compute attention. Uses as many slices as `attention_head_dim // slice_size`, and - `attention_head_dim` must be a multiple of the `slice_size`. - """ - - def __init__(self, slice_size): - self.slice_size = slice_size - - def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, - encoder_hidden_states=None, attention_mask=None) -> torch.FloatTensor: # pylint: disable=too-many-statements, too-many-locals, too-many-branches - - residual = hidden_states - - input_ndim = hidden_states.ndim - - if input_ndim == 4: - batch_size, channel, height, width = hidden_states.shape - hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) - - batch_size, sequence_length, _ = ( - hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - ) - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) - - if attn.group_norm is not None: - hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - - query = attn.to_q(hidden_states) - dim = query.shape[-1] - query = attn.head_to_batch_dim(query) - - if encoder_hidden_states is None: - encoder_hidden_states = hidden_states - elif attn.norm_cross: - encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) - - key 
= attn.to_k(encoder_hidden_states) - value = attn.to_v(encoder_hidden_states) - key = attn.head_to_batch_dim(key) - value = attn.head_to_batch_dim(value) - - batch_size_attention, query_tokens, shape_three = query.shape - hidden_states = torch.zeros( - (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype - ) - - #################################################################### - # ARC GPUs can't allocate more than 4GB to a single block, Slice it: - _, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_attention_slice_sizes(query.shape, query.element_size(), query.device.type, slice_size=self.slice_size) - - for i in range(batch_size_attention // split_slice_size): - start_idx = i * split_slice_size - end_idx = (i + 1) * split_slice_size - if do_split_2: - for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name - start_idx_2 = i2 * split_2_slice_size - end_idx_2 = (i2 + 1) * split_2_slice_size - if do_split_3: - for i3 in range(shape_three // split_3_slice_size): # pylint: disable=invalid-name - start_idx_3 = i3 * split_3_slice_size - end_idx_3 = (i3 + 1) * split_3_slice_size - - query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] - key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] - attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] if attention_mask is not None else None - - attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice) - del query_slice - del key_slice - del attn_mask_slice - attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3]) - - hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = attn_slice - del attn_slice - else: - query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2] - key_slice = key[start_idx:end_idx, 
start_idx_2:end_idx_2] - attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attention_mask is not None else None - - attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice) - del query_slice - del key_slice - del attn_mask_slice - attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2]) - - hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = attn_slice - del attn_slice - torch.xpu.synchronize(query.device) - else: - query_slice = query[start_idx:end_idx] - key_slice = key[start_idx:end_idx] - attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None - - attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice) - del query_slice - del key_slice - del attn_mask_slice - attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx]) - - hidden_states[start_idx:end_idx] = attn_slice - del attn_slice - #################################################################### - - hidden_states = attn.batch_to_head_dim(hidden_states) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - if input_ndim == 4: - hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) - - if attn.residual_connection: - hidden_states = hidden_states + residual - - hidden_states = hidden_states / attn.rescale_output_factor - - return hidden_states - - -class AttnProcessor: - r""" - Default processor for performing attention-related computations. 
- """ - - def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, - encoder_hidden_states=None, attention_mask=None, - temb=None, scale: float = 1.0) -> torch.Tensor: # pylint: disable=too-many-statements, too-many-locals, too-many-branches - - residual = hidden_states - - args = () if USE_PEFT_BACKEND else (scale,) - - if attn.spatial_norm is not None: - hidden_states = attn.spatial_norm(hidden_states, temb) - - input_ndim = hidden_states.ndim - - if input_ndim == 4: - batch_size, channel, height, width = hidden_states.shape - hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) - - batch_size, sequence_length, _ = ( - hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - ) - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) - - if attn.group_norm is not None: - hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - - query = attn.to_q(hidden_states, *args) - - if encoder_hidden_states is None: - encoder_hidden_states = hidden_states - elif attn.norm_cross: - encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) - - key = attn.to_k(encoder_hidden_states, *args) - value = attn.to_v(encoder_hidden_states, *args) - - query = attn.head_to_batch_dim(query) - key = attn.head_to_batch_dim(key) - value = attn.head_to_batch_dim(value) - - #################################################################### - # ARC GPUs can't allocate more than 4GB to a single block, Slice it: - batch_size_attention, query_tokens, shape_three = query.shape[0], query.shape[1], query.shape[2] - hidden_states = torch.zeros(query.shape, device=query.device, dtype=query.dtype) - do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_attention_slice_sizes(query.shape, query.element_size(), query.device.type) - - if do_split: - for i in range(batch_size_attention // split_slice_size): 
- start_idx = i * split_slice_size - end_idx = (i + 1) * split_slice_size - if do_split_2: - for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name - start_idx_2 = i2 * split_2_slice_size - end_idx_2 = (i2 + 1) * split_2_slice_size - if do_split_3: - for i3 in range(shape_three // split_3_slice_size): # pylint: disable=invalid-name - start_idx_3 = i3 * split_3_slice_size - end_idx_3 = (i3 + 1) * split_3_slice_size - - query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] - key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] - attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] if attention_mask is not None else None - - attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice) - del query_slice - del key_slice - del attn_mask_slice - attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3]) - - hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = attn_slice - del attn_slice - else: - query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2] - key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2] - attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attention_mask is not None else None - - attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice) - del query_slice - del key_slice - del attn_mask_slice - attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2]) - - hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = attn_slice - del attn_slice - else: - query_slice = query[start_idx:end_idx] - key_slice = key[start_idx:end_idx] - attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None - - attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice) - del query_slice - del key_slice - del 
attn_mask_slice - attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx]) - - hidden_states[start_idx:end_idx] = attn_slice - del attn_slice - torch.xpu.synchronize(query.device) - else: - attention_probs = attn.get_attention_scores(query, key, attention_mask) - hidden_states = torch.bmm(attention_probs, value) - #################################################################### - hidden_states = attn.batch_to_head_dim(hidden_states) - - # linear proj - hidden_states = attn.to_out[0](hidden_states, *args) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - if input_ndim == 4: - hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) - - if attn.residual_connection: - hidden_states = hidden_states + residual - - hidden_states = hidden_states / attn.rescale_output_factor - - return hidden_states - -def ipex_diffusers(): - #ARC GPUs can't allocate more than 4GB to a single block: - diffusers.models.attention_processor.SlicedAttnProcessor = SlicedAttnProcessor - diffusers.models.attention_processor.AttnProcessor = AttnProcessor +# Diffusers FreeU +original_fourier_filter = diffusers.utils.torch_utils.fourier_filter +@wraps(diffusers.utils.torch_utils.fourier_filter) +def fourier_filter(x_in, threshold, scale): + return_dtype = x_in.dtype + return original_fourier_filter(x_in.to(dtype=torch.float32), threshold, scale).to(dtype=return_dtype) + + +# fp64 error +class FluxPosEmbed(torch.nn.Module): + def __init__(self, theta: int, axes_dim): + super().__init__() + self.theta = theta + self.axes_dim = axes_dim + + def forward(self, ids: torch.Tensor) -> torch.Tensor: + n_axes = ids.shape[-1] + cos_out = [] + sin_out = [] + pos = ids.float() + for i in range(n_axes): + cos, sin = diffusers.models.embeddings.get_1d_rotary_pos_embed( + self.axes_dim[i], + pos[:, i], + theta=self.theta, + repeat_interleave_real=True, + use_real=True, + freqs_dtype=torch.float32, + ) + cos_out.append(cos) + sin_out.append(sin) + freqs_cos = 
torch.cat(cos_out, dim=-1).to(ids.device) + freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device) + return freqs_cos, freqs_sin + + +def ipex_diffusers(device_supports_fp64=False, can_allocate_plus_4gb=False): + diffusers.utils.torch_utils.fourier_filter = fourier_filter + if not device_supports_fp64: + diffusers.models.embeddings.FluxPosEmbed = FluxPosEmbed diff --git a/library/ipex/gradscaler.py b/library/ipex/gradscaler.py index 6eb56bc2..0a861009 100644 --- a/library/ipex/gradscaler.py +++ b/library/ipex/gradscaler.py @@ -5,7 +5,7 @@ import intel_extension_for_pytorch._C as core # pylint: disable=import-error, un # pylint: disable=protected-access, missing-function-docstring, line-too-long -device_supports_fp64 = torch.xpu.has_fp64_dtype() +device_supports_fp64 = torch.xpu.has_fp64_dtype() if hasattr(torch.xpu, "has_fp64_dtype") else torch.xpu.get_device_properties("xpu").has_fp64 OptState = ipex.cpu.autocast._grad_scaler.OptState _MultiDeviceReplicator = ipex.cpu.autocast._grad_scaler._MultiDeviceReplicator _refresh_per_optimizer_state = ipex.cpu.autocast._grad_scaler._refresh_per_optimizer_state diff --git a/library/ipex/hijacks.py b/library/ipex/hijacks.py index d3cef827..91569746 100644 --- a/library/ipex/hijacks.py +++ b/library/ipex/hijacks.py @@ -2,10 +2,19 @@ import os from functools import wraps from contextlib import nullcontext import torch -import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import import numpy as np -device_supports_fp64 = torch.xpu.has_fp64_dtype() +device_supports_fp64 = torch.xpu.has_fp64_dtype() if hasattr(torch.xpu, "has_fp64_dtype") else torch.xpu.get_device_properties("xpu").has_fp64 +if os.environ.get('IPEX_FORCE_ATTENTION_SLICE', '0') == '0' and (torch.xpu.get_device_properties("xpu").total_memory / 1024 / 1024 / 1024) > 4.1: + try: + x = torch.ones((33000,33000), dtype=torch.float32, device="xpu") + del x + torch.xpu.empty_cache() + can_allocate_plus_4gb = True + except Exception: + 
can_allocate_plus_4gb = False +else: + can_allocate_plus_4gb = bool(os.environ.get('IPEX_FORCE_ATTENTION_SLICE', '0') == '-1') # pylint: disable=protected-access, missing-function-docstring, line-too-long, unnecessary-lambda, no-else-return @@ -26,7 +35,7 @@ def check_device(device): return bool((isinstance(device, torch.device) and device.type == "cuda") or (isinstance(device, str) and "cuda" in device) or isinstance(device, int)) def return_xpu(device): - return f"xpu:{device.split(':')[-1]}" if isinstance(device, str) and ":" in device else f"xpu:{device}" if isinstance(device, int) else torch.device("xpu") if isinstance(device, torch.device) else "xpu" + return f"xpu:{device.split(':')[-1]}" if isinstance(device, str) and ":" in device else f"xpu:{device}" if isinstance(device, int) else torch.device(f"xpu:{device.index}" if device.index is not None else "xpu") if isinstance(device, torch.device) else "xpu" # Autocast @@ -42,7 +51,7 @@ def autocast_init(self, device_type, dtype=None, enabled=True, cache_enabled=Non original_interpolate = torch.nn.functional.interpolate @wraps(torch.nn.functional.interpolate) def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False): # pylint: disable=too-many-arguments - if antialias or align_corners is not None or mode == 'bicubic': + if mode in {'bicubic', 'bilinear'}: return_device = tensor.device return_dtype = tensor.dtype return original_interpolate(tensor.to("cpu", dtype=torch.float32), size=size, scale_factor=scale_factor, mode=mode, @@ -73,35 +82,46 @@ def as_tensor(data, dtype=None, device=None): return original_as_tensor(data, dtype=dtype, device=device) -if device_supports_fp64 and os.environ.get('IPEX_FORCE_ATTENTION_SLICE', None) is None: - original_torch_bmm = torch.bmm +if can_allocate_plus_4gb: original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention else: # 32 bit attention workarounds for Alchemist: 
try: - from .attention import torch_bmm_32_bit as original_torch_bmm - from .attention import scaled_dot_product_attention_32_bit as original_scaled_dot_product_attention + from .attention import dynamic_scaled_dot_product_attention as original_scaled_dot_product_attention except Exception: # pylint: disable=broad-exception-caught - original_torch_bmm = torch.bmm original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention - -# Data Type Errors: -@wraps(torch.bmm) -def torch_bmm(input, mat2, *, out=None): - if input.dtype != mat2.dtype: - mat2 = mat2.to(input.dtype) - return original_torch_bmm(input, mat2, out=out) - @wraps(torch.nn.functional.scaled_dot_product_attention) -def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False): +def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, **kwargs): if query.dtype != key.dtype: key = key.to(dtype=query.dtype) if query.dtype != value.dtype: value = value.to(dtype=query.dtype) if attn_mask is not None and query.dtype != attn_mask.dtype: attn_mask = attn_mask.to(dtype=query.dtype) - return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal) + return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs) + +# Data Type Errors: +original_torch_bmm = torch.bmm +@wraps(torch.bmm) +def torch_bmm(input, mat2, *, out=None): + if input.dtype != mat2.dtype: + mat2 = mat2.to(input.dtype) + return original_torch_bmm(input, mat2, out=out) + +# Diffusers FreeU +original_fft_fftn = torch.fft.fftn +@wraps(torch.fft.fftn) +def fft_fftn(input, s=None, dim=None, norm=None, *, out=None): + return_dtype = input.dtype + return original_fft_fftn(input.to(dtype=torch.float32), s=s, dim=dim, norm=norm, out=out).to(dtype=return_dtype) + +# Diffusers FreeU +original_fft_ifftn = 
torch.fft.ifftn +@wraps(torch.fft.ifftn) +def fft_ifftn(input, s=None, dim=None, norm=None, *, out=None): + return_dtype = input.dtype + return original_fft_ifftn(input.to(dtype=torch.float32), s=s, dim=dim, norm=norm, out=out).to(dtype=return_dtype) # A1111 FP16 original_functional_group_norm = torch.nn.functional.group_norm @@ -133,6 +153,15 @@ def functional_linear(input, weight, bias=None): bias.data = bias.data.to(dtype=weight.data.dtype) return original_functional_linear(input, weight, bias=bias) +original_functional_conv1d = torch.nn.functional.conv1d +@wraps(torch.nn.functional.conv1d) +def functional_conv1d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): + if input.dtype != weight.data.dtype: + input = input.to(dtype=weight.data.dtype) + if bias is not None and bias.data.dtype != weight.data.dtype: + bias.data = bias.data.to(dtype=weight.data.dtype) + return original_functional_conv1d(input, weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups) + original_functional_conv2d = torch.nn.functional.conv2d @wraps(torch.nn.functional.conv2d) def functional_conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): @@ -142,14 +171,15 @@ def functional_conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, bias.data = bias.data.to(dtype=weight.data.dtype) return original_functional_conv2d(input, weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups) -# A1111 Embedding BF16 -original_torch_cat = torch.cat -@wraps(torch.cat) -def torch_cat(tensor, *args, **kwargs): - if len(tensor) == 3 and (tensor[0].dtype != tensor[1].dtype or tensor[2].dtype != tensor[1].dtype): - return original_torch_cat([tensor[0].to(tensor[1].dtype), tensor[1], tensor[2].to(tensor[1].dtype)], *args, **kwargs) - else: - return original_torch_cat(tensor, *args, **kwargs) +# LTX Video +original_functional_conv3d = torch.nn.functional.conv3d +@wraps(torch.nn.functional.conv3d) +def 
functional_conv3d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): + if input.dtype != weight.data.dtype: + input = input.to(dtype=weight.data.dtype) + if bias is not None and bias.data.dtype != weight.data.dtype: + bias.data = bias.data.to(dtype=weight.data.dtype) + return original_functional_conv3d(input, weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups) # SwinIR BF16: original_functional_pad = torch.nn.functional.pad @@ -164,6 +194,7 @@ def functional_pad(input, pad, mode='constant', value=None): original_torch_tensor = torch.tensor @wraps(torch.tensor) def torch_tensor(data, *args, dtype=None, device=None, **kwargs): + global device_supports_fp64 if check_device(device): device = return_xpu(device) if not device_supports_fp64: @@ -227,7 +258,7 @@ def torch_empty(*args, device=None, **kwargs): original_torch_randn = torch.randn @wraps(torch.randn) def torch_randn(*args, device=None, dtype=None, **kwargs): - if dtype == bytes: + if dtype is bytes: dtype = None if check_device(device): return original_torch_randn(*args, device=return_xpu(device), **kwargs) @@ -250,6 +281,14 @@ def torch_zeros(*args, device=None, **kwargs): else: return original_torch_zeros(*args, device=device, **kwargs) +original_torch_full = torch.full +@wraps(torch.full) +def torch_full(*args, device=None, **kwargs): + if check_device(device): + return original_torch_full(*args, device=return_xpu(device), **kwargs) + else: + return original_torch_full(*args, device=device, **kwargs) + original_torch_linspace = torch.linspace @wraps(torch.linspace) def torch_linspace(*args, device=None, **kwargs): @@ -258,14 +297,6 @@ def torch_linspace(*args, device=None, **kwargs): else: return original_torch_linspace(*args, device=device, **kwargs) -original_torch_Generator = torch.Generator -@wraps(torch.Generator) -def torch_Generator(device=None): - if check_device(device): - return original_torch_Generator(return_xpu(device)) - else: - return 
original_torch_Generator(device) - original_torch_load = torch.load @wraps(torch.load) def torch_load(f, map_location=None, *args, **kwargs): @@ -276,9 +307,27 @@ def torch_load(f, map_location=None, *args, **kwargs): else: return original_torch_load(f, *args, map_location=map_location, **kwargs) +original_torch_Generator = torch.Generator +@wraps(torch.Generator) +def torch_Generator(device=None): + if check_device(device): + return original_torch_Generator(return_xpu(device)) + else: + return original_torch_Generator(device) + +@wraps(torch.cuda.synchronize) +def torch_cuda_synchronize(device=None): + if check_device(device): + return torch.xpu.synchronize(return_xpu(device)) + else: + return torch.xpu.synchronize(device) + # Hijack Functions: -def ipex_hijacks(): +def ipex_hijacks(legacy=True): + global device_supports_fp64, can_allocate_plus_4gb + if legacy and float(torch.__version__[:3]) < 2.5: + torch.nn.functional.interpolate = interpolate torch.tensor = torch_tensor torch.Tensor.to = Tensor_to torch.Tensor.cuda = Tensor_cuda @@ -289,9 +338,11 @@ def ipex_hijacks(): torch.randn = torch_randn torch.ones = torch_ones torch.zeros = torch_zeros + torch.full = torch_full torch.linspace = torch_linspace - torch.Generator = torch_Generator torch.load = torch_load + torch.Generator = torch_Generator + torch.cuda.synchronize = torch_cuda_synchronize torch.backends.cuda.sdp_kernel = return_null_context torch.nn.DataParallel = DummyDataParallel @@ -302,12 +353,15 @@ def ipex_hijacks(): torch.nn.functional.group_norm = functional_group_norm torch.nn.functional.layer_norm = functional_layer_norm torch.nn.functional.linear = functional_linear + torch.nn.functional.conv1d = functional_conv1d torch.nn.functional.conv2d = functional_conv2d - torch.nn.functional.interpolate = interpolate + torch.nn.functional.conv3d = functional_conv3d torch.nn.functional.pad = functional_pad torch.bmm = torch_bmm - torch.cat = torch_cat + torch.fft.fftn = fft_fftn + torch.fft.ifftn = 
fft_ifftn if not device_supports_fp64: torch.from_numpy = from_numpy torch.as_tensor = as_tensor + return device_supports_fp64, can_allocate_plus_4gb From f4a004786500d80e1b47728d216aed9d76869a9d Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 26 Feb 2025 20:50:44 +0900 Subject: [PATCH 21/33] feat: support metadata loading in MemoryEfficientSafeOpen --- library/utils.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/library/utils.py b/library/utils.py index 07079c6d..4df8bd32 100644 --- a/library/utils.py +++ b/library/utils.py @@ -261,11 +261,10 @@ def mem_eff_save_file(tensors: Dict[str, torch.Tensor], filename: str, metadata: class MemoryEfficientSafeOpen: - # does not support metadata loading def __init__(self, filename): self.filename = filename - self.header, self.header_size = self._read_header() self.file = open(filename, "rb") + self.header, self.header_size = self._read_header() def __enter__(self): return self @@ -276,6 +275,9 @@ class MemoryEfficientSafeOpen: def keys(self): return [k for k in self.header.keys() if k != "__metadata__"] + def metadata(self) -> Dict[str, str]: + return self.header.get("__metadata__", {}) + def get_tensor(self, key): if key not in self.header: raise KeyError(f"Tensor '{key}' not found in the file") @@ -293,10 +295,9 @@ class MemoryEfficientSafeOpen: return self._deserialize_tensor(tensor_bytes, metadata) def _read_header(self): - with open(self.filename, "rb") as f: - header_size = struct.unpack(" Date: Wed, 26 Feb 2025 20:50:58 +0900 Subject: [PATCH 22/33] feat: add script to merge multiple safetensors files into a single file for SD3 --- tools/merge_sd3_safetensors.py | 139 +++++++++++++++++++++++++++++++++ 1 file changed, 139 insertions(+) create mode 100644 tools/merge_sd3_safetensors.py diff --git a/tools/merge_sd3_safetensors.py b/tools/merge_sd3_safetensors.py new file mode 100644 index 00000000..bef7c9b9 --- /dev/null +++ b/tools/merge_sd3_safetensors.py @@ -0,0 +1,139 @@ +import 
argparse +import os +import gc +from typing import Dict, Optional, Union +import torch +from safetensors.torch import safe_open + +from library.utils import setup_logging +from library.utils import load_safetensors, mem_eff_save_file + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +def merge_safetensors( + dit_path: str, + vae_path: Optional[str] = None, + clip_l_path: Optional[str] = None, + clip_g_path: Optional[str] = None, + t5xxl_path: Optional[str] = None, + output_path: str = "merged_model.safetensors", + device: str = "cpu", +): + """ + Merge multiple safetensors files into a single file + + Args: + dit_path: Path to the DiT/MMDiT model + vae_path: Path to the VAE model + clip_l_path: Path to the CLIP-L model + clip_g_path: Path to the CLIP-G model + t5xxl_path: Path to the T5-XXL model + output_path: Path to save the merged model + device: Device to load tensors to + """ + logger.info("Starting to merge safetensors files...") + + # 1. Get DiT metadata if available + metadata = None + try: + with safe_open(dit_path, framework="pt") as f: + metadata = f.metadata() # may be None + if metadata: + logger.info(f"Found metadata in DiT model: {metadata}") + except Exception as e: + logger.warning(f"Failed to read metadata from DiT model: {e}") + + # 2. Create empty merged state dict + merged_state_dict = {} + + # 3. Load and merge each model with memory management + + # DiT/MMDiT - prefix: model.diffusion_model. + logger.info(f"Loading DiT model from {dit_path}") + dit_state_dict = load_safetensors(dit_path, device=device, disable_mmap=True) + logger.info(f"Adding DiT model with {len(dit_state_dict)} keys") + for key, value in dit_state_dict.items(): + merged_state_dict[f"model.diffusion_model.{key}"] = value + # Free memory + del dit_state_dict + gc.collect() + + # VAE - prefix: first_stage_model. 
+ if vae_path: + logger.info(f"Loading VAE model from {vae_path}") + vae_state_dict = load_safetensors(vae_path, device=device, disable_mmap=True) + logger.info(f"Adding VAE model with {len(vae_state_dict)} keys") + for key, value in vae_state_dict.items(): + merged_state_dict[f"first_stage_model.{key}"] = value + # Free memory + del vae_state_dict + gc.collect() + + # CLIP-L - prefix: text_encoders.clip_l. + if clip_l_path: + logger.info(f"Loading CLIP-L model from {clip_l_path}") + clip_l_state_dict = load_safetensors(clip_l_path, device=device, disable_mmap=True) + logger.info(f"Adding CLIP-L model with {len(clip_l_state_dict)} keys") + for key, value in clip_l_state_dict.items(): + merged_state_dict[f"text_encoders.clip_l.{key}"] = value + # Free memory + del clip_l_state_dict + gc.collect() + + # CLIP-G - prefix: text_encoders.clip_g. + if clip_g_path: + logger.info(f"Loading CLIP-G model from {clip_g_path}") + clip_g_state_dict = load_safetensors(clip_g_path, device=device, disable_mmap=True) + logger.info(f"Adding CLIP-G model with {len(clip_g_state_dict)} keys") + for key, value in clip_g_state_dict.items(): + merged_state_dict[f"text_encoders.clip_g.{key}"] = value + # Free memory + del clip_g_state_dict + gc.collect() + + # T5-XXL - prefix: text_encoders.t5xxl. + if t5xxl_path: + logger.info(f"Loading T5-XXL model from {t5xxl_path}") + t5xxl_state_dict = load_safetensors(t5xxl_path, device=device, disable_mmap=True) + logger.info(f"Adding T5-XXL model with {len(t5xxl_state_dict)} keys") + for key, value in t5xxl_state_dict.items(): + merged_state_dict[f"text_encoders.t5xxl.{key}"] = value + # Free memory + del t5xxl_state_dict + gc.collect() + + # 4. 
Save merged state dict + logger.info(f"Saving merged model to {output_path} with {len(merged_state_dict)} keys total") + mem_eff_save_file(merged_state_dict, output_path, metadata) + logger.info("Successfully merged safetensors files") + + +def main(): + parser = argparse.ArgumentParser(description="Merge Stable Diffusion 3.5 model components into a single safetensors file") + parser.add_argument("--dit", required=True, help="Path to the DiT/MMDiT model") + parser.add_argument("--vae", help="Path to the VAE model") + parser.add_argument("--clip_l", help="Path to the CLIP-L model") + parser.add_argument("--clip_g", help="Path to the CLIP-G model") + parser.add_argument("--t5xxl", help="Path to the T5-XXL model") + parser.add_argument("--output", default="merged_model.safetensors", help="Path to save the merged model") + parser.add_argument("--device", default="cpu", help="Device to load tensors to") + + args = parser.parse_args() + + merge_safetensors( + dit_path=args.dit, + vae_path=args.vae, + clip_l_path=args.clip_l, + clip_g_path=args.clip_g, + t5xxl_path=args.t5xxl, + output_path=args.output, + device=args.device, + ) + + +if __name__ == "__main__": + main() From ae409e83c939f2c4a997cfb1679bd7cd364baf7e Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 26 Feb 2025 20:56:32 +0900 Subject: [PATCH 23/33] fix: FLUX/SD3 network training not working without caching latents closes #1954 --- flux_train_network.py | 11 ++++++++--- sd3_train_network.py | 11 ++++++++--- 2 files changed, 16 insertions(+), 6 deletions(-) diff --git a/flux_train_network.py b/flux_train_network.py index ae4b62f5..26503df1 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -36,7 +36,12 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): self.is_schnell: Optional[bool] = None self.is_swapping_blocks: bool = False - def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: 
Optional[train_util.DatasetGroup]): + def assert_extra_args( + self, + args, + train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], + val_dataset_group: Optional[train_util.DatasetGroup], + ): super().assert_extra_args(args, train_dataset_group, val_dataset_group) # sdxl_train_util.verify_sdxl_training_args(args) @@ -323,7 +328,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): self.noise_scheduler_copy = copy.deepcopy(noise_scheduler) return noise_scheduler - def encode_images_to_latents(self, args, accelerator, vae, images): + def encode_images_to_latents(self, args, vae, images): return vae.encode(images) def shift_scale_latents(self, args, latents): @@ -341,7 +346,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): network, weight_dtype, train_unet, - is_train=True + is_train=True, ): # Sample noise that we'll add to the latents noise = torch.randn_like(latents) diff --git a/sd3_train_network.py b/sd3_train_network.py index 2f457949..9438bc7b 100644 --- a/sd3_train_network.py +++ b/sd3_train_network.py @@ -26,7 +26,12 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer): super().__init__() self.sample_prompts_te_outputs = None - def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]): + def assert_extra_args( + self, + args, + train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], + val_dataset_group: Optional[train_util.DatasetGroup], + ): # super().assert_extra_args(args, train_dataset_group) # sdxl_train_util.verify_sdxl_training_args(args) @@ -299,7 +304,7 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer): noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.training_shift) return noise_scheduler - def encode_images_to_latents(self, args, accelerator, vae, images): + def encode_images_to_latents(self, args, vae, 
images): return vae.encode(images) def shift_scale_latents(self, args, latents): @@ -317,7 +322,7 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer): network, weight_dtype, train_unet, - is_train=True + is_train=True, ): # Sample noise that we'll add to the latents noise = torch.randn_like(latents) From 3d79239be4b20d67faed67c47f693396342e3af4 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 26 Feb 2025 21:21:04 +0900 Subject: [PATCH 24/33] docs: update README to include recent improvements in validation loss calculation --- README.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/README.md b/README.md index 4bbd7617..3c699307 100644 --- a/README.md +++ b/README.md @@ -14,6 +14,11 @@ The command to install PyTorch is as follows: ### Recent Updates +Feb 26, 2025: + +- Improve the validation loss calculation in `train_network.py`, `sdxl_train_network.py`, `flux_train_network.py`, and `sd3_train_network.py`. PR [#1903](https://github.com/kohya-ss/sd-scripts/pull/1903) + - The validation loss uses the fixed timestep sampling and the fixed random seed. This is to ensure that the validation loss is not fluctuated by the random values. + Jan 25, 2025: - `train_network.py`, `sdxl_train_network.py`, `flux_train_network.py`, and `sd3_train_network.py` now support validation loss. PR [#1864](https://github.com/kohya-ss/sd-scripts/pull/1864) Thank you to rockerBOO! From 734333d0c9eec3f20582c9c16f6d148cb1ec2596 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Fri, 28 Feb 2025 23:52:29 +0900 Subject: [PATCH 25/33] feat: enhance merging logic for safetensors models to handle key prefixes correctly --- tools/merge_sd3_safetensors.py | 29 +++++++++++++++++++++++------ 1 file changed, 23 insertions(+), 6 deletions(-) diff --git a/tools/merge_sd3_safetensors.py b/tools/merge_sd3_safetensors.py index bef7c9b9..960cf6e7 100644 --- a/tools/merge_sd3_safetensors.py +++ b/tools/merge_sd3_safetensors.py @@ -53,22 +53,30 @@ def merge_safetensors( # 3. 
Load and merge each model with memory management # DiT/MMDiT - prefix: model.diffusion_model. + # This state dict may have VAE keys. logger.info(f"Loading DiT model from {dit_path}") dit_state_dict = load_safetensors(dit_path, device=device, disable_mmap=True) logger.info(f"Adding DiT model with {len(dit_state_dict)} keys") for key, value in dit_state_dict.items(): - merged_state_dict[f"model.diffusion_model.{key}"] = value + if key.startswith("model.diffusion_model.") or key.startswith("first_stage_model."): + merged_state_dict[key] = value + else: + merged_state_dict[f"model.diffusion_model.{key}"] = value # Free memory del dit_state_dict gc.collect() # VAE - prefix: first_stage_model. + # May be omitted if VAE is already included in DiT model. if vae_path: logger.info(f"Loading VAE model from {vae_path}") vae_state_dict = load_safetensors(vae_path, device=device, disable_mmap=True) logger.info(f"Adding VAE model with {len(vae_state_dict)} keys") for key, value in vae_state_dict.items(): - merged_state_dict[f"first_stage_model.{key}"] = value + if key.startswith("first_stage_model."): + merged_state_dict[key] = value + else: + merged_state_dict[f"first_stage_model.{key}"] = value # Free memory del vae_state_dict gc.collect() @@ -79,7 +87,10 @@ def merge_safetensors( clip_l_state_dict = load_safetensors(clip_l_path, device=device, disable_mmap=True) logger.info(f"Adding CLIP-L model with {len(clip_l_state_dict)} keys") for key, value in clip_l_state_dict.items(): - merged_state_dict[f"text_encoders.clip_l.{key}"] = value + if key.startswith("text_encoders.clip_l.transformer."): + merged_state_dict[key] = value + else: + merged_state_dict[f"text_encoders.clip_l.transformer.{key}"] = value # Free memory del clip_l_state_dict gc.collect() @@ -90,7 +101,10 @@ def merge_safetensors( clip_g_state_dict = load_safetensors(clip_g_path, device=device, disable_mmap=True) logger.info(f"Adding CLIP-G model with {len(clip_g_state_dict)} keys") for key, value in 
clip_g_state_dict.items(): - merged_state_dict[f"text_encoders.clip_g.{key}"] = value + if key.startswith("text_encoders.clip_g.transformer."): + merged_state_dict[key] = value + else: + merged_state_dict[f"text_encoders.clip_g.transformer.{key}"] = value # Free memory del clip_g_state_dict gc.collect() @@ -101,7 +115,10 @@ def merge_safetensors( t5xxl_state_dict = load_safetensors(t5xxl_path, device=device, disable_mmap=True) logger.info(f"Adding T5-XXL model with {len(t5xxl_state_dict)} keys") for key, value in t5xxl_state_dict.items(): - merged_state_dict[f"text_encoders.t5xxl.{key}"] = value + if key.startswith("text_encoders.t5xxl.transformer."): + merged_state_dict[key] = value + else: + merged_state_dict[f"text_encoders.t5xxl.transformer.{key}"] = value # Free memory del t5xxl_state_dict gc.collect() @@ -115,7 +132,7 @@ def merge_safetensors( def main(): parser = argparse.ArgumentParser(description="Merge Stable Diffusion 3.5 model components into a single safetensors file") parser.add_argument("--dit", required=True, help="Path to the DiT/MMDiT model") - parser.add_argument("--vae", help="Path to the VAE model") + parser.add_argument("--vae", help="Path to the VAE model. 
May be omitted if VAE is included in DiT model") parser.add_argument("--clip_l", help="Path to the CLIP-L model") parser.add_argument("--clip_g", help="Path to the CLIP-G model") parser.add_argument("--t5xxl", help="Path to the T5-XXL model") From ba5251168a91f608de9fe9e365a2f889e4bb6cf8 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 1 Mar 2025 10:31:39 +0900 Subject: [PATCH 26/33] fix: save tensors as is dtype, add save_precision option --- tools/merge_sd3_safetensors.py | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/tools/merge_sd3_safetensors.py b/tools/merge_sd3_safetensors.py index 960cf6e7..6bc1003e 100644 --- a/tools/merge_sd3_safetensors.py +++ b/tools/merge_sd3_safetensors.py @@ -6,7 +6,7 @@ import torch from safetensors.torch import safe_open from library.utils import setup_logging -from library.utils import load_safetensors, mem_eff_save_file +from library.utils import load_safetensors, mem_eff_save_file, str_to_dtype setup_logging() import logging @@ -22,6 +22,7 @@ def merge_safetensors( t5xxl_path: Optional[str] = None, output_path: str = "merged_model.safetensors", device: str = "cpu", + save_precision: Optional[str] = None, ): """ Merge multiple safetensors files into a single file @@ -34,9 +35,16 @@ def merge_safetensors( t5xxl_path: Path to the T5-XXL model output_path: Path to save the merged model device: Device to load tensors to + save_precision: Target dtype for model weights (e.g. 'fp16', 'bf16') """ logger.info("Starting to merge safetensors files...") + # Convert save_precision string to torch dtype if specified + if save_precision: + target_dtype = str_to_dtype(save_precision) + else: + target_dtype = None + # 1. Get DiT metadata if available metadata = None try: @@ -55,7 +63,7 @@ def merge_safetensors( # DiT/MMDiT - prefix: model.diffusion_model. # This state dict may have VAE keys. 
logger.info(f"Loading DiT model from {dit_path}") - dit_state_dict = load_safetensors(dit_path, device=device, disable_mmap=True) + dit_state_dict = load_safetensors(dit_path, device=device, disable_mmap=True, dtype=target_dtype) logger.info(f"Adding DiT model with {len(dit_state_dict)} keys") for key, value in dit_state_dict.items(): if key.startswith("model.diffusion_model.") or key.startswith("first_stage_model."): @@ -70,7 +78,7 @@ def merge_safetensors( # May be omitted if VAE is already included in DiT model. if vae_path: logger.info(f"Loading VAE model from {vae_path}") - vae_state_dict = load_safetensors(vae_path, device=device, disable_mmap=True) + vae_state_dict = load_safetensors(vae_path, device=device, disable_mmap=True, dtype=target_dtype) logger.info(f"Adding VAE model with {len(vae_state_dict)} keys") for key, value in vae_state_dict.items(): if key.startswith("first_stage_model."): @@ -84,7 +92,7 @@ def merge_safetensors( # CLIP-L - prefix: text_encoders.clip_l. if clip_l_path: logger.info(f"Loading CLIP-L model from {clip_l_path}") - clip_l_state_dict = load_safetensors(clip_l_path, device=device, disable_mmap=True) + clip_l_state_dict = load_safetensors(clip_l_path, device=device, disable_mmap=True, dtype=target_dtype) logger.info(f"Adding CLIP-L model with {len(clip_l_state_dict)} keys") for key, value in clip_l_state_dict.items(): if key.startswith("text_encoders.clip_l.transformer."): @@ -98,7 +106,7 @@ def merge_safetensors( # CLIP-G - prefix: text_encoders.clip_g. 
if clip_g_path: logger.info(f"Loading CLIP-G model from {clip_g_path}") - clip_g_state_dict = load_safetensors(clip_g_path, device=device, disable_mmap=True) + clip_g_state_dict = load_safetensors(clip_g_path, device=device, disable_mmap=True, dtype=target_dtype) logger.info(f"Adding CLIP-G model with {len(clip_g_state_dict)} keys") for key, value in clip_g_state_dict.items(): if key.startswith("text_encoders.clip_g.transformer."): @@ -112,7 +120,7 @@ def merge_safetensors( # T5-XXL - prefix: text_encoders.t5xxl. if t5xxl_path: logger.info(f"Loading T5-XXL model from {t5xxl_path}") - t5xxl_state_dict = load_safetensors(t5xxl_path, device=device, disable_mmap=True) + t5xxl_state_dict = load_safetensors(t5xxl_path, device=device, disable_mmap=True, dtype=target_dtype) logger.info(f"Adding T5-XXL model with {len(t5xxl_state_dict)} keys") for key, value in t5xxl_state_dict.items(): if key.startswith("text_encoders.t5xxl.transformer."): @@ -138,6 +146,7 @@ def main(): parser.add_argument("--t5xxl", help="Path to the T5-XXL model") parser.add_argument("--output", default="merged_model.safetensors", help="Path to save the merged model") parser.add_argument("--device", default="cpu", help="Device to load tensors to") + parser.add_argument("--save_precision", type=str, help="Precision to save the model in (e.g., 'fp16', 'bf16', 'float16', etc.)") args = parser.parse_args() @@ -149,6 +158,7 @@ def main(): t5xxl_path=args.t5xxl, output_path=args.output, device=args.device, + save_precision=args.save_precision, ) From acdca2abb781eb207cb760c759dc4d23f8ca5e72 Mon Sep 17 00:00:00 2001 From: Ivan Chikish Date: Sat, 1 Mar 2025 17:06:17 +0300 Subject: [PATCH 27/33] Fix [occasionally] missing text encoder attn modules Should fix #1952 I added alternative name for CLIPAttention. I have no idea why this name changed. Now it should accept both names. 
--- networks/dylora.py | 2 +- networks/lora.py | 2 +- networks/lora_diffusers.py | 2 +- networks/lora_fa.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/networks/dylora.py b/networks/dylora.py index b0925453..82d96f59 100644 --- a/networks/dylora.py +++ b/networks/dylora.py @@ -268,7 +268,7 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh class DyLoRANetwork(torch.nn.Module): UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"] UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"] - TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"] + TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPSdpaAttention", "CLIPMLP"] LORA_PREFIX_UNET = "lora_unet" LORA_PREFIX_TEXT_ENCODER = "lora_te" diff --git a/networks/lora.py b/networks/lora.py index 6f33f1a1..1699a60f 100644 --- a/networks/lora.py +++ b/networks/lora.py @@ -866,7 +866,7 @@ class LoRANetwork(torch.nn.Module): UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"] UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"] - TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"] + TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPSdpaAttention", "CLIPMLP"] LORA_PREFIX_UNET = "lora_unet" LORA_PREFIX_TEXT_ENCODER = "lora_te" diff --git a/networks/lora_diffusers.py b/networks/lora_diffusers.py index b99b0244..56b74d10 100644 --- a/networks/lora_diffusers.py +++ b/networks/lora_diffusers.py @@ -278,7 +278,7 @@ def merge_lora_weights(pipe, weights_sd: Dict, multiplier: float = 1.0): class LoRANetwork(torch.nn.Module): UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"] UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"] - TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"] + TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPSdpaAttention", "CLIPMLP"] LORA_PREFIX_UNET = "lora_unet" LORA_PREFIX_TEXT_ENCODER 
= "lora_te" diff --git a/networks/lora_fa.py b/networks/lora_fa.py index 919222ce..5fe778b4 100644 --- a/networks/lora_fa.py +++ b/networks/lora_fa.py @@ -755,7 +755,7 @@ class LoRANetwork(torch.nn.Module): UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"] UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"] - TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"] + TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPSdpaAttention", "CLIPMLP"] LORA_PREFIX_UNET = "lora_unet" LORA_PREFIX_TEXT_ENCODER = "lora_te" From 3f49053c9068a0dcfa3a360d032529d87f878f8b Mon Sep 17 00:00:00 2001 From: sdbds <865105819@qq.com> Date: Sun, 2 Mar 2025 19:32:06 +0800 Subject: [PATCH 28/33] fatser fix bug for SDXL super SD1.5 assert cant use 32 --- sdxl_train_network.py | 1 - 1 file changed, 1 deletion(-) diff --git a/sdxl_train_network.py b/sdxl_train_network.py index 83969bb1..3559ab88 100644 --- a/sdxl_train_network.py +++ b/sdxl_train_network.py @@ -18,7 +18,6 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer): self.is_sdxl = True def assert_extra_args(self, args, train_dataset_group): - super().assert_extra_args(args, train_dataset_group) sdxl_train_util.verify_sdxl_training_args(args) if args.cache_text_encoder_outputs: From aa2bde7ece17be16083acfe9645bb4e21718fb2c Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 5 Mar 2025 23:24:52 +0900 Subject: [PATCH 29/33] docs: add utility script for merging SD3 weights into a single .safetensors file --- README.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/README.md b/README.md index 3c699307..426eaed8 100644 --- a/README.md +++ b/README.md @@ -14,6 +14,10 @@ The command to install PyTorch is as follows: ### Recent Updates +Mar 6, 2025: + +- Added a utility script to merge the weights of SD3's DiT, VAE (optional), CLIP-L, CLIP-G, and T5XXL into a single .safetensors file. Run `tools/merge_sd3_safetensors.py`. See `--help` for usage. 
PR [#1960](https://github.com/kohya-ss/sd-scripts/pull/1960) + Feb 26, 2025: - Improve the validation loss calculation in `train_network.py`, `sdxl_train_network.py`, `flux_train_network.py`, and `sd3_train_network.py`. PR [#1903](https://github.com/kohya-ss/sd-scripts/pull/1903) From e5b5c7e1db5a5c8d7e0628cd565e9619f9564adb Mon Sep 17 00:00:00 2001 From: gesen2egee Date: Sat, 15 Mar 2025 13:29:32 +0800 Subject: [PATCH 30/33] Update requirements.txt --- requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements.txt b/requirements.txt index de39f588..52c3b8c7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -43,4 +43,5 @@ rich==13.7.0 # for T5XXL tokenizer (SD3/FLUX) sentencepiece==0.2.0 # for kohya_ss library +pytorch-optimizer -e . From 5b210ad7178c0b88c214686389b0afb03ba3813c Mon Sep 17 00:00:00 2001 From: gesen2egee Date: Wed, 19 Mar 2025 10:49:06 +0800 Subject: [PATCH 31/33] update prodigyopt and prodigy-plus-schedule-free --- requirements.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 52c3b8c7..7348647f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,7 +7,6 @@ opencv-python==4.8.1.78 einops==0.7.0 pytorch-lightning==1.9.0 bitsandbytes==0.44.0 -prodigyopt==1.0 lion-pytorch==0.0.6 schedulefree==1.4 tensorboard @@ -44,4 +43,6 @@ rich==13.7.0 sentencepiece==0.2.0 # for kohya_ss library pytorch-optimizer +prodigy-plus-schedule-free==1.9.0 +prodigyopt==1.1.2 -e . 
From d151833526f5f79414a995cbb416de8a31e000cb Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 20 Mar 2025 22:05:29 +0900 Subject: [PATCH 32/33] docs: update README with recent changes and specify version for pytorch-optimizer --- README.md | 4 ++++ requirements.txt | 6 +++--- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 426eaed8..59b0e676 100644 --- a/README.md +++ b/README.md @@ -14,6 +14,10 @@ The command to install PyTorch is as follows: ### Recent Updates +Mar 20, 2025: +- `pytorch-optimizer` is added to requirements.txt. Thank you to gesen2egee for PR [#1985](https://github.com/kohya-ss/sd-scripts/pull/1985). + - For example, you can use CAME optimizer with `--optimizer_type "pytorch_optimizer.CAME" --optimizer_args "weight_decay=0.01"`. + Mar 6, 2025: - Added a utility script to merge the weights of SD3's DiT, VAE (optional), CLIP-L, CLIP-G, and T5XXL into a single .safetensors file. Run `tools/merge_sd3_safetensors.py`. See `--help` for usage. PR [#1960](https://github.com/kohya-ss/sd-scripts/pull/1960) diff --git a/requirements.txt b/requirements.txt index 7348647f..767d9e8e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,6 +9,9 @@ pytorch-lightning==1.9.0 bitsandbytes==0.44.0 lion-pytorch==0.0.6 schedulefree==1.4 +pytorch-optimizer==3.5.0 +prodigy-plus-schedule-free==1.9.0 +prodigyopt==1.1.2 tensorboard safetensors==0.4.4 # gradio==3.16.2 @@ -42,7 +45,4 @@ rich==13.7.0 # for T5XXL tokenizer (SD3/FLUX) sentencepiece==0.2.0 # for kohya_ss library -pytorch-optimizer -prodigy-plus-schedule-free==1.9.0 -prodigyopt==1.1.2 -e . 
From 8f4ee8fc343b047965cd8976fca65c3a35b7593a Mon Sep 17 00:00:00 2001 From: Kohya S Date: Fri, 21 Mar 2025 22:05:48 +0900 Subject: [PATCH 33/33] doc: update README for latest --- README.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/README.md b/README.md index 6beee5e3..7ed3a2f5 100644 --- a/README.md +++ b/README.md @@ -3,6 +3,8 @@ This repository contains training, generation and utility scripts for Stable Dif [__Change History__](#change-history) is moved to the bottom of the page. 更新履歴は[ページ末尾](#change-history)に移しました。 +Latest update: 2025-03-21 (Version 0.9.1) + [日本語版READMEはこちら](./README-ja.md) The development version is in the `dev` branch. Please check the dev branch for the latest changes. @@ -146,6 +148,11 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser ## Change History +### Mar 21, 2025 / 2025-03-21 Version 0.9.1 + +- Fixed a bug where some of LoRA modules for CLIP Text Encoder were not trained. Thank you Nekotekina for PR [#1964](https://github.com/kohya-ss/sd-scripts/pull/1964) + - The LoRA modules for CLIP Text Encoder are now 264 modules, which is the same as before. Only 88 modules were trained in the previous version. + ### Jan 17, 2025 / 2025-01-17 Version 0.9.0 - __important__ The dependent libraries are updated. Please see [Upgrade](#upgrade) and update the libraries.