fix cond image normlization, add independent LR for control

This commit is contained in:
Kohya S
2024-10-03 21:32:21 +09:00
parent 793999d116
commit c2440f9e53
3 changed files with 46 additions and 7 deletions

View File

@@ -253,11 +253,20 @@ def train(args):
# 学習に必要なクラスを準備する
accelerator.print("prepare optimizer, data loader etc.")
trainable_params = list(control_net.parameters())
# for p in trainable_params:
# p.requires_grad = True
logger.info(f"trainable params count: {len(trainable_params)}")
logger.info(f"number of trainable parameters: {sum(p.numel() for p in trainable_params if p.requires_grad)}")
trainable_params = []
ctrlnet_params = []
unet_params = []
for name, param in control_net.named_parameters():
if name.startswith("controlnet_"):
ctrlnet_params.append(param)
else:
unet_params.append(param)
trainable_params.append({"params": ctrlnet_params, "lr": args.control_net_lr})
trainable_params.append({"params": unet_params, "lr": args.learning_rate})
all_params = ctrlnet_params + unet_params
logger.info(f"trainable params count: {len(all_params)}")
logger.info(f"number of trainable parameters: {sum(p.numel() for p in all_params)}")
_, _, optimizer = train_util.get_optimizer(args, trainable_params)
@@ -456,6 +465,8 @@ def train(args):
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
current_epoch.value = epoch + 1
control_net.train()
for step, batch in enumerate(train_dataloader):
current_step.value = global_step
with accelerator.accumulate(control_net):
@@ -510,6 +521,9 @@ def train(args):
controlnet_image = batch["conditioning_images"].to(dtype=weight_dtype)
# '-1 to +1' to '0 to 1'
controlnet_image = (controlnet_image + 1) / 2
with accelerator.autocast():
input_resi_add, mid_add = control_net(
noisy_latents, timesteps, text_embedding, vector_embedding, controlnet_image
@@ -690,6 +704,12 @@ def setup_parser() -> argparse.ArgumentParser:
action="store_true",
help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う",
)
parser.add_argument(
"--control_net_lr",
type=float,
default=1e-4,
help="learning rate for controlnet / controlnetの学習率",
)
return parser