Add options to reduce memory usage in extract_lora_from_models.py closes #1059

This commit is contained in:
Kohya S
2024-01-20 18:45:54 +09:00
parent fef172966f
commit c59249a664
2 changed files with 79 additions and 11 deletions

View File

@@ -262,6 +262,9 @@ ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [docum
- For example, if you train with state A as `1.0` and state B as `-1.0`, you may be able to generate by switching between state A and B depending on the LoRA application rate. - For example, if you train with state A as `1.0` and state B as `-1.0`, you may be able to generate by switching between state A and B depending on the LoRA application rate.
- Also, if you prepare five states and train them as `0.2`, `0.4`, `0.6`, `0.8`, and `1.0`, you may be able to generate by switching the states smoothly depending on the application rate. - Also, if you prepare five states and train them as `0.2`, `0.4`, `0.6`, `0.8`, and `1.0`, you may be able to generate by switching the states smoothly depending on the application rate.
- Please specify `network_multiplier` in `[[datasets]]` in `.toml` file. - Please specify `network_multiplier` in `[[datasets]]` in `.toml` file.
- Some options are added to `networks/extract_lora_from_models.py` to reduce the memory usage.
- `--load_precision` option can be used to specify the precision when loading the model. If the model is saved in fp16, you can reduce the memory usage by specifying `--load_precision fp16` without losing precision.
- `--load_original_model_to` option can be used to specify the device to load the original model. `--load_tuned_model_to` option can be used to specify the device to load the derived model. The default is `cpu` for both options, but you can specify `cuda` etc. You can reduce the memory usage by loading one of them to GPU.
- 実験的 LoRA等の学習スクリプトで、ベースモデルU-Net、および Text Encoder のモジュール学習時は Text Encoder も)の重みを fp8 にして学習するオプションが追加されました。 PR [#1057](https://github.com/kohya-ss/sd-scripts/pull/1057) KohakuBlueleaf 氏に感謝します。 - 実験的 LoRA等の学習スクリプトで、ベースモデルU-Net、および Text Encoder のモジュール学習時は Text Encoder も)の重みを fp8 にして学習するオプションが追加されました。 PR [#1057](https://github.com/kohya-ss/sd-scripts/pull/1057) KohakuBlueleaf 氏に感謝します。
- `train_network.py` または `sdxl_train_network.py``--fp8_base` を指定してください。 - `train_network.py` または `sdxl_train_network.py``--fp8_base` を指定してください。
@@ -273,6 +276,10 @@ ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [docum
- たとえば状態 A を `1.0`、状態 B を `-1.0` として学習すると、LoRA の適用率に応じて状態 A と B を切り替えつつ生成できるかもしれません。 - たとえば状態 A を `1.0`、状態 B を `-1.0` として学習すると、LoRA の適用率に応じて状態 A と B を切り替えつつ生成できるかもしれません。
- また、五段階の状態を用意し、それぞれ `0.2``0.4``0.6``0.8``1.0` として学習すると、適用率でなめらかに状態を切り替えて生成できるかもしれません。 - また、五段階の状態を用意し、それぞれ `0.2``0.4``0.6``0.8``1.0` として学習すると、適用率でなめらかに状態を切り替えて生成できるかもしれません。
- `.toml` ファイルで `[[datasets]]``network_multiplier` を指定してください。 - `.toml` ファイルで `[[datasets]]``network_multiplier` を指定してください。
- `networks/extract_lora_from_models.py` に使用メモリ量を削減するいくつかのオプションを追加しました。
- `--load_precision` で読み込み時の精度を指定できます。モデルが fp16 で保存されている場合は `--load_precision fp16` を指定して精度を変えずにメモリ量を削減できます。
- `--load_original_model_to` で元モデルを読み込むデバイスを、`--load_tuned_model_to` で派生モデルを読み込むデバイスを指定できます。デフォルトは両方とも `cpu` ですがそれぞれ `cuda` 等を指定できます。片方を GPU に読み込むことでメモリ量を削減できます。
- `.toml` example for network multiplier / ネットワーク適用率の `.toml` の記述例 - `.toml` example for network multiplier / ネットワーク適用率の `.toml` の記述例

View File

@@ -43,6 +43,9 @@ def svd(
clamp_quantile=0.99, clamp_quantile=0.99,
min_diff=0.01, min_diff=0.01,
no_metadata=False, no_metadata=False,
load_precision=None,
load_original_model_to=None,
load_tuned_model_to=None,
): ):
def str_to_dtype(p): def str_to_dtype(p):
if p == "float": if p == "float":
@@ -57,28 +60,51 @@ def svd(
if v_parameterization is None: if v_parameterization is None:
v_parameterization = v2 v_parameterization = v2
load_dtype = str_to_dtype(load_precision) if load_precision else None
save_dtype = str_to_dtype(save_precision) save_dtype = str_to_dtype(save_precision)
work_device = "cpu"
# load models # load models
if not sdxl: if not sdxl:
print(f"loading original SD model : {model_org}") print(f"loading original SD model : {model_org}")
text_encoder_o, _, unet_o = model_util.load_models_from_stable_diffusion_checkpoint(v2, model_org) text_encoder_o, _, unet_o = model_util.load_models_from_stable_diffusion_checkpoint(v2, model_org)
text_encoders_o = [text_encoder_o] text_encoders_o = [text_encoder_o]
if load_dtype is not None:
text_encoder_o = text_encoder_o.to(load_dtype)
unet_o = unet_o.to(load_dtype)
print(f"loading tuned SD model : {model_tuned}") print(f"loading tuned SD model : {model_tuned}")
text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(v2, model_tuned) text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(v2, model_tuned)
text_encoders_t = [text_encoder_t] text_encoders_t = [text_encoder_t]
if load_dtype is not None:
text_encoder_t = text_encoder_t.to(load_dtype)
unet_t = unet_t.to(load_dtype)
model_version = model_util.get_model_version_str_for_sd1_sd2(v2, v_parameterization) model_version = model_util.get_model_version_str_for_sd1_sd2(v2, v_parameterization)
else: else:
device_org = load_original_model_to if load_original_model_to else "cpu"
device_tuned = load_tuned_model_to if load_tuned_model_to else "cpu"
print(f"loading original SDXL model : {model_org}") print(f"loading original SDXL model : {model_org}")
text_encoder_o1, text_encoder_o2, _, unet_o, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint( text_encoder_o1, text_encoder_o2, _, unet_o, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_org, "cpu" sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_org, device_org
) )
text_encoders_o = [text_encoder_o1, text_encoder_o2] text_encoders_o = [text_encoder_o1, text_encoder_o2]
if load_dtype is not None:
text_encoder_o1 = text_encoder_o1.to(load_dtype)
text_encoder_o2 = text_encoder_o2.to(load_dtype)
unet_o = unet_o.to(load_dtype)
print(f"loading original SDXL model : {model_tuned}") print(f"loading original SDXL model : {model_tuned}")
text_encoder_t1, text_encoder_t2, _, unet_t, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint( text_encoder_t1, text_encoder_t2, _, unet_t, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_tuned, "cpu" sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_tuned, device_tuned
) )
text_encoders_t = [text_encoder_t1, text_encoder_t2] text_encoders_t = [text_encoder_t1, text_encoder_t2]
if load_dtype is not None:
text_encoder_t1 = text_encoder_t1.to(load_dtype)
text_encoder_t2 = text_encoder_t2.to(load_dtype)
unet_t = unet_t.to(load_dtype)
model_version = sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0 model_version = sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0
# create LoRA network to extract weights: Use dim (rank) as alpha # create LoRA network to extract weights: Use dim (rank) as alpha
@@ -100,38 +126,54 @@ def svd(
lora_name = lora_o.lora_name lora_name = lora_o.lora_name
module_o = lora_o.org_module module_o = lora_o.org_module
module_t = lora_t.org_module module_t = lora_t.org_module
diff = module_t.weight - module_o.weight diff = module_t.weight.to(work_device) - module_o.weight.to(work_device)
# clear weight to save memory
module_o.weight = None
module_t.weight = None
# Text Encoder might be same # Text Encoder might be same
if not text_encoder_different and torch.max(torch.abs(diff)) > min_diff: if not text_encoder_different and torch.max(torch.abs(diff)) > min_diff:
text_encoder_different = True text_encoder_different = True
print(f"Text encoder is different. {torch.max(torch.abs(diff))} > {min_diff}") print(f"Text encoder is different. {torch.max(torch.abs(diff))} > {min_diff}")
diff = diff.float()
diffs[lora_name] = diff diffs[lora_name] = diff
# clear target Text Encoder to save memory
for text_encoder in text_encoders_t:
del text_encoder
if not text_encoder_different: if not text_encoder_different:
print("Text encoder is same. Extract U-Net only.") print("Text encoder is same. Extract U-Net only.")
lora_network_o.text_encoder_loras = [] lora_network_o.text_encoder_loras = []
diffs = {} diffs = {} # clear diffs
for i, (lora_o, lora_t) in enumerate(zip(lora_network_o.unet_loras, lora_network_t.unet_loras)): for i, (lora_o, lora_t) in enumerate(zip(lora_network_o.unet_loras, lora_network_t.unet_loras)):
lora_name = lora_o.lora_name lora_name = lora_o.lora_name
module_o = lora_o.org_module module_o = lora_o.org_module
module_t = lora_t.org_module module_t = lora_t.org_module
diff = module_t.weight - module_o.weight diff = module_t.weight.to(work_device) - module_o.weight.to(work_device)
diff = diff.float()
if args.device: # clear weight to save memory
diff = diff.to(args.device) module_o.weight = None
module_t.weight = None
diffs[lora_name] = diff diffs[lora_name] = diff
# clear LoRA network, target U-Net to save memory
del lora_network_o
del lora_network_t
del unet_t
# make LoRA with svd # make LoRA with svd
print("calculating by svd") print("calculating by svd")
lora_weights = {} lora_weights = {}
with torch.no_grad(): with torch.no_grad():
for lora_name, mat in tqdm(list(diffs.items())): for lora_name, mat in tqdm(list(diffs.items())):
if args.device:
mat = mat.to(args.device)
mat = mat.to(torch.float) # calc by float
# if conv_dim is None, diffs do not include LoRAs for conv2d-3x3 # if conv_dim is None, diffs do not include LoRAs for conv2d-3x3
conv2d = len(mat.size()) == 4 conv2d = len(mat.size()) == 4
kernel_size = None if not conv2d else mat.size()[2:4] kernel_size = None if not conv2d else mat.size()[2:4]
@@ -171,8 +213,8 @@ def svd(
U = U.reshape(out_dim, rank, 1, 1) U = U.reshape(out_dim, rank, 1, 1)
Vh = Vh.reshape(rank, in_dim, kernel_size[0], kernel_size[1]) Vh = Vh.reshape(rank, in_dim, kernel_size[0], kernel_size[1])
U = U.to("cpu").contiguous() U = U.to(work_device, dtype=save_dtype).contiguous()
Vh = Vh.to("cpu").contiguous() Vh = Vh.to(work_device, dtype=save_dtype).contiguous()
lora_weights[lora_name] = (U, Vh) lora_weights[lora_name] = (U, Vh)
@@ -230,6 +272,13 @@ def setup_parser() -> argparse.ArgumentParser:
parser.add_argument( parser.add_argument(
"--sdxl", action="store_true", help="load Stable Diffusion SDXL base model / Stable Diffusion SDXL baseのモデルを読み込む" "--sdxl", action="store_true", help="load Stable Diffusion SDXL base model / Stable Diffusion SDXL baseのモデルを読み込む"
) )
parser.add_argument(
"--load_precision",
type=str,
default=None,
choices=[None, "float", "fp16", "bf16"],
help="precision in loading, model default if omitted / 読み込み時に精度を変更して読み込む、省略時はモデルファイルによる"
)
parser.add_argument( parser.add_argument(
"--save_precision", "--save_precision",
type=str, type=str,
@@ -285,6 +334,18 @@ def setup_parser() -> argparse.ArgumentParser:
help="do not save sai modelspec metadata (minimum ss_metadata for LoRA is saved) / " help="do not save sai modelspec metadata (minimum ss_metadata for LoRA is saved) / "
+ "sai modelspecのメタデータを保存しないLoRAの最低限のss_metadataは保存される", + "sai modelspecのメタデータを保存しないLoRAの最低限のss_metadataは保存される",
) )
parser.add_argument(
"--load_original_model_to",
type=str,
default=None,
help="location to load original model, cpu or cuda, cuda:0, etc, default is cpu, only for SDXL / 元モデル読み込み先、cpuまたはcuda、cuda:0など、省略時はcpu、SDXLのみ有効",
)
parser.add_argument(
"--load_tuned_model_to",
type=str,
default=None,
help="location to load tuned model, cpu or cuda, cuda:0, etc, default is cpu, only for SDXL / 派生モデル読み込み先、cpuまたはcuda、cuda:0など、省略時はcpu、SDXLのみ有効",
)
return parser return parser