diff --git a/library/train_util.py b/library/train_util.py index 989758ad..74aae0a7 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3246,6 +3246,12 @@ def add_sd_models_arguments(parser: argparse.ArgumentParser): default=None, help="directory for caching Tokenizer (for offline training) / Tokenizerをキャッシュするディレクトリ(ネット接続なしでの学習のため)", ) + parser.add_argument( + "--num_last_block_to_freeze", + type=int, + default=None, + help="num_last_block_to_freeze", + ) def add_optimizer_arguments(parser: argparse.ArgumentParser): @@ -5758,6 +5764,21 @@ def sample_image_inference( pass +def freeze_blocks(model, num_last_block_to_freeze, block_name="x_block"): + + filtered_blocks = [(name, param) for name, param in model.named_parameters() if block_name in name] + print(f"filtered_blocks: {len(filtered_blocks)}") + + num_blocks_to_freeze = min(len(filtered_blocks), num_last_block_to_freeze) + + print(f"freeze_blocks: {num_blocks_to_freeze}") + + start_freezing_from = max(0, len(filtered_blocks) - num_blocks_to_freeze) + + for i in range(start_freezing_from, len(filtered_blocks)): + _, param = filtered_blocks[i] + param.requires_grad = False + # endregion diff --git a/sd3_train.py b/sd3_train.py index 3b6c8a11..ce9500b0 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -368,12 +368,19 @@ def train(args): vae.eval() vae.to(accelerator.device, dtype=vae_dtype) + mmdit.requires_grad_(train_mmdit) + if not train_mmdit: + mmdit.to(accelerator.device, dtype=weight_dtype) # because of unet is not prepared + + if args.num_last_block_to_freeze: + train_util.freeze_blocks(mmdit,num_last_block_to_freeze=args.num_last_block_to_freeze) + training_models = [] params_to_optimize = [] # if train_unet: training_models.append(mmdit) # if block_lrs is None: - params_to_optimize.append({"params": list(mmdit.parameters()), "lr": args.learning_rate}) + params_to_optimize.append({"params": list(filter(lambda p: p.requires_grad, mmdit.parameters())), "lr": args.learning_rate}) # else: # params_to_optimize.extend(get_block_params_to_optimize(mmdit, block_lrs))