Fix to work full_bf16 and full_fp16.

This commit is contained in:
Kohya S
2024-06-29 17:45:50 +09:00
parent 19086465e8
commit ea18d5ba6d
3 changed files with 24 additions and 18 deletions

View File

@@ -891,6 +891,14 @@ class MMDiT(nn.Module):
def model_type(self):
return "m" # only support medium
@property
def device(self):
return next(self.parameters()).device
@property
def dtype(self):
return next(self.parameters()).dtype
def enable_gradient_checkpointing(self):
self.gradient_checkpointing = True
for block in self.joint_blocks: