From af8e216035128767234163a24debf2f4df5aa36d Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 28 Oct 2024 22:08:57 +0900 Subject: [PATCH 1/2] Fix sample image gen to work with block swap --- library/sd3_train_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/library/sd3_train_utils.py b/library/sd3_train_utils.py index a0202ad4..054d1b4a 100644 --- a/library/sd3_train_utils.py +++ b/library/sd3_train_utils.py @@ -364,6 +364,7 @@ def do_sample( x_c_nc = torch.cat([x, x], dim=0) # print(x_c_nc.shape, timestep.shape, c_crossattn.shape, y.shape) + mmdit.prepare_block_swap_before_forward() model_output = mmdit(x_c_nc, timestep, context=c_crossattn, y=y) model_output = model_output.float() batched = model_sampling.calculate_denoised(sigma_hat, model_output, x) @@ -385,6 +386,7 @@ def do_sample( x = x + d * dt x = x.to(dtype) + mmdit.prepare_block_swap_before_forward() return x From 75554867ce390ec0957cc52a70c0695e19c71fe2 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 29 Oct 2024 08:34:31 +0900 Subject: [PATCH 2/2] Fix error on saving T5XXL --- library/sd3_train_utils.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/library/sd3_train_utils.py b/library/sd3_train_utils.py index 054d1b4a..1702e81c 100644 --- a/library/sd3_train_utils.py +++ b/library/sd3_train_utils.py @@ -75,7 +75,14 @@ def save_models( save_file(clip_g.state_dict(), clip_g_path) if t5xxl is not None: t5xxl_path = ckpt_path.replace(".safetensors", "_t5xxl.safetensors") - save_file(t5xxl.state_dict(), t5xxl_path) + t5xxl_state_dict = t5xxl.state_dict() + + # replace "shared.weight" with copy of it to avoid annoying shared tensor error on safetensors.save_file + shared_weight = t5xxl_state_dict["shared.weight"] + shared_weight_copy = shared_weight.detach().clone() + t5xxl_state_dict["shared.weight"] = shared_weight_copy + + save_file(t5xxl_state_dict, t5xxl_path) def save_sd3_model_on_train_end(