diff --git a/library/ipex/__init__.py b/library/ipex/__init__.py index 33350493..9f2e7c41 100644 --- a/library/ipex/__init__.py +++ b/library/ipex/__init__.py @@ -125,9 +125,13 @@ def ipex_init(): # pylint: disable=too-many-statements # AMP: torch.cuda.amp = torch.xpu.amp + torch.is_autocast_enabled = torch.xpu.is_autocast_xpu_enabled + torch.get_autocast_gpu_dtype = torch.xpu.get_autocast_xpu_dtype + if not hasattr(torch.cuda.amp, "common"): torch.cuda.amp.common = contextlib.nullcontext() torch.cuda.amp.common.amp_definitely_not_available = lambda: False + try: torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler except Exception: # pylint: disable=broad-exception-caught @@ -151,15 +155,16 @@ def ipex_init(): # pylint: disable=too-many-statements torch.cuda.has_half = True torch.cuda.is_bf16_supported = lambda *args, **kwargs: True torch.cuda.is_fp16_supported = lambda *args, **kwargs: True - torch.version.cuda = "11.7" - torch.cuda.get_device_capability = lambda *args, **kwargs: [11,7] - torch.cuda.get_device_properties.major = 11 - torch.cuda.get_device_properties.minor = 7 + torch.backends.cuda.is_built = lambda *args, **kwargs: True + torch.version.cuda = "12.1" + torch.cuda.get_device_capability = lambda *args, **kwargs: [12,1] + torch.cuda.get_device_properties.major = 12 + torch.cuda.get_device_properties.minor = 1 torch.cuda.ipc_collect = lambda *args, **kwargs: None torch.cuda.utilization = lambda *args, **kwargs: 0 ipex_hijacks() - if not torch.xpu.has_fp64_dtype(): + if not torch.xpu.has_fp64_dtype() or os.environ.get('IPEX_FORCE_ATTENTION_SLICE', None) is not None: try: from .diffusers import ipex_diffusers ipex_diffusers() diff --git a/library/ipex/attention.py b/library/ipex/attention.py index e98807a8..8253c5b1 100644 --- a/library/ipex/attention.py +++ b/library/ipex/attention.py @@ -124,6 +124,7 @@ def torch_bmm_32_bit(input, mat2, *, out=None): ) else: return original_torch_bmm(input, mat2, out=out) + torch.xpu.synchronize(input.device) return hidden_states original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention @@ -172,4 +173,5 @@ def scaled_dot_product_attention_32_bit(query, key, value, attn_mask=None, dropo ) else: return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal) + torch.xpu.synchronize(query.device) return hidden_states diff --git a/library/ipex/diffusers.py b/library/ipex/diffusers.py index 47b0375a..732a1856 100644 --- a/library/ipex/diffusers.py +++ b/library/ipex/diffusers.py @@ -149,6 +149,7 @@ class SlicedAttnProcessor: # pylint: disable=too-few-public-methods hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = attn_slice del attn_slice + torch.xpu.synchronize(query.device) else: query_slice = query[start_idx:end_idx] key_slice = key[start_idx:end_idx] @@ -283,6 +284,7 @@ class AttnProcessor: hidden_states[start_idx:end_idx] = attn_slice del attn_slice + torch.xpu.synchronize(query.device) else: attention_probs = attn.get_attention_scores(query, key, attention_mask) hidden_states = torch.bmm(attention_probs, value) diff --git a/library/ipex/hijacks.py b/library/ipex/hijacks.py index b6d246dd..7b2d26d4 100644 --- a/library/ipex/hijacks.py +++ b/library/ipex/hijacks.py @@ -1,6 +1,11 @@ -import contextlib +import os +from functools import wraps +from contextlib import nullcontext import torch import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import +import numpy as np + +device_supports_fp64 = torch.xpu.has_fp64_dtype() # pylint: disable=protected-access, missing-function-docstring, line-too-long, unnecessary-lambda, no-else-return @@ -11,7 +16,7 @@ class DummyDataParallel(torch.nn.Module): # pylint: disable=missing-class-docstr return module.to("xpu") def return_null_context(*args, **kwargs): # pylint: disable=unused-argument - return contextlib.nullcontext() + return nullcontext() @property def is_cuda(self): @@ -25,15 +30,17 @@ def return_xpu(device): # Autocast -original_autocast = torch.autocast -def ipex_autocast(*args, **kwargs): - if len(args) > 0 and args[0] == "cuda": - return original_autocast("xpu", *args[1:], **kwargs) +original_autocast_init = torch.amp.autocast_mode.autocast.__init__ +@wraps(torch.amp.autocast_mode.autocast.__init__) +def autocast_init(self, device_type, dtype=None, enabled=True, cache_enabled=None): + if device_type == "cuda": + return original_autocast_init(self, device_type="xpu", dtype=dtype, enabled=enabled, cache_enabled=cache_enabled) else: - return original_autocast(*args, **kwargs) + return original_autocast_init(self, device_type=device_type, dtype=dtype, enabled=enabled, cache_enabled=cache_enabled) # Latent Antialias CPU Offload: original_interpolate = torch.nn.functional.interpolate +@wraps(torch.nn.functional.interpolate) def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False): # pylint: disable=too-many-arguments if antialias or align_corners is not None: return_device = tensor.device @@ -44,15 +51,29 @@ def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corn return original_interpolate(tensor, size=size, scale_factor=scale_factor, mode=mode, align_corners=align_corners, recompute_scale_factor=recompute_scale_factor, antialias=antialias) + # Diffusers Float64 (Alchemist GPUs doesn't support 64 bit): original_from_numpy = torch.from_numpy +@wraps(torch.from_numpy) def from_numpy(ndarray): if ndarray.dtype == float: return original_from_numpy(ndarray.astype('float32')) else: return original_from_numpy(ndarray) -if torch.xpu.has_fp64_dtype(): +original_as_tensor = torch.as_tensor +@wraps(torch.as_tensor) +def as_tensor(data, dtype=None, device=None): + if check_device(device): + device = return_xpu(device) + if isinstance(data, np.ndarray) and data.dtype == float and not ( + (isinstance(device, torch.device) and device.type == "cpu") or (isinstance(device, str) and "cpu" in device)): + return original_as_tensor(data, dtype=torch.float32, device=device) + else: + return original_as_tensor(data, dtype=dtype, device=device) + + +if device_supports_fp64 and os.environ.get('IPEX_FORCE_ATTENTION_SLICE', None) is None: original_torch_bmm = torch.bmm original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention else: @@ -66,20 +87,25 @@ else: # Data Type Errors: +@wraps(torch.bmm) def torch_bmm(input, mat2, *, out=None): if input.dtype != mat2.dtype: mat2 = mat2.to(input.dtype) return original_torch_bmm(input, mat2, out=out) +@wraps(torch.nn.functional.scaled_dot_product_attention) def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False): if query.dtype != key.dtype: key = key.to(dtype=query.dtype) if query.dtype != value.dtype: value = value.to(dtype=query.dtype) + if attn_mask is not None and query.dtype != attn_mask.dtype: + attn_mask = attn_mask.to(dtype=query.dtype) return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal) # A1111 FP16 original_functional_group_norm = torch.nn.functional.group_norm +@wraps(torch.nn.functional.group_norm) def functional_group_norm(input, num_groups, weight=None, bias=None, eps=1e-05): if weight is not None and input.dtype != weight.data.dtype: input = input.to(dtype=weight.data.dtype) @@ -89,6 +115,7 @@ def functional_group_norm(input, num_groups, weight=None, bias=None, eps=1e-05): # A1111 BF16 original_functional_layer_norm = torch.nn.functional.layer_norm +@wraps(torch.nn.functional.layer_norm) def functional_layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-05): if weight is not None and input.dtype != weight.data.dtype: input = input.to(dtype=weight.data.dtype) @@ -98,6 +125,7 @@ def functional_layer_norm(input, normalized_shape, weight=None, bias=None, eps=1 # Training original_functional_linear = torch.nn.functional.linear +@wraps(torch.nn.functional.linear) def functional_linear(input, weight, bias=None): if input.dtype != weight.data.dtype: input = input.to(dtype=weight.data.dtype) @@ -106,6 +134,7 @@ def functional_linear(input, weight, bias=None): return original_functional_linear(input, weight, bias=bias) original_functional_conv2d = torch.nn.functional.conv2d +@wraps(torch.nn.functional.conv2d) def functional_conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): if input.dtype != weight.data.dtype: input = input.to(dtype=weight.data.dtype) @@ -115,6 +144,7 @@ def functional_conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, # A1111 Embedding BF16 original_torch_cat = torch.cat +@wraps(torch.cat) def torch_cat(tensor, *args, **kwargs): if len(tensor) == 3 and (tensor[0].dtype != tensor[1].dtype or tensor[2].dtype != tensor[1].dtype): return original_torch_cat([tensor[0].to(tensor[1].dtype), tensor[1], tensor[2].to(tensor[1].dtype)], *args, **kwargs) @@ -123,6 +153,7 @@ def torch_cat(tensor, *args, **kwargs): # SwinIR BF16: original_functional_pad = torch.nn.functional.pad +@wraps(torch.nn.functional.pad) def functional_pad(input, pad, mode='constant', value=None): if mode == 'reflect' and input.dtype == torch.bfloat16: return original_functional_pad(input.to(torch.float32), pad, mode=mode, value=value).to(dtype=torch.bfloat16) @@ -131,13 +162,20 @@ def functional_pad(input, pad, mode='constant', value=None): original_torch_tensor = torch.tensor -def torch_tensor(*args, device=None, **kwargs): +@wraps(torch.tensor) +def torch_tensor(data, *args, dtype=None, device=None, **kwargs): if check_device(device): - return original_torch_tensor(*args, device=return_xpu(device), **kwargs) - else: - return original_torch_tensor(*args, device=device, **kwargs) + device = return_xpu(device) + if not device_supports_fp64: + if (isinstance(device, torch.device) and device.type == "xpu") or (isinstance(device, str) and "xpu" in device): + if dtype == torch.float64: + dtype = torch.float32 + elif dtype is None and (hasattr(data, "dtype") and (data.dtype == torch.float64 or data.dtype == float)): + dtype = torch.float32 + return original_torch_tensor(data, *args, dtype=dtype, device=device, **kwargs) original_Tensor_to = torch.Tensor.to +@wraps(torch.Tensor.to) def Tensor_to(self, device=None, *args, **kwargs): if check_device(device): return original_Tensor_to(self, return_xpu(device), *args, **kwargs) @@ -145,6 +183,7 @@ def Tensor_to(self, device=None, *args, **kwargs): return original_Tensor_to(self, device, *args, **kwargs) original_Tensor_cuda = torch.Tensor.cuda +@wraps(torch.Tensor.cuda) def Tensor_cuda(self, device=None, *args, **kwargs): if check_device(device): return original_Tensor_cuda(self, return_xpu(device), *args, **kwargs) @@ -152,6 +191,7 @@ def Tensor_cuda(self, device=None, *args, **kwargs): return original_Tensor_cuda(self, device, *args, **kwargs) original_UntypedStorage_init = torch.UntypedStorage.__init__ +@wraps(torch.UntypedStorage.__init__) def UntypedStorage_init(*args, device=None, **kwargs): if check_device(device): return original_UntypedStorage_init(*args, device=return_xpu(device), **kwargs) @@ -159,6 +199,7 @@ def UntypedStorage_init(*args, device=None, **kwargs): return original_UntypedStorage_init(*args, device=device, **kwargs) original_UntypedStorage_cuda = torch.UntypedStorage.cuda +@wraps(torch.UntypedStorage.cuda) def UntypedStorage_cuda(self, device=None, *args, **kwargs): if check_device(device): return original_UntypedStorage_cuda(self, return_xpu(device), *args, **kwargs) @@ -166,6 +207,7 @@ def UntypedStorage_cuda(self, device=None, *args, **kwargs): return original_UntypedStorage_cuda(self, device, *args, **kwargs) original_torch_empty = torch.empty +@wraps(torch.empty) def torch_empty(*args, device=None, **kwargs): if check_device(device): return original_torch_empty(*args, device=return_xpu(device), **kwargs) @@ -173,6 +215,7 @@ def torch_empty(*args, device=None, **kwargs): return original_torch_empty(*args, device=device, **kwargs) original_torch_randn = torch.randn +@wraps(torch.randn) def torch_randn(*args, device=None, **kwargs): if check_device(device): return original_torch_randn(*args, device=return_xpu(device), **kwargs) @@ -180,6 +223,7 @@ def torch_randn(*args, device=None, **kwargs): return original_torch_randn(*args, device=device, **kwargs) original_torch_ones = torch.ones +@wraps(torch.ones) def torch_ones(*args, device=None, **kwargs): if check_device(device): return original_torch_ones(*args, device=return_xpu(device), **kwargs) @@ -187,6 +231,7 @@ def torch_ones(*args, device=None, **kwargs): return original_torch_ones(*args, device=device, **kwargs) original_torch_zeros = torch.zeros +@wraps(torch.zeros) def torch_zeros(*args, device=None, **kwargs): if check_device(device): return original_torch_zeros(*args, device=return_xpu(device), **kwargs) @@ -194,6 +239,7 @@ def torch_zeros(*args, device=None, **kwargs): return original_torch_zeros(*args, device=device, **kwargs) original_torch_linspace = torch.linspace +@wraps(torch.linspace) def torch_linspace(*args, device=None, **kwargs): if check_device(device): return original_torch_linspace(*args, device=return_xpu(device), **kwargs) @@ -201,6 +247,7 @@ def torch_linspace(*args, device=None, **kwargs): return original_torch_linspace(*args, device=device, **kwargs) original_torch_Generator = torch.Generator +@wraps(torch.Generator) def torch_Generator(device=None): if check_device(device): return original_torch_Generator(return_xpu(device)) @@ -208,12 +255,14 @@ def torch_Generator(device=None): return original_torch_Generator(device) original_torch_load = torch.load +@wraps(torch.load) def torch_load(f, map_location=None, pickle_module=None, *, weights_only=False, mmap=None, **kwargs): if check_device(map_location): return original_torch_load(f, map_location=return_xpu(map_location), pickle_module=pickle_module, weights_only=weights_only, mmap=mmap, **kwargs) else: return original_torch_load(f, map_location=map_location, pickle_module=pickle_module, weights_only=weights_only, mmap=mmap, **kwargs) + # Hijack Functions: def ipex_hijacks(): torch.tensor = torch_tensor @@ -232,7 +281,7 @@ def ipex_hijacks(): torch.backends.cuda.sdp_kernel = return_null_context torch.nn.DataParallel = DummyDataParallel torch.UntypedStorage.is_cuda = is_cuda - torch.autocast = ipex_autocast + torch.amp.autocast_mode.autocast.__init__ = autocast_init torch.nn.functional.scaled_dot_product_attention = scaled_dot_product_attention torch.nn.functional.group_norm = functional_group_norm @@ -244,5 +293,6 @@ def ipex_hijacks(): torch.bmm = torch_bmm torch.cat = torch_cat - if not torch.xpu.has_fp64_dtype(): + if not device_supports_fp64: torch.from_numpy = from_numpy + torch.as_tensor = as_tensor diff --git a/library/train_util.py b/library/train_util.py index ba428e50..32198774 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -19,7 +19,7 @@ from typing import ( Tuple, Union, ) -from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs +from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs, PartialState import gc import glob import math @@ -4636,7 +4636,6 @@ def line_to_prompt_dict(line: str) -> dict: return prompt_dict - def sample_images_common( pipe_class, accelerator: Accelerator, @@ -4654,6 +4653,7 @@ def sample_images_common( """ StableDiffusionLongPromptWeightingPipelineの改造版を使うようにしたので、clip skipおよびプロンプトの重みづけに対応した """ + if steps == 0: if not args.sample_at_first: return @@ -4673,8 +4673,10 @@ def sample_images_common( print(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}") return + distributed_state = PartialState() # for multi gpu distributed inference. this is a singleton, so it's safe to use it here + org_vae_device = vae.device # CPUにいるはず - vae.to(device) + vae.to(distributed_state.device) # unwrap unet and text_encoder(s) unet = accelerator.unwrap_model(unet) @@ -4684,10 +4686,6 @@ def sample_images_common( text_encoder = accelerator.unwrap_model(text_encoder) # read prompts - - # with open(args.sample_prompts, "rt", encoding="utf-8") as f: - # prompts = f.readlines() - if args.sample_prompts.endswith(".txt"): with open(args.sample_prompts, "r", encoding="utf-8") as f: lines = f.readlines() @@ -4700,12 +4698,11 @@ def sample_images_common( with open(args.sample_prompts, "r", encoding="utf-8") as f: prompts = json.load(f) - schedulers: dict = {} + # schedulers: dict = {} cannot find where this is used default_scheduler = get_my_scheduler( sample_sampler=args.sample_sampler, v_parameterization=args.v_parameterization, ) - schedulers[args.sample_sampler] = default_scheduler pipeline = pipe_class( text_encoder=text_encoder, @@ -4718,114 +4715,145 @@ def sample_images_common( requires_safety_checker=False, clip_skip=args.clip_skip, ) - pipeline.to(device) - + pipeline.to(distributed_state.device) save_dir = args.output_dir + "/sample" os.makedirs(save_dir, exist_ok=True) + # preprocess prompts + for i in range(len(prompts)): + prompt_dict = prompts[i] + if isinstance(prompt_dict, str): + prompt_dict = line_to_prompt_dict(prompt_dict) + prompts[i] = prompt_dict + assert isinstance(prompt_dict, dict) + + # Adds an enumerator to the dict based on prompt position. Used later to name image files. Also cleanup of extra data in original prompt dict. + prompt_dict["enum"] = i + prompt_dict.pop("subset", None) + + # save random state to restore later rng_state = torch.get_rng_state() - cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None - - with torch.no_grad(): - # with accelerator.autocast(): - for i, prompt_dict in enumerate(prompts): - if not accelerator.is_main_process: - continue - - if isinstance(prompt_dict, str): - prompt_dict = line_to_prompt_dict(prompt_dict) - - assert isinstance(prompt_dict, dict) - negative_prompt = prompt_dict.get("negative_prompt") - sample_steps = prompt_dict.get("sample_steps", 30) - width = prompt_dict.get("width", 512) - height = prompt_dict.get("height", 512) - scale = prompt_dict.get("scale", 7.5) - seed = prompt_dict.get("seed") - controlnet_image = prompt_dict.get("controlnet_image") - prompt: str = prompt_dict.get("prompt", "") - sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler) - - if seed is not None: - torch.manual_seed(seed) - torch.cuda.manual_seed(seed) - - scheduler = schedulers.get(sampler_name) - if scheduler is None: - scheduler = get_my_scheduler( - sample_sampler=sampler_name, - v_parameterization=args.v_parameterization, - ) - schedulers[sampler_name] = scheduler - pipeline.scheduler = scheduler - - if prompt_replacement is not None: - prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1]) - if negative_prompt is not None: - negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1]) - - if controlnet_image is not None: - controlnet_image = Image.open(controlnet_image).convert("RGB") - controlnet_image = controlnet_image.resize((width, height), Image.LANCZOS) - - height = max(64, height - height % 8) # round to divisible by 8 - width = max(64, width - width % 8) # round to divisible by 8 - print(f"prompt: {prompt}") - print(f"negative_prompt: {negative_prompt}") - print(f"height: {height}") - print(f"width: {width}") - print(f"sample_steps: {sample_steps}") - print(f"scale: {scale}") - print(f"sample_sampler: {sampler_name}") - if seed is not None: - print(f"seed: {seed}") - with accelerator.autocast(): - latents = pipeline( - prompt=prompt, - height=height, - width=width, - num_inference_steps=sample_steps, - guidance_scale=scale, - negative_prompt=negative_prompt, - controlnet=controlnet, - controlnet_image=controlnet_image, - ) - - image = pipeline.latents_to_image(latents)[0] - - ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime()) - num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}" - seed_suffix = "" if seed is None else f"_{seed}" - img_filename = ( - f"{'' if args.output_name is None else args.output_name + '_'}{ts_str}_{num_suffix}_{i:02d}{seed_suffix}.png" - ) - - image.save(os.path.join(save_dir, img_filename)) - - # wandb有効時のみログを送信 - try: - wandb_tracker = accelerator.get_tracker("wandb") - try: - import wandb - except ImportError: # 事前に一度確認するのでここはエラー出ないはず - raise ImportError("No wandb / wandb がインストールされていないようです") - - wandb_tracker.log({f"sample_{i}": wandb.Image(image)}) - except: # wandb 無効時 - pass + cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None # TODO mps etc. support + + if distributed_state.num_processes <= 1: + # If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts. + with torch.no_grad(): + for prompt_dict in prompts: + sample_image_inference(accelerator, args, pipeline, save_dir, prompt_dict, epoch, steps, prompt_replacement, controlnet=controlnet) + else: + # Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processes available (number of devices available) + # prompt_dicts are assigned to lists based on order of processes, to attempt to time the image creation time to match enum order. Probably only works when steps and sampler are identical. + per_process_prompts = [] # list of lists + for i in range(distributed_state.num_processes): + per_process_prompts.append(prompts[i::distributed_state.num_processes]) + + with torch.no_grad(): + with distributed_state.split_between_processes(per_process_prompts) as prompt_dict_lists: + for prompt_dict in prompt_dict_lists[0]: + sample_image_inference(accelerator, args, pipeline, save_dir, prompt_dict, epoch, steps, prompt_replacement, controlnet=controlnet) # clear pipeline and cache to reduce vram usage del pipeline - torch.cuda.empty_cache() + with torch.cuda.device(torch.cuda.current_device()): + torch.cuda.empty_cache() + torch.set_rng_state(rng_state) if cuda_rng_state is not None: torch.cuda.set_rng_state(cuda_rng_state) vae.to(org_vae_device) +def sample_image_inference(accelerator: Accelerator, args: argparse.Namespace, pipeline, save_dir, prompt_dict, epoch, steps, prompt_replacement, controlnet=None): + assert isinstance(prompt_dict, dict) + negative_prompt = prompt_dict.get("negative_prompt") + sample_steps = prompt_dict.get("sample_steps", 30) + width = prompt_dict.get("width", 512) + height = prompt_dict.get("height", 512) + scale = prompt_dict.get("scale", 7.5) + seed = prompt_dict.get("seed") + controlnet_image = prompt_dict.get("controlnet_image") + prompt: str = prompt_dict.get("prompt", "") + sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler) + if prompt_replacement is not None: + prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1]) + if negative_prompt is not None: + negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1]) + + if seed is not None: + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + else: + # True random sample image generation + torch.seed() + torch.cuda.seed() + + scheduler = get_my_scheduler( + sample_sampler=sampler_name, + v_parameterization=args.v_parameterization, + ) + pipeline.scheduler = scheduler + + if controlnet_image is not None: + controlnet_image = Image.open(controlnet_image).convert("RGB") + controlnet_image = controlnet_image.resize((width, height), Image.LANCZOS) + + height = max(64, height - height % 8) # round to divisible by 8 + width = max(64, width - width % 8) # round to divisible by 8 + print(f"\nprompt: {prompt}") + print(f"negative_prompt: {negative_prompt}") + print(f"height: {height}") + print(f"width: {width}") + print(f"sample_steps: {sample_steps}") + print(f"scale: {scale}") + print(f"sample_sampler: {sampler_name}") + if seed is not None: + print(f"seed: {seed}") + with accelerator.autocast(): + latents = pipeline( + prompt=prompt, + height=height, + width=width, + num_inference_steps=sample_steps, + guidance_scale=scale, + negative_prompt=negative_prompt, + controlnet=controlnet, + controlnet_image=controlnet_image, + ) + + with torch.cuda.device(torch.cuda.current_device()): + torch.cuda.empty_cache() + + image = pipeline.latents_to_image(latents)[0] + + # adding accelerator.wait_for_everyone() here should sync up and ensure that sample images are saved in the same order as the original prompt list + # but adding 'enum' to the filename should be enough + + ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime()) + num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}" + seed_suffix = "" if seed is None else f"_{seed}" + i: int = prompt_dict["enum"] + img_filename = ( + f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png" + ) + image.save(os.path.join(save_dir, img_filename)) + + # wandb有効時のみログを送信 + try: + wandb_tracker = accelerator.get_tracker("wandb") + try: + import wandb + except ImportError: # 事前に一度確認するのでここはエラー出ないはず + raise ImportError("No wandb / wandb がインストールされていないようです") + + wandb_tracker.log({f"sample_{i}": wandb.Image(image)}) + except: # wandb 無効時 + pass # endregion + + + # region 前処理用 diff --git a/networks/svd_merge_lora.py b/networks/svd_merge_lora.py index 16e813b3..ee2dae30 100644 --- a/networks/svd_merge_lora.py +++ b/networks/svd_merge_lora.py @@ -1,4 +1,3 @@ -import math import argparse import os import time @@ -6,8 +5,6 @@ import torch from safetensors.torch import load_file, save_file from tqdm import tqdm from library import sai_model_spec, train_util -import library.model_util as model_util -import lora CLAMP_QUANTILE = 0.99