Merge branch 'dev' into gradual_latent_hires_fix

This commit is contained in:
Kohya S
2024-01-28 08:21:15 +09:00
24 changed files with 873 additions and 690 deletions

File diff suppressed because it is too large Load Diff

24
library/ipex_interop.py Normal file
View File

@@ -0,0 +1,24 @@
import torch
def init_ipex():
    """
    Apply the IPEX hijacks when Intel Extension for PyTorch is installed.

    The function first attempts to import `intel_extension_for_pytorch`.
    When that import raises `ImportError`, it returns immediately without
    doing anything. Otherwise it imports `library.ipex.ipex_init` and, if
    `torch.xpu.is_available()` is true, calls it.

    `ipex_init()` is unpacked as an `(is_initialized, error_message)` pair;
    a falsy first element is reported on stdout. Any exception raised during
    this second stage is caught and reported the same way, so this function
    does not propagate initialization failures to the caller.
    """
    try:
        import intel_extension_for_pytorch  # noqa: F401
    except ImportError:
        # IPEX is an optional dependency: nothing to set up without it.
        return

    try:
        from library.ipex import ipex_init

        if not torch.xpu.is_available():
            return

        succeeded, reason = ipex_init()
        if not succeeded:
            print("failed to initialize ipex:", reason)
    except Exception as exc:
        # Best-effort setup: report the problem and keep going.
        print("failed to initialize ipex:", exc)

View File

@@ -5,15 +5,9 @@ import math
import os
import torch
try:
import intel_extension_for_pytorch as ipex
from library.ipex_interop import init_ipex
if torch.xpu.is_available():
from library.ipex import ipex_init
ipex_init()
except Exception:
pass
init_ipex()
import diffusers
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig, logging
from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline # , UNet2DConditionModel
@@ -1245,8 +1239,13 @@ def save_diffusers_checkpoint(v2, output_dir, text_encoder, unet, pretrained_mod
if vae is None:
vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")
# original U-Net cannot be saved, so we need to convert it to the Diffusers version
# TODO this consumes a lot of memory
diffusers_unet = diffusers.UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder="unet")
diffusers_unet.load_state_dict(unet.state_dict())
pipeline = StableDiffusionPipeline(
unet=unet,
unet=diffusers_unet,
text_encoder=text_encoder,
vae=vae,
scheduler=scheduler,

View File

@@ -1262,9 +1262,9 @@ class CrossAttnUpBlock2D(nn.Module):
for attn in self.attentions:
attn.set_use_memory_efficient_attention(xformers, mem_eff)
def set_use_sdpa(self, spda):
def set_use_sdpa(self, sdpa):
for attn in self.attentions:
attn.set_use_sdpa(spda)
attn.set_use_sdpa(sdpa)
def forward(
self,

View File

@@ -923,7 +923,11 @@ class SdxlStableDiffusionLongPromptWeightingPipeline:
if up1 is not None:
uncond_pool = up1
dtype = self.unet.dtype
unet_dtype = self.unet.dtype
dtype = unet_dtype
if hasattr(dtype, "itemsize") and dtype.itemsize == 1: # fp8
dtype = torch.float16
self.unet.to(dtype)
# 4. Preprocess image and mask
if isinstance(image, PIL.Image.Image):
@@ -1028,6 +1032,7 @@ class SdxlStableDiffusionLongPromptWeightingPipeline:
if is_cancelled_callback is not None and is_cancelled_callback():
return None
self.unet.to(unet_dtype)
return latents
def latents_to_image(self, latents):

View File

@@ -558,6 +558,7 @@ class BaseDataset(torch.utils.data.Dataset):
tokenizer: Union[CLIPTokenizer, List[CLIPTokenizer]],
max_token_length: int,
resolution: Optional[Tuple[int, int]],
network_multiplier: float,
debug_dataset: bool,
) -> None:
super().__init__()
@@ -567,6 +568,7 @@ class BaseDataset(torch.utils.data.Dataset):
self.max_token_length = max_token_length
# width/height is used when enable_bucket==False
self.width, self.height = (None, None) if resolution is None else resolution
self.network_multiplier = network_multiplier
self.debug_dataset = debug_dataset
self.subsets: List[Union[DreamBoothSubset, FineTuningSubset]] = []
@@ -1106,7 +1108,9 @@ class BaseDataset(torch.utils.data.Dataset):
for image_key in bucket[image_index : image_index + bucket_batch_size]:
image_info = self.image_data[image_key]
subset = self.image_to_subset[image_key]
loss_weights.append(self.prior_loss_weight if image_info.is_reg else 1.0)
loss_weights.append(
self.prior_loss_weight if image_info.is_reg else 1.0
) # in case of fine tuning, is_reg is always False
flipped = subset.flip_aug and random.random() < 0.5 # not flipped or flipped with 50% chance
@@ -1272,6 +1276,8 @@ class BaseDataset(torch.utils.data.Dataset):
example["target_sizes_hw"] = torch.stack([torch.LongTensor(x) for x in target_sizes_hw])
example["flippeds"] = flippeds
example["network_multipliers"] = torch.FloatTensor([self.network_multiplier] * len(captions))
if self.debug_dataset:
example["image_keys"] = bucket[image_index : image_index + self.batch_size]
return example
@@ -1346,15 +1352,16 @@ class DreamBoothDataset(BaseDataset):
tokenizer,
max_token_length,
resolution,
network_multiplier: float,
enable_bucket: bool,
min_bucket_reso: int,
max_bucket_reso: int,
bucket_reso_steps: int,
bucket_no_upscale: bool,
prior_loss_weight: float,
debug_dataset,
debug_dataset: bool,
) -> None:
super().__init__(tokenizer, max_token_length, resolution, debug_dataset)
super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset)
assert resolution is not None, f"resolution is required / resolution解像度指定は必須です"
@@ -1520,14 +1527,15 @@ class FineTuningDataset(BaseDataset):
tokenizer,
max_token_length,
resolution,
network_multiplier: float,
enable_bucket: bool,
min_bucket_reso: int,
max_bucket_reso: int,
bucket_reso_steps: int,
bucket_no_upscale: bool,
debug_dataset,
debug_dataset: bool,
) -> None:
super().__init__(tokenizer, max_token_length, resolution, debug_dataset)
super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset)
self.batch_size = batch_size
@@ -1724,14 +1732,15 @@ class ControlNetDataset(BaseDataset):
tokenizer,
max_token_length,
resolution,
network_multiplier: float,
enable_bucket: bool,
min_bucket_reso: int,
max_bucket_reso: int,
bucket_reso_steps: int,
bucket_no_upscale: bool,
debug_dataset,
debug_dataset: float,
) -> None:
super().__init__(tokenizer, max_token_length, resolution, debug_dataset)
super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset)
db_subsets = []
for subset in subsets:
@@ -1765,6 +1774,7 @@ class ControlNetDataset(BaseDataset):
tokenizer,
max_token_length,
resolution,
network_multiplier,
enable_bucket,
min_bucket_reso,
max_bucket_reso,
@@ -2039,6 +2049,8 @@ def debug_dataset(train_dataset, show_input_ids=False):
print(
f'{ik}, size: {train_dataset.image_data[ik].image_size}, loss weight: {lw}, caption: "{cap}", original size: {orgsz}, crop top left: {crptl}, target size: {trgsz}, flipped: {flpdz}'
)
if "network_multipliers" in example:
print(f"network multiplier: {example['network_multipliers'][j]}")
if show_input_ids:
print(f"input ids: {iid}")
@@ -2105,8 +2117,8 @@ def glob_images_pathlib(dir_path, recursive):
class MinimalDataset(BaseDataset):
def __init__(self, tokenizer, max_token_length, resolution, debug_dataset=False):
super().__init__(tokenizer, max_token_length, resolution, debug_dataset)
def __init__(self, tokenizer, max_token_length, resolution, network_multiplier, debug_dataset=False):
super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset)
self.num_train_images = 0 # update in subclass
self.num_reg_images = 0 # update in subclass
@@ -2850,14 +2862,14 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
)
parser.add_argument("--torch_compile", action="store_true", help="use torch.compile (requires PyTorch 2.0) / torch.compile を使う")
parser.add_argument(
"--dynamo_backend",
type=str,
default="inductor",
"--dynamo_backend",
type=str,
default="inductor",
# available backends:
# https://github.com/huggingface/accelerate/blob/d1abd59114ada8ba673e1214218cb2878c13b82d/src/accelerate/utils/dataclasses.py#L376-L388C5
# https://pytorch.org/docs/stable/torch.compiler.html
choices=["eager", "aot_eager", "inductor", "aot_ts_nvfuser", "nvprims_nvfuser", "cudagraphs", "ofi", "fx2trt", "onnxrt"],
help="dynamo backend type (default is inductor) / dynamoのbackendの種類デフォルトは inductor"
choices=["eager", "aot_eager", "inductor", "aot_ts_nvfuser", "nvprims_nvfuser", "cudagraphs", "ofi", "fx2trt", "onnxrt"],
help="dynamo backend type (default is inductor) / dynamoのbackendの種類デフォルトは inductor",
)
parser.add_argument("--xformers", action="store_true", help="use xformers for CrossAttention / CrossAttentionにxformersを使う")
parser.add_argument(
@@ -2904,6 +2916,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
parser.add_argument(
"--full_bf16", action="store_true", help="bf16 training including gradients / 勾配も含めてbf16で学習する"
) # TODO move to SDXL training, because it is not supported by SD1/2
parser.add_argument("--fp8_base", action="store_true", help="use fp8 for base model / base modelにfp8を使う")
parser.add_argument(
"--ddp_timeout",
type=int,
@@ -2946,6 +2959,12 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
default=None,
help="name of tracker to use for logging, default is script-specific default name / ログ出力に使用するtrackerの名前、省略時はスクリプトごとのデフォルト名",
)
parser.add_argument(
"--wandb_run_name",
type=str,
default=None,
help="The name of the specific wandb session / wandb ログに表示される特定の実行の名前",
)
parser.add_argument(
"--log_tracker_config",
type=str,
@@ -3880,7 +3899,7 @@ def prepare_accelerator(args: argparse.Namespace):
os.environ["WANDB_DIR"] = logging_dir
if args.wandb_api_key is not None:
wandb.login(key=args.wandb_api_key)
# torch.compile のオプション。 NO の場合は torch.compile は使わない
dynamo_backend = "NO"
if args.torch_compile: