Mirror of https://github.com/kohya-ss/sd-scripts.git
fix: text encoders are on GPU even when not trained
@@ -95,8 +95,8 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
             unet.to(org_unet_device)
         else:
             # Text Encoderから毎回出力を取得するので、GPUに乗せておく
-            text_encoders[0].to(accelerator.device)
-            text_encoders[1].to(accelerator.device)
+            text_encoders[0].to(accelerator.device, dtype=weight_dtype)
+            text_encoders[1].to(accelerator.device, dtype=weight_dtype)
 
     def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype):
         if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None:
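The Japanese comment in this hunk says that the outputs are fetched from the Text Encoders on every step, so they are kept on the GPU. The fix additionally casts them to weight_dtype, so the frozen encoders match the training precision instead of staying in fp32. A minimal sketch of the effect, assuming torch is available and weight_dtype is torch.float16; the module below is only a stand-in for a real text encoder and is not taken from the commit:

import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
weight_dtype = torch.float16  # assumption: mixed-precision training in fp16

text_encoder = nn.Linear(768, 1280)  # stand-in for a real text encoder
text_encoder.requires_grad_(False)   # frozen, inference only

# Without dtype the weights stay in fp32 and use twice the memory:
text_encoder.to(device)
print(next(text_encoder.parameters()).dtype)  # torch.float32

# With dtype the frozen encoder matches the rest of the half-precision model:
text_encoder.to(device, dtype=weight_dtype)
print(next(text_encoder.parameters()).dtype)  # torch.float16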
@@ -117,7 +117,7 @@ class NetworkTrainer:
         self, args, accelerator, unet, vae, tokenizers, text_encoders, data_loader, weight_dtype
     ):
         for t_enc in text_encoders:
-            t_enc.to(accelerator.device)
+            t_enc.to(accelerator.device, dtype=weight_dtype)
 
     def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype):
         input_ids = batch["input_ids"].to(accelerator.device)
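In the base NetworkTrainer the same hook moves every text encoder to the device, and get_text_cond then encodes the batch's input_ids on that device on every step. A rough sketch of that per-step path, with illustrative names only (encode_captions is not a function from the repository):

import torch

def encode_captions(text_encoder, input_ids, device):
    # Move the tokenized captions to the same device as the (frozen) encoder.
    input_ids = input_ids.to(device)  # integer ids: device move only, no dtype cast
    with torch.no_grad():             # the encoder takes no gradients here
        return text_encoder(input_ids)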
@@ -278,6 +278,7 @@ class NetworkTrainer:
         accelerator.wait_for_everyone()
 
         # 必要ならテキストエンコーダーの出力をキャッシュする: Text Encoderはcpuまたはgpuへ移される
+        # cache text encoder outputs if needed: Text Encoder is moved to cpu or gpu
         self.cache_text_encoder_outputs_if_needed(
             args, accelerator, unet, vae, tokenizers, text_encoders, train_dataset_group, weight_dtype
         )
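The added English line translates the existing Japanese comment: cache the text encoder outputs if needed, after which the Text Encoder is moved to CPU or GPU accordingly. The caching branch works roughly like the sketch below, which is only an outline under assumed names (cache_text_encoder_outputs, all_input_ids) and not the repository's implementation:

import torch

def cache_text_encoder_outputs(text_encoder, all_input_ids, device, weight_dtype):
    # Encode every caption once on the GPU, then move the encoder off it.
    text_encoder.to(device, dtype=weight_dtype)
    cache = []
    with torch.no_grad():
        for input_ids in all_input_ids:
            cache.append(text_encoder(input_ids.to(device)))
    text_encoder.to("cpu")  # the cached outputs are all the training loop needs
    return cache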
@@ -394,8 +395,7 @@ class NetworkTrainer:
         for t_enc in text_encoders:
             t_enc.requires_grad_(False)
 
-        # acceleratorがなんかよろしくやってくれるらしい
-        # TODO めちゃくちゃ冗長なのでコードを整理する
+        # acceleratorがなんかよろしくやってくれるらしい / accelerator will do something good
         if train_unet:
             unet = accelerator.prepare(unet)
         else:
@@ -407,8 +407,8 @@ class NetworkTrainer:
             text_encoder = accelerator.prepare(text_encoder)
             text_encoders = [text_encoder]
         else:
-            for t_enc in text_encoders:
-                t_enc.to(accelerator.device, dtype=weight_dtype)
+            pass  # if text_encoder is not trained, no need to prepare. and device and dtype are already set
+
         network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(network, optimizer, train_dataloader, lr_scheduler)
 
         if args.gradient_checkpointing:
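These two hunks keep the translated comment ("accelerator seems to handle this properly"), drop the TODO about the code being very redundant, and replace the loop over frozen text encoders with a pass, since their device and dtype were already set by cache_text_encoder_outputs_if_needed. The overall pattern is that only trained modules go through accelerator.prepare, while frozen ones are placed manually. A hedged sketch of that pattern as a standalone function, not the trainer's actual structure:

from accelerate import Accelerator

def prepare_for_training(accelerator: Accelerator, unet, text_encoders, network,
                         optimizer, train_dataloader, lr_scheduler,
                         train_unet: bool, train_text_encoder: bool, weight_dtype):
    # Only modules that receive gradients go through accelerator.prepare();
    # frozen modules are simply placed on the device in the training precision.
    if train_unet:
        unet = accelerator.prepare(unet)
    else:
        unet.to(accelerator.device, dtype=weight_dtype)

    if train_text_encoder:
        text_encoders = [accelerator.prepare(t_enc) for t_enc in text_encoders]
    else:
        pass  # frozen: device and dtype were already set when caching/moving earlier

    network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        network, optimizer, train_dataloader, lr_scheduler
    )
    return unet, text_encoders, network, optimizer, train_dataloader, lr_scheduler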
@@ -685,7 +685,7 @@ class NetworkTrainer:
         if accelerator.is_main_process:
             init_kwargs = {}
             if args.wandb_run_name:
-                init_kwargs['wandb'] = {'name': args.wandb_run_name}
+                init_kwargs["wandb"] = {"name": args.wandb_run_name}
             if args.log_tracker_config is not None:
                 init_kwargs = toml.load(args.log_tracker_config)
             accelerator.init_trackers(
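This hunk only switches the wandb init_kwargs entry to double quotes; the surrounding logic builds init_kwargs from --wandb_run_name, lets a TOML file given via --log_tracker_config replace it entirely, and passes the result to accelerator.init_trackers. A minimal sketch of that flow; the Namespace values and the project name "network_train" are illustrative assumptions:

import argparse
import toml
from accelerate import Accelerator

# Illustrative argument values; the real values come from the training CLI.
args = argparse.Namespace(wandb_run_name="my-run", log_tracker_config=None)

accelerator = Accelerator(log_with="wandb")

init_kwargs = {}
if args.wandb_run_name:
    init_kwargs["wandb"] = {"name": args.wandb_run_name}  # double quotes, per the style fix
if args.log_tracker_config is not None:
    init_kwargs = toml.load(args.log_tracker_config)      # a TOML file replaces the dict above

accelerator.init_trackers("network_train", init_kwargs=init_kwargs)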