mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 14:34:23 +00:00
Support SD3.5M multi resolutional training
This commit is contained in:
@@ -88,6 +88,78 @@ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
|
|||||||
return emb
|
return emb
|
||||||
|
|
||||||
|
|
||||||
|
def get_scaled_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0, sample_size=64, base_size=16):
|
||||||
|
"""
|
||||||
|
This function is contributed by KohakuBlueleaf. Thanks for the contribution!
|
||||||
|
|
||||||
|
Creates scaled 2D sinusoidal positional embeddings that maintain consistent relative positions
|
||||||
|
when the resolution differs from the training resolution.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
embed_dim (int): Dimension of the positional embedding.
|
||||||
|
grid_size (int or tuple): Size of the position grid (H, W). If int, assumes square grid.
|
||||||
|
cls_token (bool): Whether to include class token. Defaults to False.
|
||||||
|
extra_tokens (int): Number of extra tokens (e.g., cls_token). Defaults to 0.
|
||||||
|
sample_size (int): Reference resolution (typically training resolution). Defaults to 64.
|
||||||
|
base_size (int): Base grid size used during training. Defaults to 16.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
numpy.ndarray: Positional embeddings of shape (H*W, embed_dim) or
|
||||||
|
(H*W + extra_tokens, embed_dim) if cls_token is True.
|
||||||
|
"""
|
||||||
|
# Convert grid_size to tuple if it's an integer
|
||||||
|
if isinstance(grid_size, int):
|
||||||
|
grid_size = (grid_size, grid_size)
|
||||||
|
|
||||||
|
# Create normalized grid coordinates (0 to 1)
|
||||||
|
grid_h = np.arange(grid_size[0], dtype=np.float32) / grid_size[0]
|
||||||
|
grid_w = np.arange(grid_size[1], dtype=np.float32) / grid_size[1]
|
||||||
|
|
||||||
|
# Calculate scaling factors for height and width
|
||||||
|
# This ensures that the central region matches the original resolution's embeddings
|
||||||
|
scale_h = base_size * grid_size[0] / (sample_size)
|
||||||
|
scale_w = base_size * grid_size[1] / (sample_size)
|
||||||
|
|
||||||
|
# Calculate shift values to center the original resolution's embedding region
|
||||||
|
# This ensures that the central sample_size x sample_size region has similar
|
||||||
|
# positional embeddings to the original resolution
|
||||||
|
shift_h = 1 * scale_h * (grid_size[0] - sample_size) / (2 * grid_size[0])
|
||||||
|
shift_w = 1 * scale_w * (grid_size[1] - sample_size) / (2 * grid_size[1])
|
||||||
|
|
||||||
|
# Apply scaling and shifting to create the final grid coordinates
|
||||||
|
grid_h = grid_h * scale_h - shift_h
|
||||||
|
grid_w = grid_w * scale_w - shift_w
|
||||||
|
|
||||||
|
# Create 2D grid using meshgrid (note: w goes first)
|
||||||
|
grid = np.meshgrid(grid_w, grid_h)
|
||||||
|
grid = np.stack(grid, axis=0)
|
||||||
|
|
||||||
|
# # Calculate the starting indices for the central region
|
||||||
|
# # This is used for debugging/visualization of the central region
|
||||||
|
# st_h = (grid_size[0] - sample_size) // 2
|
||||||
|
# st_w = (grid_size[1] - sample_size) // 2
|
||||||
|
# print(grid[:, st_h : st_h + sample_size, st_w : st_w + sample_size])
|
||||||
|
|
||||||
|
# Reshape grid for positional embedding calculation
|
||||||
|
grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
|
||||||
|
|
||||||
|
# Generate the sinusoidal positional embeddings
|
||||||
|
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
|
||||||
|
|
||||||
|
# Add zeros for extra tokens (e.g., [CLS] token) if required
|
||||||
|
if cls_token and extra_tokens > 0:
|
||||||
|
pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
|
||||||
|
|
||||||
|
return pos_embed
|
||||||
|
|
||||||
|
|
||||||
|
# if __name__ == "__main__":
|
||||||
|
# # This is what you get when you load SD3.5 state dict
|
||||||
|
# pos_emb = torch.from_numpy(get_scaled_2d_sincos_pos_embed(
|
||||||
|
# 1536, [384, 384], sample_size=64, base_size=16
|
||||||
|
# )).float().unsqueeze(0)
|
||||||
|
|
||||||
|
|
||||||
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
|
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
|
||||||
"""
|
"""
|
||||||
embed_dim: output dimension for each position
|
embed_dim: output dimension for each position
|
||||||
@@ -617,7 +689,7 @@ class MMDiTBlock(nn.Module):
|
|||||||
|
|
||||||
self.context_block = SingleDiTBlock(*args, pre_only=pre_only, **kwargs)
|
self.context_block = SingleDiTBlock(*args, pre_only=pre_only, **kwargs)
|
||||||
self.x_block = SingleDiTBlock(*args, pre_only=False, x_block_self_attn=x_block_self_attn, **kwargs)
|
self.x_block = SingleDiTBlock(*args, pre_only=False, x_block_self_attn=x_block_self_attn, **kwargs)
|
||||||
|
|
||||||
self.head_dim = self.x_block.attn.head_dim
|
self.head_dim = self.x_block.attn.head_dim
|
||||||
self.mode = self.x_block.attn_mode
|
self.mode = self.x_block.attn_mode
|
||||||
self.gradient_checkpointing = False
|
self.gradient_checkpointing = False
|
||||||
@@ -669,6 +741,9 @@ class MMDiT(nn.Module):
|
|||||||
Diffusion model with a Transformer backbone.
|
Diffusion model with a Transformer backbone.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# prepare pos_embed for latent size * 2
|
||||||
|
POS_EMBED_MAX_RATIO = 1.5
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
input_size: int = 32,
|
input_size: int = 32,
|
||||||
@@ -697,6 +772,8 @@ class MMDiT(nn.Module):
|
|||||||
x_block_self_attn_layers: Optional[list[int]] = [],
|
x_block_self_attn_layers: Optional[list[int]] = [],
|
||||||
qkv_bias: bool = True,
|
qkv_bias: bool = True,
|
||||||
pos_emb_random_crop_rate: float = 0.0,
|
pos_emb_random_crop_rate: float = 0.0,
|
||||||
|
use_scaled_pos_embed: bool = False,
|
||||||
|
pos_embed_latent_sizes: Optional[list[int]] = None,
|
||||||
model_type: str = "sd3m",
|
model_type: str = "sd3m",
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -722,6 +799,8 @@ class MMDiT(nn.Module):
|
|||||||
|
|
||||||
self.num_heads = num_heads
|
self.num_heads = num_heads
|
||||||
|
|
||||||
|
self.enable_scaled_pos_embed(use_scaled_pos_embed, pos_embed_latent_sizes)
|
||||||
|
|
||||||
self.x_embedder = PatchEmbed(
|
self.x_embedder = PatchEmbed(
|
||||||
input_size,
|
input_size,
|
||||||
patch_size,
|
patch_size,
|
||||||
@@ -785,6 +864,43 @@ class MMDiT(nn.Module):
|
|||||||
self.blocks_to_swap = None
|
self.blocks_to_swap = None
|
||||||
self.thread_pool: Optional[ThreadPoolExecutor] = None
|
self.thread_pool: Optional[ThreadPoolExecutor] = None
|
||||||
|
|
||||||
|
def enable_scaled_pos_embed(self, use_scaled_pos_embed: bool, latent_sizes: Optional[list[int]]):
|
||||||
|
self.use_scaled_pos_embed = use_scaled_pos_embed
|
||||||
|
|
||||||
|
if self.use_scaled_pos_embed:
|
||||||
|
# # remove pos_embed to free up memory up to 0.4 GB
|
||||||
|
self.pos_embed = None
|
||||||
|
|
||||||
|
# sort latent sizes in ascending order
|
||||||
|
latent_sizes = sorted(latent_sizes)
|
||||||
|
|
||||||
|
patched_sizes = [latent_size // self.patch_size for latent_size in latent_sizes]
|
||||||
|
|
||||||
|
# calculate value range for each latent area: this is used to determine the pos_emb size from the latent shape
|
||||||
|
max_areas = []
|
||||||
|
for i in range(1, len(patched_sizes)):
|
||||||
|
prev_area = patched_sizes[i - 1] ** 2
|
||||||
|
area = patched_sizes[i] ** 2
|
||||||
|
max_areas.append((prev_area + area) // 2)
|
||||||
|
|
||||||
|
# area of the last latent size, if the latent size exceeds this, error will be raised
|
||||||
|
max_areas.append(int((patched_sizes[-1] * MMDiT.POS_EMBED_MAX_RATIO) ** 2))
|
||||||
|
# print("max_areas", max_areas)
|
||||||
|
|
||||||
|
self.resolution_area_to_latent_size = [(area, latent_size) for area, latent_size in zip(max_areas, patched_sizes)]
|
||||||
|
|
||||||
|
self.resolution_pos_embeds = {}
|
||||||
|
for patched_size in patched_sizes:
|
||||||
|
grid_size = int(patched_size * MMDiT.POS_EMBED_MAX_RATIO)
|
||||||
|
pos_embed = get_scaled_2d_sincos_pos_embed(self.hidden_size, grid_size, sample_size=patched_size)
|
||||||
|
pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0)
|
||||||
|
self.resolution_pos_embeds[patched_size] = pos_embed
|
||||||
|
# print(f"pos_embed for {patched_size}x{patched_size} latent size: {pos_embed.shape}")
|
||||||
|
|
||||||
|
else:
|
||||||
|
self.resolution_area_to_latent_size = None
|
||||||
|
self.resolution_pos_embeds = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def model_type(self):
|
def model_type(self):
|
||||||
return self._model_type
|
return self._model_type
|
||||||
@@ -884,6 +1000,54 @@ class MMDiT(nn.Module):
|
|||||||
spatial_pos_embed = spatial_pos_embed.reshape(1, -1, spatial_pos_embed.shape[-1])
|
spatial_pos_embed = spatial_pos_embed.reshape(1, -1, spatial_pos_embed.shape[-1])
|
||||||
return spatial_pos_embed
|
return spatial_pos_embed
|
||||||
|
|
||||||
|
def cropped_scaled_pos_embed(self, h, w, device=None, dtype=None, random_crop: bool = False):
|
||||||
|
p = self.x_embedder.patch_size
|
||||||
|
# patched size
|
||||||
|
h = (h + 1) // p
|
||||||
|
w = (w + 1) // p
|
||||||
|
|
||||||
|
# select pos_embed size based on area
|
||||||
|
area = h * w
|
||||||
|
patched_size = None
|
||||||
|
for area_, patched_size_ in self.resolution_area_to_latent_size:
|
||||||
|
if area <= area_:
|
||||||
|
patched_size = patched_size_
|
||||||
|
break
|
||||||
|
if patched_size is None:
|
||||||
|
raise ValueError(f"Area {area} is too large for the given latent sizes {self.resolution_area_to_latent_size}.")
|
||||||
|
|
||||||
|
pos_embed_size = int(patched_size * MMDiT.POS_EMBED_MAX_RATIO)
|
||||||
|
if h > pos_embed_size or w > pos_embed_size:
|
||||||
|
# fallback to normal pos_embed
|
||||||
|
logger.warning(
|
||||||
|
f"Using normal pos_embed for size {h}x{w} as it exceeds the scaled pos_embed size {pos_embed_size}. Image is too tall or wide."
|
||||||
|
)
|
||||||
|
return self.cropped_pos_embed(h, w, device=device, random_crop=random_crop)
|
||||||
|
|
||||||
|
if not random_crop:
|
||||||
|
top = (pos_embed_size - h) // 2
|
||||||
|
left = (pos_embed_size - w) // 2
|
||||||
|
else:
|
||||||
|
top = torch.randint(0, pos_embed_size - h + 1, (1,)).item()
|
||||||
|
left = torch.randint(0, pos_embed_size - w + 1, (1,)).item()
|
||||||
|
|
||||||
|
pos_embed = self.resolution_pos_embeds[patched_size]
|
||||||
|
if pos_embed.device != device:
|
||||||
|
pos_embed = pos_embed.to(device)
|
||||||
|
# which is better to update device, or transfer every time to device? -> 64x64 emb is 96*96*1536*4=56MB. It's okay to update device.
|
||||||
|
self.resolution_pos_embeds[patched_size] = pos_embed # update device
|
||||||
|
if pos_embed.dtype != dtype:
|
||||||
|
pos_embed = pos_embed.to(dtype)
|
||||||
|
self.resolution_pos_embeds[patched_size] = pos_embed # update dtype
|
||||||
|
|
||||||
|
spatial_pos_embed = pos_embed.reshape(1, pos_embed_size, pos_embed_size, pos_embed.shape[-1])
|
||||||
|
spatial_pos_embed = spatial_pos_embed[:, top : top + h, left : left + w, :]
|
||||||
|
spatial_pos_embed = spatial_pos_embed.reshape(1, -1, spatial_pos_embed.shape[-1])
|
||||||
|
# print(
|
||||||
|
# f"patched size: {h}x{w}, pos_embed size: {pos_embed_size}, pos_embed shape: {pos_embed.shape}, top: {top}, left: {left}"
|
||||||
|
# )
|
||||||
|
return spatial_pos_embed
|
||||||
|
|
||||||
def enable_block_swap(self, num_blocks: int):
|
def enable_block_swap(self, num_blocks: int):
|
||||||
self.blocks_to_swap = num_blocks
|
self.blocks_to_swap = num_blocks
|
||||||
|
|
||||||
@@ -931,7 +1095,16 @@ class MMDiT(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
B, C, H, W = x.shape
|
B, C, H, W = x.shape
|
||||||
x = self.x_embedder(x) + self.cropped_pos_embed(H, W, device=x.device, random_crop=pos_emb_random_crop).to(dtype=x.dtype)
|
|
||||||
|
# x = self.x_embedder(x) + self.cropped_pos_embed(H, W, device=x.device, random_crop=pos_emb_random_crop).to(dtype=x.dtype)
|
||||||
|
if not self.use_scaled_pos_embed:
|
||||||
|
pos_embed = self.cropped_pos_embed(H, W, device=x.device, random_crop=pos_emb_random_crop).to(dtype=x.dtype)
|
||||||
|
else:
|
||||||
|
# print(f"Using scaled pos_embed for size {H}x{W}")
|
||||||
|
pos_embed = self.cropped_scaled_pos_embed(H, W, device=x.device, dtype=x.dtype, random_crop=pos_emb_random_crop)
|
||||||
|
x = self.x_embedder(x) + pos_embed
|
||||||
|
del pos_embed
|
||||||
|
|
||||||
c = self.t_embedder(t, dtype=x.dtype) # (N, D)
|
c = self.t_embedder(t, dtype=x.dtype) # (N, D)
|
||||||
if y is not None and self.y_embedder is not None:
|
if y is not None and self.y_embedder is not None:
|
||||||
y = self.y_embedder(y) # (N, D)
|
y = self.y_embedder(y) # (N, D)
|
||||||
|
|||||||
@@ -246,6 +246,12 @@ def add_sd3_training_arguments(parser: argparse.ArgumentParser):
|
|||||||
help="Random crop rate for positional embeddings, default is 0.0. Only for SD3.5M"
|
help="Random crop rate for positional embeddings, default is 0.0. Only for SD3.5M"
|
||||||
" / 位置埋め込みのランダムクロップ率、デフォルトは0.0。SD3.5M以外では予期しない動作になります",
|
" / 位置埋め込みのランダムクロップ率、デフォルトは0.0。SD3.5M以外では予期しない動作になります",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--enable_scaled_pos_embed",
|
||||||
|
action="store_true",
|
||||||
|
help="Scale position embeddings for each resolution during multi-resolution training. Only for SD3.5M"
|
||||||
|
" / 複数解像度学習時に解像度ごとに位置埋め込みをスケーリングする。SD3.5M以外では予期しない動作になります",
|
||||||
|
)
|
||||||
|
|
||||||
# copy from Diffusers
|
# copy from Diffusers
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
|
|||||||
@@ -518,7 +518,7 @@ class LatentsCachingStrategy:
|
|||||||
self, npz_path: str, bucket_reso: Tuple[int, int]
|
self, npz_path: str, bucket_reso: Tuple[int, int]
|
||||||
) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
|
) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
|
||||||
"""
|
"""
|
||||||
for SD/SDXL/SD3.0
|
for SD/SDXL
|
||||||
"""
|
"""
|
||||||
return self._default_load_latents_from_disk(None, npz_path, bucket_reso)
|
return self._default_load_latents_from_disk(None, npz_path, bucket_reso)
|
||||||
|
|
||||||
|
|||||||
@@ -212,7 +212,7 @@ class FluxLatentsCachingStrategy(LatentsCachingStrategy):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
|
def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
|
||||||
return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask, True)
|
return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True)
|
||||||
|
|
||||||
def load_latents_from_disk(
|
def load_latents_from_disk(
|
||||||
self, npz_path: str, bucket_reso: Tuple[int, int]
|
self, npz_path: str, bucket_reso: Tuple[int, int]
|
||||||
@@ -226,7 +226,7 @@ class FluxLatentsCachingStrategy(LatentsCachingStrategy):
|
|||||||
vae_dtype = vae.dtype
|
vae_dtype = vae.dtype
|
||||||
|
|
||||||
self._default_cache_batch_latents(
|
self._default_cache_batch_latents(
|
||||||
encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop, True
|
encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop, multi_resolution=True
|
||||||
)
|
)
|
||||||
|
|
||||||
if not train_util.HIGH_VRAM:
|
if not train_util.HIGH_VRAM:
|
||||||
|
|||||||
@@ -399,7 +399,12 @@ class Sd3LatentsCachingStrategy(LatentsCachingStrategy):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
|
def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
|
||||||
return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask)
|
return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True)
|
||||||
|
|
||||||
|
def load_latents_from_disk(
|
||||||
|
self, npz_path: str, bucket_reso: Tuple[int, int]
|
||||||
|
) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
|
||||||
|
return self._default_load_latents_from_disk(8, npz_path, bucket_reso) # support multi-resolution
|
||||||
|
|
||||||
# TODO remove circular dependency for ImageInfo
|
# TODO remove circular dependency for ImageInfo
|
||||||
def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool):
|
def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool):
|
||||||
@@ -407,7 +412,9 @@ class Sd3LatentsCachingStrategy(LatentsCachingStrategy):
|
|||||||
vae_device = vae.device
|
vae_device = vae.device
|
||||||
vae_dtype = vae.dtype
|
vae_dtype = vae.dtype
|
||||||
|
|
||||||
self._default_cache_batch_latents(encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop)
|
self._default_cache_batch_latents(
|
||||||
|
encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop, multi_resolution=True
|
||||||
|
)
|
||||||
|
|
||||||
if not train_util.HIGH_VRAM:
|
if not train_util.HIGH_VRAM:
|
||||||
train_util.clean_memory_on_device(vae.device)
|
train_util.clean_memory_on_device(vae.device)
|
||||||
|
|||||||
@@ -2510,6 +2510,9 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
|
|||||||
for dataset in self.datasets:
|
for dataset in self.datasets:
|
||||||
dataset.verify_bucket_reso_steps(min_steps)
|
dataset.verify_bucket_reso_steps(min_steps)
|
||||||
|
|
||||||
|
def get_resolutions(self) -> List[Tuple[int, int]]:
|
||||||
|
return [(dataset.width, dataset.height) for dataset in self.datasets]
|
||||||
|
|
||||||
def is_latent_cacheable(self) -> bool:
|
def is_latent_cacheable(self) -> bool:
|
||||||
return all([dataset.is_latent_cacheable() for dataset in self.datasets])
|
return all([dataset.is_latent_cacheable() for dataset in self.datasets])
|
||||||
|
|
||||||
|
|||||||
@@ -361,7 +361,14 @@ def train(args):
|
|||||||
# ), f"attn_mode {attn_mode} is not supported yet. Please use `--sdpa` instead of `--xformers`. / attn_mode {attn_mode} はサポートされていません。`--xformers`の代わりに`--sdpa`を使ってください。"
|
# ), f"attn_mode {attn_mode} is not supported yet. Please use `--sdpa` instead of `--xformers`. / attn_mode {attn_mode} はサポートされていません。`--xformers`の代わりに`--sdpa`を使ってください。"
|
||||||
|
|
||||||
mmdit.set_pos_emb_random_crop_rate(args.pos_emb_random_crop_rate)
|
mmdit.set_pos_emb_random_crop_rate(args.pos_emb_random_crop_rate)
|
||||||
|
|
||||||
|
# set resolutions for positional embeddings
|
||||||
|
if args.enable_scaled_pos_embed:
|
||||||
|
resolutions = train_dataset_group.get_resolutions()
|
||||||
|
latent_sizes = [round(math.sqrt(res[0] * res[1])) // 8 for res in resolutions] # 8 is stride for latent
|
||||||
|
logger.info(f"Prepare scaled positional embeddings for resolutions: {resolutions}, sizes: {latent_sizes}")
|
||||||
|
mmdit.enable_scaled_pos_embed(True, latent_sizes)
|
||||||
|
|
||||||
if args.gradient_checkpointing:
|
if args.gradient_checkpointing:
|
||||||
mmdit.enable_gradient_checkpointing()
|
mmdit.enable_gradient_checkpointing()
|
||||||
|
|
||||||
|
|||||||
@@ -26,8 +26,8 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.sample_prompts_te_outputs = None
|
self.sample_prompts_te_outputs = None
|
||||||
|
|
||||||
def assert_extra_args(self, args, train_dataset_group):
|
def assert_extra_args(self, args, train_dataset_group: train_util.DatasetGroup):
|
||||||
super().assert_extra_args(args, train_dataset_group)
|
# super().assert_extra_args(args, train_dataset_group)
|
||||||
# sdxl_train_util.verify_sdxl_training_args(args)
|
# sdxl_train_util.verify_sdxl_training_args(args)
|
||||||
|
|
||||||
if args.fp8_base_unet:
|
if args.fp8_base_unet:
|
||||||
@@ -53,6 +53,9 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
|
|||||||
|
|
||||||
train_dataset_group.verify_bucket_reso_steps(32) # TODO check this
|
train_dataset_group.verify_bucket_reso_steps(32) # TODO check this
|
||||||
|
|
||||||
|
# enumerate resolutions from dataset for positional embeddings
|
||||||
|
self.resolutions = train_dataset_group.get_resolutions()
|
||||||
|
|
||||||
def load_target_model(self, args, weight_dtype, accelerator):
|
def load_target_model(self, args, weight_dtype, accelerator):
|
||||||
# currently offload to cpu for some models
|
# currently offload to cpu for some models
|
||||||
|
|
||||||
@@ -67,6 +70,12 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
|
|||||||
self.model_type = mmdit.model_type
|
self.model_type = mmdit.model_type
|
||||||
mmdit.set_pos_emb_random_crop_rate(args.pos_emb_random_crop_rate)
|
mmdit.set_pos_emb_random_crop_rate(args.pos_emb_random_crop_rate)
|
||||||
|
|
||||||
|
# set resolutions for positional embeddings
|
||||||
|
if args.enable_scaled_pos_embed:
|
||||||
|
latent_sizes = [round(math.sqrt(res[0] * res[1])) // 8 for res in self.resolutions] # 8 is stride for latent
|
||||||
|
logger.info(f"Prepare scaled positional embeddings for resolutions: {self.resolutions}, sizes: {latent_sizes}")
|
||||||
|
mmdit.enable_scaled_pos_embed(True, latent_sizes)
|
||||||
|
|
||||||
if args.fp8_base:
|
if args.fp8_base:
|
||||||
# check dtype of model
|
# check dtype of model
|
||||||
if mmdit.dtype == torch.float8_e4m3fnuz or mmdit.dtype == torch.float8_e5m2 or mmdit.dtype == torch.float8_e5m2fnuz:
|
if mmdit.dtype == torch.float8_e4m3fnuz or mmdit.dtype == torch.float8_e5m2 or mmdit.dtype == torch.float8_e5m2fnuz:
|
||||||
|
|||||||
Reference in New Issue
Block a user