mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Better implementation for te autocast (#895)
* Better implementation for te * Fix some misunderstanding * as same as unet, add explicit convert * Better cache TE and TE lr * Fix with list * Add timeout settings * Fix arg style
This commit is contained in:
@@ -287,6 +287,8 @@ def train(args):
|
||||
training_models.append(text_encoder2)
|
||||
# set require_grad=True later
|
||||
else:
|
||||
text_encoder1.to(weight_dtype)
|
||||
text_encoder2.to(weight_dtype)
|
||||
text_encoder1.requires_grad_(False)
|
||||
text_encoder2.requires_grad_(False)
|
||||
text_encoder1.eval()
|
||||
@@ -295,7 +297,7 @@ def train(args):
|
||||
# TextEncoderの出力をキャッシュする
|
||||
if args.cache_text_encoder_outputs:
|
||||
# Text Encodes are eval and no grad
|
||||
with torch.no_grad():
|
||||
with torch.no_grad(), accelerator.autocast():
|
||||
train_dataset_group.cache_text_encoder_outputs(
|
||||
(tokenizer1, tokenizer2),
|
||||
(text_encoder1, text_encoder2),
|
||||
@@ -315,25 +317,23 @@ def train(args):
|
||||
m.requires_grad_(True)
|
||||
|
||||
if block_lrs is None:
|
||||
params = []
|
||||
for m in training_models:
|
||||
params.extend(m.parameters())
|
||||
params_to_optimize = params
|
||||
|
||||
# calculate number of trainable parameters
|
||||
n_params = 0
|
||||
for p in params:
|
||||
n_params += p.numel()
|
||||
params_to_optimize = [
|
||||
{"params": list(training_models[0].parameters()), "lr": args.learning_rate},
|
||||
]
|
||||
else:
|
||||
params_to_optimize = get_block_params_to_optimize(training_models[0], block_lrs) # U-Net
|
||||
for m in training_models[1:]: # Text Encoders if exists
|
||||
params_to_optimize.append({"params": m.parameters(), "lr": args.learning_rate})
|
||||
|
||||
# calculate number of trainable parameters
|
||||
n_params = 0
|
||||
for params in params_to_optimize:
|
||||
for p in params["params"]:
|
||||
n_params += p.numel()
|
||||
for m in training_models[1:]: # Text Encoders if exists
|
||||
params_to_optimize.append({
|
||||
"params": list(m.parameters()),
|
||||
"lr": args.learning_rate_te or args.learning_rate
|
||||
})
|
||||
|
||||
# calculate number of trainable parameters
|
||||
n_params = 0
|
||||
for params in params_to_optimize:
|
||||
for p in params["params"]:
|
||||
n_params += p.numel()
|
||||
|
||||
accelerator.print(f"number of models: {len(training_models)}")
|
||||
accelerator.print(f"number of trainable parameters: {n_params}")
|
||||
@@ -396,8 +396,6 @@ def train(args):
|
||||
else:
|
||||
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
|
||||
(unet,) = train_util.transform_models_if_DDP([unet])
|
||||
text_encoder1.to(weight_dtype)
|
||||
text_encoder2.to(weight_dtype)
|
||||
|
||||
# TextEncoderの出力をキャッシュするときにはCPUへ移動する
|
||||
if args.cache_text_encoder_outputs:
|
||||
@@ -728,6 +726,7 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
config_util.add_config_arguments(parser)
|
||||
custom_train_functions.add_custom_train_arguments(parser)
|
||||
sdxl_train_util.add_sdxl_training_arguments(parser)
|
||||
parser.add_argument("--learning_rate_te", type=float, default=0.0, help="learning rate for text encoder")
|
||||
|
||||
parser.add_argument("--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する")
|
||||
parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する")
|
||||
|
||||
Reference in New Issue
Block a user