Support SD3.5M multi resolutional training

This commit is contained in:
Kohya S
2024-10-31 19:58:22 +09:00
parent 70a179e446
commit 1434d8506f
8 changed files with 215 additions and 10 deletions

View File

@@ -361,7 +361,14 @@ def train(args):
# ), f"attn_mode {attn_mode} is not supported yet. Please use `--sdpa` instead of `--xformers`. / attn_mode {attn_mode} はサポートされていません。`--xformers`の代わりに`--sdpa`を使ってください。"
mmdit.set_pos_emb_random_crop_rate(args.pos_emb_random_crop_rate)
# set resolutions for positional embeddings
if args.enable_scaled_pos_embed:
resolutions = train_dataset_group.get_resolutions()
latent_sizes = [round(math.sqrt(res[0] * res[1])) // 8 for res in resolutions] # 8 is stride for latent
logger.info(f"Prepare scaled positional embeddings for resolutions: {resolutions}, sizes: {latent_sizes}")
mmdit.enable_scaled_pos_embed(True, latent_sizes)
if args.gradient_checkpointing:
mmdit.enable_gradient_checkpointing()