reduce VRAM usage, increasing main RAM usage instead

Author: Kohya S
Date: 2024-10-27 10:19:05 +09:00
parent 56b4ea963e
commit ca44e3e447

2 changed files with 14 additions and 4 deletions

README.md

@@ -137,6 +137,13 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser
 ## Change History

+### Oct 27, 2024 / 2024-10-27:
+
+- `svd_merge_lora.py` VRAM usage has been reduced. However, main memory usage will increase (32GB is sufficient).
+  - This will be included in the next release.
+
 ### Oct 26, 2024 / 2024-10-26:

 - Fixed a bug in `svd_merge_lora.py`, `sdxl_merge_lora.py`, and `resize_lora.py` where the hash value of LoRA metadata was not correctly calculated when the `save_precision` was different from the `precision` used in the calculation. See issue [#1722](https://github.com/kohya-ss/sd-scripts/pull/1722) for details. Thanks to JujoHotaru for raising the issue.
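The pattern behind this trade-off is straightforward: the accumulator dictionary lives in main RAM, and only the tensor currently being merged is staged onto the GPU. A minimal sketch of the idea, with hypothetical names (`merge_on_cpu`, `modules`) rather than the script's actual interface:

```python
import torch

def merge_on_cpu(modules, device="cuda"):
    # Hypothetical illustration of the VRAM-for-RAM trade: every entry in
    # `merged` is parked in main RAM; tensors visit the GPU one at a time.
    merged = {}
    for name, delta in modules:
        weight = merged.get(name, torch.zeros_like(delta))
        weight = weight.to(device)       # stage onto the GPU for the math
        weight = weight + delta.to(device)
        merged[name] = weight.to("cpu")  # park the result back in main RAM
    return merged
```

Peak VRAM then scales with the largest single module rather than with the whole merged state dict, which is exactly the behavior the changelog entry above describes.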

networks/svd_merge_lora.py

@@ -301,10 +301,10 @@ def merge_lora_models(models, ratios, lbws, new_rank, new_conv_rank, device, mer
             # make original weight if not exist
             if lora_module_name not in merged_sd:
                 weight = torch.zeros((out_dim, in_dim, *kernel_size) if conv2d else (out_dim, in_dim), dtype=merge_dtype)
-                if device:
-                    weight = weight.to(device)
             else:
                 weight = merged_sd[lora_module_name]
+            if device:
+                weight = weight.to(device)

             # merge to weight
             if device:
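Moving the `if device:` transfer below the `if`/`else` (rather than keeping it inside the first branch) matters because of the next hunk: `merged_sd` now holds CPU tensors, so a re-fetched weight must be staged back onto the GPU before it is combined with the GPU-resident up/down weights. A tiny sketch of the failure the move avoids (requires a CUDA device):

```python
import torch

a = torch.zeros(2, 2)                # CPU tensor, as now stored in merged_sd
b = torch.ones(2, 2, device="cuda")  # GPU tensor, like the staged LoRA weights
# a + b                              # RuntimeError: tensors on different devices
a = a.to("cuda")                     # the transfer this hunk repositions
c = a + b                            # both operands now live on the GPU
```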
@@ -336,13 +336,16 @@ def merge_lora_models(models, ratios, lbws, new_rank, new_conv_rank, device, mer
                     conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
                     weight = weight + ratio * conved * scale

-            merged_sd[lora_module_name] = weight
+            merged_sd[lora_module_name] = weight.to("cpu")

     # extract from merged weights
     logger.info("extract new lora...")
     merged_lora_sd = {}
     with torch.no_grad():
         for lora_module_name, mat in tqdm(list(merged_sd.items())):
+            if device:
+                mat = mat.to(device)
+
             conv2d = len(mat.size()) == 4
             kernel_size = None if not conv2d else mat.size()[2:4]
             conv2d_3x3 = conv2d and kernel_size != (1, 1)
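The extraction pass follows the same staging discipline: each merged weight is promoted to the GPU just for its factorization, and the resulting factors return to the CPU (visible in the `.to("cpu").contiguous()` lines of the next hunk). A rough sketch of that shape for the 2D case, with `extract_up_down` as a hypothetical stand-in for the script's actual SVD logic:

```python
import torch

def extract_up_down(mat: torch.Tensor, rank: int, device=None):
    # Hypothetical helper: stage the CPU-resident matrix onto the GPU,
    # factorize, and hand back CPU tensors so VRAM is freed per module.
    if device:
        mat = mat.to(device)
    U, S, Vh = torch.linalg.svd(mat.to(torch.float32))
    up = (U[:, :rank] * S[:rank]).to("cpu").contiguous()  # (out_dim, rank)
    down = Vh[:rank, :].to("cpu").contiguous()            # (rank, in_dim)
    return up, down  # up @ down is a rank-`rank` approximation of mat
```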
@@ -381,7 +384,7 @@ def merge_lora_models(models, ratios, lbws, new_rank, new_conv_rank, device, mer
             merged_lora_sd[lora_module_name + ".lora_up.weight"] = up_weight.to("cpu").contiguous()
             merged_lora_sd[lora_module_name + ".lora_down.weight"] = down_weight.to("cpu").contiguous()
-            merged_lora_sd[lora_module_name + ".alpha"] = torch.tensor(module_new_rank)
+            merged_lora_sd[lora_module_name + ".alpha"] = torch.tensor(module_new_rank, device="cpu")

     # build minimum metadata
     dims = f"{new_rank}"
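With every tensor in `merged_lora_sd` explicitly on the CPU, the state dict is ready to serialize as-is. A hedged sketch of that final step, assuming the common `safetensors.torch.save_file` path rather than the script's own save wrapper (the metadata key shown is illustrative, not taken from this diff):

```python
from safetensors.torch import save_file

# save_file requires CPU tensors and string-valued metadata, which is why
# the loop above pins everything to "cpu" before this point.
metadata = {"ss_network_dim": dims}  # illustrative use of the `dims` string built above
save_file(merged_lora_sd, "merged_lora.safetensors", metadata=metadata)
```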