Sd3 freeze x_block (#1417)

* Update sd3_train.py

* add freeze block lr

* Update train_util.py

* update
This commit is contained in:
青龍聖者@bdsqlsz
2024-09-01 17:41:01 +08:00
committed by GitHub
parent 928e0fc096
commit ef510b3cb9
2 changed files with 29 additions and 1 deletions

View File

@@ -3246,6 +3246,12 @@ def add_sd_models_arguments(parser: argparse.ArgumentParser):
default=None,
help="directory for caching Tokenizer (for offline training) / Tokenizerをキャッシュするディレクトリネット接続なしでの学習のため",
)
parser.add_argument(
"--num_last_block_to_freeze",
type=int,
default=None,
help="num_last_block_to_freeze",
)
def add_optimizer_arguments(parser: argparse.ArgumentParser):
@@ -5758,6 +5764,21 @@ def sample_image_inference(
pass
def freeze_blocks(model, num_last_block_to_freeze, block_name="x_block"):
filtered_blocks = [(name, param) for name, param in model.named_parameters() if block_name in name]
print(f"filtered_blocks: {len(filtered_blocks)}")
num_blocks_to_freeze = min(len(filtered_blocks), num_last_block_to_freeze)
print(f"freeze_blocks: {num_blocks_to_freeze}")
start_freezing_from = max(0, len(filtered_blocks) - num_blocks_to_freeze)
for i in range(start_freezing_from, len(filtered_blocks)):
_, param = filtered_blocks[i]
param.requires_grad = False
# endregion