mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
Fix AttributeError: 'T5EncoderModel' object has no attribute 'text_model'
While loading T5 model in GPU.
This commit is contained in:
@@ -540,9 +540,13 @@ class NetworkTrainer:
|
|||||||
# in case of cpu, dtype is already set to fp32 because cpu does not support fp8/fp16/bf16
|
# in case of cpu, dtype is already set to fp32 because cpu does not support fp8/fp16/bf16
|
||||||
if t_enc.device.type != "cpu":
|
if t_enc.device.type != "cpu":
|
||||||
t_enc.to(dtype=te_weight_dtype)
|
t_enc.to(dtype=te_weight_dtype)
|
||||||
if hasattr(t_enc.text_model, "embeddings"):
|
if hasattr(t_enc, "text_model") and hasattr(t_enc.text_model, "embeddings"):
|
||||||
# nn.Embedding not support FP8
|
# nn.Embedding not support FP8
|
||||||
t_enc.text_model.embeddings.to(dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype))
|
t_enc.text_model.embeddings.to(
|
||||||
|
dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype))
|
||||||
|
elif hasattr(t_enc, "encoder") and hasattr(t_enc.encoder, "embeddings"):
|
||||||
|
t_enc.encoder.embeddings.to(
|
||||||
|
dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype))
|
||||||
|
|
||||||
# acceleratorがなんかよろしくやってくれるらしい / accelerator will do something good
|
# acceleratorがなんかよろしくやってくれるらしい / accelerator will do something good
|
||||||
if args.deepspeed:
|
if args.deepspeed:
|
||||||
|
|||||||
Reference in New Issue
Block a user