mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
Fix error on saving T5XXL
This commit is contained in:
@@ -75,7 +75,14 @@ def save_models(
|
|||||||
save_file(clip_g.state_dict(), clip_g_path)
|
save_file(clip_g.state_dict(), clip_g_path)
|
||||||
if t5xxl is not None:
|
if t5xxl is not None:
|
||||||
t5xxl_path = ckpt_path.replace(".safetensors", "_t5xxl.safetensors")
|
t5xxl_path = ckpt_path.replace(".safetensors", "_t5xxl.safetensors")
|
||||||
save_file(t5xxl.state_dict(), t5xxl_path)
|
t5xxl_state_dict = t5xxl.state_dict()
|
||||||
|
|
||||||
|
# replace "shared.weight" with copy of it to avoid annoying shared tensor error on safetensors.save_file
|
||||||
|
shared_weight = t5xxl_state_dict["shared.weight"]
|
||||||
|
shared_weight_copy = shared_weight.detach().clone()
|
||||||
|
t5xxl_state_dict["shared.weight"] = shared_weight_copy
|
||||||
|
|
||||||
|
save_file(t5xxl_state_dict, t5xxl_path)
|
||||||
|
|
||||||
|
|
||||||
def save_sd3_model_on_train_end(
|
def save_sd3_model_on_train_end(
|
||||||
|
|||||||
Reference in New Issue
Block a user