Merge pull request #844 from kohya-ss/dev

IPEX update, concat LoRA
This commit is contained in:
Kohya S
2023-10-01 12:06:36 +09:00
committed by GitHub
5 changed files with 94 additions and 30 deletions

View File

@@ -144,7 +144,7 @@ def ipex_init(): # pylint: disable=too-many-statements
ipex._C._DeviceProperties.minor = 2
#Fix functions with ipex:
torch.cuda.mem_get_info = lambda device=None: [(torch.xpu.get_device_properties(device).total_memory - torch.xpu.memory_allocated(device)), torch.xpu.get_device_properties(device).total_memory]
torch.cuda.mem_get_info = lambda device=None: [(torch.xpu.get_device_properties(device).total_memory - torch.xpu.memory_reserved(device)), torch.xpu.get_device_properties(device).total_memory]
torch._utils._get_available_device_type = lambda: "xpu"
torch.has_cuda = True
torch.cuda.has_half = True
@@ -156,7 +156,6 @@ def ipex_init(): # pylint: disable=too-many-statements
torch.cuda.get_device_properties.minor = 7
torch.cuda.ipc_collect = lambda *args, **kwargs: None
torch.cuda.utilization = lambda *args, **kwargs: 0
# getDeviceIdListForCard is renamed since https://github.com/intel/intel-extension-for-pytorch/commit/835b41fd5c8b6facf9efee8312f20699850ee592
if hasattr(torch.xpu, 'getDeviceIdListForCard'):
torch.cuda.getDeviceIdListForCard = torch.xpu.getDeviceIdListForCard
torch.cuda.get_device_id_list_per_card = torch.xpu.getDeviceIdListForCard

View File

@@ -10,13 +10,15 @@ def torch_bmm(input, mat2, *, out=None):
#ARC GPUs can't allocate more than 4GB to a single block, Slice it:
batch_size_attention, input_tokens, mat2_shape = input.shape[0], input.shape[1], mat2.shape[2]
block_multiply = 2.4 if input.dtype == torch.float32 else 1.2
block_size = (batch_size_attention * input_tokens * mat2_shape) / 1024 * block_multiply #MB
block_multiply = input.element_size()
slice_block_size = input_tokens * mat2_shape / 1024 / 1024 * block_multiply
block_size = batch_size_attention * slice_block_size
split_slice_size = batch_size_attention
if block_size >= 4000:
if block_size > 4:
do_split = True
#Find something divisible with the input_tokens
while ((split_slice_size * input_tokens * mat2_shape) / 1024 * block_multiply) > 4000:
while (split_slice_size * slice_block_size) > 4:
split_slice_size = split_slice_size // 2
if split_slice_size <= 1:
split_slice_size = 1
@@ -24,12 +26,12 @@ def torch_bmm(input, mat2, *, out=None):
else:
do_split = False
split_block_size = (split_slice_size * input_tokens * mat2_shape) / 1024 * block_multiply #MB
split_2_slice_size = input_tokens
if split_block_size >= 4000:
if split_slice_size * slice_block_size > 4:
slice_block_size2 = split_slice_size * mat2_shape / 1024 / 1024 * block_multiply
do_split_2 = True
#Find something divisible with the input_tokens
while ((split_slice_size * split_2_slice_size * mat2_shape) / 1024 * block_multiply) > 4000:
while (split_2_slice_size * slice_block_size2) > 4:
split_2_slice_size = split_2_slice_size // 2
if split_2_slice_size <= 1:
split_2_slice_size = 1
@@ -71,13 +73,16 @@ def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.
else:
shape_one, batch_size_attention, query_tokens, shape_four = query.shape
no_shape_one = False
block_multiply = 3.6 if query.dtype == torch.float32 else 1.8
block_size = (shape_one * batch_size_attention * query_tokens * shape_four) / 1024 * block_multiply #MB
block_multiply = query.element_size()
slice_block_size = shape_one * query_tokens * shape_four / 1024 / 1024 * block_multiply
block_size = batch_size_attention * slice_block_size
split_slice_size = batch_size_attention
if block_size >= 4000:
if block_size > 4:
do_split = True
#Find something divisible with the shape_one
while ((shape_one * split_slice_size * query_tokens * shape_four) / 1024 * block_multiply) > 4000:
while (split_slice_size * slice_block_size) > 4:
split_slice_size = split_slice_size // 2
if split_slice_size <= 1:
split_slice_size = 1
@@ -85,12 +90,12 @@ def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.
else:
do_split = False
split_block_size = (shape_one * split_slice_size * query_tokens * shape_four) / 1024 * block_multiply #MB
split_2_slice_size = query_tokens
if split_block_size >= 4000:
if split_slice_size * slice_block_size > 4:
slice_block_size2 = shape_one * split_slice_size * shape_four / 1024 / 1024 * block_multiply
do_split_2 = True
#Find something divisible with the batch_size_attention
while ((shape_one * split_slice_size * split_2_slice_size * shape_four) / 1024 * block_multiply) > 4000:
while (split_2_slice_size * slice_block_size2) > 4:
split_2_slice_size = split_2_slice_size // 2
if split_2_slice_size <= 1:
split_2_slice_size = 1

View File

@@ -55,13 +55,14 @@ class SlicedAttnProcessor: # pylint: disable=too-few-public-methods
)
#ARC GPUs can't allocate more than 4GB to a single block, Slice it:
block_multiply = 2.4 if query.dtype == torch.float32 else 1.2
block_size = (batch_size_attention * query_tokens * shape_three) / 1024 * block_multiply #MB
block_multiply = query.element_size()
slice_block_size = self.slice_size * shape_three / 1024 / 1024 * block_multiply
block_size = query_tokens * slice_block_size
split_2_slice_size = query_tokens
if block_size >= 4000:
if block_size > 4:
do_split_2 = True
#Find something divisible with the query_tokens
while ((self.slice_size * split_2_slice_size * shape_three) / 1024 * block_multiply) > 4000:
while (split_2_slice_size * slice_block_size) > 4:
split_2_slice_size = split_2_slice_size // 2
if split_2_slice_size <= 1:
split_2_slice_size = 1

View File

@@ -110,7 +110,7 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype):
module.weight = torch.nn.Parameter(weight)
def merge_lora_models(models, ratios, merge_dtype):
def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False):
base_alphas = {} # alpha for merged model
base_dims = {}
@@ -158,6 +158,12 @@ def merge_lora_models(models, ratios, merge_dtype):
for key in lora_sd.keys():
if "alpha" in key:
continue
if "lora_up" in key and concat:
concat_dim = 1
elif "lora_down" in key and concat:
concat_dim = 0
else:
concat_dim = None
lora_module_name = key[: key.rfind(".lora_")]
@@ -165,12 +171,16 @@ def merge_lora_models(models, ratios, merge_dtype):
alpha = alphas[lora_module_name]
scale = math.sqrt(alpha / base_alpha) * ratio
scale = abs(scale) if "lora_up" in key else scale # マイナスの重みに対応する。
if key in merged_sd:
assert (
merged_sd[key].size() == lora_sd[key].size()
merged_sd[key].size() == lora_sd[key].size() or concat_dim is not None
), f"weights shape mismatch merging v1 and v2, different dims? / 重みのサイズが合いません。v1とv2、または次元数の異なるモデルはマージできません"
merged_sd[key] = merged_sd[key] + lora_sd[key] * scale
if concat_dim is not None:
merged_sd[key] = torch.cat([merged_sd[key], lora_sd[key] * scale], dim=concat_dim)
else:
merged_sd[key] = merged_sd[key] + lora_sd[key] * scale
else:
merged_sd[key] = lora_sd[key] * scale
@@ -178,6 +188,13 @@ def merge_lora_models(models, ratios, merge_dtype):
for lora_module_name, alpha in base_alphas.items():
key = lora_module_name + ".alpha"
merged_sd[key] = torch.tensor(alpha)
if shuffle:
key_down = lora_module_name + ".lora_down.weight"
key_up = lora_module_name + ".lora_up.weight"
dim = merged_sd[key_down].shape[0]
perm = torch.randperm(dim)
merged_sd[key_down] = merged_sd[key_down][perm]
merged_sd[key_up] = merged_sd[key_up][:,perm]
print("merged model")
print(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}")
@@ -256,7 +273,7 @@ def merge(args):
args.v2, args.save_to, text_encoder, unet, args.sd_model, 0, 0, sai_metadata, save_dtype, vae
)
else:
state_dict, metadata, v2 = merge_lora_models(args.models, args.ratios, merge_dtype)
state_dict, metadata, v2 = merge_lora_models(args.models, args.ratios, merge_dtype, args.concat, args.shuffle)
print(f"calculating hashes and creating metadata...")
@@ -317,7 +334,19 @@ def setup_parser() -> argparse.ArgumentParser:
help="do not save sai modelspec metadata (minimum ss_metadata for LoRA is saved) / "
+ "sai modelspecのメタデータを保存しないLoRAの最低限のss_metadataは保存される",
)
parser.add_argument(
"--concat",
action="store_true",
help="concat lora instead of merge (The dim(rank) of the output LoRA is the sum of the input dims) / "
+ "マージの代わりに結合するLoRAのdim(rank)は入力dimの合計になる",
)
parser.add_argument(
"--shuffle",
action="store_true",
help="shuffle lora weight./ "
+ "LoRAの重みをシャッフルする",
)
return parser

View File

@@ -113,7 +113,7 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_
module.weight = torch.nn.Parameter(weight)
def merge_lora_models(models, ratios, merge_dtype):
def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False):
base_alphas = {} # alpha for merged model
base_dims = {}
@@ -161,6 +161,13 @@ def merge_lora_models(models, ratios, merge_dtype):
for key in tqdm(lora_sd.keys()):
if "alpha" in key:
continue
if "lora_up" in key and concat:
concat_dim = 1
elif "lora_down" in key and concat:
concat_dim = 0
else:
concat_dim = None
lora_module_name = key[: key.rfind(".lora_")]
@@ -168,12 +175,16 @@ def merge_lora_models(models, ratios, merge_dtype):
alpha = alphas[lora_module_name]
scale = math.sqrt(alpha / base_alpha) * ratio
scale = abs(scale) if "lora_up" in key else scale # マイナスの重みに対応する。
if key in merged_sd:
assert (
merged_sd[key].size() == lora_sd[key].size()
merged_sd[key].size() == lora_sd[key].size() or concat_dim is not None
), f"weights shape mismatch merging v1 and v2, different dims? / 重みのサイズが合いません。v1とv2、または次元数の異なるモデルはマージできません"
merged_sd[key] = merged_sd[key] + lora_sd[key] * scale
if concat_dim is not None:
merged_sd[key] = torch.cat([merged_sd[key], lora_sd[key] * scale], dim=concat_dim)
else:
merged_sd[key] = merged_sd[key] + lora_sd[key] * scale
else:
merged_sd[key] = lora_sd[key] * scale
@@ -181,6 +192,13 @@ def merge_lora_models(models, ratios, merge_dtype):
for lora_module_name, alpha in base_alphas.items():
key = lora_module_name + ".alpha"
merged_sd[key] = torch.tensor(alpha)
if shuffle:
key_down = lora_module_name + ".lora_down.weight"
key_up = lora_module_name + ".lora_up.weight"
dim = merged_sd[key_down].shape[0]
perm = torch.randperm(dim)
merged_sd[key_down] = merged_sd[key_down][perm]
merged_sd[key_up] = merged_sd[key_up][:,perm]
print("merged model")
print(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}")
@@ -252,7 +270,7 @@ def merge(args):
args.save_to, text_model1, text_model2, unet, 0, 0, ckpt_info, vae, logit_scale, sai_metadata, save_dtype
)
else:
state_dict, metadata = merge_lora_models(args.models, args.ratios, merge_dtype)
state_dict, metadata = merge_lora_models(args.models, args.ratios, merge_dtype, args.concat, args.shuffle)
print(f"calculating hashes and creating metadata...")
@@ -307,6 +325,18 @@ def setup_parser() -> argparse.ArgumentParser:
help="do not save sai modelspec metadata (minimum ss_metadata for LoRA is saved) / "
+ "sai modelspecのメタデータを保存しないLoRAの最低限のss_metadataは保存される",
)
parser.add_argument(
"--concat",
action="store_true",
help="concat lora instead of merge (The dim(rank) of the output LoRA is the sum of the input dims) / "
+ "マージの代わりに結合するLoRAのdim(rank)は入力dimの合計になる",
)
parser.add_argument(
"--shuffle",
action="store_true",
help="shuffle lora weight./ "
+ "LoRAの重みをシャッフルする",
)
return parser