fix: text encoders are on GPU even when not trained

This commit is contained in:
Kohya S
2024-01-17 21:31:50 +09:00
parent dcf0eeb5b6
commit 976d092c68
2 changed files with 8 additions and 8 deletions

View File

@@ -95,8 +95,8 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
unet.to(org_unet_device) unet.to(org_unet_device)
else: else:
# Text Encoderから毎回出力を取得するので、GPUに乗せておく # Text Encoderから毎回出力を取得するので、GPUに乗せておく
text_encoders[0].to(accelerator.device) text_encoders[0].to(accelerator.device, dtype=weight_dtype)
text_encoders[1].to(accelerator.device) text_encoders[1].to(accelerator.device, dtype=weight_dtype)
def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype): def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype):
if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None: if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None:

View File

@@ -117,7 +117,7 @@ class NetworkTrainer:
self, args, accelerator, unet, vae, tokenizers, text_encoders, data_loader, weight_dtype self, args, accelerator, unet, vae, tokenizers, text_encoders, data_loader, weight_dtype
): ):
for t_enc in text_encoders: for t_enc in text_encoders:
t_enc.to(accelerator.device) t_enc.to(accelerator.device, dtype=weight_dtype)
def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype): def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype):
input_ids = batch["input_ids"].to(accelerator.device) input_ids = batch["input_ids"].to(accelerator.device)
@@ -278,6 +278,7 @@ class NetworkTrainer:
accelerator.wait_for_everyone() accelerator.wait_for_everyone()
# 必要ならテキストエンコーダーの出力をキャッシュする: Text Encoderはcpuまたはgpuへ移される # 必要ならテキストエンコーダーの出力をキャッシュする: Text Encoderはcpuまたはgpuへ移される
# cache text encoder outputs if needed: Text Encoder is moved to cpu or gpu
self.cache_text_encoder_outputs_if_needed( self.cache_text_encoder_outputs_if_needed(
args, accelerator, unet, vae, tokenizers, text_encoders, train_dataset_group, weight_dtype args, accelerator, unet, vae, tokenizers, text_encoders, train_dataset_group, weight_dtype
) )
@@ -394,8 +395,7 @@ class NetworkTrainer:
for t_enc in text_encoders: for t_enc in text_encoders:
t_enc.requires_grad_(False) t_enc.requires_grad_(False)
# acceleratorがなんかよろしくやってくれるらしい # acceleratorがなんかよろしくやってくれるらしい / accelerator will do something good
# TODO めちゃくちゃ冗長なのでコードを整理する
if train_unet: if train_unet:
unet = accelerator.prepare(unet) unet = accelerator.prepare(unet)
else: else:
@@ -407,8 +407,8 @@ class NetworkTrainer:
text_encoder = accelerator.prepare(text_encoder) text_encoder = accelerator.prepare(text_encoder)
text_encoders = [text_encoder] text_encoders = [text_encoder]
else: else:
for t_enc in text_encoders: pass # if text_encoder is not trained, no need to prepare. and device and dtype are already set
t_enc.to(accelerator.device, dtype=weight_dtype)
network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(network, optimizer, train_dataloader, lr_scheduler) network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(network, optimizer, train_dataloader, lr_scheduler)
if args.gradient_checkpointing: if args.gradient_checkpointing:
@@ -685,7 +685,7 @@ class NetworkTrainer:
if accelerator.is_main_process: if accelerator.is_main_process:
init_kwargs = {} init_kwargs = {}
if args.wandb_run_name: if args.wandb_run_name:
init_kwargs['wandb'] = {'name': args.wandb_run_name} init_kwargs["wandb"] = {"name": args.wandb_run_name}
if args.log_tracker_config is not None: if args.log_tracker_config is not None:
init_kwargs = toml.load(args.log_tracker_config) init_kwargs = toml.load(args.log_tracker_config)
accelerator.init_trackers( accelerator.init_trackers(