mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
Merge branch 'sd3' into multi-gpu-caching
This commit is contained in:
@@ -31,8 +31,10 @@ import hashlib
|
||||
import subprocess
|
||||
from io import BytesIO
|
||||
import toml
|
||||
# from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
|
||||
from tqdm import tqdm
|
||||
from packaging.version import Version
|
||||
|
||||
import torch
|
||||
from library.device_utils import init_ipex, clean_memory_on_device
|
||||
@@ -74,6 +76,7 @@ import imagesize
|
||||
import cv2
|
||||
import safetensors.torch
|
||||
from library.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipeline
|
||||
from library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline
|
||||
import library.model_util as model_util
|
||||
import library.huggingface_util as huggingface_util
|
||||
import library.sai_model_spec as sai_model_spec
|
||||
@@ -911,6 +914,23 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
if info.image_size is None:
|
||||
info.image_size = self.get_image_size(info.absolute_path)
|
||||
|
||||
# # run in parallel
|
||||
# max_workers = min(os.cpu_count(), len(self.image_data)) # TODO consider multi-gpu (processes)
|
||||
# with ThreadPoolExecutor(max_workers) as executor:
|
||||
# futures = []
|
||||
# for info in tqdm(self.image_data.values(), desc="loading image sizes"):
|
||||
# if info.image_size is None:
|
||||
# def get_and_set_image_size(info):
|
||||
# info.image_size = self.get_image_size(info.absolute_path)
|
||||
# futures.append(executor.submit(get_and_set_image_size, info))
|
||||
# # consume futures to reduce memory usage and prevent Ctrl-C hang
|
||||
# if len(futures) >= max_workers:
|
||||
# for future in futures:
|
||||
# future.result()
|
||||
# futures = []
|
||||
# for future in futures:
|
||||
# future.result()
|
||||
|
||||
if self.enable_bucket:
|
||||
logger.info("make buckets")
|
||||
else:
|
||||
@@ -1830,7 +1850,7 @@ class DreamBoothDataset(BaseDataset):
|
||||
# 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う
|
||||
captions = []
|
||||
missing_captions = []
|
||||
for img_path in img_paths:
|
||||
for img_path in tqdm(img_paths, desc="read caption"):
|
||||
cap_for_img = read_caption(img_path, subset.caption_extension, subset.enable_wildcard)
|
||||
if cap_for_img is None and subset.class_tokens is None:
|
||||
logger.warning(
|
||||
@@ -3586,7 +3606,20 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
||||
# available backends:
|
||||
# https://github.com/huggingface/accelerate/blob/d1abd59114ada8ba673e1214218cb2878c13b82d/src/accelerate/utils/dataclasses.py#L376-L388C5
|
||||
# https://pytorch.org/docs/stable/torch.compiler.html
|
||||
choices=["eager", "aot_eager", "inductor", "aot_ts_nvfuser", "nvprims_nvfuser", "cudagraphs", "ofi", "fx2trt", "onnxrt"],
|
||||
choices=[
|
||||
"eager",
|
||||
"aot_eager",
|
||||
"inductor",
|
||||
"aot_ts_nvfuser",
|
||||
"nvprims_nvfuser",
|
||||
"cudagraphs",
|
||||
"ofi",
|
||||
"fx2trt",
|
||||
"onnxrt",
|
||||
"tensort",
|
||||
"ipex",
|
||||
"tvm",
|
||||
],
|
||||
help="dynamo backend type (default is inductor) / dynamoのbackendの種類(デフォルトは inductor)",
|
||||
)
|
||||
parser.add_argument("--xformers", action="store_true", help="use xformers for CrossAttention / CrossAttentionにxformersを使う")
|
||||
@@ -5050,17 +5083,18 @@ def prepare_accelerator(args: argparse.Namespace):
|
||||
if args.torch_compile:
|
||||
dynamo_backend = args.dynamo_backend
|
||||
|
||||
kwargs_handlers = (
|
||||
InitProcessGroupKwargs(timeout=datetime.timedelta(minutes=args.ddp_timeout)) if args.ddp_timeout else None,
|
||||
(
|
||||
DistributedDataParallelKwargs(
|
||||
gradient_as_bucket_view=args.ddp_gradient_as_bucket_view, static_graph=args.ddp_static_graph
|
||||
)
|
||||
if args.ddp_gradient_as_bucket_view or args.ddp_static_graph
|
||||
else None
|
||||
),
|
||||
)
|
||||
kwargs_handlers = list(filter(lambda x: x is not None, kwargs_handlers))
|
||||
kwargs_handlers = [
|
||||
InitProcessGroupKwargs(
|
||||
backend = "gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl",
|
||||
init_method="env://?use_libuv=False" if os.name == "nt" and Version(torch.__version__) >= Version("2.4.0") else None,
|
||||
timeout=datetime.timedelta(minutes=args.ddp_timeout) if args.ddp_timeout else None
|
||||
) if torch.cuda.device_count() > 1 else None,
|
||||
DistributedDataParallelKwargs(
|
||||
gradient_as_bucket_view=args.ddp_gradient_as_bucket_view,
|
||||
static_graph=args.ddp_static_graph
|
||||
) if args.ddp_gradient_as_bucket_view or args.ddp_static_graph else None
|
||||
]
|
||||
kwargs_handlers = [i for i in kwargs_handlers if i is not None]
|
||||
deepspeed_plugin = deepspeed_utils.prepare_deepspeed_plugin(args)
|
||||
|
||||
accelerator = Accelerator(
|
||||
@@ -5855,8 +5889,8 @@ def sample_images_common(
|
||||
pipe_class,
|
||||
accelerator: Accelerator,
|
||||
args: argparse.Namespace,
|
||||
epoch,
|
||||
steps,
|
||||
epoch: int,
|
||||
steps: int,
|
||||
device,
|
||||
vae,
|
||||
tokenizer,
|
||||
@@ -5915,11 +5949,7 @@ def sample_images_common(
|
||||
with open(args.sample_prompts, "r", encoding="utf-8") as f:
|
||||
prompts = json.load(f)
|
||||
|
||||
# schedulers: dict = {} cannot find where this is used
|
||||
default_scheduler = get_my_scheduler(
|
||||
sample_sampler=args.sample_sampler,
|
||||
v_parameterization=args.v_parameterization,
|
||||
)
|
||||
default_scheduler = get_my_scheduler(sample_sampler=args.sample_sampler, v_parameterization=args.v_parameterization)
|
||||
|
||||
pipeline = pipe_class(
|
||||
text_encoder=text_encoder,
|
||||
@@ -5980,21 +6010,18 @@ def sample_images_common(
|
||||
# clear pipeline and cache to reduce vram usage
|
||||
del pipeline
|
||||
|
||||
# I'm not sure which of these is the correct way to clear the memory, but accelerator's device is used in the pipeline, so I'm using it here.
|
||||
# with torch.cuda.device(torch.cuda.current_device()):
|
||||
# torch.cuda.empty_cache()
|
||||
clean_memory_on_device(accelerator.device)
|
||||
|
||||
torch.set_rng_state(rng_state)
|
||||
if torch.cuda.is_available() and cuda_rng_state is not None:
|
||||
torch.cuda.set_rng_state(cuda_rng_state)
|
||||
vae.to(org_vae_device)
|
||||
|
||||
clean_memory_on_device(accelerator.device)
|
||||
|
||||
|
||||
def sample_image_inference(
|
||||
accelerator: Accelerator,
|
||||
args: argparse.Namespace,
|
||||
pipeline,
|
||||
pipeline: Union[StableDiffusionLongPromptWeightingPipeline, SdxlStableDiffusionLongPromptWeightingPipeline],
|
||||
save_dir,
|
||||
prompt_dict,
|
||||
epoch,
|
||||
|
||||
Reference in New Issue
Block a user