mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Merge branch 'kohya-ss:main' into feature/stratified_lr
This commit is contained in:
@@ -127,6 +127,11 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser
|
|||||||
|
|
||||||
## Change History
|
## Change History
|
||||||
|
|
||||||
|
- 31 Mar. 2023, 2023/3/31:
|
||||||
|
- Fix an issue where VRAM usage temporarily increases when loading a model in `train_network.py`.
|
||||||
|
- Fix an issue where an error occurs when loading a `.safetensors` model in `train_network.py`. [#354](https://github.com/kohya-ss/sd-scripts/issues/354)
|
||||||
|
- `train_network.py` でモデル読み込み時にVRAM使用量が一時的に大きくなる不具合を修正しました。
|
||||||
|
- `train_network.py` で `.safetensors` 形式のモデルを読み込むとエラーになる不具合を修正しました。[#354](https://github.com/kohya-ss/sd-scripts/issues/354)
|
||||||
- 30 Mar. 2023, 2023/3/30:
|
- 30 Mar. 2023, 2023/3/30:
|
||||||
- Support [P+](https://prompt-plus.github.io/) training. Thank you jakaline-dev!
|
- Support [P+](https://prompt-plus.github.io/) training. Thank you jakaline-dev!
|
||||||
- See [#327](https://github.com/kohya-ss/sd-scripts/pull/327) for details.
|
- See [#327](https://github.com/kohya-ss/sd-scripts/pull/327) for details.
|
||||||
|
|||||||
@@ -841,7 +841,7 @@ def load_checkpoint_with_text_encoder_conversion(ckpt_path, device="cpu"):
|
|||||||
|
|
||||||
if is_safetensors(ckpt_path):
|
if is_safetensors(ckpt_path):
|
||||||
checkpoint = None
|
checkpoint = None
|
||||||
state_dict = load_file(ckpt_path, device)
|
state_dict = load_file(ckpt_path) # , device) # may causes error
|
||||||
else:
|
else:
|
||||||
checkpoint = torch.load(ckpt_path, map_location=device)
|
checkpoint = torch.load(ckpt_path, map_location=device)
|
||||||
if "state_dict" in checkpoint:
|
if "state_dict" in checkpoint:
|
||||||
|
|||||||
@@ -131,16 +131,21 @@ def train(args):
|
|||||||
# TODO: modify other training scripts as well
|
# TODO: modify other training scripts as well
|
||||||
if pi == accelerator.state.local_process_index:
|
if pi == accelerator.state.local_process_index:
|
||||||
print(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}")
|
print(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}")
|
||||||
text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator.device)
|
|
||||||
|
text_encoder, vae, unet, _ = train_util.load_target_model(
|
||||||
|
args, weight_dtype, accelerator.device if args.lowram else "cpu"
|
||||||
|
)
|
||||||
|
|
||||||
|
# work on low-ram device
|
||||||
|
if args.lowram:
|
||||||
|
text_encoder.to(accelerator.device)
|
||||||
|
unet.to(accelerator.device)
|
||||||
|
vae.to(accelerator.device)
|
||||||
|
|
||||||
gc.collect()
|
gc.collect()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
accelerator.wait_for_everyone()
|
accelerator.wait_for_everyone()
|
||||||
|
|
||||||
# work on low-ram device
|
|
||||||
# NOTE: this may not be necessary because we already load them on gpu
|
|
||||||
if args.lowram:
|
|
||||||
text_encoder.to(accelerator.device)
|
|
||||||
unet.to(accelerator.device)
|
|
||||||
|
|
||||||
# モデルに xformers とか memory efficient attention を組み込む
|
# モデルに xformers とか memory efficient attention を組み込む
|
||||||
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
|
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
|
||||||
@@ -566,7 +571,7 @@ def train(args):
|
|||||||
loss = loss * loss_weights
|
loss = loss * loss_weights
|
||||||
|
|
||||||
if args.min_snr_gamma:
|
if args.min_snr_gamma:
|
||||||
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
|
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
|
||||||
|
|
||||||
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user