FLUX.1 LoRA supports CLIP-L

This commit is contained in:
Kohya S
2024-08-27 19:59:40 +09:00
parent 72287d39c7
commit 0087a46e14
6 changed files with 101 additions and 43 deletions

View File

@@ -58,7 +58,7 @@ def sample_images(
logger.info("")
logger.info(f"generating sample images at step / サンプル画像生成 ステップ: {steps}")
if not os.path.isfile(args.sample_prompts):
if not os.path.isfile(args.sample_prompts) and sample_prompts_te_outputs is None:
logger.error(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}")
return
@@ -66,7 +66,8 @@ def sample_images(
# unwrap unet and text_encoder(s)
flux = accelerator.unwrap_model(flux)
text_encoders = [accelerator.unwrap_model(te) for te in text_encoders]
if text_encoders is not None:
text_encoders = [accelerator.unwrap_model(te) for te in text_encoders]
# print([(te.parameters().__next__().device if te is not None else None) for te in text_encoders])
prompts = load_prompts(args.sample_prompts)
@@ -134,7 +135,7 @@ def sample_image_inference(
accelerator: Accelerator,
args: argparse.Namespace,
flux: flux_models.Flux,
text_encoders: List[CLIPTextModel],
text_encoders: Optional[List[CLIPTextModel]],
ae: flux_models.AutoEncoder,
save_dir,
prompt_dict,
@@ -387,6 +388,7 @@ def get_noisy_model_input_and_timesteps(
elif args.timestep_sampling == "shift":
shift = args.discrete_flow_shift
logits_norm = torch.randn(bsz, device=device)
logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling
timesteps = logits_norm.sigmoid()
timesteps = (timesteps * shift) / (1 + (shift - 1) * timesteps)

View File

@@ -60,7 +60,7 @@ class FluxTextEncodingStrategy(TextEncodingStrategy):
if apply_t5_attn_mask is None:
apply_t5_attn_mask = self.apply_t5_attn_mask
clip_l, t5xxl = models
clip_l, t5xxl = models if len(models) == 2 else (models[0], None)
l_tokens, t5_tokens = tokens[:2]
t5_attn_mask = tokens[2] if len(tokens) > 2 else None
@@ -81,6 +81,7 @@ class FluxTextEncodingStrategy(TextEncodingStrategy):
else:
t5_out = None
txt_ids = None
t5_attn_mask = None # caption may be dropped/shuffled, so t5_attn_mask should not be used to make sure the mask is same as the cached one
return [l_pooled, t5_out, txt_ids, t5_attn_mask] # returns t5_attn_mask for attention mask in transformer