mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
fix vae type error during training sdxl
This commit is contained in:
@@ -17,7 +17,6 @@ TOKENIZER2_PATH = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
|
|||||||
|
|
||||||
|
|
||||||
def load_target_model(args, accelerator, model_version: str, weight_dtype):
|
def load_target_model(args, accelerator, model_version: str, weight_dtype):
|
||||||
# load models for each process
|
|
||||||
model_dtype = match_mixed_precision(args, weight_dtype) # prepare fp16/bf16
|
model_dtype = match_mixed_precision(args, weight_dtype) # prepare fp16/bf16
|
||||||
for pi in range(accelerator.state.num_processes):
|
for pi in range(accelerator.state.num_processes):
|
||||||
if pi == accelerator.state.local_process_index:
|
if pi == accelerator.state.local_process_index:
|
||||||
|
|||||||
@@ -4042,28 +4042,23 @@ def _load_target_model(args: argparse.Namespace, weight_dtype, device="cpu", une
|
|||||||
|
|
||||||
|
|
||||||
def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projection_in_v2=False):
|
def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projection_in_v2=False):
|
||||||
# load models for each process
|
|
||||||
for pi in range(accelerator.state.num_processes):
|
for pi in range(accelerator.state.num_processes):
|
||||||
if pi == accelerator.state.local_process_index:
|
if pi == accelerator.state.local_process_index:
|
||||||
print(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}")
|
print(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}")
|
||||||
|
|
||||||
text_encoder, vae, unet, load_stable_diffusion_format = _load_target_model(
|
text_encoder, vae, unet, load_stable_diffusion_format = _load_target_model(
|
||||||
args,
|
args,
|
||||||
weight_dtype,
|
weight_dtype,
|
||||||
accelerator.device if args.lowram else "cpu",
|
accelerator.device if args.lowram else "cpu",
|
||||||
unet_use_linear_projection_in_v2=unet_use_linear_projection_in_v2,
|
unet_use_linear_projection_in_v2=unet_use_linear_projection_in_v2,
|
||||||
)
|
)
|
||||||
|
|
||||||
# work on low-ram device
|
# work on low-ram device
|
||||||
if args.lowram:
|
if args.lowram:
|
||||||
text_encoder.to(accelerator.device)
|
text_encoder.to(accelerator.device)
|
||||||
unet.to(accelerator.device)
|
unet.to(accelerator.device)
|
||||||
vae.to(accelerator.device)
|
vae.to(accelerator.device)
|
||||||
|
|
||||||
gc.collect()
|
gc.collect()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
accelerator.wait_for_everyone()
|
accelerator.wait_for_everyone()
|
||||||
|
|
||||||
return text_encoder, vae, unet, load_stable_diffusion_format
|
return text_encoder, vae, unet, load_stable_diffusion_format
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -392,23 +392,20 @@ def train(args):
|
|||||||
if args.deepspeed:
|
if args.deepspeed:
|
||||||
# Wrapping model for DeepSpeed
|
# Wrapping model for DeepSpeed
|
||||||
class DeepSpeedModel(torch.nn.Module):
|
class DeepSpeedModel(torch.nn.Module):
|
||||||
def __init__(self, unet, text_encoder, vae) -> None:
|
def __init__(self, unet, text_encoder) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.unet = unet
|
self.unet = unet
|
||||||
self.text_encoders = self.text_encoder = torch.nn.ModuleList(text_encoder)
|
self.text_encoders = self.text_encoder = torch.nn.ModuleList(text_encoder)
|
||||||
self.vae = vae
|
|
||||||
|
|
||||||
def get_models(self):
|
def get_models(self):
|
||||||
return self.unet, self.text_encoders, self.vae
|
return self.unet, self.text_encoders
|
||||||
text_encoders = [text_encoder1, text_encoder2]
|
text_encoders = [text_encoder1, text_encoder2]
|
||||||
unet.to(accelerator.device, dtype=weight_dtype)
|
ds_model = DeepSpeedModel(unet, text_encoders)
|
||||||
[t_enc.to(accelerator.device, dtype=weight_dtype) for t_enc in text_encoders]
|
|
||||||
ds_model = DeepSpeedModel(unet, text_encoders, vae)
|
|
||||||
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(ds_model, optimizer, train_dataloader, lr_scheduler)
|
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(ds_model, optimizer, train_dataloader, lr_scheduler)
|
||||||
# Now, ds_model is an instance of DeepSpeedEngine.
|
# Now, ds_model is an instance of DeepSpeedEngine.
|
||||||
unet, text_encoders, vae = ds_model.get_models() # for compatibility
|
unet, text_encoders = ds_model.get_models() # for compatibility
|
||||||
vae.to(vae_dtype) # to avoid explicitly half-vae
|
text_encoder1, text_encoder2 = text_encoder = text_encoders
|
||||||
text_encoder1, text_encoder2 = text_encoders[0], text_encoders[1]
|
training_models = [unet, text_encoder1, text_encoder2]
|
||||||
else: # acceleratorがなんかよろしくやってくれるらしい
|
else: # acceleratorがなんかよろしくやってくれるらしい
|
||||||
if train_unet:
|
if train_unet:
|
||||||
unet = accelerator.prepare(unet)
|
unet = accelerator.prepare(unet)
|
||||||
@@ -493,10 +490,10 @@ def train(args):
|
|||||||
for step, batch in enumerate(train_dataloader):
|
for step, batch in enumerate(train_dataloader):
|
||||||
current_step.value = global_step
|
current_step.value = global_step
|
||||||
with accelerator.accumulate(*training_models):
|
with accelerator.accumulate(*training_models):
|
||||||
if "latents" in batch and batch["latents"] is not None:
|
with torch.no_grad(): # why does this block differ from the one in train_network.py?
|
||||||
latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
|
if "latents" in batch and batch["latents"] is not None:
|
||||||
else:
|
latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
|
||||||
with torch.no_grad():
|
else:
|
||||||
# latentに変換
|
# latentに変換
|
||||||
latents = vae.encode(batch["images"].to(vae_dtype)).latent_dist.sample().to(weight_dtype)
|
latents = vae.encode(batch["images"].to(vae_dtype)).latent_dist.sample().to(weight_dtype)
|
||||||
|
|
||||||
@@ -504,7 +501,7 @@ def train(args):
|
|||||||
if torch.any(torch.isnan(latents)):
|
if torch.any(torch.isnan(latents)):
|
||||||
accelerator.print("NaN found in latents, replacing with zeros")
|
accelerator.print("NaN found in latents, replacing with zeros")
|
||||||
latents = torch.nan_to_num(latents, 0, out=latents)
|
latents = torch.nan_to_num(latents, 0, out=latents)
|
||||||
latents = latents * sdxl_model_util.VAE_SCALE_FACTOR
|
latents = latents * sdxl_model_util.VAE_SCALE_FACTOR
|
||||||
|
|
||||||
if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None:
|
if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None:
|
||||||
input_ids1 = batch["input_ids"]
|
input_ids1 = batch["input_ids"]
|
||||||
|
|||||||
Reference in New Issue
Block a user