mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
Sd3 freeze x_block (#1417)
* Update sd3_train.py * add freeze block lr * Update train_util.py * update
This commit is contained in:
@@ -3246,6 +3246,12 @@ def add_sd_models_arguments(parser: argparse.ArgumentParser):
|
|||||||
default=None,
|
default=None,
|
||||||
help="directory for caching Tokenizer (for offline training) / Tokenizerをキャッシュするディレクトリ(ネット接続なしでの学習のため)",
|
help="directory for caching Tokenizer (for offline training) / Tokenizerをキャッシュするディレクトリ(ネット接続なしでの学習のため)",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--num_last_block_to_freeze",
|
||||||
|
type=int,
|
||||||
|
default=None,
|
||||||
|
help="num_last_block_to_freeze",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def add_optimizer_arguments(parser: argparse.ArgumentParser):
|
def add_optimizer_arguments(parser: argparse.ArgumentParser):
|
||||||
@@ -5758,6 +5764,21 @@ def sample_image_inference(
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def freeze_blocks(model, num_last_block_to_freeze, block_name="x_block"):
|
||||||
|
|
||||||
|
filtered_blocks = [(name, param) for name, param in model.named_parameters() if block_name in name]
|
||||||
|
print(f"filtered_blocks: {len(filtered_blocks)}")
|
||||||
|
|
||||||
|
num_blocks_to_freeze = min(len(filtered_blocks), num_last_block_to_freeze)
|
||||||
|
|
||||||
|
print(f"freeze_blocks: {num_blocks_to_freeze}")
|
||||||
|
|
||||||
|
start_freezing_from = max(0, len(filtered_blocks) - num_blocks_to_freeze)
|
||||||
|
|
||||||
|
for i in range(start_freezing_from, len(filtered_blocks)):
|
||||||
|
_, param = filtered_blocks[i]
|
||||||
|
param.requires_grad = False
|
||||||
|
|
||||||
# endregion
|
# endregion
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -368,12 +368,19 @@ def train(args):
|
|||||||
vae.eval()
|
vae.eval()
|
||||||
vae.to(accelerator.device, dtype=vae_dtype)
|
vae.to(accelerator.device, dtype=vae_dtype)
|
||||||
|
|
||||||
|
mmdit.requires_grad_(train_mmdit)
|
||||||
|
if not train_mmdit:
|
||||||
|
mmdit.to(accelerator.device, dtype=weight_dtype) # because of unet is not prepared
|
||||||
|
|
||||||
|
if args.num_last_block_to_freeze:
|
||||||
|
train_util.freeze_blocks(mmdit,num_last_block_to_freeze=args.num_last_block_to_freeze)
|
||||||
|
|
||||||
training_models = []
|
training_models = []
|
||||||
params_to_optimize = []
|
params_to_optimize = []
|
||||||
# if train_unet:
|
# if train_unet:
|
||||||
training_models.append(mmdit)
|
training_models.append(mmdit)
|
||||||
# if block_lrs is None:
|
# if block_lrs is None:
|
||||||
params_to_optimize.append({"params": list(mmdit.parameters()), "lr": args.learning_rate})
|
params_to_optimize.append({"params": list(filter(lambda p: p.requires_grad, mmdit.parameters())), "lr": args.learning_rate})
|
||||||
# else:
|
# else:
|
||||||
# params_to_optimize.extend(get_block_params_to_optimize(mmdit, block_lrs))
|
# params_to_optimize.extend(get_block_params_to_optimize(mmdit, block_lrs))
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user