mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
fix sample generation is not working in FLUX1 fine tuning #1647
This commit is contained in:
@@ -999,8 +999,9 @@ class Flux(nn.Module):
|
||||
|
||||
def prepare_block_swap_before_forward(self):
|
||||
# make: first n blocks are on cuda, and last n blocks are on cpu
|
||||
if self.blocks_to_swap is None:
|
||||
raise ValueError("Block swap is not enabled.")
|
||||
if self.blocks_to_swap is None or self.blocks_to_swap == 0:
|
||||
# raise ValueError("Block swap is not enabled.")
|
||||
return
|
||||
for i in range(self.num_block_units - self.blocks_to_swap):
|
||||
for b in self.get_block_unit(i):
|
||||
b.to(self.device)
|
||||
|
||||
@@ -313,6 +313,7 @@ def denoise(
|
||||
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
|
||||
for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]):
|
||||
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
|
||||
model.prepare_block_swap_before_forward()
|
||||
pred = model(
|
||||
img=img,
|
||||
img_ids=img_ids,
|
||||
@@ -325,7 +326,8 @@ def denoise(
|
||||
)
|
||||
|
||||
img = img + (t_prev - t_curr) * pred
|
||||
|
||||
|
||||
model.prepare_block_swap_before_forward()
|
||||
return img
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user