Better implementation for te autocast (#895)

* Better implementation for te

* Fix some misunderstandings

* As with unet, add explicit conversion

* Better handling of cached TE and TE lr

* Fix by wrapping params with list()

* Add timeout settings

* Fix arg style
Author: Kohaku-Blueleaf
Date: 2023-10-28 14:49:59 +08:00
Committed by: GitHub
Parent: 202f2c3292
Commit: 1cefb2a753
4 changed files with 41 additions and 30 deletions


@@ -287,6 +287,8 @@ def train(args):
         training_models.append(text_encoder2)
         # set require_grad=True later
     else:
+        text_encoder1.to(weight_dtype)
+        text_encoder2.to(weight_dtype)
         text_encoder1.requires_grad_(False)
         text_encoder2.requires_grad_(False)
         text_encoder1.eval()
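
When the text encoders are not trained, this hunk now casts them to the training weight dtype explicitly, mirroring how the U-Net is handled, instead of relying on the later conversion that is removed further down. A minimal sketch of the pattern, with a stand-in module and an assumed fp16 weight dtype (both illustrative, not from the diff):

import torch
import torch.nn as nn

weight_dtype = torch.float16           # assumption: fp16 mixed-precision training
text_encoder = nn.Linear(768, 768)     # toy stand-in for a real text encoder

text_encoder.to(weight_dtype)          # explicit convert, same as the U-Net
text_encoder.requires_grad_(False)     # frozen: no gradients are computed
text_encoder.eval()                    # eval mode (e.g. disables dropout)
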
@@ -295,7 +297,7 @@ def train(args):
     # cache the TextEncoder outputs
     if args.cache_text_encoder_outputs:
         # Text Encoders are eval and no grad
-        with torch.no_grad():
+        with torch.no_grad(), accelerator.autocast():
             train_dataset_group.cache_text_encoder_outputs(
                 (tokenizer1, tokenizer2),
                 (text_encoder1, text_encoder2),
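
Caching the text encoder outputs now runs under accelerator.autocast() in addition to torch.no_grad(), so the forward passes of the converted encoders execute in the same mixed-precision context as the rest of training. A small sketch of stacking the two context managers; it uses plain torch.autocast on CPU with bf16 purely for illustration, since an Accelerator object is not assumed here:

import torch

model = torch.nn.Linear(16, 16)   # fp32 toy stand-in for a text encoder
x = torch.randn(4, 16)

# no gradients are tracked, and ops inside the block run under autocast
with torch.no_grad(), torch.autocast("cpu", dtype=torch.bfloat16):
    out = model(x)

print(out.dtype)  # torch.bfloat16
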
@@ -315,25 +317,23 @@ def train(args):
             m.requires_grad_(True)

     if block_lrs is None:
-        params = []
-        for m in training_models:
-            params.extend(m.parameters())
-        params_to_optimize = params
-
-        # calculate number of trainable parameters
-        n_params = 0
-        for p in params:
-            n_params += p.numel()
+        params_to_optimize = [
+            {"params": list(training_models[0].parameters()), "lr": args.learning_rate},
+        ]
     else:
         params_to_optimize = get_block_params_to_optimize(training_models[0], block_lrs)  # U-Net
-        for m in training_models[1:]:  # Text Encoders if exists
-            params_to_optimize.append({"params": m.parameters(), "lr": args.learning_rate})
-
-        # calculate number of trainable parameters
-        n_params = 0
-        for params in params_to_optimize:
-            for p in params["params"]:
-                n_params += p.numel()
+
+    for m in training_models[1:]:  # Text Encoders if exists
+        params_to_optimize.append({
+            "params": list(m.parameters()),
+            "lr": args.learning_rate_te or args.learning_rate
+        })
+
+    # calculate number of trainable parameters
+    n_params = 0
+    for params in params_to_optimize:
+        for p in params["params"]:
+            n_params += p.numel()

     accelerator.print(f"number of models: {len(training_models)}")
     accelerator.print(f"number of trainable parameters: {n_params}")
@@ -396,8 +396,6 @@ def train(args):
     else:
         unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
         (unet,) = train_util.transform_models_if_DDP([unet])
-    text_encoder1.to(weight_dtype)
-    text_encoder2.to(weight_dtype)

     # move to CPU when caching the TextEncoder outputs
     if args.cache_text_encoder_outputs:
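
The two casts removed here are not lost: the explicit conversion now happens earlier, in the frozen-text-encoder branch of the first hunk, before the text encoder outputs are cached.
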
@@ -728,6 +726,7 @@ def setup_parser() -> argparse.ArgumentParser:
     config_util.add_config_arguments(parser)
     custom_train_functions.add_custom_train_arguments(parser)
     sdxl_train_util.add_sdxl_training_arguments(parser)

+    parser.add_argument("--learning_rate_te", type=float, default=0.0, help="learning rate for text encoder")
     parser.add_argument("--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する")
     parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する")