Fix error on saving T5XXL

This commit is contained in:
Kohya S
2024-10-29 08:34:31 +09:00
parent af8e216035
commit 75554867ce

View File

@@ -75,7 +75,14 @@ def save_models(
save_file(clip_g.state_dict(), clip_g_path) save_file(clip_g.state_dict(), clip_g_path)
if t5xxl is not None: if t5xxl is not None:
t5xxl_path = ckpt_path.replace(".safetensors", "_t5xxl.safetensors") t5xxl_path = ckpt_path.replace(".safetensors", "_t5xxl.safetensors")
save_file(t5xxl.state_dict(), t5xxl_path) t5xxl_state_dict = t5xxl.state_dict()
# replace "shared.weight" with copy of it to avoid annoying shared tensor error on safetensors.save_file
shared_weight = t5xxl_state_dict["shared.weight"]
shared_weight_copy = shared_weight.detach().clone()
t5xxl_state_dict["shared.weight"] = shared_weight_copy
save_file(t5xxl_state_dict, t5xxl_path)
def save_sd3_model_on_train_end( def save_sd3_model_on_train_end(