diff --git a/library/train_util.py b/library/train_util.py index d2eb7cb2..71cf49f3 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -19,7 +19,7 @@ from typing import ( Tuple, Union, ) -from accelerate import Accelerator, InitProcessGroupKwargs +from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs import gc import glob import math @@ -2878,6 +2878,12 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: default=None, help="DDP timeout (min, None for default of accelerate) / DDPのタイムアウト(分、Noneでaccelerateのデフォルト)", ) + parser.add_argument( + "--gradient_as_bucket_view", action="store_true", help="enable gradient_as_bucket_view for DDP", + ) + parser.add_argument( + "--static_graph", action="store_true", help="enable static_graph for DDP", + ) parser.add_argument( "--clip_skip", type=int, @@ -3832,9 +3838,12 @@ def prepare_accelerator(args: argparse.Namespace): if args.wandb_api_key is not None: wandb.login(key=args.wandb_api_key) + kwargs_handlers = ( - None if args.ddp_timeout is None else [InitProcessGroupKwargs(timeout=datetime.timedelta(minutes=args.ddp_timeout))] + InitProcessGroupKwargs(timeout=datetime.timedelta(minutes=args.ddp_timeout)) if args.ddp_timeout else None, + DistributedDataParallelKwargs(gradient_as_bucket_view=args.gradient_as_bucket_view, static_graph=args.static_graph) ) + kwargs_handlers = list(filter(lambda x: x is not None, kwargs_handlers)) accelerator = Accelerator( gradient_accumulation_steps=args.gradient_accumulation_steps, mixed_precision=args.mixed_precision, diff --git a/sdxl_train.py b/sdxl_train.py index 501eef65..aa2eb5df 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -398,6 +398,9 @@ def train(args): if train_unet: unet = accelerator.prepare(unet) if train_text_encoder1: + # freeze last layer and final_layer_norm in te1 since we use the output of the penultimate layer + text_encoder1.text_model.encoder.layers[-1].requires_grad_(False) + text_encoder1.text_model.final_layer_norm.requires_grad_(False) text_encoder1 = accelerator.prepare(text_encoder1) if train_text_encoder2: text_encoder2 = accelerator.prepare(text_encoder2)