Support sample generation in SDXL ControlNet training

This commit is contained in:
Kohya S
2024-09-30 23:39:32 +09:00
parent d78f6a775c
commit 793999d116
5 changed files with 322 additions and 165 deletions

View File

@@ -74,6 +74,7 @@ import imagesize
import cv2
import safetensors.torch
from library.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipeline
from library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline
import library.model_util as model_util
import library.huggingface_util as huggingface_util
import library.sai_model_spec as sai_model_spec
@@ -3581,7 +3582,20 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
# available backends:
# https://github.com/huggingface/accelerate/blob/d1abd59114ada8ba673e1214218cb2878c13b82d/src/accelerate/utils/dataclasses.py#L376-L388C5
# https://pytorch.org/docs/stable/torch.compiler.html
choices=["eager", "aot_eager", "inductor", "aot_ts_nvfuser", "nvprims_nvfuser", "cudagraphs", "ofi", "fx2trt", "onnxrt", "tensort", "ipex", "tvm"],
choices=[
"eager",
"aot_eager",
"inductor",
"aot_ts_nvfuser",
"nvprims_nvfuser",
"cudagraphs",
"ofi",
"fx2trt",
"onnxrt",
"tensort",
"ipex",
"tvm",
],
help="dynamo backend type (default is inductor) / dynamoのbackendの種類デフォルトは inductor",
)
parser.add_argument("--xformers", action="store_true", help="use xformers for CrossAttention / CrossAttentionにxformersを使う")
@@ -5850,8 +5864,8 @@ def sample_images_common(
pipe_class,
accelerator: Accelerator,
args: argparse.Namespace,
epoch,
steps,
epoch: int,
steps: int,
device,
vae,
tokenizer,
@@ -5910,11 +5924,7 @@ def sample_images_common(
with open(args.sample_prompts, "r", encoding="utf-8") as f:
prompts = json.load(f)
# schedulers: dict = {} cannot find where this is used
default_scheduler = get_my_scheduler(
sample_sampler=args.sample_sampler,
v_parameterization=args.v_parameterization,
)
default_scheduler = get_my_scheduler(sample_sampler=args.sample_sampler, v_parameterization=args.v_parameterization)
pipeline = pipe_class(
text_encoder=text_encoder,
@@ -5975,21 +5985,18 @@ def sample_images_common(
# clear pipeline and cache to reduce vram usage
del pipeline
# I'm not sure which of these is the correct way to clear the memory, but accelerator's device is used in the pipeline, so I'm using it here.
# with torch.cuda.device(torch.cuda.current_device()):
# torch.cuda.empty_cache()
clean_memory_on_device(accelerator.device)
torch.set_rng_state(rng_state)
if torch.cuda.is_available() and cuda_rng_state is not None:
torch.cuda.set_rng_state(cuda_rng_state)
vae.to(org_vae_device)
clean_memory_on_device(accelerator.device)
def sample_image_inference(
accelerator: Accelerator,
args: argparse.Namespace,
pipeline,
pipeline: Union[StableDiffusionLongPromptWeightingPipeline, SdxlStableDiffusionLongPromptWeightingPipeline],
save_dir,
prompt_dict,
epoch,