fix get_trainable_params in controlnet-llite training

This commit is contained in:
AngelBottomless
2024-05-07 18:21:31 +09:00
committed by GitHub
parent 71e2c91330
commit 793aeb94da

View File

@@ -477,7 +477,7 @@ def train(args):
accelerator.backward(loss)
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
-                    params_to_clip = unet.get_trainable_params()
+                    params_to_clip = accelerator.unwrap_model(unet).get_trainable_params()
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
optimizer.step()