Mirror of https://github.com/kohya-ss/sd-scripts.git (synced 2026-04-09 06:45:09 +00:00)
Cast weights to correct precision before transferring them to GPU
This commit is contained in:
@@ -358,6 +358,11 @@ class NetworkTrainer:
 accelerator.print("enable full fp16 training.")
 network.to(weight_dtype)
 
+unet.requires_grad_(False)
+unet.to(dtype=weight_dtype)
+for t_enc in text_encoders:
+    t_enc.requires_grad_(False)
+
 # acceleratorがなんかよろしくやってくれるらしい
 # TODO めちゃくちゃ冗長なのでコードを整理する
 if train_unet and train_text_encoder:
@@ -397,11 +402,6 @@ class NetworkTrainer:
 text_encoders = train_util.transform_models_if_DDP(text_encoders)
 unet, network = train_util.transform_models_if_DDP([unet, network])
 
-unet.requires_grad_(False)
-unet.to(accelerator.device, dtype=weight_dtype)
-for t_enc in text_encoders:
-    t_enc.requires_grad_(False)
-
 if args.gradient_checkpointing:
     # according to TI example in Diffusers, train is required
     unet.train()
Reference in New Issue
Block a user