Merge pull request #1000 from Isotr0py/dev

Fix multi-gpu SDXL training
This commit is contained in:
Kohya S
2023-12-13 20:52:11 +09:00
committed by GitHub
2 changed files with 14 additions and 2 deletions

View File

@@ -19,7 +19,7 @@ from typing import (
     Tuple,
     Union,
 )
-from accelerate import Accelerator, InitProcessGroupKwargs
+from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs
 import gc
 import glob
 import math
@@ -2899,6 +2899,12 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
         default=None,
         help="DDP timeout (min, None for default of accelerate) / DDPのタイムアウト分、Noneでaccelerateのデフォルト",
     )
+    parser.add_argument(
+        "--gradient_as_bucket_view", action="store_true", help="enable gradient_as_bucket_view for DDP",
+    )
+    parser.add_argument(
+        "--static_graph", action="store_true", help="enable static_graph for DDP",
+    )
     parser.add_argument(
         "--clip_skip",
         type=int,
@@ -3860,9 +3866,12 @@ def prepare_accelerator(args: argparse.Namespace):
     if args.wandb_api_key is not None:
         wandb.login(key=args.wandb_api_key)
     kwargs_handlers = (
-        None if args.ddp_timeout is None else [InitProcessGroupKwargs(timeout=datetime.timedelta(minutes=args.ddp_timeout))]
+        InitProcessGroupKwargs(timeout=datetime.timedelta(minutes=args.ddp_timeout)) if args.ddp_timeout else None,
+        DistributedDataParallelKwargs(gradient_as_bucket_view=args.gradient_as_bucket_view, static_graph=args.static_graph)
     )
+    kwargs_handlers = list(filter(lambda x: x is not None, kwargs_handlers))
     accelerator = Accelerator(
         gradient_accumulation_steps=args.gradient_accumulation_steps,
         mixed_precision=args.mixed_precision,

View File

@@ -398,6 +398,9 @@ def train(args):
         if train_unet:
             unet = accelerator.prepare(unet)
         if train_text_encoder1:
+            # freeze last layer and final_layer_norm in te1 since we use the output of the penultimate layer
+            text_encoder1.text_model.encoder.layers[-1].requires_grad_(False)
+            text_encoder1.text_model.final_layer_norm.requires_grad_(False)
             text_encoder1 = accelerator.prepare(text_encoder1)
         if train_text_encoder2:
             text_encoder2 = accelerator.prepare(text_encoder2)