mirror of https://github.com/kohya-ss/sd-scripts.git
Reduce peak RAM usage
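The change threads a `device` argument through the checkpoint loaders so tensors are materialized directly on each process's device, and staggers model loading across processes so host RAM holds at most one extra copy of the weights at a time.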
@@ -831,7 +831,7 @@ def is_safetensors(path):
   return os.path.splitext(path)[1].lower() == '.safetensors'
 
 
-def load_checkpoint_with_text_encoder_conversion(ckpt_path):
+def load_checkpoint_with_text_encoder_conversion(ckpt_path, device):
   # support models that store the text encoder in a different layout (keys lack 'text_model')
   TEXT_ENCODER_KEY_REPLACEMENTS = [
       ('cond_stage_model.transformer.embeddings.', 'cond_stage_model.transformer.text_model.embeddings.'),
@@ -841,9 +841,9 @@ def load_checkpoint_with_text_encoder_conversion(ckpt_path):
 
   if is_safetensors(ckpt_path):
     checkpoint = None
-    state_dict = load_file(ckpt_path, "cpu")
+    state_dict = load_file(ckpt_path, device)
   else:
-    checkpoint = torch.load(ckpt_path, map_location="cpu")
+    checkpoint = torch.load(ckpt_path, map_location=device)
     if "state_dict" in checkpoint:
       state_dict = checkpoint["state_dict"]
     else:
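For reference, the pattern the two hunks above introduce can be shown on its own. This is a minimal sketch rather than code from the repository, and the checkpoint path is a placeholder: safetensors' load_file accepts a target device, and torch.load's map_location plays the same role for pickled checkpoints, so a GPU-bound state dict never needs a full host-RAM copy.

# Minimal sketch of device-targeted checkpoint loading; the path is a placeholder.
import torch
from safetensors.torch import load_file

def load_state_dict_to(ckpt_path, device):
  if ckpt_path.lower().endswith('.safetensors'):
    # safetensors materializes tensors directly on `device`,
    # skipping the intermediate copy in host RAM
    return load_file(ckpt_path, device)
  # map_location moves tensors to `device` as they are deserialized
  checkpoint = torch.load(ckpt_path, map_location=device)
  return checkpoint.get('state_dict', checkpoint)

state_dict = load_state_dict_to('model.safetensors', 'cuda:0')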
@@ -865,18 +865,14 @@ def load_checkpoint_with_text_encoder_conversion(ckpt_path):
 
 
 # TODO the behavior of the dtype option is suspicious and should be checked; not yet confirmed that the text_encoder can be created in the specified dtype
-def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, dtype=None):
-  _, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path)
-  if dtype is not None:
-    for k, v in state_dict.items():
-      if type(v) is torch.Tensor:
-        state_dict[k] = v.to(dtype)
+def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device='cpu', dtype=None):
+  _, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path, device)
 
   # Convert the UNet2DConditionModel model.
   unet_config = create_unet_diffusers_config(v2)
   converted_unet_checkpoint = convert_ldm_unet_checkpoint(v2, state_dict, unet_config)
 
-  unet = UNet2DConditionModel(**unet_config)
+  unet = UNet2DConditionModel(**unet_config).to(device)
   info = unet.load_state_dict(converted_unet_checkpoint)
   print("loading u-net:", info)
 
@@ -884,7 +880,7 @@ def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, dtype=None):
   vae_config = create_vae_diffusers_config()
   converted_vae_checkpoint = convert_ldm_vae_checkpoint(state_dict, vae_config)
 
-  vae = AutoencoderKL(**vae_config)
+  vae = AutoencoderKL(**vae_config).to(device)
   info = vae.load_state_dict(converted_vae_checkpoint)
   print("loading vae:", info)
 
@@ -918,7 +914,7 @@ def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, dtype=None):
   converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v1(state_dict)
 
   logging.set_verbosity_error()  # don't show annoying warning
-  text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
+  text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(device)
   logging.set_verbosity_warning()
 
   info = text_model.load_state_dict(converted_text_encoder_checkpoint)
@@ -2536,13 +2536,13 @@ def prepare_dtype(args: argparse.Namespace):
   return weight_dtype, save_dtype
 
 
-def load_target_model(args: argparse.Namespace, weight_dtype):
+def load_target_model(args: argparse.Namespace, weight_dtype, device='cpu'):
   name_or_path = args.pretrained_model_name_or_path
   name_or_path = os.readlink(name_or_path) if os.path.islink(name_or_path) else name_or_path
   load_stable_diffusion_format = os.path.isfile(name_or_path)  # determine SD or Diffusers
   if load_stable_diffusion_format:
     print("load StableDiffusion checkpoint")
-    text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, name_or_path)
+    text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, name_or_path, device)
   else:
     print("load Diffusers pretrained models")
     try:
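A hypothetical call site for the extended loader; `args` mirrors the training scripts' argparse namespace and the device string is illustrative:

# Load the weights straight onto this process's GPU instead of staging them in CPU RAM.
weight_dtype, save_dtype = train_util.prepare_dtype(args)
text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, 'cuda:0')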
@@ -123,12 +123,18 @@ def train(args):
   weight_dtype, save_dtype = train_util.prepare_dtype(args)
 
   # load the model
-  text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype)
+  for pi in range(accelerator.state.num_processes):
+    if pi == accelerator.state.local_process_index:
+      print(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}")
+      text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator.device)
+      gc.collect()
+      torch.cuda.empty_cache()
+    accelerator.wait_for_everyone()
 
   # work on low-ram device
   if args.lowram:
-    text_encoder.to("cuda")
-    unet.to("cuda")
+    text_encoder.to(accelerator.device)
+    unet.to(accelerator.device)
 
   # incorporate xformers or memory efficient attention into the model
   train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
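The training-script hunk is where the peak-RAM saving comes from: with N processes loading concurrently, up to N full copies of the checkpoint can sit in host memory at once, while taking turns caps the peak near a single copy. A standalone sketch of the same pattern, assuming build_model() is a placeholder for any memory-heavy loading step:

# Staggered per-process model loading with Accelerate; build_model() is hypothetical.
import gc
import torch
from accelerate import Accelerator

accelerator = Accelerator()

for pi in range(accelerator.state.num_processes):
  if pi == accelerator.state.local_process_index:
    # only one process allocates at a time, so peak host RAM stays
    # near one model's worth instead of num_processes models
    model = build_model()
    gc.collect()
    torch.cuda.empty_cache()
  # everyone blocks here until process pi has finished loading
  accelerator.wait_for_everyone()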