mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
Better implementation for te autocast (#895)
* Better implementation for te * Fix some misunderstanding * as same as unet, add explicit convert * Better cache TE and TE lr * Fix with list * Add timeout settings * Fix arg style
This commit is contained in:
@@ -3,6 +3,7 @@
|
||||
import argparse
|
||||
import ast
|
||||
import asyncio
|
||||
import datetime
|
||||
import importlib
|
||||
import json
|
||||
import pathlib
|
||||
@@ -18,7 +19,7 @@ from typing import (
|
||||
Tuple,
|
||||
Union,
|
||||
)
|
||||
from accelerate import Accelerator
|
||||
from accelerate import Accelerator, InitProcessGroupKwargs
|
||||
import gc
|
||||
import glob
|
||||
import math
|
||||
@@ -2855,6 +2856,9 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
||||
parser.add_argument(
|
||||
"--full_bf16", action="store_true", help="bf16 training including gradients / 勾配も含めてbf16で学習する"
|
||||
) # TODO move to SDXL training, because it is not supported by SD1/2
|
||||
parser.add_argument(
|
||||
"--ddp_timeout", type=int, default=30, help="DDP timeout (min) / DDPのタイムアウト(min)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--clip_skip",
|
||||
type=int,
|
||||
@@ -3786,6 +3790,7 @@ def prepare_accelerator(args: argparse.Namespace):
|
||||
mixed_precision=args.mixed_precision,
|
||||
log_with=log_with,
|
||||
project_dir=logging_dir,
|
||||
kwargs_handlers=[InitProcessGroupKwargs(timeout=datetime.timedelta(minutes=args.ddp_timeout))],
|
||||
)
|
||||
return accelerator
|
||||
|
||||
|
||||
Reference in New Issue
Block a user