mirror of https://github.com/kohya-ss/sd-scripts.git
fix DDP SDXL training
@@ -19,7 +19,7 @@ from typing import (
     Tuple,
     Union,
 )
-from accelerate import Accelerator, InitProcessGroupKwargs
+from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs
 import gc
 import glob
 import math
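
The newly imported DistributedDataParallelKwargs is accelerate's container for torch DDP constructor options; Accelerator forwards its fields to torch.nn.parallel.DistributedDataParallel when it wraps the model. A minimal sketch of where it plugs in (illustrative values, not part of this commit):

# Sketch only: illustrative values, not the repo's defaults.
from accelerate import Accelerator, DistributedDataParallelKwargs

ddp_kwargs = DistributedDataParallelKwargs(
    gradient_as_bucket_view=True,  # gradients become views into DDP's flat communication buckets, saving memory
    static_graph=True,  # declare the trained graph identical across iterations, letting DDP optimize
)
accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])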
@@ -2878,6 +2878,12 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
         default=None,
         help="DDP timeout (min, None for default of accelerate) / DDPのタイムアウト(分、Noneでaccelerateのデフォルト)",
     )
+    parser.add_argument(
+        "--gradient_as_bucket_view", action="store_true", help="enable gradient_as_bucket_view for DDP",
+    )
+    parser.add_argument(
+        "--static_graph", action="store_true", help="enable static_graph for DDP",
+    )
     parser.add_argument(
         "--clip_skip",
         type=int,
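
Both new options are plain store_true flags: they default to False and only change DDP behavior when passed explicitly. A hypothetical smoke test of the parsing, not part of the repository:

# Hypothetical snippet: confirms the two new flags parse as store_true booleans.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--gradient_as_bucket_view", action="store_true", help="enable gradient_as_bucket_view for DDP")
parser.add_argument("--static_graph", action="store_true", help="enable static_graph for DDP")

args = parser.parse_args([])  # flags omitted: both default to False
assert not args.gradient_as_bucket_view and not args.static_graph

args = parser.parse_args(["--gradient_as_bucket_view", "--static_graph"])
assert args.gradient_as_bucket_view and args.static_graph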
@@ -3832,9 +3838,12 @@ def prepare_accelerator(args: argparse.Namespace):
     if args.wandb_api_key is not None:
         wandb.login(key=args.wandb_api_key)
 
     kwargs_handlers = (
-        None if args.ddp_timeout is None else [InitProcessGroupKwargs(timeout=datetime.timedelta(minutes=args.ddp_timeout))]
+        InitProcessGroupKwargs(timeout=datetime.timedelta(minutes=args.ddp_timeout)) if args.ddp_timeout else None,
+        DistributedDataParallelKwargs(gradient_as_bucket_view=args.gradient_as_bucket_view, static_graph=args.static_graph)
     )
+    kwargs_handlers = list(filter(lambda x: x is not None, kwargs_handlers))
     accelerator = Accelerator(
         gradient_accumulation_steps=args.gradient_accumulation_steps,
         mixed_precision=args.mixed_precision,
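
Before this change, kwargs_handlers was either None or a one-element list, so the timeout handler and a DistributedDataParallelKwargs handler could not coexist; the rewrite builds a tuple that may contain None and strips it before the list reaches Accelerator. Note the guard also changed from an "is None" check to truthiness, so --ddp_timeout 0 now falls through to accelerate's default. A standalone sketch of the filtering pattern (build_handlers and its signature are assumptions for illustration):

# Sketch of the handler-filtering pattern above; build_handlers is hypothetical.
import datetime
from accelerate import InitProcessGroupKwargs, DistributedDataParallelKwargs

def build_handlers(ddp_timeout, gradient_as_bucket_view, static_graph):
    handlers = (
        # None when no timeout was requested; dropped by the filter below
        InitProcessGroupKwargs(timeout=datetime.timedelta(minutes=ddp_timeout)) if ddp_timeout else None,
        DistributedDataParallelKwargs(gradient_as_bucket_view=gradient_as_bucket_view, static_graph=static_graph),
    )
    return list(filter(lambda x: x is not None, handlers))

assert len(build_handlers(None, False, False)) == 1  # timeout handler dropped
assert len(build_handlers(30, True, True)) == 2  # both handlers kept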
@@ -398,6 +398,9 @@ def train(args):
     if train_unet:
         unet = accelerator.prepare(unet)
     if train_text_encoder1:
+        # freeze last layer and final_layer_norm in te1 since we use the output of the penultimate layer
+        text_encoder1.text_model.encoder.layers[-1].requires_grad_(False)
+        text_encoder1.text_model.final_layer_norm.requires_grad_(False)
         text_encoder1 = accelerator.prepare(text_encoder1)
     if train_text_encoder2:
         text_encoder2 = accelerator.prepare(text_encoder2)
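
This last hunk, in the SDXL training script, is the core of the fix. SDXL conditions on the penultimate hidden state of text encoder 1, so the final transformer layer and final_layer_norm never receive gradients; under DistributedDataParallel, trainable parameters that get no gradient make the reducer error out unless find_unused_parameters is enabled. Freezing those modules before accelerator.prepare wraps the model avoids that. A toy reproduction of the pattern (hypothetical module, not the CLIP encoder):

# Hypothetical toy model: a trainable submodule outside the loss path is exactly
# what makes DDP's reducer wait for a gradient that never arrives.
import torch
import torch.nn as nn

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.used = nn.Linear(4, 4)
        self.unused = nn.Linear(4, 4)  # analogous to te1's last layer when only the penultimate output is used

    def forward(self, x):
        return self.used(x)  # self.unused never contributes to the loss

model = Toy()
model.unused.requires_grad_(False)  # the commit's fix: freeze what the loss never touches

loss = model(torch.randn(2, 4)).sum()
loss.backward()
assert model.used.weight.grad is not None
assert model.unused.weight.grad is None  # frozen, so DDP would no longer expect a gradient here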