fix: resolve model corruption issue with pos_embed when using --enable_scaled_pos_embed

This commit is contained in:
Kohya S
2024-12-07 15:12:27 +09:00
parent 8b36d907d8
commit 6bee18db4f
2 changed files with 6 additions and 2 deletions

View File

@@ -14,6 +14,8 @@ The command to install PyTorch is as follows:
### Recent Updates ### Recent Updates
Dec 7, 2024:
- Fixed an issue where the saved model would be corrupted (pos_embed would not be saved) when `--enable_scaled_pos_embed` was specified in `sd3_train.py`.
Dec 3, 2024: Dec 3, 2024:

View File

@@ -870,8 +870,10 @@ class MMDiT(nn.Module):
self.use_scaled_pos_embed = use_scaled_pos_embed self.use_scaled_pos_embed = use_scaled_pos_embed
if self.use_scaled_pos_embed: if self.use_scaled_pos_embed:
# remove pos_embed to free up memory up to 0.4 GB # # remove pos_embed to free up memory up to 0.4 GB -> this causes error because pos_embed is not saved
self.pos_embed = None # self.pos_embed = None
# move pos_embed to CPU to free up memory up to 0.4 GB
self.pos_embed = self.pos_embed.cpu()
# remove duplicates and sort latent sizes in ascending order # remove duplicates and sort latent sizes in ascending order
latent_sizes = list(set(latent_sizes)) latent_sizes = list(set(latent_sizes))