Merge branch 'sd3' into new_cache

This commit is contained in:
Kohya S
2024-12-09 18:36:10 +09:00
9 changed files with 760 additions and 683 deletions

View File

@@ -564,7 +564,7 @@ def add_flux_train_arguments(parser: argparse.ArgumentParser):
)
parser.add_argument("--ae", type=str, help="path to ae (*.sft or *.safetensors) / aeのパス*.sftまたは*.safetensors")
parser.add_argument(
-"--controlnet",
+"--controlnet_model_name_or_path",
type=str,
default=None,
help="path to controlnet (*.sft or *.safetensors) / controlnetのパス*.sftまたは*.safetensors"

View File

@@ -870,8 +870,10 @@ class MMDiT(nn.Module):
self.use_scaled_pos_embed = use_scaled_pos_embed
if self.use_scaled_pos_embed:
-# remove pos_embed to free up memory up to 0.4 GB
-self.pos_embed = None
+# # remove pos_embed to free up memory up to 0.4 GB -> this causes error because pos_embed is not saved
+# self.pos_embed = None
+# move pos_embed to CPU to free up memory up to 0.4 GB
+self.pos_embed = self.pos_embed.cpu()
# remove duplicates and sort latent sizes in ascending order
latent_sizes = list(set(latent_sizes))