mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-16 00:49:40 +00:00
Add partitioned VAE
This commit is contained in:
@@ -141,6 +141,10 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
|
||||
ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
|
||||
|
||||
# Apply partitioned for Diffusion4k
|
||||
if args.partitioned_vae:
|
||||
ae.decoder.partitioned = True
|
||||
|
||||
return flux_utils.MODEL_VERSION_FLUX_V1, [clip_l, t5xxl], ae, model
|
||||
|
||||
def get_tokenize_strategy(self, args):
|
||||
@@ -360,7 +364,13 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
# pack latents and get img_ids
|
||||
packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input) # b, c, h*2, w*2 -> b, h*w, c*4
|
||||
packed_latent_height, packed_latent_width = noisy_model_input.shape[2] // 2, noisy_model_input.shape[3] // 2
|
||||
img_ids = flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width).to(device=accelerator.device)
|
||||
|
||||
if args.partitioned_vae:
|
||||
packed_latent_height, packed_latent_width = noisy_model_input.shape[2], noisy_model_input.shape[3]
|
||||
img_ids = flux_utils.prepare_paritioned_img_ids(bsz, packed_latent_height, packed_latent_width).to(device=accelerator.device)
|
||||
else:
|
||||
packed_latent_height, packed_latent_width = noisy_model_input.shape[2] // 2, noisy_model_input.shape[3] // 2
|
||||
img_ids = flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width).to(device=accelerator.device)
|
||||
|
||||
# get guidance
|
||||
# ensure guidance_scale in args is float
|
||||
@@ -408,7 +418,11 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
)
|
||||
|
||||
# unpack latents
|
||||
model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width)
|
||||
if args.partitioned_vae:
|
||||
model_pred = flux_utils.unpack_partitioned_latents(model_pred, packed_latent_height, packed_latent_width)
|
||||
else:
|
||||
# unpack latents
|
||||
model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width)
|
||||
|
||||
# apply model prediction type
|
||||
model_pred, weighting = flux_train_utils.apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas)
|
||||
|
||||
@@ -54,6 +54,8 @@ class AutoEncoderParams:
|
||||
z_channels: int
|
||||
scale_factor: float
|
||||
shift_factor: float
|
||||
stride: int
|
||||
partitioned: bool
|
||||
|
||||
|
||||
def swish(x: Tensor) -> Tensor:
|
||||
@@ -228,6 +230,8 @@ class Decoder(nn.Module):
|
||||
in_channels: int,
|
||||
resolution: int,
|
||||
z_channels: int,
|
||||
partitioned=False,
|
||||
stride=1,
|
||||
):
|
||||
super().__init__()
|
||||
self.ch = ch
|
||||
@@ -236,6 +240,8 @@ class Decoder(nn.Module):
|
||||
self.resolution = resolution
|
||||
self.in_channels = in_channels
|
||||
self.ffactor = 2 ** (self.num_resolutions - 1)
|
||||
self.stride = stride
|
||||
self.partitioned = partitioned
|
||||
|
||||
# compute in_ch_mult, block_in and curr_res at lowest res
|
||||
block_in = ch * ch_mult[self.num_resolutions - 1]
|
||||
@@ -272,7 +278,7 @@ class Decoder(nn.Module):
|
||||
self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
|
||||
self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
def forward(self, z: Tensor) -> Tensor:
|
||||
def forward(self, z: Tensor, partitioned=None) -> Tensor:
|
||||
# z to block_in
|
||||
h = self.conv_in(z)
|
||||
|
||||
@@ -291,9 +297,55 @@ class Decoder(nn.Module):
|
||||
h = self.up[i_level].upsample(h)
|
||||
|
||||
# end
|
||||
h = self.norm_out(h)
|
||||
h = swish(h)
|
||||
h = self.conv_out(h)
|
||||
|
||||
# Diffusion4k
|
||||
partitioned = partitioned if not None else self.partitioned
|
||||
if self.stride > 1 and partitioned:
|
||||
h = swish(h)
|
||||
|
||||
overlap_size = 1 # because last conv kernel_size = 3
|
||||
res = []
|
||||
partitioned_height = h.shape[2] // self.stride
|
||||
partitioned_width = h.shape[3] // self.stride
|
||||
|
||||
assert self.stride == 2 # only support stride = 2 for now
|
||||
rows = []
|
||||
for i in range(0, h.shape[2], partitioned_height):
|
||||
row = []
|
||||
for j in range(0, h.shape[3], partitioned_width):
|
||||
partition = h[:,:, max(i - overlap_size, 0) : min(i + partitioned_height + overlap_size, h.shape[2]), max(j - overlap_size, 0) : min(j + partitioned_width + overlap_size, h.shape[3])]
|
||||
|
||||
# for strih
|
||||
if i==0 and j==0:
|
||||
partition = torch.nn.functional.pad(partition, (1, 0, 1, 0), "constant", 0)
|
||||
elif i==0:
|
||||
partition = torch.nn.functional.pad(partition, (0, 1, 1, 0), "constant", 0)
|
||||
elif i>0 and j==0:
|
||||
partition = torch.nn.functional.pad(partition, (1, 0, 0, 1), "constant", 0)
|
||||
elif i>0 and j>0:
|
||||
partition = torch.nn.functional.pad(partition, (0, 1, 0, 1), "constant", 0)
|
||||
|
||||
partition = torch.nn.functional.interpolate(partition, scale_factor=self.stride, mode='nearest')
|
||||
partition = self.conv_out(partition)
|
||||
partition = partition[:,:,overlap_size:partitioned_height*2+overlap_size,overlap_size:partitioned_width*2+overlap_size]
|
||||
|
||||
row.append(partition)
|
||||
rows.append(row)
|
||||
|
||||
for row in rows:
|
||||
res.append(torch.cat(row, dim=3))
|
||||
|
||||
h = torch.cat(res, dim=2)
|
||||
# Diffusion4k
|
||||
elif self.stride > 1:
|
||||
h = self.norm_out(h)
|
||||
h = torch.nn.functional.interpolate(h, scale_factor=self.stride, mode='nearest')
|
||||
h = swish(h)
|
||||
h = self.conv_out(h)
|
||||
else:
|
||||
h = self.norm_out(h)
|
||||
h = swish(h)
|
||||
h = self.conv_out(h)
|
||||
return h
|
||||
|
||||
|
||||
@@ -404,6 +456,9 @@ configs = {
|
||||
z_channels=16,
|
||||
scale_factor=0.3611,
|
||||
shift_factor=0.1159,
|
||||
# Diffusion4k
|
||||
stride=1,
|
||||
partitioned=False,
|
||||
),
|
||||
),
|
||||
"schnell": ModelSpec(
|
||||
@@ -436,6 +491,9 @@ configs = {
|
||||
z_channels=16,
|
||||
scale_factor=0.3611,
|
||||
shift_factor=0.1159,
|
||||
# Diffusion4k
|
||||
stride=1,
|
||||
partitioned=False,
|
||||
),
|
||||
),
|
||||
}
|
||||
|
||||
@@ -680,3 +680,4 @@ def add_flux_train_arguments(parser: argparse.ArgumentParser):
|
||||
default=3.0,
|
||||
help="Discrete flow shift for the Euler Discrete Scheduler, default is 3.0. / Euler Discrete Schedulerの離散フローシフト、デフォルトは3.0。",
|
||||
)
|
||||
parser.add_argument("--partitioned_vae", action="store_true", help="Partitioned VAE from Diffusion4k paper")
|
||||
|
||||
@@ -346,6 +346,13 @@ def prepare_img_ids(batch_size: int, packed_latent_height: int, packed_latent_wi
|
||||
img_ids = einops.repeat(img_ids, "h w c -> b (h w) c", b=batch_size)
|
||||
return img_ids
|
||||
|
||||
def prepare_paritioned_img_ids(batch_size: int, packed_latent_height: int, packed_latent_width: int):
    """Build positional img ids for the partitioned-VAE (Diffusion4k) path.

    Same layout as ``prepare_img_ids`` but on a grid half the size in each
    dimension: the partitioned caller passes the *unpacked* latent
    height/width (``noisy_model_input.shape[2]``/``shape[3]``), so both are
    halved here to recover the packed token grid.

    NOTE(review): "paritioned" is a typo for "partitioned"; the name is kept
    because existing callers already use it.

    Args:
        batch_size: number of samples in the batch.
        packed_latent_height: full latent height (halved internally).
        packed_latent_width: full latent width (halved internally).

    Returns:
        Tensor of shape ``(batch_size, (h//2)*(w//2), 3)`` where channel 1
        holds the row index, channel 2 the column index, channel 0 stays 0.
    """
    grid_h = packed_latent_height // 2
    grid_w = packed_latent_width // 2
    img_ids = torch.zeros(grid_h, grid_w, 3)
    # direct broadcast assignment: the channels are freshly zeroed, so the
    # original "x = x + arange" round-trip was redundant
    img_ids[..., 1] = torch.arange(grid_h)[:, None]
    img_ids[..., 2] = torch.arange(grid_w)[None, :]
    img_ids = einops.repeat(img_ids, "h w c -> b (h w) c", b=batch_size)
    return img_ids
|
||||
|
||||
|
||||
def unpack_latents(x: torch.Tensor, packed_latent_height: int, packed_latent_width: int) -> torch.Tensor:
|
||||
"""
|
||||
@@ -354,6 +361,13 @@ def unpack_latents(x: torch.Tensor, packed_latent_height: int, packed_latent_wid
|
||||
x = einops.rearrange(x, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=packed_latent_height, w=packed_latent_width, ph=2, pw=2)
|
||||
return x
|
||||
|
||||
def unpack_partitioned_latents(x: torch.FloatTensor, packed_latent_height: int, packed_latent_width: int) -> torch.FloatTensor:
    """
    x: [b (h w) (c ph pw)] -> [b c (h ph) (w pw)], ph=2, pw=2

    Partitioned-VAE (Diffusion4k) variant of ``unpack_latents``: callers pass
    the full latent height/width, so the packed token grid is half that in
    each dimension. Delegates to ``unpack_latents`` with the halved grid
    instead of duplicating the same einops rearrange pattern.
    """
    return unpack_latents(x, packed_latent_height // 2, packed_latent_width // 2)
|
||||
|
||||
|
||||
def pack_latents(x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user