mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-18 01:30:02 +00:00
T5XXL LoRA training, fp8 T5XXL support
This commit is contained in:
@@ -85,7 +85,7 @@ def sample_images(
|
||||
|
||||
if distributed_state.num_processes <= 1:
|
||||
# If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts.
|
||||
with torch.no_grad():
|
||||
with torch.no_grad(), accelerator.autocast():
|
||||
for prompt_dict in prompts:
|
||||
sample_image_inference(
|
||||
accelerator,
|
||||
@@ -187,14 +187,27 @@ def sample_image_inference(
|
||||
tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy()
|
||||
encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy()
|
||||
|
||||
text_encoder_conds = []
|
||||
if sample_prompts_te_outputs and prompt in sample_prompts_te_outputs:
|
||||
te_outputs = sample_prompts_te_outputs[prompt]
|
||||
else:
|
||||
text_encoder_conds = sample_prompts_te_outputs[prompt]
|
||||
print(f"Using cached text encoder outputs for prompt: {prompt}")
|
||||
if text_encoders is not None:
|
||||
print(f"Encoding prompt: {prompt}")
|
||||
tokens_and_masks = tokenize_strategy.tokenize(prompt)
|
||||
# strategy has apply_t5_attn_mask option
|
||||
te_outputs = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, tokens_and_masks)
|
||||
encoded_text_encoder_conds = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, tokens_and_masks)
|
||||
print([x.shape if x is not None else None for x in encoded_text_encoder_conds])
|
||||
|
||||
l_pooled, t5_out, txt_ids, t5_attn_mask = te_outputs
|
||||
# if text_encoder_conds is not cached, use encoded_text_encoder_conds
|
||||
if len(text_encoder_conds) == 0:
|
||||
text_encoder_conds = encoded_text_encoder_conds
|
||||
else:
|
||||
# if encoded_text_encoder_conds is not None, update cached text_encoder_conds
|
||||
for i in range(len(encoded_text_encoder_conds)):
|
||||
if encoded_text_encoder_conds[i] is not None:
|
||||
text_encoder_conds[i] = encoded_text_encoder_conds[i]
|
||||
|
||||
l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds
|
||||
|
||||
# sample image
|
||||
weight_dtype = ae.dtype # TOFO give dtype as argument
|
||||
|
||||
Reference in New Issue
Block a user