Mirror of https://github.com/kohya-ss/sd-scripts.git, synced 2026-04-09 06:45:09 +00:00
Add options to reduce memory usage in extract_lora_from_models.py closes #1059
--- a/README.md
+++ b/README.md
@@ -262,6 +262,9 @@ ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [docum
   - For example, if you train with state A as `1.0` and state B as `-1.0`, you may be able to generate by switching between state A and B depending on the LoRA application rate.
   - Also, if you prepare five states and train them as `0.2`, `0.4`, `0.6`, `0.8`, and `1.0`, you may be able to generate by switching the states smoothly depending on the application rate.
   - Please specify `network_multiplier` in `[[datasets]]` in the `.toml` file.
+- Some options are added to `networks/extract_lora_from_models.py` to reduce memory usage.
+  - The `--load_precision` option specifies the precision when loading the model. If the model is saved in fp16, you can reduce memory usage by specifying `--load_precision fp16` without losing precision.
+  - The `--load_original_model_to` option specifies the device to load the original model to, and `--load_tuned_model_to` the device for the tuned (derived) model. Both default to `cpu`, but you can specify `cuda` etc. Loading one of the two models to the GPU reduces memory usage.
 
 - (Experimental) An option has been added to the training scripts for LoRA etc. to train with the base model weights (U-Net, and also the Text Encoder when its modules are trained) cast to fp8. PR [#1057](https://github.com/kohya-ss/sd-scripts/pull/1057) Thanks to KohakuBlueleaf!
   - Specify `--fp8_base` in `train_network.py` or `sdxl_train_network.py`.
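The saving from `--load_precision fp16` comes from converting weights to half precision at load time. A minimal standalone sketch (not from this commit) of the effect on a single tensor:

```python
# Illustrative only: a 1024x1024 fp32 tensor takes 4 MiB; converting to
# fp16 halves that. If the checkpoint was saved in fp16, the conversion
# loses nothing for the stored values.
import torch

w = torch.randn(1024, 1024)              # fp32, 4 bytes per element
print(w.numel() * w.element_size())      # 4194304 bytes
w16 = w.to(torch.float16)                # 2 bytes per element
print(w16.numel() * w16.element_size())  # 2097152 bytes
```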
@@ -273,6 +276,10 @@ ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [docum
   - For example, if you train with state A as `1.0` and state B as `-1.0`, you may be able to generate by switching between states A and B depending on the LoRA application rate.
   - Also, if you prepare five states and train them as `0.2`, `0.4`, `0.6`, `0.8`, and `1.0`, you may be able to switch between the states smoothly depending on the application rate.
   - Specify `network_multiplier` under `[[datasets]]` in the `.toml` file.
+- Added some options to `networks/extract_lora_from_models.py` to reduce memory usage.
+  - `--load_precision` specifies the precision when loading the models. If a model is saved in fp16, specifying `--load_precision fp16` reduces memory usage without losing precision.
+  - `--load_original_model_to` specifies the device to load the original model to, and `--load_tuned_model_to` the device for the tuned (derived) model. Both default to `cpu`, but `cuda` etc. can be specified. Loading one of the two models to the GPU reduces memory usage.
+
 
 - `.toml` example for network multiplier / ネットワーク適用率の `.toml` の記述例
 
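The `.toml` example itself is not shown in this hunk. A hypothetical sketch of the shape the `[[datasets]]` entries take, where only `network_multiplier` is named by this commit and the other keys are illustrative assumptions:

```python
# Hypothetical config sketch, parsed with the standard library
# (tomllib requires Python 3.11+). Only network_multiplier is named
# by this commit; resolution/batch_size are assumed placeholder keys.
import tomllib

example = """
[[datasets]]
resolution = 512
batch_size = 4
network_multiplier = 1.0

[[datasets]]
resolution = 512
batch_size = 4
network_multiplier = -1.0
"""

for ds in tomllib.loads(example)["datasets"]:
    print(ds["network_multiplier"])  # 1.0, then -1.0
```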
--- a/networks/extract_lora_from_models.py
+++ b/networks/extract_lora_from_models.py
@@ -43,6 +43,9 @@ def svd(
     clamp_quantile=0.99,
     min_diff=0.01,
     no_metadata=False,
+    load_precision=None,
+    load_original_model_to=None,
+    load_tuned_model_to=None,
 ):
     def str_to_dtype(p):
         if p == "float":
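The new parameters rely on `str_to_dtype`, whose body lies mostly outside this diff. A hedged reconstruction of the mapping it presumably implements, based on the `choices` added to the parser below:

```python
import torch

# Assumed mapping; only the first line of the real function is visible
# in this diff, so treat the rest as a plausible reconstruction.
def str_to_dtype(p):
    if p == "float":
        return torch.float
    if p == "fp16":
        return torch.float16
    if p == "bf16":
        return torch.bfloat16
    return None

print(str_to_dtype("fp16"))  # torch.float16
```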
@@ -57,28 +60,51 @@ def svd(
     if v_parameterization is None:
         v_parameterization = v2
 
+    load_dtype = str_to_dtype(load_precision) if load_precision else None
     save_dtype = str_to_dtype(save_precision)
+    work_device = "cpu"
 
     # load models
     if not sdxl:
         print(f"loading original SD model : {model_org}")
         text_encoder_o, _, unet_o = model_util.load_models_from_stable_diffusion_checkpoint(v2, model_org)
         text_encoders_o = [text_encoder_o]
+        if load_dtype is not None:
+            text_encoder_o = text_encoder_o.to(load_dtype)
+            unet_o = unet_o.to(load_dtype)
+
         print(f"loading tuned SD model : {model_tuned}")
         text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(v2, model_tuned)
         text_encoders_t = [text_encoder_t]
+        if load_dtype is not None:
+            text_encoder_t = text_encoder_t.to(load_dtype)
+            unet_t = unet_t.to(load_dtype)
+
         model_version = model_util.get_model_version_str_for_sd1_sd2(v2, v_parameterization)
     else:
+        device_org = load_original_model_to if load_original_model_to else "cpu"
+        device_tuned = load_tuned_model_to if load_tuned_model_to else "cpu"
+
         print(f"loading original SDXL model : {model_org}")
         text_encoder_o1, text_encoder_o2, _, unet_o, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
-            sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_org, "cpu"
+            sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_org, device_org
         )
         text_encoders_o = [text_encoder_o1, text_encoder_o2]
+        if load_dtype is not None:
+            text_encoder_o1 = text_encoder_o1.to(load_dtype)
+            text_encoder_o2 = text_encoder_o2.to(load_dtype)
+            unet_o = unet_o.to(load_dtype)
+
         print(f"loading tuned SDXL model : {model_tuned}")
         text_encoder_t1, text_encoder_t2, _, unet_t, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
-            sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_tuned, "cpu"
+            sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_tuned, device_tuned
         )
         text_encoders_t = [text_encoder_t1, text_encoder_t2]
+        if load_dtype is not None:
+            text_encoder_t1 = text_encoder_t1.to(load_dtype)
+            text_encoder_t2 = text_encoder_t2.to(load_dtype)
+            unet_t = unet_t.to(load_dtype)
+
         model_version = sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0
 
     # create LoRA network to extract weights: Use dim (rank) as alpha
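The saving from `--load_original_model_to` / `--load_tuned_model_to` is a placement trick: with one model resident on the GPU and the other on the CPU, neither memory pool has to hold both. A minimal sketch of the pattern with plain tensors (the loader calls above are not reproduced):

```python
import torch

# One 'model' per device so neither RAM nor VRAM holds both at once.
device_tuned = "cuda" if torch.cuda.is_available() else "cpu"
w_org = torch.randn(320, 768, device="cpu")
w_tuned = torch.randn(320, 768, device=device_tuned)

# Pairs meet only on the work device, one tensor at a time.
work_device = "cpu"
diff = w_tuned.to(work_device) - w_org.to(work_device)
print(diff.device, diff.shape)
```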
@@ -100,38 +126,54 @@ def svd(
         lora_name = lora_o.lora_name
         module_o = lora_o.org_module
         module_t = lora_t.org_module
-        diff = module_t.weight - module_o.weight
+        diff = module_t.weight.to(work_device) - module_o.weight.to(work_device)
+
+        # clear weight to save memory
+        module_o.weight = None
+        module_t.weight = None
 
         # Text Encoder might be same
         if not text_encoder_different and torch.max(torch.abs(diff)) > min_diff:
             text_encoder_different = True
             print(f"Text encoder is different. {torch.max(torch.abs(diff))} > {min_diff}")
 
-        diff = diff.float()
         diffs[lora_name] = diff
 
+    # clear target Text Encoder to save memory
+    for text_encoder in text_encoders_t:
+        del text_encoder
+
     if not text_encoder_different:
         print("Text encoder is same. Extract U-Net only.")
         lora_network_o.text_encoder_loras = []
-        diffs = {}
+        diffs = {}  # clear diffs
 
     for i, (lora_o, lora_t) in enumerate(zip(lora_network_o.unet_loras, lora_network_t.unet_loras)):
         lora_name = lora_o.lora_name
         module_o = lora_o.org_module
         module_t = lora_t.org_module
-        diff = module_t.weight - module_o.weight
-        diff = diff.float()
-
-        if args.device:
-            diff = diff.to(args.device)
+        diff = module_t.weight.to(work_device) - module_o.weight.to(work_device)
+
+        # clear weight to save memory
+        module_o.weight = None
+        module_t.weight = None
 
         diffs[lora_name] = diff
 
+    # clear LoRA network, target U-Net to save memory
+    del lora_network_o
+    del lora_network_t
+    del unet_t
+
     # make LoRA with svd
     print("calculating by svd")
     lora_weights = {}
     with torch.no_grad():
         for lora_name, mat in tqdm(list(diffs.items())):
+            if args.device:
+                mat = mat.to(args.device)
+            mat = mat.to(torch.float)  # calc by float
+
             # if conv_dim is None, diffs do not include LoRAs for conv2d-3x3
             conv2d = len(mat.size()) == 4
             kernel_size = None if not conv2d else mat.size()[2:4]
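Setting `module_o.weight = None` works because `nn.Module` allows clearing a registered parameter; once the module drops its reference, the tensor can be garbage-collected. A standalone sketch:

```python
import torch
import torch.nn as nn

linear = nn.Linear(1024, 1024)
diff = linear.weight.detach().clone()  # keep only what we need
linear.weight = None                   # module releases the parameter
# The original 4 MiB weight is now collectable; `diff` is the sole copy.
print(diff.shape)
```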
@@ -171,8 +213,8 @@ def svd(
             if conv2d:
                 U = U.reshape(out_dim, rank, 1, 1)
                 Vh = Vh.reshape(rank, in_dim, kernel_size[0], kernel_size[1])
 
-            U = U.to("cpu").contiguous()
-            Vh = Vh.to("cpu").contiguous()
+            U = U.to(work_device, dtype=save_dtype).contiguous()
+            Vh = Vh.to(work_device, dtype=save_dtype).contiguous()
 
             lora_weights[lora_name] = (U, Vh)
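This hunk only changes where the factors land (`work_device`, `save_dtype`). For context, a minimal sketch of the rank-truncated SVD the surrounding loop performs on each weight difference; variable names mirror the script's parameters, but the body is an illustration rather than the exact code:

```python
import torch

def lowrank_factors(mat: torch.Tensor, rank: int, clamp_quantile: float = 0.99):
    # mat is a 2D weight difference of shape (out_dim, in_dim)
    U, S, Vh = torch.linalg.svd(mat)
    U = U[:, :rank] @ torch.diag(S[:rank])  # fold singular values into U
    Vh = Vh[:rank, :]
    # clamp outliers so a few extreme values do not dominate the factors
    hi = torch.quantile(torch.cat([U.flatten(), Vh.flatten()]).abs(), clamp_quantile)
    return U.clamp(-hi, hi), Vh.clamp(-hi, hi)

U, Vh = lowrank_factors(torch.randn(320, 768), rank=4)
print(U.shape, Vh.shape)  # torch.Size([320, 4]) torch.Size([4, 768])
```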
@@ -230,6 +272,13 @@ def setup_parser() -> argparse.ArgumentParser:
     parser.add_argument(
         "--sdxl", action="store_true", help="load Stable Diffusion SDXL base model / Stable Diffusion SDXL baseのモデルを読み込む"
     )
+    parser.add_argument(
+        "--load_precision",
+        type=str,
+        default=None,
+        choices=[None, "float", "fp16", "bf16"],
+        help="precision in loading, model default if omitted / 読み込み時に精度を変更して読み込む、省略時はモデルファイルによる"
+    )
     parser.add_argument(
         "--save_precision",
         type=str,
@@ -285,6 +334,18 @@ def setup_parser() -> argparse.ArgumentParser:
         help="do not save sai modelspec metadata (minimum ss_metadata for LoRA is saved) / "
         + "sai modelspecのメタデータを保存しない(LoRAの最低限のss_metadataは保存される)",
     )
+    parser.add_argument(
+        "--load_original_model_to",
+        type=str,
+        default=None,
+        help="location to load original model, cpu or cuda, cuda:0, etc, default is cpu, only for SDXL / 元モデル読み込み先、cpuまたはcuda、cuda:0など、省略時はcpu、SDXLのみ有効",
+    )
+    parser.add_argument(
+        "--load_tuned_model_to",
+        type=str,
+        default=None,
+        help="location to load tuned model, cpu or cuda, cuda:0, etc, default is cpu, only for SDXL / 派生モデル読み込み先、cpuまたはcuda、cuda:0など、省略時はcpu、SDXLのみ有効",
+    )
 
     return parser
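A self-contained way to see how the three new flags parse together (argparse only; the values are placeholders, not from the commit):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--load_precision", type=str, default=None, choices=[None, "float", "fp16", "bf16"])
parser.add_argument("--load_original_model_to", type=str, default=None)
parser.add_argument("--load_tuned_model_to", type=str, default=None)

args = parser.parse_args(
    ["--load_precision", "fp16", "--load_original_model_to", "cpu", "--load_tuned_model_to", "cuda"]
)
print(args.load_precision, args.load_original_model_to, args.load_tuned_model_to)
# fp16 cpu cuda
```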