support SD3 LoRA

This commit is contained in:
kohya-ss
2024-10-25 21:58:31 +09:00
parent f52fb66e8f
commit d2c549d7b2
7 changed files with 1334 additions and 67 deletions

View File

@@ -198,6 +198,23 @@ def add_sd3_training_arguments(parser: argparse.ArgumentParser):
help="[DOES NOT WORK] not supported yet. T5-XXL dtype. if not specified, use default dtype (from mixed precision) / T5-XXL dtype。指定しない場合はデフォルトのdtypemixed precisionからを使用",
)
parser.add_argument(
"--t5xxl_max_token_length",
type=int,
default=256,
help="maximum token length for T5-XXL. 256 is the default value / T5-XXLの最大トークン長。デフォルトは256",
)
parser.add_argument(
"--apply_lg_attn_mask",
action="store_true",
help="apply attention mask (zero embs) to CLIP-L and G / CLIP-LとGにアテンションマスクゼロ埋めを適用する",
)
parser.add_argument(
"--apply_t5_attn_mask",
action="store_true",
help="apply attention mask (zero embs) to T5-XXL / T5-XXLにアテンションマスクゼロ埋めを適用する",
)
# copy from Diffusers
parser.add_argument(
"--weighting_scheme",
@@ -317,36 +334,36 @@ def do_sample(
x = noise_scaled.to(device).to(dtype)
# print(x.shape)
with torch.no_grad():
for i in tqdm(range(len(sigmas) - 1)):
sigma_hat = sigmas[i]
# with torch.no_grad():
for i in tqdm(range(len(sigmas) - 1)):
sigma_hat = sigmas[i]
timestep = model_sampling.timestep(sigma_hat).float()
timestep = torch.FloatTensor([timestep, timestep]).to(device)
timestep = model_sampling.timestep(sigma_hat).float()
timestep = torch.FloatTensor([timestep, timestep]).to(device)
x_c_nc = torch.cat([x, x], dim=0)
# print(x_c_nc.shape, timestep.shape, c_crossattn.shape, y.shape)
x_c_nc = torch.cat([x, x], dim=0)
# print(x_c_nc.shape, timestep.shape, c_crossattn.shape, y.shape)
model_output = mmdit(x_c_nc, timestep, context=c_crossattn, y=y)
model_output = model_output.float()
batched = model_sampling.calculate_denoised(sigma_hat, model_output, x)
model_output = mmdit(x_c_nc, timestep, context=c_crossattn, y=y)
model_output = model_output.float()
batched = model_sampling.calculate_denoised(sigma_hat, model_output, x)
pos_out, neg_out = batched.chunk(2)
denoised = neg_out + (pos_out - neg_out) * guidance_scale
# print(denoised.shape)
pos_out, neg_out = batched.chunk(2)
denoised = neg_out + (pos_out - neg_out) * guidance_scale
# print(denoised.shape)
# d = to_d(x, sigma_hat, denoised)
dims_to_append = x.ndim - sigma_hat.ndim
sigma_hat_dims = sigma_hat[(...,) + (None,) * dims_to_append]
# print(dims_to_append, x.shape, sigma_hat.shape, denoised.shape, sigma_hat_dims.shape)
"""Converts a denoiser output to a Karras ODE derivative."""
d = (x - denoised) / sigma_hat_dims
# d = to_d(x, sigma_hat, denoised)
dims_to_append = x.ndim - sigma_hat.ndim
sigma_hat_dims = sigma_hat[(...,) + (None,) * dims_to_append]
# print(dims_to_append, x.shape, sigma_hat.shape, denoised.shape, sigma_hat_dims.shape)
"""Converts a denoiser output to a Karras ODE derivative."""
d = (x - denoised) / sigma_hat_dims
dt = sigmas[i + 1] - sigma_hat
dt = sigmas[i + 1] - sigma_hat
# Euler method
x = x + d * dt
x = x.to(dtype)
# Euler method
x = x + d * dt
x = x.to(dtype)
return x
@@ -378,7 +395,7 @@ def sample_images(
logger.info("")
logger.info(f"generating sample images at step / サンプル画像生成 ステップ: {steps}")
if not os.path.isfile(args.sample_prompts):
if not os.path.isfile(args.sample_prompts) and sample_prompts_te_outputs is None:
logger.error(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}")
return
@@ -386,7 +403,7 @@ def sample_images(
# unwrap unet and text_encoder(s)
mmdit = accelerator.unwrap_model(mmdit)
text_encoders = [accelerator.unwrap_model(te) for te in text_encoders]
text_encoders = None if text_encoders is None else [accelerator.unwrap_model(te) for te in text_encoders]
# print([(te.parameters().__next__().device if te is not None else None) for te in text_encoders])
prompts = train_util.load_prompts(args.sample_prompts)
@@ -404,7 +421,7 @@ def sample_images(
if distributed_state.num_processes <= 1:
# If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts.
with torch.no_grad():
with torch.no_grad(), accelerator.autocast():
for prompt_dict in prompts:
sample_image_inference(
accelerator,
@@ -506,29 +523,39 @@ def sample_image_inference(
tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy()
encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy()
if sample_prompts_te_outputs and prompt in sample_prompts_te_outputs:
te_outputs = sample_prompts_te_outputs[prompt]
else:
l_tokens, g_tokens, t5_tokens = tokenize_strategy.tokenize(prompt)
te_outputs = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, [l_tokens, g_tokens, t5_tokens])
def encode_prompt(prpt):
text_encoder_conds = []
if sample_prompts_te_outputs and prpt in sample_prompts_te_outputs:
text_encoder_conds = sample_prompts_te_outputs[prpt]
print(f"Using cached text encoder outputs for prompt: {prpt}")
if text_encoders is not None:
print(f"Encoding prompt: {prpt}")
tokens_and_masks = tokenize_strategy.tokenize(prpt)
# strategy has apply_t5_attn_mask option
encoded_text_encoder_conds = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, tokens_and_masks)
lg_out, t5_out, pooled, l_attn_mask, g_attn_mask, t5_attn_mask = te_outputs
# if text_encoder_conds is not cached, use encoded_text_encoder_conds
if len(text_encoder_conds) == 0:
text_encoder_conds = encoded_text_encoder_conds
else:
# if encoded_text_encoder_conds is not None, update cached text_encoder_conds
for i in range(len(encoded_text_encoder_conds)):
if encoded_text_encoder_conds[i] is not None:
text_encoder_conds[i] = encoded_text_encoder_conds[i]
return text_encoder_conds
lg_out, t5_out, pooled, l_attn_mask, g_attn_mask, t5_attn_mask = encode_prompt(prompt)
cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled)
# encode negative prompts
if sample_prompts_te_outputs and negative_prompt in sample_prompts_te_outputs:
neg_te_outputs = sample_prompts_te_outputs[negative_prompt]
else:
l_tokens, g_tokens, t5_tokens = tokenize_strategy.tokenize(negative_prompt)
neg_te_outputs = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, [l_tokens, g_tokens, t5_tokens])
lg_out, t5_out, pooled, l_attn_mask, g_attn_mask, t5_attn_mask = neg_te_outputs
lg_out, t5_out, pooled, l_attn_mask, g_attn_mask, t5_attn_mask = encode_prompt(negative_prompt)
neg_cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled)
# sample image
clean_memory_on_device(accelerator.device)
with accelerator.autocast():
latents = do_sample(height, width, seed, cond, neg_cond, mmdit, sample_steps, scale, mmdit.dtype, accelerator.device)
with accelerator.autocast(), torch.no_grad():
# mmdit may be fp8, so we need weight_dtype here. vae is always in that dtype.
latents = do_sample(height, width, seed, cond, neg_cond, mmdit, sample_steps, scale, vae.dtype, accelerator.device)
# latent to image
clean_memory_on_device(accelerator.device)
@@ -538,7 +565,7 @@ def sample_image_inference(
image = vae.decode(latents)
vae.to(org_vae_device)
clean_memory_on_device(accelerator.device)
image = image.float()
image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)[0]
decoded_np = 255.0 * np.moveaxis(image.cpu().numpy(), 0, 2)