diff --git a/train_db.py b/train_db.py index 9ae6c8ca..b6d11d8d 100644 --- a/train_db.py +++ b/train_db.py @@ -1011,6 +1011,7 @@ def train(args): if stop_text_encoder_training: print(f"stop text encoder training at step {global_step}") text_encoder.train(False) + text_encoder.requires_grad_(False) with accelerator.accumulate(unet): with torch.no_grad():