mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
Move freeze_blocks to sd3_train because it's only for sd3
This commit is contained in:
@@ -3246,12 +3246,6 @@ def add_sd_models_arguments(parser: argparse.ArgumentParser):
|
||||
default=None,
|
||||
help="directory for caching Tokenizer (for offline training) / Tokenizerをキャッシュするディレクトリ(ネット接続なしでの学習のため)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_last_block_to_freeze",
|
||||
type=int,
|
||||
default=None,
|
||||
help="num_last_block_to_freeze",
|
||||
)
|
||||
|
||||
|
||||
def add_optimizer_arguments(parser: argparse.ArgumentParser):
|
||||
@@ -5764,21 +5758,6 @@ def sample_image_inference(
|
||||
pass
|
||||
|
||||
|
||||
def freeze_blocks(model, num_last_block_to_freeze, block_name="x_block"):
|
||||
|
||||
filtered_blocks = [(name, param) for name, param in model.named_parameters() if block_name in name]
|
||||
print(f"filtered_blocks: {len(filtered_blocks)}")
|
||||
|
||||
num_blocks_to_freeze = min(len(filtered_blocks), num_last_block_to_freeze)
|
||||
|
||||
print(f"freeze_blocks: {num_blocks_to_freeze}")
|
||||
|
||||
start_freezing_from = max(0, len(filtered_blocks) - num_blocks_to_freeze)
|
||||
|
||||
for i in range(start_freezing_from, len(filtered_blocks)):
|
||||
_, param = filtered_blocks[i]
|
||||
param.requires_grad = False
|
||||
|
||||
# endregion
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user