mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-10 15:00:23 +00:00
Fix block swap for sample images
This commit is contained in:
@@ -317,7 +317,6 @@ def denoise(
|
|||||||
# this is ignored for schnell
|
# this is ignored for schnell
|
||||||
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
|
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
|
||||||
|
|
||||||
|
|
||||||
for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]):
|
for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]):
|
||||||
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
|
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
|
||||||
model.prepare_block_swap_before_forward()
|
model.prepare_block_swap_before_forward()
|
||||||
|
|||||||
@@ -604,7 +604,6 @@ def retrieve_timesteps(
|
|||||||
timesteps = scheduler.timesteps
|
timesteps = scheduler.timesteps
|
||||||
return timesteps, num_inference_steps
|
return timesteps, num_inference_steps
|
||||||
|
|
||||||
|
|
||||||
def denoise(
|
def denoise(
|
||||||
scheduler,
|
scheduler,
|
||||||
model: lumina_models.NextDiT,
|
model: lumina_models.NextDiT,
|
||||||
@@ -648,6 +647,7 @@ def denoise(
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
for i, t in enumerate(tqdm(timesteps)):
|
for i, t in enumerate(tqdm(timesteps)):
|
||||||
|
|
||||||
# reverse the timestep since Lumina uses t=0 as the noise and t=1 as the image
|
# reverse the timestep since Lumina uses t=0 as the noise and t=1 as the image
|
||||||
current_timestep = 1 - t / scheduler.config.num_train_timesteps
|
current_timestep = 1 - t / scheduler.config.num_train_timesteps
|
||||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||||
@@ -700,6 +700,7 @@ def denoise(
|
|||||||
noise_pred = -noise_pred
|
noise_pred = -noise_pred
|
||||||
img = scheduler.step(noise_pred, t, img, return_dict=False)[0]
|
img = scheduler.step(noise_pred, t, img, return_dict=False)[0]
|
||||||
|
|
||||||
|
model.prepare_block_swap_before_forward()
|
||||||
return img
|
return img
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -367,6 +367,7 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
|
|||||||
accelerator.unwrap_model(unet).prepare_block_swap_before_forward()
|
accelerator.unwrap_model(unet).prepare_block_swap_before_forward()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def setup_parser() -> argparse.ArgumentParser:
|
def setup_parser() -> argparse.ArgumentParser:
|
||||||
parser = train_network.setup_parser()
|
parser = train_network.setup_parser()
|
||||||
train_util.add_dit_training_arguments(parser)
|
train_util.add_dit_training_arguments(parser)
|
||||||
|
|||||||
Reference in New Issue
Block a user