mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
fix sample generation is not working in FLUX1 fine tuning #1647
This commit is contained in:
@@ -999,8 +999,9 @@ class Flux(nn.Module):
|
|||||||
|
|
||||||
def prepare_block_swap_before_forward(self):
|
def prepare_block_swap_before_forward(self):
|
||||||
# make: first n blocks are on cuda, and last n blocks are on cpu
|
# make: first n blocks are on cuda, and last n blocks are on cpu
|
||||||
if self.blocks_to_swap is None:
|
if self.blocks_to_swap is None or self.blocks_to_swap == 0:
|
||||||
raise ValueError("Block swap is not enabled.")
|
# raise ValueError("Block swap is not enabled.")
|
||||||
|
return
|
||||||
for i in range(self.num_block_units - self.blocks_to_swap):
|
for i in range(self.num_block_units - self.blocks_to_swap):
|
||||||
for b in self.get_block_unit(i):
|
for b in self.get_block_unit(i):
|
||||||
b.to(self.device)
|
b.to(self.device)
|
||||||
|
|||||||
@@ -313,6 +313,7 @@ def denoise(
|
|||||||
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
|
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
|
||||||
for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]):
|
for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]):
|
||||||
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
|
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
|
||||||
|
model.prepare_block_swap_before_forward()
|
||||||
pred = model(
|
pred = model(
|
||||||
img=img,
|
img=img,
|
||||||
img_ids=img_ids,
|
img_ids=img_ids,
|
||||||
@@ -325,7 +326,8 @@ def denoise(
|
|||||||
)
|
)
|
||||||
|
|
||||||
img = img + (t_prev - t_curr) * pred
|
img = img + (t_prev - t_curr) * pred
|
||||||
|
|
||||||
|
model.prepare_block_swap_before_forward()
|
||||||
return img
|
return img
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user