fix sample generation is not working in FLUX1 fine tuning #1647

This commit is contained in:
Kohya S
2024-09-28 17:12:56 +09:00
parent 3ebb65f945
commit a9aa52658a
2 changed files with 6 additions and 3 deletions

View File

@@ -999,8 +999,9 @@ class Flux(nn.Module):
def prepare_block_swap_before_forward(self): def prepare_block_swap_before_forward(self):
# make: first n blocks are on cuda, and last n blocks are on cpu # make: first n blocks are on cuda, and last n blocks are on cpu
if self.blocks_to_swap is None: if self.blocks_to_swap is None or self.blocks_to_swap == 0:
raise ValueError("Block swap is not enabled.") # raise ValueError("Block swap is not enabled.")
return
for i in range(self.num_block_units - self.blocks_to_swap): for i in range(self.num_block_units - self.blocks_to_swap):
for b in self.get_block_unit(i): for b in self.get_block_unit(i):
b.to(self.device) b.to(self.device)

View File

@@ -313,6 +313,7 @@ def denoise(
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype) guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]): for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]):
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device) t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
model.prepare_block_swap_before_forward()
pred = model( pred = model(
img=img, img=img,
img_ids=img_ids, img_ids=img_ids,
@@ -325,7 +326,8 @@ def denoise(
) )
img = img + (t_prev - t_curr) * pred img = img + (t_prev - t_curr) * pred
model.prepare_block_swap_before_forward()
return img return img