Compare commits

...

6 Commits

Author SHA1 Message Date
ykume
ae0872ba3b search block-wise application weights 2024-02-04 14:56:31 +09:00
Kohya S
7f948db158 Merge pull request #1087 from mgz-dev/fix-imports-on-svd_merge_lora
fix broken import in svd_merge_lora script
2024-01-31 21:08:40 +09:00
Kohya S
9d7729c00d Merge pull request #1086 from Disty0/dev
Update IPEX Libs
2024-01-31 21:06:34 +09:00
Disty0
988dee02b9 IPEX torch.tensor FP64 workaround 2024-01-30 01:52:32 +03:00
mgz
d4b9568269 fix broken import in svd_merge_lora script
remove missing import, and remove unused imports
2024-01-28 11:59:07 -06:00
Disty0
ccc3a481e7 Update IPEX Libs 2024-01-28 14:14:31 +03:00
8 changed files with 1193 additions and 23 deletions

View File

@@ -1,3 +1,39 @@
## Searching for layer-wise application weights for LoRA

A script that searches for layer-wise (block-wise) application weights, `train_network_appl_weights.py`, has been added. Currently only SDXL is supported.

It searches for the application weights by running the normal training process on one or more trained networks (such as LoRA) while varying the layer-wise weights. In other words, it finds which layer-wise application weights produce images closest to the training data.

The sum of the layer-wise weights can be used as a penalty term. That is, the search should be able to find weights that reproduce the images while lowering the weights of layers with little influence.

Multiple networks can be searched at once. The search requires at least one training image. (The minimum number of images needed for correct operation has not been verified; it has been tested with about 50 images. The training data presumably does not have to be the data used to train the LoRA, but this is also unverified.)

The command-line options are almost the same as those of `sdxl_train_network.py`, but the following options have been added or extended (an example invocation follows the list).

- `--application_loss_weight` : the weight used when adding the layer-wise application weights to the loss. The default is 0.0001. Larger values push the search toward lower application weights. With 0, no penalty is applied, so the search freely looks for the weights that give the highest fidelity.
- `--network_module` : multiple modules to search can be specified, e.g. `--network_module networks.lora networks.lora`.
- `--network_weights` : the weights of the multiple networks to search can be specified, e.g. `--network_weights model1.safetensors model2.safetensors`.
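For illustration, an invocation might look like the following (a sketch only: the model and dataset paths are placeholders, and apart from the three options documented above, the flags simply follow the usual `sdxl_train_network.py` conventions):

```
accelerate launch train_network_appl_weights.py \
  --pretrained_model_name_or_path=sd_xl_base_1.0.safetensors \
  --train_data_dir=train_images --output_dir=output \
  --optimizer_type=AdamW --learning_rate=1e-1 \
  --application_loss_weight=0.0001 \
  --network_module networks.lora networks.lora \
  --network_weights model1.safetensors model2.safetensors
```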
There are 20 layer-wise weight parameters: `BASE, IN00-08, MID, OUT00-08`. `BASE` is applied to the Text Encoder. (Operation with LoRAs that target the Text Encoder is unverified.)

The parameters are saved to a file, but it is recommended to copy the values displayed on screen and save them yourself.
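As a rough illustration of how the 20 values line up with the blocks, and how they are applied through the `set_block_wise_weights` method added to `LoRANetwork` in this change set (the weight values here are invented, and `network` is assumed to be an already-loaded `LoRANetwork` instance):

```python
# 20 application weights in the order [BASE, IN00-08, MID, OUT00-08]
weights = [1.0] + [0.5] * 9 + [0.8] + [0.5] * 9  # len(weights) == 20

# weights[0] is applied to the Text Encoder (BASE); weights[1:10] go to IN00-08,
# weights[10] to MID, and weights[11:20] to OUT00-08 of the U-Net.
network.set_block_wise_weights(weights)
```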
### Notes

Operation has been confirmed with the AdamW optimizer and a learning rate of 1e-1; the learning rate can apparently be set quite high. With these settings, reasonable results are obtained in roughly 1/20 to 1/10 of the number of epochs used for LoRA training.

Setting `application_loss_weight` larger than 0.0001 seems to make the total application weight quite low (the LoRA is barely applied). This likely depends on your conditions, so adjust as appropriate.

If negative weights were penalized like positive ones, the search would drive the weights of low-influence layers to extreme negative values just to shrink the sum, so negative values are given 10x weighting in the penalty (-0.01 incurs roughly the same penalty as 0.1). Modify the source if you want to change this weighting. A minimal sketch of the penalty follows.
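The sketch below is an illustration only, not the actual implementation in `train_network_appl_weights.py`; the 10x factor for negative values and the 0.0001 default are taken from this document:

```python
import torch

def application_weight_penalty(weights: torch.Tensor, loss_weight: float = 0.0001) -> torch.Tensor:
    # Sum of application weights, with negative values penalized 10x
    # so that -0.01 costs roughly the same as 0.1.
    per_layer = torch.where(weights < 0, weights.abs() * 10.0, weights)
    return loss_weight * per_layer.sum()
```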
Beyond the straightforward use of "lowering the weights of unnecessary layers to narrow a LoRA's range of influence", other uses are conceivable: for example, "with images of a certain character in a certain pose as training data, search for the application weights of a pose LoRA that keep the character intact", or "with images of a certain character in a certain art style as training data, search for the application weights of a style LoRA and a character LoRA together".

It may even be possible to "use images of a character deliberately rendered in a different style as training data to find the layers needed to reproduce the character's attributes", or to "apply many candidate LoRAs against an ideal target image and search for the application weights that reproduce it best (though training slows down as the number of LoRAs grows)".
---
__SDXL is now supported. The sdxl branch has been merged into the main branch. If you update the repository, please follow the upgrade instructions. Also, the version of accelerate has been updated, so please run accelerate config again.__ The documentation for SDXL training is [here](./README.md#sdxl-training).

This repository contains training, generation and utility scripts for Stable Diffusion.

View File

@@ -125,9 +125,13 @@ def ipex_init(): # pylint: disable=too-many-statements
     # AMP:
     torch.cuda.amp = torch.xpu.amp
+    torch.is_autocast_enabled = torch.xpu.is_autocast_xpu_enabled
+    torch.get_autocast_gpu_dtype = torch.xpu.get_autocast_xpu_dtype
     if not hasattr(torch.cuda.amp, "common"):
         torch.cuda.amp.common = contextlib.nullcontext()
         torch.cuda.amp.common.amp_definitely_not_available = lambda: False
     try:
         torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler
     except Exception: # pylint: disable=broad-exception-caught
@@ -151,15 +155,16 @@ def ipex_init(): # pylint: disable=too-many-statements
     torch.cuda.has_half = True
     torch.cuda.is_bf16_supported = lambda *args, **kwargs: True
     torch.cuda.is_fp16_supported = lambda *args, **kwargs: True
-    torch.version.cuda = "11.7"
-    torch.cuda.get_device_capability = lambda *args, **kwargs: [11,7]
-    torch.cuda.get_device_properties.major = 11
-    torch.cuda.get_device_properties.minor = 7
+    torch.backends.cuda.is_built = lambda *args, **kwargs: True
+    torch.version.cuda = "12.1"
+    torch.cuda.get_device_capability = lambda *args, **kwargs: [12,1]
+    torch.cuda.get_device_properties.major = 12
+    torch.cuda.get_device_properties.minor = 1
     torch.cuda.ipc_collect = lambda *args, **kwargs: None
     torch.cuda.utilization = lambda *args, **kwargs: 0

     ipex_hijacks()
-    if not torch.xpu.has_fp64_dtype():
+    if not torch.xpu.has_fp64_dtype() or os.environ.get('IPEX_FORCE_ATTENTION_SLICE', None) is not None:
         try:
             from .diffusers import ipex_diffusers
             ipex_diffusers()

View File

@@ -124,6 +124,7 @@ def torch_bmm_32_bit(input, mat2, *, out=None):
         )
     else:
         return original_torch_bmm(input, mat2, out=out)
+    torch.xpu.synchronize(input.device)
     return hidden_states

 original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
@@ -172,4 +173,5 @@ def scaled_dot_product_attention_32_bit(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False):
         )
     else:
         return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal)
+    torch.xpu.synchronize(query.device)
     return hidden_states

View File

@@ -149,6 +149,7 @@ class SlicedAttnProcessor: # pylint: disable=too-few-public-methods
                 hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = attn_slice
                 del attn_slice
+            torch.xpu.synchronize(query.device)
         else:
             query_slice = query[start_idx:end_idx]
             key_slice = key[start_idx:end_idx]
@@ -283,6 +284,7 @@ class AttnProcessor:
                 hidden_states[start_idx:end_idx] = attn_slice
                 del attn_slice
+            torch.xpu.synchronize(query.device)
         else:
             attention_probs = attn.get_attention_scores(query, key, attention_mask)
             hidden_states = torch.bmm(attention_probs, value)

View File

@@ -1,6 +1,11 @@
-import contextlib
+import os
+from functools import wraps
+from contextlib import nullcontext
 import torch
 import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
+import numpy as np
+
+device_supports_fp64 = torch.xpu.has_fp64_dtype()

 # pylint: disable=protected-access, missing-function-docstring, line-too-long, unnecessary-lambda, no-else-return
@@ -11,7 +16,7 @@ class DummyDataParallel(torch.nn.Module): # pylint: disable=missing-class-docstring
     return module.to("xpu")

 def return_null_context(*args, **kwargs): # pylint: disable=unused-argument
-    return contextlib.nullcontext()
+    return nullcontext()

 @property
 def is_cuda(self):
@@ -25,15 +30,17 @@ def return_xpu(device):

 # Autocast
-original_autocast = torch.autocast
-def ipex_autocast(*args, **kwargs):
-    if len(args) > 0 and args[0] == "cuda":
-        return original_autocast("xpu", *args[1:], **kwargs)
+original_autocast_init = torch.amp.autocast_mode.autocast.__init__
+@wraps(torch.amp.autocast_mode.autocast.__init__)
+def autocast_init(self, device_type, dtype=None, enabled=True, cache_enabled=None):
+    if device_type == "cuda":
+        return original_autocast_init(self, device_type="xpu", dtype=dtype, enabled=enabled, cache_enabled=cache_enabled)
     else:
-        return original_autocast(*args, **kwargs)
+        return original_autocast_init(self, device_type=device_type, dtype=dtype, enabled=enabled, cache_enabled=cache_enabled)

 # Latent Antialias CPU Offload:
 original_interpolate = torch.nn.functional.interpolate
+@wraps(torch.nn.functional.interpolate)
 def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False): # pylint: disable=too-many-arguments
     if antialias or align_corners is not None:
         return_device = tensor.device
@@ -44,15 +51,29 @@ def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False):
         return original_interpolate(tensor, size=size, scale_factor=scale_factor, mode=mode,
             align_corners=align_corners, recompute_scale_factor=recompute_scale_factor, antialias=antialias)

 # Diffusers Float64 (Alchemist GPUs doesn't support 64 bit):
 original_from_numpy = torch.from_numpy
+@wraps(torch.from_numpy)
 def from_numpy(ndarray):
     if ndarray.dtype == float:
         return original_from_numpy(ndarray.astype('float32'))
     else:
         return original_from_numpy(ndarray)

-if torch.xpu.has_fp64_dtype():
+original_as_tensor = torch.as_tensor
+@wraps(torch.as_tensor)
+def as_tensor(data, dtype=None, device=None):
+    if check_device(device):
+        device = return_xpu(device)
+    if isinstance(data, np.ndarray) and data.dtype == float and not (
+        (isinstance(device, torch.device) and device.type == "cpu") or (isinstance(device, str) and "cpu" in device)):
+        return original_as_tensor(data, dtype=torch.float32, device=device)
+    else:
+        return original_as_tensor(data, dtype=dtype, device=device)
+
+if device_supports_fp64 and os.environ.get('IPEX_FORCE_ATTENTION_SLICE', None) is None:
     original_torch_bmm = torch.bmm
     original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
 else:
@@ -66,20 +87,25 @@ else:

 # Data Type Errors:
+@wraps(torch.bmm)
 def torch_bmm(input, mat2, *, out=None):
     if input.dtype != mat2.dtype:
         mat2 = mat2.to(input.dtype)
     return original_torch_bmm(input, mat2, out=out)

+@wraps(torch.nn.functional.scaled_dot_product_attention)
 def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False):
     if query.dtype != key.dtype:
         key = key.to(dtype=query.dtype)
     if query.dtype != value.dtype:
         value = value.to(dtype=query.dtype)
+    if attn_mask is not None and query.dtype != attn_mask.dtype:
+        attn_mask = attn_mask.to(dtype=query.dtype)
     return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal)

 # A1111 FP16
 original_functional_group_norm = torch.nn.functional.group_norm
+@wraps(torch.nn.functional.group_norm)
 def functional_group_norm(input, num_groups, weight=None, bias=None, eps=1e-05):
     if weight is not None and input.dtype != weight.data.dtype:
         input = input.to(dtype=weight.data.dtype)
@@ -89,6 +115,7 @@ def functional_group_norm(input, num_groups, weight=None, bias=None, eps=1e-05):

 # A1111 BF16
 original_functional_layer_norm = torch.nn.functional.layer_norm
+@wraps(torch.nn.functional.layer_norm)
 def functional_layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-05):
     if weight is not None and input.dtype != weight.data.dtype:
         input = input.to(dtype=weight.data.dtype)
@@ -98,6 +125,7 @@ def functional_layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-05):

 # Training
 original_functional_linear = torch.nn.functional.linear
+@wraps(torch.nn.functional.linear)
 def functional_linear(input, weight, bias=None):
     if input.dtype != weight.data.dtype:
         input = input.to(dtype=weight.data.dtype)
@@ -106,6 +134,7 @@ def functional_linear(input, weight, bias=None):
     return original_functional_linear(input, weight, bias=bias)

 original_functional_conv2d = torch.nn.functional.conv2d
+@wraps(torch.nn.functional.conv2d)
 def functional_conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
     if input.dtype != weight.data.dtype:
         input = input.to(dtype=weight.data.dtype)
@@ -115,6 +144,7 @@ def functional_conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):

 # A1111 Embedding BF16
 original_torch_cat = torch.cat
+@wraps(torch.cat)
 def torch_cat(tensor, *args, **kwargs):
     if len(tensor) == 3 and (tensor[0].dtype != tensor[1].dtype or tensor[2].dtype != tensor[1].dtype):
         return original_torch_cat([tensor[0].to(tensor[1].dtype), tensor[1], tensor[2].to(tensor[1].dtype)], *args, **kwargs)
@@ -123,6 +153,7 @@ def torch_cat(tensor, *args, **kwargs):

 # SwinIR BF16:
 original_functional_pad = torch.nn.functional.pad
+@wraps(torch.nn.functional.pad)
 def functional_pad(input, pad, mode='constant', value=None):
     if mode == 'reflect' and input.dtype == torch.bfloat16:
         return original_functional_pad(input.to(torch.float32), pad, mode=mode, value=value).to(dtype=torch.bfloat16)
@@ -131,13 +162,20 @@ def functional_pad(input, pad, mode='constant', value=None):

 original_torch_tensor = torch.tensor
-def torch_tensor(*args, device=None, **kwargs):
+@wraps(torch.tensor)
+def torch_tensor(data, *args, dtype=None, device=None, **kwargs):
     if check_device(device):
-        return original_torch_tensor(*args, device=return_xpu(device), **kwargs)
-    else:
-        return original_torch_tensor(*args, device=device, **kwargs)
+        device = return_xpu(device)
+    if not device_supports_fp64:
+        if (isinstance(device, torch.device) and device.type == "xpu") or (isinstance(device, str) and "xpu" in device):
+            if dtype == torch.float64:
+                dtype = torch.float32
+            elif dtype is None and (hasattr(data, "dtype") and (data.dtype == torch.float64 or data.dtype == float)):
+                dtype = torch.float32
+    return original_torch_tensor(data, *args, dtype=dtype, device=device, **kwargs)

 original_Tensor_to = torch.Tensor.to
+@wraps(torch.Tensor.to)
 def Tensor_to(self, device=None, *args, **kwargs):
     if check_device(device):
         return original_Tensor_to(self, return_xpu(device), *args, **kwargs)
@@ -145,6 +183,7 @@ def Tensor_to(self, device=None, *args, **kwargs):
     return original_Tensor_to(self, device, *args, **kwargs)

 original_Tensor_cuda = torch.Tensor.cuda
+@wraps(torch.Tensor.cuda)
 def Tensor_cuda(self, device=None, *args, **kwargs):
     if check_device(device):
         return original_Tensor_cuda(self, return_xpu(device), *args, **kwargs)
@@ -152,6 +191,7 @@ def Tensor_cuda(self, device=None, *args, **kwargs):
     return original_Tensor_cuda(self, device, *args, **kwargs)

 original_UntypedStorage_init = torch.UntypedStorage.__init__
+@wraps(torch.UntypedStorage.__init__)
 def UntypedStorage_init(*args, device=None, **kwargs):
     if check_device(device):
         return original_UntypedStorage_init(*args, device=return_xpu(device), **kwargs)
@@ -159,6 +199,7 @@ def UntypedStorage_init(*args, device=None, **kwargs):
     return original_UntypedStorage_init(*args, device=device, **kwargs)

 original_UntypedStorage_cuda = torch.UntypedStorage.cuda
+@wraps(torch.UntypedStorage.cuda)
 def UntypedStorage_cuda(self, device=None, *args, **kwargs):
     if check_device(device):
         return original_UntypedStorage_cuda(self, return_xpu(device), *args, **kwargs)
@@ -166,6 +207,7 @@ def UntypedStorage_cuda(self, device=None, *args, **kwargs):
     return original_UntypedStorage_cuda(self, device, *args, **kwargs)

 original_torch_empty = torch.empty
+@wraps(torch.empty)
 def torch_empty(*args, device=None, **kwargs):
     if check_device(device):
         return original_torch_empty(*args, device=return_xpu(device), **kwargs)
@@ -173,6 +215,7 @@ def torch_empty(*args, device=None, **kwargs):
     return original_torch_empty(*args, device=device, **kwargs)

 original_torch_randn = torch.randn
+@wraps(torch.randn)
 def torch_randn(*args, device=None, **kwargs):
     if check_device(device):
         return original_torch_randn(*args, device=return_xpu(device), **kwargs)
@@ -180,6 +223,7 @@ def torch_randn(*args, device=None, **kwargs):
     return original_torch_randn(*args, device=device, **kwargs)

 original_torch_ones = torch.ones
+@wraps(torch.ones)
 def torch_ones(*args, device=None, **kwargs):
     if check_device(device):
         return original_torch_ones(*args, device=return_xpu(device), **kwargs)
@@ -187,6 +231,7 @@ def torch_ones(*args, device=None, **kwargs):
     return original_torch_ones(*args, device=device, **kwargs)

 original_torch_zeros = torch.zeros
+@wraps(torch.zeros)
 def torch_zeros(*args, device=None, **kwargs):
     if check_device(device):
         return original_torch_zeros(*args, device=return_xpu(device), **kwargs)
@@ -194,6 +239,7 @@ def torch_zeros(*args, device=None, **kwargs):
     return original_torch_zeros(*args, device=device, **kwargs)

 original_torch_linspace = torch.linspace
+@wraps(torch.linspace)
 def torch_linspace(*args, device=None, **kwargs):
     if check_device(device):
         return original_torch_linspace(*args, device=return_xpu(device), **kwargs)
@@ -201,6 +247,7 @@ def torch_linspace(*args, device=None, **kwargs):
     return original_torch_linspace(*args, device=device, **kwargs)

 original_torch_Generator = torch.Generator
+@wraps(torch.Generator)
 def torch_Generator(device=None):
     if check_device(device):
         return original_torch_Generator(return_xpu(device))
@@ -208,12 +255,14 @@ def torch_Generator(device=None):
     return original_torch_Generator(device)

 original_torch_load = torch.load
+@wraps(torch.load)
 def torch_load(f, map_location=None, pickle_module=None, *, weights_only=False, mmap=None, **kwargs):
     if check_device(map_location):
         return original_torch_load(f, map_location=return_xpu(map_location), pickle_module=pickle_module, weights_only=weights_only, mmap=mmap, **kwargs)
     else:
         return original_torch_load(f, map_location=map_location, pickle_module=pickle_module, weights_only=weights_only, mmap=mmap, **kwargs)

 # Hijack Functions:
 def ipex_hijacks():
     torch.tensor = torch_tensor
@@ -232,7 +281,7 @@ def ipex_hijacks():
     torch.backends.cuda.sdp_kernel = return_null_context
     torch.nn.DataParallel = DummyDataParallel
     torch.UntypedStorage.is_cuda = is_cuda

-    torch.autocast = ipex_autocast
+    torch.amp.autocast_mode.autocast.__init__ = autocast_init
     torch.nn.functional.scaled_dot_product_attention = scaled_dot_product_attention
     torch.nn.functional.group_norm = functional_group_norm
@@ -244,5 +293,6 @@ def ipex_hijacks():
     torch.bmm = torch_bmm
     torch.cat = torch_cat
-    if not torch.xpu.has_fp64_dtype():
+    if not device_supports_fp64:
         torch.from_numpy = from_numpy
+        torch.as_tensor = as_tensor

View File

@@ -511,7 +511,9 @@ def get_block_dims_and_alphas(
             len(block_dims) == num_total_blocks
         ), f"block_dims must have {num_total_blocks} elements / block_dimsは{num_total_blocks}個指定してください"
     else:
-        print(f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります")
+        print(
+            f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります"
+        )
         block_dims = [network_dim] * num_total_blocks

     if block_alphas is not None:
@@ -1223,3 +1225,40 @@ class LoRANetwork(torch.nn.Module):
             norms.append(scalednorm.item())

         return keys_scaled, sum(norms) / len(norms), max(norms)
+
+    # region application weight
+    def get_number_of_blocks(self):
+        # only for SDXL
+        return 20
+
+    def has_text_encoder_block(self):
+        return self.text_encoder_loras is not None and len(self.text_encoder_loras) > 0
+
+    def set_block_wise_weights(self, weights):
+        if self.text_encoder_loras:
+            for lora in self.text_encoder_loras:
+                lora.multiplier = weights[0]
+
+        for lora in self.unet_loras:
+            # determine block index
+            key = lora.lora_name[10:]  # remove "lora_unet_"
+            if key.startswith("input_blocks"):
+                block_index = int(key.split("_")[2]) + 1  # 1-9
+            elif key.startswith("middle_block"):
+                block_index = 10  # int(key.split("_")[2]) + 10
+            elif key.startswith("output_blocks"):
+                block_index = int(key.split("_")[2]) + 11  # 11-19
+            else:
+                print(f"unknown block: {key}")
+                block_index = 0
+            lora.multiplier = weights[block_index]
+            # print(f"{lora.lora_name} block index: {block_index}, weight: {lora.multiplier}")
+        # print(f"set block-wise weights to {weights}")
+
+    # TODO precomputing the LoRA weights so that applying them is just a multiply by the multiplier should be faster
+    # endregion

View File

@@ -1,4 +1,3 @@
-import math
 import argparse
 import os
 import time
@@ -6,8 +5,6 @@ import torch
 from safetensors.torch import load_file, save_file
 from tqdm import tqdm
 from library import sai_model_spec, train_util
-import library.model_util as model_util
-import lora

 CLAMP_QUANTILE = 0.99

File diff suppressed because it is too large