From ccc3a481e7cce678a226552e64c906628d0cad32 Mon Sep 17 00:00:00 2001 From: Disty0 Date: Sun, 28 Jan 2024 14:14:31 +0300 Subject: [PATCH 1/7] Update IPEX Libs --- library/ipex/__init__.py | 15 ++++++---- library/ipex/attention.py | 2 ++ library/ipex/diffusers.py | 2 ++ library/ipex/hijacks.py | 58 +++++++++++++++++++++++++++++++++------ 4 files changed, 63 insertions(+), 14 deletions(-) diff --git a/library/ipex/__init__.py b/library/ipex/__init__.py index 33350493..9f2e7c41 100644 --- a/library/ipex/__init__.py +++ b/library/ipex/__init__.py @@ -125,9 +125,13 @@ def ipex_init(): # pylint: disable=too-many-statements # AMP: torch.cuda.amp = torch.xpu.amp + torch.is_autocast_enabled = torch.xpu.is_autocast_xpu_enabled + torch.get_autocast_gpu_dtype = torch.xpu.get_autocast_xpu_dtype + if not hasattr(torch.cuda.amp, "common"): torch.cuda.amp.common = contextlib.nullcontext() torch.cuda.amp.common.amp_definitely_not_available = lambda: False + try: torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler except Exception: # pylint: disable=broad-exception-caught @@ -151,15 +155,16 @@ def ipex_init(): # pylint: disable=too-many-statements torch.cuda.has_half = True torch.cuda.is_bf16_supported = lambda *args, **kwargs: True torch.cuda.is_fp16_supported = lambda *args, **kwargs: True - torch.version.cuda = "11.7" - torch.cuda.get_device_capability = lambda *args, **kwargs: [11,7] - torch.cuda.get_device_properties.major = 11 - torch.cuda.get_device_properties.minor = 7 + torch.backends.cuda.is_built = lambda *args, **kwargs: True + torch.version.cuda = "12.1" + torch.cuda.get_device_capability = lambda *args, **kwargs: [12,1] + torch.cuda.get_device_properties.major = 12 + torch.cuda.get_device_properties.minor = 1 torch.cuda.ipc_collect = lambda *args, **kwargs: None torch.cuda.utilization = lambda *args, **kwargs: 0 ipex_hijacks() - if not torch.xpu.has_fp64_dtype(): + if not torch.xpu.has_fp64_dtype() or os.environ.get('IPEX_FORCE_ATTENTION_SLICE', None) is not None: try: from .diffusers import ipex_diffusers ipex_diffusers() diff --git a/library/ipex/attention.py b/library/ipex/attention.py index e98807a8..8253c5b1 100644 --- a/library/ipex/attention.py +++ b/library/ipex/attention.py @@ -124,6 +124,7 @@ def torch_bmm_32_bit(input, mat2, *, out=None): ) else: return original_torch_bmm(input, mat2, out=out) + torch.xpu.synchronize(input.device) return hidden_states original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention @@ -172,4 +173,5 @@ def scaled_dot_product_attention_32_bit(query, key, value, attn_mask=None, dropo ) else: return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal) + torch.xpu.synchronize(query.device) return hidden_states diff --git a/library/ipex/diffusers.py b/library/ipex/diffusers.py index 47b0375a..732a1856 100644 --- a/library/ipex/diffusers.py +++ b/library/ipex/diffusers.py @@ -149,6 +149,7 @@ class SlicedAttnProcessor: # pylint: disable=too-few-public-methods hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = attn_slice del attn_slice + torch.xpu.synchronize(query.device) else: query_slice = query[start_idx:end_idx] key_slice = key[start_idx:end_idx] @@ -283,6 +284,7 @@ class AttnProcessor: hidden_states[start_idx:end_idx] = attn_slice del attn_slice + torch.xpu.synchronize(query.device) else: attention_probs = attn.get_attention_scores(query, key, attention_mask) hidden_states = torch.bmm(attention_probs, value) diff --git a/library/ipex/hijacks.py b/library/ipex/hijacks.py index b6d246dd..51de31ad 100644 --- a/library/ipex/hijacks.py +++ b/library/ipex/hijacks.py @@ -1,6 +1,9 @@ -import contextlib +import os +from functools import wraps +from contextlib import nullcontext import torch import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import +import numpy as np # pylint: disable=protected-access, missing-function-docstring, line-too-long, unnecessary-lambda, no-else-return @@ -11,7 +14,7 @@ class DummyDataParallel(torch.nn.Module): # pylint: disable=missing-class-docstr return module.to("xpu") def return_null_context(*args, **kwargs): # pylint: disable=unused-argument - return contextlib.nullcontext() + return nullcontext() @property def is_cuda(self): @@ -25,15 +28,17 @@ def return_xpu(device): # Autocast -original_autocast = torch.autocast -def ipex_autocast(*args, **kwargs): - if len(args) > 0 and args[0] == "cuda": - return original_autocast("xpu", *args[1:], **kwargs) +original_autocast_init = torch.amp.autocast_mode.autocast.__init__ +@wraps(torch.amp.autocast_mode.autocast.__init__) +def autocast_init(self, device_type, dtype=None, enabled=True, cache_enabled=None): + if device_type == "cuda": + return original_autocast_init(self, device_type="xpu", dtype=dtype, enabled=enabled, cache_enabled=cache_enabled) else: - return original_autocast(*args, **kwargs) + return original_autocast_init(self, device_type=device_type, dtype=dtype, enabled=enabled, cache_enabled=cache_enabled) # Latent Antialias CPU Offload: original_interpolate = torch.nn.functional.interpolate +@wraps(torch.nn.functional.interpolate) def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False): # pylint: disable=too-many-arguments if antialias or align_corners is not None: return_device = tensor.device @@ -46,13 +51,25 @@ def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corn # Diffusers Float64 (Alchemist GPUs doesn't support 64 bit): original_from_numpy = torch.from_numpy +@wraps(torch.from_numpy) def from_numpy(ndarray): if ndarray.dtype == float: return original_from_numpy(ndarray.astype('float32')) else: return original_from_numpy(ndarray) -if torch.xpu.has_fp64_dtype(): +original_as_tensor = torch.as_tensor +@wraps(torch.as_tensor) +def as_tensor(data, dtype=None, device=None): + if check_device(device): + device = return_xpu(device) + if isinstance(data, np.ndarray) and data.dtype == float and not ( + (isinstance(device, torch.device) and device.type == "cpu") or (isinstance(device, str) and "cpu" in device)): + return original_as_tensor(data, dtype=torch.float32, device=device) + else: + return original_as_tensor(data, dtype=dtype, device=device) + +if torch.xpu.has_fp64_dtype() and os.environ.get('IPEX_FORCE_ATTENTION_SLICE', None) is None: original_torch_bmm = torch.bmm original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention else: @@ -66,20 +83,25 @@ else: # Data Type Errors: +@wraps(torch.bmm) def torch_bmm(input, mat2, *, out=None): if input.dtype != mat2.dtype: mat2 = mat2.to(input.dtype) return original_torch_bmm(input, mat2, out=out) +@wraps(torch.nn.functional.scaled_dot_product_attention) def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False): if query.dtype != key.dtype: key = key.to(dtype=query.dtype) if query.dtype != value.dtype: value = value.to(dtype=query.dtype) + if attn_mask is not None and query.dtype != attn_mask.dtype: + attn_mask = attn_mask.to(dtype=query.dtype) return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal) # A1111 FP16 original_functional_group_norm = torch.nn.functional.group_norm +@wraps(torch.nn.functional.group_norm) def functional_group_norm(input, num_groups, weight=None, bias=None, eps=1e-05): if weight is not None and input.dtype != weight.data.dtype: input = input.to(dtype=weight.data.dtype) @@ -89,6 +111,7 @@ def functional_group_norm(input, num_groups, weight=None, bias=None, eps=1e-05): # A1111 BF16 original_functional_layer_norm = torch.nn.functional.layer_norm +@wraps(torch.nn.functional.layer_norm) def functional_layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-05): if weight is not None and input.dtype != weight.data.dtype: input = input.to(dtype=weight.data.dtype) @@ -98,6 +121,7 @@ def functional_layer_norm(input, normalized_shape, weight=None, bias=None, eps=1 # Training original_functional_linear = torch.nn.functional.linear +@wraps(torch.nn.functional.linear) def functional_linear(input, weight, bias=None): if input.dtype != weight.data.dtype: input = input.to(dtype=weight.data.dtype) @@ -106,6 +130,7 @@ def functional_linear(input, weight, bias=None): return original_functional_linear(input, weight, bias=bias) original_functional_conv2d = torch.nn.functional.conv2d +@wraps(torch.nn.functional.conv2d) def functional_conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): if input.dtype != weight.data.dtype: input = input.to(dtype=weight.data.dtype) @@ -115,6 +140,7 @@ def functional_conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, # A1111 Embedding BF16 original_torch_cat = torch.cat +@wraps(torch.cat) def torch_cat(tensor, *args, **kwargs): if len(tensor) == 3 and (tensor[0].dtype != tensor[1].dtype or tensor[2].dtype != tensor[1].dtype): return original_torch_cat([tensor[0].to(tensor[1].dtype), tensor[1], tensor[2].to(tensor[1].dtype)], *args, **kwargs) @@ -123,6 +149,7 @@ def torch_cat(tensor, *args, **kwargs): # SwinIR BF16: original_functional_pad = torch.nn.functional.pad +@wraps(torch.nn.functional.pad) def functional_pad(input, pad, mode='constant', value=None): if mode == 'reflect' and input.dtype == torch.bfloat16: return original_functional_pad(input.to(torch.float32), pad, mode=mode, value=value).to(dtype=torch.bfloat16) @@ -131,6 +158,7 @@ def functional_pad(input, pad, mode='constant', value=None): original_torch_tensor = torch.tensor +@wraps(torch.tensor) def torch_tensor(*args, device=None, **kwargs): if check_device(device): return original_torch_tensor(*args, device=return_xpu(device), **kwargs) @@ -138,6 +166,7 @@ def torch_tensor(*args, device=None, **kwargs): return original_torch_tensor(*args, device=device, **kwargs) original_Tensor_to = torch.Tensor.to +@wraps(torch.Tensor.to) def Tensor_to(self, device=None, *args, **kwargs): if check_device(device): return original_Tensor_to(self, return_xpu(device), *args, **kwargs) @@ -145,6 +174,7 @@ def Tensor_to(self, device=None, *args, **kwargs): return original_Tensor_to(self, device, *args, **kwargs) original_Tensor_cuda = torch.Tensor.cuda +@wraps(torch.Tensor.cuda) def Tensor_cuda(self, device=None, *args, **kwargs): if check_device(device): return original_Tensor_cuda(self, return_xpu(device), *args, **kwargs) @@ -152,6 +182,7 @@ def Tensor_cuda(self, device=None, *args, **kwargs): return original_Tensor_cuda(self, device, *args, **kwargs) original_UntypedStorage_init = torch.UntypedStorage.__init__ +@wraps(torch.UntypedStorage.__init__) def UntypedStorage_init(*args, device=None, **kwargs): if check_device(device): return original_UntypedStorage_init(*args, device=return_xpu(device), **kwargs) @@ -159,6 +190,7 @@ def UntypedStorage_init(*args, device=None, **kwargs): return original_UntypedStorage_init(*args, device=device, **kwargs) original_UntypedStorage_cuda = torch.UntypedStorage.cuda +@wraps(torch.UntypedStorage.cuda) def UntypedStorage_cuda(self, device=None, *args, **kwargs): if check_device(device): return original_UntypedStorage_cuda(self, return_xpu(device), *args, **kwargs) @@ -166,6 +198,7 @@ def UntypedStorage_cuda(self, device=None, *args, **kwargs): return original_UntypedStorage_cuda(self, device, *args, **kwargs) original_torch_empty = torch.empty +@wraps(torch.empty) def torch_empty(*args, device=None, **kwargs): if check_device(device): return original_torch_empty(*args, device=return_xpu(device), **kwargs) @@ -173,6 +206,7 @@ def torch_empty(*args, device=None, **kwargs): return original_torch_empty(*args, device=device, **kwargs) original_torch_randn = torch.randn +@wraps(torch.randn) def torch_randn(*args, device=None, **kwargs): if check_device(device): return original_torch_randn(*args, device=return_xpu(device), **kwargs) @@ -180,6 +214,7 @@ def torch_randn(*args, device=None, **kwargs): return original_torch_randn(*args, device=device, **kwargs) original_torch_ones = torch.ones +@wraps(torch.ones) def torch_ones(*args, device=None, **kwargs): if check_device(device): return original_torch_ones(*args, device=return_xpu(device), **kwargs) @@ -187,6 +222,7 @@ def torch_ones(*args, device=None, **kwargs): return original_torch_ones(*args, device=device, **kwargs) original_torch_zeros = torch.zeros +@wraps(torch.zeros) def torch_zeros(*args, device=None, **kwargs): if check_device(device): return original_torch_zeros(*args, device=return_xpu(device), **kwargs) @@ -194,6 +230,7 @@ def torch_zeros(*args, device=None, **kwargs): return original_torch_zeros(*args, device=device, **kwargs) original_torch_linspace = torch.linspace +@wraps(torch.linspace) def torch_linspace(*args, device=None, **kwargs): if check_device(device): return original_torch_linspace(*args, device=return_xpu(device), **kwargs) @@ -201,6 +238,7 @@ def torch_linspace(*args, device=None, **kwargs): return original_torch_linspace(*args, device=device, **kwargs) original_torch_Generator = torch.Generator +@wraps(torch.Generator) def torch_Generator(device=None): if check_device(device): return original_torch_Generator(return_xpu(device)) @@ -208,6 +246,7 @@ def torch_Generator(device=None): return original_torch_Generator(device) original_torch_load = torch.load +@wraps(torch.load) def torch_load(f, map_location=None, pickle_module=None, *, weights_only=False, mmap=None, **kwargs): if check_device(map_location): return original_torch_load(f, map_location=return_xpu(map_location), pickle_module=pickle_module, weights_only=weights_only, mmap=mmap, **kwargs) @@ -232,7 +271,7 @@ def ipex_hijacks(): torch.backends.cuda.sdp_kernel = return_null_context torch.nn.DataParallel = DummyDataParallel torch.UntypedStorage.is_cuda = is_cuda - torch.autocast = ipex_autocast + torch.amp.autocast_mode.autocast.__init__ = autocast_init torch.nn.functional.scaled_dot_product_attention = scaled_dot_product_attention torch.nn.functional.group_norm = functional_group_norm @@ -246,3 +285,4 @@ def ipex_hijacks(): torch.cat = torch_cat if not torch.xpu.has_fp64_dtype(): torch.from_numpy = from_numpy + torch.as_tensor = as_tensor From d4b9568269d09f5e4f81d5f243ca0a9bd3e6dcd3 Mon Sep 17 00:00:00 2001 From: mgz <49577754+mgz-dev@users.noreply.github.com> Date: Sun, 28 Jan 2024 11:59:07 -0600 Subject: [PATCH 2/7] fix broken import in svd_merge_lora script remove missing import, and remove unused imports --- networks/svd_merge_lora.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/networks/svd_merge_lora.py b/networks/svd_merge_lora.py index 16e813b3..ee2dae30 100644 --- a/networks/svd_merge_lora.py +++ b/networks/svd_merge_lora.py @@ -1,4 +1,3 @@ -import math import argparse import os import time @@ -6,8 +5,6 @@ import torch from safetensors.torch import load_file, save_file from tqdm import tqdm from library import sai_model_spec, train_util -import library.model_util as model_util -import lora CLAMP_QUANTILE = 0.99 From 988dee02b9feae8a5fba1dd12cd40aaa42905df8 Mon Sep 17 00:00:00 2001 From: Disty0 Date: Tue, 30 Jan 2024 01:52:32 +0300 Subject: [PATCH 3/7] IPEX torch.tensor FP64 workaround --- library/ipex/hijacks.py | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/library/ipex/hijacks.py b/library/ipex/hijacks.py index 51de31ad..7b2d26d4 100644 --- a/library/ipex/hijacks.py +++ b/library/ipex/hijacks.py @@ -5,6 +5,8 @@ import torch import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import import numpy as np +device_supports_fp64 = torch.xpu.has_fp64_dtype() + # pylint: disable=protected-access, missing-function-docstring, line-too-long, unnecessary-lambda, no-else-return class DummyDataParallel(torch.nn.Module): # pylint: disable=missing-class-docstring, unused-argument, too-few-public-methods @@ -49,6 +51,7 @@ def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corn return original_interpolate(tensor, size=size, scale_factor=scale_factor, mode=mode, align_corners=align_corners, recompute_scale_factor=recompute_scale_factor, antialias=antialias) + # Diffusers Float64 (Alchemist GPUs doesn't support 64 bit): original_from_numpy = torch.from_numpy @wraps(torch.from_numpy) @@ -69,7 +72,8 @@ def as_tensor(data, dtype=None, device=None): else: return original_as_tensor(data, dtype=dtype, device=device) -if torch.xpu.has_fp64_dtype() and os.environ.get('IPEX_FORCE_ATTENTION_SLICE', None) is None: + +if device_supports_fp64 and os.environ.get('IPEX_FORCE_ATTENTION_SLICE', None) is None: original_torch_bmm = torch.bmm original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention else: @@ -159,11 +163,16 @@ def functional_pad(input, pad, mode='constant', value=None): original_torch_tensor = torch.tensor @wraps(torch.tensor) -def torch_tensor(*args, device=None, **kwargs): +def torch_tensor(data, *args, dtype=None, device=None, **kwargs): if check_device(device): - return original_torch_tensor(*args, device=return_xpu(device), **kwargs) - else: - return original_torch_tensor(*args, device=device, **kwargs) + device = return_xpu(device) + if not device_supports_fp64: + if (isinstance(device, torch.device) and device.type == "xpu") or (isinstance(device, str) and "xpu" in device): + if dtype == torch.float64: + dtype = torch.float32 + elif dtype is None and (hasattr(data, "dtype") and (data.dtype == torch.float64 or data.dtype == float)): + dtype = torch.float32 + return original_torch_tensor(data, *args, dtype=dtype, device=device, **kwargs) original_Tensor_to = torch.Tensor.to @wraps(torch.Tensor.to) @@ -253,6 +262,7 @@ def torch_load(f, map_location=None, pickle_module=None, *, weights_only=False, else: return original_torch_load(f, map_location=map_location, pickle_module=pickle_module, weights_only=weights_only, mmap=mmap, **kwargs) + # Hijack Functions: def ipex_hijacks(): torch.tensor = torch_tensor @@ -283,6 +293,6 @@ def ipex_hijacks(): torch.bmm = torch_bmm torch.cat = torch_cat - if not torch.xpu.has_fp64_dtype(): + if not device_supports_fp64: torch.from_numpy = from_numpy torch.as_tensor = as_tensor From 1567ce1e1777c169bd59237f1f722b2e9722bbf3 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sat, 3 Feb 2024 20:46:31 +0800 Subject: [PATCH 4/7] Enable distributed sample image generation on multi-GPU enviroment (#1061) * Update train_util.py Modifying to attempt enable multi GPU inference * Update train_util.py additional VRAM checking, refactor check_vram_usage to return string for use with accelerator.print * Update train_network.py * Update train_util.py * Update train_util.py * Update train_util.py * Update train_util.py * Update train_util.py * Update train_util.py * Update train_util.py * Update train_util.py * Update train_util.py * Update train_util.py * Update train_util.py * Update train_util.py * Update train_util.py * Update train_util.py * Update train_util.py * Update train_util.py remove sample image debug outputs * Update train_util.py * Update train_util.py * Update train_network.py * Update train_util.py * Update train_util.py * Update train_util.py * Update train_util.py * Update train_util.py * Update train_util.py * Update train_util.py * Update train_util.py * Update train_util.py * Update train_network.py * Update train_util.py * Update train_network.py * Update train_network.py * Update train_network.py * Cleanup of debugging outputs * adopt more elegant coding Co-authored-by: Aarni Koskela * Update train_util.py Fix leftover debugging code attempt to refactor inference into separate function * refactor in function generate_per_device_prompt_list() generation of distributed prompt list * Clean up missing variables * fix syntax error * Update train_util.py * Update train_util.py * Update train_util.py * Update train_util.py * Update train_util.py * Update train_util.py * Update train_util.py * Update train_util.py * true random sample image generation update code to reinitialize random seed to true random if seed was set * true random sample image generation * simplify per process prompt * Update train_util.py * Update train_util.py * Update train_util.py * Update train_util.py * Update train_util.py * Update train_util.py * Update train_util.py * Update train_util.py * Update train_network.py * Update train_network.py * Update train_network.py --------- Co-authored-by: Aarni Koskela --- library/train_util.py | 208 +++++++++++++++++++++++------------------- 1 file changed, 115 insertions(+), 93 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index ba428e50..3e6125f0 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -19,7 +19,7 @@ from typing import ( Tuple, Union, ) -from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs +from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs, PartialState import gc import glob import math @@ -4636,7 +4636,6 @@ def line_to_prompt_dict(line: str) -> dict: return prompt_dict - def sample_images_common( pipe_class, accelerator: Accelerator, @@ -4654,6 +4653,7 @@ def sample_images_common( """ StableDiffusionLongPromptWeightingPipelineの改造版を使うようにしたので、clip skipおよびプロンプトの重みづけに対応した """ + if steps == 0: if not args.sample_at_first: return @@ -4668,13 +4668,15 @@ def sample_images_common( if steps % args.sample_every_n_steps != 0 or epoch is not None: # steps is not divisible or end of epoch return + distributed_state = PartialState() #testing implementation of multi gpu distributed inference + print(f"\ngenerating sample images at step / サンプル画像生成 ステップ: {steps}") if not os.path.isfile(args.sample_prompts): print(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}") return org_vae_device = vae.device # CPUにいるはず - vae.to(device) + vae.to(distributed_state.device) # unwrap unet and text_encoder(s) unet = accelerator.unwrap_model(unet) @@ -4700,12 +4702,11 @@ def sample_images_common( with open(args.sample_prompts, "r", encoding="utf-8") as f: prompts = json.load(f) - schedulers: dict = {} + # schedulers: dict = {} cannot find where this is used default_scheduler = get_my_scheduler( sample_sampler=args.sample_sampler, v_parameterization=args.v_parameterization, ) - schedulers[args.sample_sampler] = default_scheduler pipeline = pipe_class( text_encoder=text_encoder, @@ -4718,114 +4719,135 @@ def sample_images_common( requires_safety_checker=False, clip_skip=args.clip_skip, ) - pipeline.to(device) - + pipeline.to(distributed_state.device) save_dir = args.output_dir + "/sample" os.makedirs(save_dir, exist_ok=True) + + # Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processess available (number of devices available) + # prompt_dicts are assigned to lists based on order of processes, to attempt to time the image creation time to match enum order. Probably only works when steps and sampler are identical. + per_process_prompts = generate_per_device_prompt_list(prompts, num_of_processes = distributed_state.num_processes, prompt_replacement = prompt_replacement) rng_state = torch.get_rng_state() cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None + # True random sample image generation + torch.seed() + torch.cuda.seed() with torch.no_grad(): - # with accelerator.autocast(): - for i, prompt_dict in enumerate(prompts): - if not accelerator.is_main_process: - continue + with distributed_state.split_between_processes(per_process_prompts) as prompt_dict_lists: + for prompt_dict in prompt_dict_lists[0]: + sample_image_inference(accelerator, args, pipeline, save_dir, prompt_dict, epoch, steps, controlnet=controlnet) - if isinstance(prompt_dict, str): - prompt_dict = line_to_prompt_dict(prompt_dict) - - assert isinstance(prompt_dict, dict) - negative_prompt = prompt_dict.get("negative_prompt") - sample_steps = prompt_dict.get("sample_steps", 30) - width = prompt_dict.get("width", 512) - height = prompt_dict.get("height", 512) - scale = prompt_dict.get("scale", 7.5) - seed = prompt_dict.get("seed") - controlnet_image = prompt_dict.get("controlnet_image") - prompt: str = prompt_dict.get("prompt", "") - sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler) - - if seed is not None: - torch.manual_seed(seed) - torch.cuda.manual_seed(seed) - - scheduler = schedulers.get(sampler_name) - if scheduler is None: - scheduler = get_my_scheduler( - sample_sampler=sampler_name, - v_parameterization=args.v_parameterization, - ) - schedulers[sampler_name] = scheduler - pipeline.scheduler = scheduler - - if prompt_replacement is not None: - prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1]) - if negative_prompt is not None: - negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1]) - - if controlnet_image is not None: - controlnet_image = Image.open(controlnet_image).convert("RGB") - controlnet_image = controlnet_image.resize((width, height), Image.LANCZOS) - - height = max(64, height - height % 8) # round to divisible by 8 - width = max(64, width - width % 8) # round to divisible by 8 - print(f"prompt: {prompt}") - print(f"negative_prompt: {negative_prompt}") - print(f"height: {height}") - print(f"width: {width}") - print(f"sample_steps: {sample_steps}") - print(f"scale: {scale}") - print(f"sample_sampler: {sampler_name}") - if seed is not None: - print(f"seed: {seed}") - with accelerator.autocast(): - latents = pipeline( - prompt=prompt, - height=height, - width=width, - num_inference_steps=sample_steps, - guidance_scale=scale, - negative_prompt=negative_prompt, - controlnet=controlnet, - controlnet_image=controlnet_image, - ) - - image = pipeline.latents_to_image(latents)[0] - - ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime()) - num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}" - seed_suffix = "" if seed is None else f"_{seed}" - img_filename = ( - f"{'' if args.output_name is None else args.output_name + '_'}{ts_str}_{num_suffix}_{i:02d}{seed_suffix}.png" - ) - - image.save(os.path.join(save_dir, img_filename)) - - # wandb有効時のみログを送信 - try: - wandb_tracker = accelerator.get_tracker("wandb") - try: - import wandb - except ImportError: # 事前に一度確認するのでここはエラー出ないはず - raise ImportError("No wandb / wandb がインストールされていないようです") - - wandb_tracker.log({f"sample_{i}": wandb.Image(image)}) - except: # wandb 無効時 - pass # clear pipeline and cache to reduce vram usage del pipeline - torch.cuda.empty_cache() + with torch.cuda.device(torch.cuda.current_device()): + torch.cuda.empty_cache() + torch.set_rng_state(rng_state) if cuda_rng_state is not None: torch.cuda.set_rng_state(cuda_rng_state) vae.to(org_vae_device) +def generate_per_device_prompt_list(prompts, num_of_processes, prompt_replacement=None): + + # Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processess available (number of devices available) + # prompt_dicts are assigned to lists based on order of processes, to attempt to time the image creation time to match enum order. Probably only works when steps and sampler are identical. + per_process_prompts = [[] for i in range(num_of_processes)] + for i, prompt in enumerate(prompts): + if isinstance(prompt, str): + prompt = line_to_prompt_dict(prompt) + assert isinstance(prompt, dict) + prompt.pop("subset", None) # Clean up subset key + prompt["enum"] = i + # Adds an enumerator to the dict based on prompt position. Used later to name image files. Also cleanup of extra data in original prompt dict. + if prompt_replacement is not None: + prompt["prompt"] = prompt["prompt"].replace(prompt_replacement[0], prompt_replacement[1]) + if prompt["negative_prompt"] is not None: + prompt["negative_prompt"] = prompt["negative_prompt"].replace(prompt_replacement[0], prompt_replacement[1]) + # Refactor prompt replacement to here in order to simplify sample_image_inference function. + per_process_prompts[i % num_of_processes].append(prompt) + return per_process_prompts +def sample_image_inference(accelerator: Accelerator, args: argparse.Namespace, pipeline, save_dir, prompt_dict, epoch, steps, controlnet=None): + assert isinstance(prompt_dict, dict) + negative_prompt = prompt_dict.get("negative_prompt") + sample_steps = prompt_dict.get("sample_steps", 30) + width = prompt_dict.get("width", 512) + height = prompt_dict.get("height", 512) + scale = prompt_dict.get("scale", 7.5) + seed = prompt_dict.get("seed") + controlnet_image = prompt_dict.get("controlnet_image") + prompt: str = prompt_dict.get("prompt", "") + sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler) + + if seed is not None: + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + scheduler = get_my_scheduler( + sample_sampler=sampler_name, + v_parameterization=args.v_parameterization, + ) + pipeline.scheduler = scheduler + + if controlnet_image is not None: + controlnet_image = Image.open(controlnet_image).convert("RGB") + controlnet_image = controlnet_image.resize((width, height), Image.LANCZOS) + + height = max(64, height - height % 8) # round to divisible by 8 + width = max(64, width - width % 8) # round to divisible by 8 + print(f"\nprompt: {prompt}") + print(f"negative_prompt: {negative_prompt}") + print(f"height: {height}") + print(f"width: {width}") + print(f"sample_steps: {sample_steps}") + print(f"scale: {scale}") + print(f"sample_sampler: {sampler_name}") + if seed is not None: + print(f"seed: {seed}") + with accelerator.autocast(): + latents = pipeline( + prompt=prompt, + height=height, + width=width, + num_inference_steps=sample_steps, + guidance_scale=scale, + negative_prompt=negative_prompt, + controlnet=controlnet, + controlnet_image=controlnet_image, + ) + image = pipeline.latents_to_image(latents)[0] + # adding accelerator.wait_for_everyone() here should sync up and ensure that sample images are saved in the same order as the original prompt list + ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime()) + num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}" + seed_suffix = "" if seed is None else f"_{seed}" + i: int = prompt_dict["enum"] + img_filename = ( + f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png" + ) + + image.save(os.path.join(save_dir, img_filename)) + if seed is not None: + torch.seed() + torch.cuda.seed() + # wandb有効時のみログを送信 + try: + wandb_tracker = accelerator.get_tracker("wandb") + try: + import wandb + except ImportError: # 事前に一度確認するのでここはエラー出ないはず + raise ImportError("No wandb / wandb がインストールされていないようです") + + wandb_tracker.log({f"sample_{i}": wandb.Image(image)}) + except: # wandb 無効時 + pass # endregion + + + # region 前処理用 From 11aced35005221c05920e33658ceba46fc0e4272 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 3 Feb 2024 22:25:29 +0900 Subject: [PATCH 5/7] simplify multi-GPU sample generation --- library/train_util.py | 94 ++++++++++++++++++++++--------------------- 1 file changed, 48 insertions(+), 46 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 3e6125f0..177fae55 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -4668,13 +4668,13 @@ def sample_images_common( if steps % args.sample_every_n_steps != 0 or epoch is not None: # steps is not divisible or end of epoch return - distributed_state = PartialState() #testing implementation of multi gpu distributed inference - print(f"\ngenerating sample images at step / サンプル画像生成 ステップ: {steps}") if not os.path.isfile(args.sample_prompts): print(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}") return + distributed_state = PartialState() # for multi gpu distributed inference. this is a singleton, so it's safe to use it here + org_vae_device = vae.device # CPUにいるはず vae.to(distributed_state.device) @@ -4686,10 +4686,6 @@ def sample_images_common( text_encoder = accelerator.unwrap_model(text_encoder) # read prompts - - # with open(args.sample_prompts, "rt", encoding="utf-8") as f: - # prompts = f.readlines() - if args.sample_prompts.endswith(".txt"): with open(args.sample_prompts, "r", encoding="utf-8") as f: lines = f.readlines() @@ -4722,22 +4718,39 @@ def sample_images_common( pipeline.to(distributed_state.device) save_dir = args.output_dir + "/sample" os.makedirs(save_dir, exist_ok=True) - - # Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processess available (number of devices available) - # prompt_dicts are assigned to lists based on order of processes, to attempt to time the image creation time to match enum order. Probably only works when steps and sampler are identical. - per_process_prompts = generate_per_device_prompt_list(prompts, num_of_processes = distributed_state.num_processes, prompt_replacement = prompt_replacement) + # preprocess prompts + for i in range(len(prompts)): + prompt_dict = prompts[i] + if isinstance(prompt_dict, str): + prompt_dict = line_to_prompt_dict(prompt_dict) + prompts[i] = prompt_dict + assert isinstance(prompt_dict, dict) + + # Adds an enumerator to the dict based on prompt position. Used later to name image files. Also cleanup of extra data in original prompt dict. + prompt_dict["enum"] = i + prompt_dict.pop("subset", None) + + # save random state to restore later rng_state = torch.get_rng_state() - cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None - # True random sample image generation - torch.seed() - torch.cuda.seed() - - with torch.no_grad(): - with distributed_state.split_between_processes(per_process_prompts) as prompt_dict_lists: - for prompt_dict in prompt_dict_lists[0]: - sample_image_inference(accelerator, args, pipeline, save_dir, prompt_dict, epoch, steps, controlnet=controlnet) - + cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None # TODO mps etc. support + + if distributed_state.num_processes <= 1: + # If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts. + with torch.no_grad(): + for prompt_dict in prompts: + sample_image_inference(accelerator, args, pipeline, save_dir, prompt_dict, epoch, steps, prompt_replacement, controlnet=controlnet) + else: + # Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processess available (number of devices available) + # prompt_dicts are assigned to lists based on order of processes, to attempt to time the image creation time to match enum order. Probably only works when steps and sampler are identical. + per_process_prompts = [] # list of lists + for i in range(distributed_state.num_processes): + per_process_prompts.append(prompts[i::distributed_state.num_processes]) + + with torch.no_grad(): + with distributed_state.split_between_processes(per_process_prompts) as prompt_dict_lists: + for prompt_dict in prompt_dict_lists[0]: + sample_image_inference(accelerator, args, pipeline, save_dir, prompt_dict, epoch, steps, prompt_replacement, controlnet=controlnet) # clear pipeline and cache to reduce vram usage del pipeline @@ -4750,27 +4763,7 @@ def sample_images_common( torch.cuda.set_rng_state(cuda_rng_state) vae.to(org_vae_device) -def generate_per_device_prompt_list(prompts, num_of_processes, prompt_replacement=None): - - # Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processess available (number of devices available) - # prompt_dicts are assigned to lists based on order of processes, to attempt to time the image creation time to match enum order. Probably only works when steps and sampler are identical. - per_process_prompts = [[] for i in range(num_of_processes)] - for i, prompt in enumerate(prompts): - if isinstance(prompt, str): - prompt = line_to_prompt_dict(prompt) - assert isinstance(prompt, dict) - prompt.pop("subset", None) # Clean up subset key - prompt["enum"] = i - # Adds an enumerator to the dict based on prompt position. Used later to name image files. Also cleanup of extra data in original prompt dict. - if prompt_replacement is not None: - prompt["prompt"] = prompt["prompt"].replace(prompt_replacement[0], prompt_replacement[1]) - if prompt["negative_prompt"] is not None: - prompt["negative_prompt"] = prompt["negative_prompt"].replace(prompt_replacement[0], prompt_replacement[1]) - # Refactor prompt replacement to here in order to simplify sample_image_inference function. - per_process_prompts[i % num_of_processes].append(prompt) - return per_process_prompts - -def sample_image_inference(accelerator: Accelerator, args: argparse.Namespace, pipeline, save_dir, prompt_dict, epoch, steps, controlnet=None): +def sample_image_inference(accelerator: Accelerator, args: argparse.Namespace, pipeline, save_dir, prompt_dict, epoch, steps, prompt_replacement, controlnet=None): assert isinstance(prompt_dict, dict) negative_prompt = prompt_dict.get("negative_prompt") sample_steps = prompt_dict.get("sample_steps", 30) @@ -4781,10 +4774,19 @@ def sample_image_inference(accelerator: Accelerator, args: argparse.Namespace, p controlnet_image = prompt_dict.get("controlnet_image") prompt: str = prompt_dict.get("prompt", "") sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler) - + + if prompt_replacement is not None: + prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1]) + if negative_prompt is not None: + negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1]) + if seed is not None: torch.manual_seed(seed) torch.cuda.manual_seed(seed) + else: + # True random sample image generation + torch.seed() + torch.cuda.seed() scheduler = get_my_scheduler( sample_sampler=sampler_name, @@ -4819,7 +4821,10 @@ def sample_image_inference(accelerator: Accelerator, args: argparse.Namespace, p controlnet_image=controlnet_image, ) image = pipeline.latents_to_image(latents)[0] + # adding accelerator.wait_for_everyone() here should sync up and ensure that sample images are saved in the same order as the original prompt list + # but adding 'enum' to the filename should be enough + ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime()) num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}" seed_suffix = "" if seed is None else f"_{seed}" @@ -4827,11 +4832,8 @@ def sample_image_inference(accelerator: Accelerator, args: argparse.Namespace, p img_filename = ( f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png" ) - image.save(os.path.join(save_dir, img_filename)) - if seed is not None: - torch.seed() - torch.cuda.seed() + # wandb有効時のみログを送信 try: wandb_tracker = accelerator.get_tracker("wandb") From 2f9a34429729c8b44c6f04cb27656d578ecb9420 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 3 Feb 2024 23:26:57 +0900 Subject: [PATCH 6/7] fix typo --- library/train_util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/library/train_util.py b/library/train_util.py index 177fae55..1377997c 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -4741,7 +4741,7 @@ def sample_images_common( for prompt_dict in prompts: sample_image_inference(accelerator, args, pipeline, save_dir, prompt_dict, epoch, steps, prompt_replacement, controlnet=controlnet) else: - # Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processess available (number of devices available) + # Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processes available (number of devices available) # prompt_dicts are assigned to lists based on order of processes, to attempt to time the image creation time to match enum order. Probably only works when steps and sampler are identical. per_process_prompts = [] # list of lists for i in range(distributed_state.num_processes): From e793d7780d779855f23210d1c88368fd9286666e Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 4 Feb 2024 17:31:01 +0900 Subject: [PATCH 7/7] reduce peak VRAM in sample gen --- library/train_util.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/library/train_util.py b/library/train_util.py index 1377997c..32198774 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -4820,6 +4820,10 @@ def sample_image_inference(accelerator: Accelerator, args: argparse.Namespace, p controlnet=controlnet, controlnet_image=controlnet_image, ) + + with torch.cuda.device(torch.cuda.current_device()): + torch.cuda.empty_cache() + image = pipeline.latents_to_image(latents)[0] # adding accelerator.wait_for_everyone() here should sync up and ensure that sample images are saved in the same order as the original prompt list