Add T5-XXL max token length; support schnell

This commit is contained in:
Kohya S
2024-08-16 17:06:05 +09:00
parent 739a8969bc
commit 3921a4efda
3 changed files with 44 additions and 8 deletions

View File

@@ -863,7 +863,8 @@ class Flux(nn.Module):
self.time_in.enable_gradient_checkpointing()
self.vector_in.enable_gradient_checkpointing()
self.guidance_in.enable_gradient_checkpointing()
if self.guidance_in.__class__ != nn.Identity:
self.guidance_in.enable_gradient_checkpointing()
for block in self.double_blocks + self.single_blocks:
block.enable_gradient_checkpointing()
@@ -875,7 +876,8 @@ class Flux(nn.Module):
self.time_in.disable_gradient_checkpointing()
self.vector_in.disable_gradient_checkpointing()
self.guidance_in.disable_gradient_checkpointing()
if self.guidance_in.__class__ != nn.Identity:
self.guidance_in.disable_gradient_checkpointing()
for block in self.double_blocks + self.single_blocks:
block.disable_gradient_checkpointing()
@@ -972,7 +974,8 @@ class FluxUpper(nn.Module):
self.time_in.enable_gradient_checkpointing()
self.vector_in.enable_gradient_checkpointing()
self.guidance_in.enable_gradient_checkpointing()
if self.guidance_in.__class__ != nn.Identity:
self.guidance_in.enable_gradient_checkpointing()
for block in self.double_blocks:
block.enable_gradient_checkpointing()
@@ -984,7 +987,8 @@ class FluxUpper(nn.Module):
self.time_in.disable_gradient_checkpointing()
self.vector_in.disable_gradient_checkpointing()
self.guidance_in.disable_gradient_checkpointing()
if self.guidance_in.__class__ != nn.Identity:
self.guidance_in.disable_gradient_checkpointing()
for block in self.double_blocks:
block.disable_gradient_checkpointing()