mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 14:45:19 +00:00
Move freeze_blocks to sd3_train because it's only for sd3
This commit is contained in:
@@ -309,6 +309,9 @@ resolution = [512, 512]
|
|||||||
|
|
||||||
SD3 training is done with `sd3_train.py`.
|
SD3 training is done with `sd3_train.py`.
|
||||||
|
|
||||||
|
__Sep 1, 2024__:
|
||||||
|
- `--num_last_block_to_freeze` is added to `sd3_train.py`. This option is to freeze the last n blocks of the MMDiT. See [#1417](https://github.com/kohya-ss/sd-scripts/pull/1417) for details. Thanks to sdbds!
|
||||||
|
|
||||||
__Jul 27, 2024__:
|
__Jul 27, 2024__:
|
||||||
- Latents and text encoder outputs caching mechanism is refactored significantly.
|
- Latents and text encoder outputs caching mechanism is refactored significantly.
|
||||||
- Existing cache files for SD3 need to be recreated. Please delete the previous cache files.
|
- Existing cache files for SD3 need to be recreated. Please delete the previous cache files.
|
||||||
|
|||||||
@@ -3246,12 +3246,6 @@ def add_sd_models_arguments(parser: argparse.ArgumentParser):
|
|||||||
default=None,
|
default=None,
|
||||||
help="directory for caching Tokenizer (for offline training) / Tokenizerをキャッシュするディレクトリ(ネット接続なしでの学習のため)",
|
help="directory for caching Tokenizer (for offline training) / Tokenizerをキャッシュするディレクトリ(ネット接続なしでの学習のため)",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
|
||||||
"--num_last_block_to_freeze",
|
|
||||||
type=int,
|
|
||||||
default=None,
|
|
||||||
help="num_last_block_to_freeze",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def add_optimizer_arguments(parser: argparse.ArgumentParser):
|
def add_optimizer_arguments(parser: argparse.ArgumentParser):
|
||||||
@@ -5764,21 +5758,6 @@ def sample_image_inference(
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def freeze_blocks(model, num_last_block_to_freeze, block_name="x_block"):
|
|
||||||
|
|
||||||
filtered_blocks = [(name, param) for name, param in model.named_parameters() if block_name in name]
|
|
||||||
print(f"filtered_blocks: {len(filtered_blocks)}")
|
|
||||||
|
|
||||||
num_blocks_to_freeze = min(len(filtered_blocks), num_last_block_to_freeze)
|
|
||||||
|
|
||||||
print(f"freeze_blocks: {num_blocks_to_freeze}")
|
|
||||||
|
|
||||||
start_freezing_from = max(0, len(filtered_blocks) - num_blocks_to_freeze)
|
|
||||||
|
|
||||||
for i in range(start_freezing_from, len(filtered_blocks)):
|
|
||||||
_, param = filtered_blocks[i]
|
|
||||||
param.requires_grad = False
|
|
||||||
|
|
||||||
# endregion
|
# endregion
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
22
sd3_train.py
22
sd3_train.py
@@ -373,7 +373,20 @@ def train(args):
|
|||||||
mmdit.to(accelerator.device, dtype=weight_dtype) # because of unet is not prepared
|
mmdit.to(accelerator.device, dtype=weight_dtype) # because of unet is not prepared
|
||||||
|
|
||||||
if args.num_last_block_to_freeze:
|
if args.num_last_block_to_freeze:
|
||||||
train_util.freeze_blocks(mmdit,num_last_block_to_freeze=args.num_last_block_to_freeze)
|
# freeze last n blocks of MM-DIT
|
||||||
|
block_name = "x_block"
|
||||||
|
filtered_blocks = [(name, param) for name, param in mmdit.named_parameters() if block_name in name]
|
||||||
|
accelerator.print(f"filtered_blocks: {len(filtered_blocks)}")
|
||||||
|
|
||||||
|
num_blocks_to_freeze = min(len(filtered_blocks), args.num_last_block_to_freeze)
|
||||||
|
|
||||||
|
accelerator.print(f"freeze_blocks: {num_blocks_to_freeze}")
|
||||||
|
|
||||||
|
start_freezing_from = max(0, len(filtered_blocks) - num_blocks_to_freeze)
|
||||||
|
|
||||||
|
for i in range(start_freezing_from, len(filtered_blocks)):
|
||||||
|
_, param = filtered_blocks[i]
|
||||||
|
param.requires_grad = False
|
||||||
|
|
||||||
training_models = []
|
training_models = []
|
||||||
params_to_optimize = []
|
params_to_optimize = []
|
||||||
@@ -1033,12 +1046,17 @@ def setup_parser() -> argparse.ArgumentParser:
|
|||||||
default=None,
|
default=None,
|
||||||
help="number of optimizers for fused backward pass and optimizer step / fused backward passとoptimizer stepのためのoptimizer数",
|
help="number of optimizers for fused backward pass and optimizer step / fused backward passとoptimizer stepのためのoptimizer数",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--skip_latents_validity_check",
|
"--skip_latents_validity_check",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
help="skip latents validity check / latentsの正当性チェックをスキップする",
|
help="skip latents validity check / latentsの正当性チェックをスキップする",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--num_last_block_to_freeze",
|
||||||
|
type=int,
|
||||||
|
default=None,
|
||||||
|
help="freeze last n blocks of MM-DIT / MM-DITの最後のnブロックを凍結する",
|
||||||
|
)
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user