mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-10 06:54:17 +00:00
Merge branch 'sd3_5_support' of https://github.com/kohya-ss/sd-scripts into sd3_5_support
This commit is contained in:
@@ -75,7 +75,14 @@ def save_models(
|
||||
save_file(clip_g.state_dict(), clip_g_path)
|
||||
if t5xxl is not None:
|
||||
t5xxl_path = ckpt_path.replace(".safetensors", "_t5xxl.safetensors")
|
||||
save_file(t5xxl.state_dict(), t5xxl_path)
|
||||
t5xxl_state_dict = t5xxl.state_dict()
|
||||
|
||||
# replace "shared.weight" with copy of it to avoid annoying shared tensor error on safetensors.save_file
|
||||
shared_weight = t5xxl_state_dict["shared.weight"]
|
||||
shared_weight_copy = shared_weight.detach().clone()
|
||||
t5xxl_state_dict["shared.weight"] = shared_weight_copy
|
||||
|
||||
save_file(t5xxl_state_dict, t5xxl_path)
|
||||
|
||||
|
||||
def save_sd3_model_on_train_end(
|
||||
@@ -364,6 +371,7 @@ def do_sample(
|
||||
x_c_nc = torch.cat([x, x], dim=0)
|
||||
# print(x_c_nc.shape, timestep.shape, c_crossattn.shape, y.shape)
|
||||
|
||||
mmdit.prepare_block_swap_before_forward()
|
||||
model_output = mmdit(x_c_nc, timestep, context=c_crossattn, y=y)
|
||||
model_output = model_output.float()
|
||||
batched = model_sampling.calculate_denoised(sigma_hat, model_output, x)
|
||||
@@ -385,6 +393,7 @@ def do_sample(
|
||||
x = x + d * dt
|
||||
x = x.to(dtype)
|
||||
|
||||
mmdit.prepare_block_swap_before_forward()
|
||||
return x
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user