mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
remove duplicate resolution for scaled pos embed
This commit is contained in:
@@ -871,7 +871,8 @@ class MMDiT(nn.Module):
|
|||||||
# remove pos_embed to free up memory up to 0.4 GB
|
# remove pos_embed to free up memory up to 0.4 GB
|
||||||
self.pos_embed = None
|
self.pos_embed = None
|
||||||
|
|
||||||
# sort latent sizes in ascending order
|
# remove duplcates and sort latent sizes in ascending order
|
||||||
|
latent_sizes = list(set(latent_sizes))
|
||||||
latent_sizes = sorted(latent_sizes)
|
latent_sizes = sorted(latent_sizes)
|
||||||
|
|
||||||
patched_sizes = [latent_size // self.patch_size for latent_size in latent_sizes]
|
patched_sizes = [latent_size // self.patch_size for latent_size in latent_sizes]
|
||||||
|
|||||||
@@ -366,6 +366,7 @@ def train(args):
|
|||||||
if args.enable_scaled_pos_embed:
|
if args.enable_scaled_pos_embed:
|
||||||
resolutions = train_dataset_group.get_resolutions()
|
resolutions = train_dataset_group.get_resolutions()
|
||||||
latent_sizes = [round(math.sqrt(res[0] * res[1])) // 8 for res in resolutions] # 8 is stride for latent
|
latent_sizes = [round(math.sqrt(res[0] * res[1])) // 8 for res in resolutions] # 8 is stride for latent
|
||||||
|
latent_sizes = list(set(latent_sizes)) # remove duplicates
|
||||||
logger.info(f"Prepare scaled positional embeddings for resolutions: {resolutions}, sizes: {latent_sizes}")
|
logger.info(f"Prepare scaled positional embeddings for resolutions: {resolutions}, sizes: {latent_sizes}")
|
||||||
mmdit.enable_scaled_pos_embed(True, latent_sizes)
|
mmdit.enable_scaled_pos_embed(True, latent_sizes)
|
||||||
|
|
||||||
|
|||||||
@@ -73,6 +73,7 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
|
|||||||
# set resolutions for positional embeddings
|
# set resolutions for positional embeddings
|
||||||
if args.enable_scaled_pos_embed:
|
if args.enable_scaled_pos_embed:
|
||||||
latent_sizes = [round(math.sqrt(res[0] * res[1])) // 8 for res in self.resolutions] # 8 is stride for latent
|
latent_sizes = [round(math.sqrt(res[0] * res[1])) // 8 for res in self.resolutions] # 8 is stride for latent
|
||||||
|
latent_sizes = list(set(latent_sizes)) # remove duplicates
|
||||||
logger.info(f"Prepare scaled positional embeddings for resolutions: {self.resolutions}, sizes: {latent_sizes}")
|
logger.info(f"Prepare scaled positional embeddings for resolutions: {self.resolutions}, sizes: {latent_sizes}")
|
||||||
mmdit.enable_scaled_pos_embed(True, latent_sizes)
|
mmdit.enable_scaled_pos_embed(True, latent_sizes)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user