sample generation in SDXL ControlNet training

This commit is contained in:
Kohya S
2024-09-30 23:39:32 +09:00
parent d78f6a775c
commit 793999d116
5 changed files with 322 additions and 165 deletions

View File

@@ -13,12 +13,20 @@ from tqdm import tqdm
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from diffusers import SchedulerMixin, StableDiffusionPipeline
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
from diffusers.models import AutoencoderKL
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from diffusers.utils import logging
from PIL import Image
from library import sdxl_model_util, sdxl_train_util, train_util
from library import (
sdxl_model_util,
sdxl_train_util,
strategy_base,
strategy_sdxl,
train_util,
sdxl_original_unet,
sdxl_original_control_net,
)
try:
@@ -537,7 +545,7 @@ class SdxlStableDiffusionLongPromptWeightingPipeline:
vae: AutoencoderKL,
text_encoder: List[CLIPTextModel],
tokenizer: List[CLIPTokenizer],
unet: UNet2DConditionModel,
unet: Union[sdxl_original_unet.SdxlUNet2DConditionModel, sdxl_original_control_net.SdxlControlledUNet],
scheduler: SchedulerMixin,
# clip_skip: int,
safety_checker: StableDiffusionSafetyChecker,
@@ -594,74 +602,6 @@ class SdxlStableDiffusionLongPromptWeightingPipeline:
return torch.device(module._hf_hook.execution_device)
return self.device
def _encode_prompt(
self,
prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt,
max_embeddings_multiples,
is_sdxl_text_encoder2,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `list(int)`):
prompt to be encoded
device: (`torch.device`):
torch device
num_images_per_prompt (`int`):
number of images that should be generated per prompt
do_classifier_free_guidance (`bool`):
whether to use classifier free guidance or not
negative_prompt (`str` or `List[str]`):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is less than `1`).
max_embeddings_multiples (`int`, *optional*, defaults to `3`):
The max multiple length of prompt embeddings compared to the max output length of text encoder.
"""
batch_size = len(prompt) if isinstance(prompt, list) else 1
if negative_prompt is None:
negative_prompt = [""] * batch_size
elif isinstance(negative_prompt, str):
negative_prompt = [negative_prompt] * batch_size
if batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
text_embeddings, text_pool, uncond_embeddings, uncond_pool = get_weighted_text_embeddings(
pipe=self,
prompt=prompt,
uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
max_embeddings_multiples=max_embeddings_multiples,
clip_skip=self.clip_skip,
is_sdxl_text_encoder2=is_sdxl_text_encoder2,
)
bs_embed, seq_len, _ = text_embeddings.shape
text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1) # ??
text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
if text_pool is not None:
text_pool = text_pool.repeat(1, num_images_per_prompt)
text_pool = text_pool.view(bs_embed * num_images_per_prompt, -1)
if do_classifier_free_guidance:
bs_embed, seq_len, _ = uncond_embeddings.shape
uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
uncond_embeddings = uncond_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
if uncond_pool is not None:
uncond_pool = uncond_pool.repeat(1, num_images_per_prompt)
uncond_pool = uncond_pool.view(bs_embed * num_images_per_prompt, -1)
return text_embeddings, text_pool, uncond_embeddings, uncond_pool
return text_embeddings, text_pool, None, None
def check_inputs(self, prompt, height, width, strength, callback_steps):
if not isinstance(prompt, str) and not isinstance(prompt, list):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
@@ -792,7 +732,7 @@ class SdxlStableDiffusionLongPromptWeightingPipeline:
max_embeddings_multiples: Optional[int] = 3,
output_type: Optional[str] = "pil",
return_dict: bool = True,
controlnet=None,
controlnet: sdxl_original_control_net.SdxlControlNet = None,
controlnet_image=None,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
is_cancelled_callback: Optional[Callable[[], bool]] = None,
@@ -896,32 +836,24 @@ class SdxlStableDiffusionLongPromptWeightingPipeline:
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
# 実装を簡単にするためにtokenzer/text encoderを切り替えて二回呼び出す
# To simplify the implementation, switch the tokenzer/text encoder and call it twice
text_embeddings_list = []
text_pool = None
uncond_embeddings_list = []
uncond_pool = None
for i in range(len(self.tokenizers)):
self.tokenizer = self.tokenizers[i]
self.text_encoder = self.text_encoders[i]
tokenize_strategy: strategy_sdxl.SdxlTokenizeStrategy = strategy_base.TokenizeStrategy.get_strategy()
encoding_strategy: strategy_sdxl.SdxlTextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy()
text_embeddings, tp1, uncond_embeddings, up1 = self._encode_prompt(
prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt,
max_embeddings_multiples,
is_sdxl_text_encoder2=i == 1,
text_input_ids, text_weights = tokenize_strategy.tokenize_with_weights(prompt)
hidden_states_1, hidden_states_2, text_pool = encoding_strategy.encode_tokens_with_weights(
tokenize_strategy, self.text_encoders, text_input_ids, text_weights
)
text_embeddings = torch.cat([hidden_states_1, hidden_states_2], dim=-1)
if do_classifier_free_guidance:
input_ids, weights = tokenize_strategy.tokenize_with_weights(negative_prompt or "")
hidden_states_1, hidden_states_2, uncond_pool = encoding_strategy.encode_tokens_with_weights(
tokenize_strategy, self.text_encoders, input_ids, weights
)
text_embeddings_list.append(text_embeddings)
uncond_embeddings_list.append(uncond_embeddings)
if tp1 is not None:
text_pool = tp1
if up1 is not None:
uncond_pool = up1
uncond_embeddings = torch.cat([hidden_states_1, hidden_states_2], dim=-1)
else:
uncond_embeddings = None
uncond_pool = None
unet_dtype = self.unet.dtype
dtype = unet_dtype
@@ -970,23 +902,23 @@ class SdxlStableDiffusionLongPromptWeightingPipeline:
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# create size embs and concat embeddings for SDXL
orig_size = torch.tensor([height, width]).repeat(batch_size * num_images_per_prompt, 1).to(dtype)
orig_size = torch.tensor([height, width]).repeat(batch_size * num_images_per_prompt, 1).to(device, dtype)
crop_size = torch.zeros_like(orig_size)
target_size = orig_size
embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, device).to(dtype)
embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, device).to(device, dtype)
# make conditionings
text_pool = text_pool.to(device, dtype)
if do_classifier_free_guidance:
text_embeddings = torch.cat(text_embeddings_list, dim=2)
uncond_embeddings = torch.cat(uncond_embeddings_list, dim=2)
text_embedding = torch.cat([uncond_embeddings, text_embeddings]).to(dtype)
text_embedding = torch.cat([uncond_embeddings, text_embeddings]).to(device, dtype)
cond_vector = torch.cat([text_pool, embs], dim=1)
uncond_vector = torch.cat([uncond_pool, embs], dim=1)
vector_embedding = torch.cat([uncond_vector, cond_vector]).to(dtype)
uncond_pool = uncond_pool.to(device, dtype)
cond_vector = torch.cat([text_pool, embs], dim=1).to(dtype)
uncond_vector = torch.cat([uncond_pool, embs], dim=1).to(dtype)
vector_embedding = torch.cat([uncond_vector, cond_vector])
else:
text_embedding = torch.cat(text_embeddings_list, dim=2).to(dtype)
vector_embedding = torch.cat([text_pool, embs], dim=1).to(dtype)
text_embedding = text_embeddings.to(device, dtype)
vector_embedding = torch.cat([text_pool, embs], dim=1)
# 8. Denoising loop
for i, t in enumerate(self.progress_bar(timesteps)):
@@ -994,22 +926,14 @@ class SdxlStableDiffusionLongPromptWeightingPipeline:
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
unet_additional_args = {}
if controlnet is not None:
down_block_res_samples, mid_block_res_sample = controlnet(
latent_model_input,
t,
encoder_hidden_states=text_embeddings,
controlnet_cond=controlnet_image,
conditioning_scale=1.0,
guess_mode=False,
return_dict=False,
)
unet_additional_args["down_block_additional_residuals"] = down_block_res_samples
unet_additional_args["mid_block_additional_residual"] = mid_block_res_sample
# FIXME SD1 ControlNet is not working
# predict the noise residual
noise_pred = self.unet(latent_model_input, t, text_embedding, vector_embedding)
if controlnet is not None:
input_resi_add, mid_add = controlnet(latent_model_input, t, text_embedding, vector_embedding, controlnet_image)
noise_pred = self.unet(latent_model_input, t, text_embedding, vector_embedding, input_resi_add, mid_add)
else:
noise_pred = self.unet(latent_model_input, t, text_embedding, vector_embedding)
noise_pred = noise_pred.to(dtype) # U-Net changes dtype in LoRA training
# perform guidance

View File

@@ -1,6 +1,7 @@
# base class for platform strategies. this file defines the interface for strategies
import os
import re
from typing import Any, List, Optional, Tuple, Union
import numpy as np
@@ -22,6 +23,24 @@ logger = logging.getLogger(__name__)
class TokenizeStrategy:
_strategy = None # strategy instance: actual strategy class
_re_attention = re.compile(
r"""\\\(|
\\\)|
\\\[|
\\]|
\\\\|
\\|
\(|
\[|
:([+-]?[.\d]+)\)|
\)|
]|
[^\\()\[\]:]+|
:
""",
re.X,
)
@classmethod
def set_strategy(cls, strategy):
if cls._strategy is not None:
@@ -54,7 +73,151 @@ class TokenizeStrategy:
def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]:
raise NotImplementedError
def _get_input_ids(self, tokenizer: CLIPTokenizer, text: str, max_length: Optional[int] = None) -> torch.Tensor:
def tokenize_with_weights(self, text: Union[str, List[str]]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
raise NotImplementedError
def _get_weighted_input_ids(
self, tokenizer: CLIPTokenizer, text: str, max_length: Optional[int] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
max_length includes starting and ending tokens.
"""
def parse_prompt_attention(text):
"""
Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
Accepted tokens are:
(abc) - increases attention to abc by a multiplier of 1.1
(abc:3.12) - increases attention to abc by a multiplier of 3.12
[abc] - decreases attention to abc by a multiplier of 1.1
\( - literal character '('
\[ - literal character '['
\) - literal character ')'
\] - literal character ']'
\\ - literal character '\'
anything else - just text
>>> parse_prompt_attention('normal text')
[['normal text', 1.0]]
>>> parse_prompt_attention('an (important) word')
[['an ', 1.0], ['important', 1.1], [' word', 1.0]]
>>> parse_prompt_attention('(unbalanced')
[['unbalanced', 1.1]]
>>> parse_prompt_attention('\(literal\]')
[['(literal]', 1.0]]
>>> parse_prompt_attention('(unnecessary)(parens)')
[['unnecessaryparens', 1.1]]
>>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
[['a ', 1.0],
['house', 1.5730000000000004],
[' ', 1.1],
['on', 1.0],
[' a ', 1.1],
['hill', 0.55],
[', sun, ', 1.1],
['sky', 1.4641000000000006],
['.', 1.1]]
"""
res = []
round_brackets = []
square_brackets = []
round_bracket_multiplier = 1.1
square_bracket_multiplier = 1 / 1.1
def multiply_range(start_position, multiplier):
for p in range(start_position, len(res)):
res[p][1] *= multiplier
for m in TokenizeStrategy._re_attention.finditer(text):
text = m.group(0)
weight = m.group(1)
if text.startswith("\\"):
res.append([text[1:], 1.0])
elif text == "(":
round_brackets.append(len(res))
elif text == "[":
square_brackets.append(len(res))
elif weight is not None and len(round_brackets) > 0:
multiply_range(round_brackets.pop(), float(weight))
elif text == ")" and len(round_brackets) > 0:
multiply_range(round_brackets.pop(), round_bracket_multiplier)
elif text == "]" and len(square_brackets) > 0:
multiply_range(square_brackets.pop(), square_bracket_multiplier)
else:
res.append([text, 1.0])
for pos in round_brackets:
multiply_range(pos, round_bracket_multiplier)
for pos in square_brackets:
multiply_range(pos, square_bracket_multiplier)
if len(res) == 0:
res = [["", 1.0]]
# merge runs of identical weights
i = 0
while i + 1 < len(res):
if res[i][1] == res[i + 1][1]:
res[i][0] += res[i + 1][0]
res.pop(i + 1)
else:
i += 1
return res
def get_prompts_with_weights(text: str, max_length: int):
r"""
Tokenize a list of prompts and return its tokens with weights of each token. max_length does not include starting and ending token.
No padding, starting or ending token is included.
"""
truncated = False
texts_and_weights = parse_prompt_attention(text)
tokens = []
weights = []
for word, weight in texts_and_weights:
# tokenize and discard the starting and the ending token
token = tokenizer(word).input_ids[1:-1]
tokens += token
# copy the weight by length of token
weights += [weight] * len(token)
# stop if the text is too long (longer than truncation limit)
if len(tokens) > max_length:
truncated = True
break
# truncate
if len(tokens) > max_length:
truncated = True
tokens = tokens[:max_length]
weights = weights[:max_length]
if truncated:
logger.warning("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
return tokens, weights
def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, pad):
r"""
Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
"""
tokens = [bos] + tokens + [eos] + [pad] * (max_length - 2 - len(tokens))
weights = [1.0] + weights + [1.0] * (max_length - 1 - len(weights))
return tokens, weights
if max_length is None:
max_length = tokenizer.model_max_length
tokens, weights = get_prompts_with_weights(text, max_length - 2)
tokens, weights = pad_tokens_and_weights(
tokens, weights, max_length, tokenizer.bos_token_id, tokenizer.eos_token_id, tokenizer.pad_token_id
)
return torch.tensor(tokens).unsqueeze(0), torch.tensor(weights).unsqueeze(0)
def _get_input_ids(
self, tokenizer: CLIPTokenizer, text: str, max_length: Optional[int] = None, weighted: bool = False
) -> torch.Tensor:
"""
for SD1.5/2.0/SDXL
TODO support batch input
@@ -62,7 +225,10 @@ class TokenizeStrategy:
if max_length is None:
max_length = tokenizer.model_max_length - 2
input_ids = tokenizer(text, padding="max_length", truncation=True, max_length=max_length, return_tensors="pt").input_ids
if weighted:
input_ids, weights = self._get_weighted_input_ids(tokenizer, text, max_length)
else:
input_ids = tokenizer(text, padding="max_length", truncation=True, max_length=max_length, return_tensors="pt").input_ids
if max_length > tokenizer.model_max_length:
input_ids = input_ids.squeeze(0)
@@ -101,6 +267,17 @@ class TokenizeStrategy:
iids_list.append(ids_chunk)
input_ids = torch.stack(iids_list) # 3,77
if weighted:
weights = weights.squeeze(0)
new_weights = torch.ones(input_ids.shape)
for i in range(1, max_length - tokenizer.model_max_length + 2, tokenizer.model_max_length - 2):
b = i // (tokenizer.model_max_length - 2)
new_weights[b, 1 : 1 + tokenizer.model_max_length - 2] = weights[i : i + tokenizer.model_max_length - 2]
weights = new_weights
if weighted:
return input_ids, weights
return input_ids
@@ -126,6 +303,17 @@ class TextEncodingStrategy:
:return: list of output embeddings for each architecture
"""
raise NotImplementedError
def encode_tokens_with_weights(
self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor], weights: List[torch.Tensor]
) -> List[torch.Tensor]:
"""
Encode tokens into embeddings and outputs.
:param tokens: list of token tensors for each TextModel
:param weights: list of weight tensors for each TextModel
:return: list of output embeddings for each architecture
"""
raise NotImplementedError
class TextEncoderOutputsCachingStrategy:

View File

@@ -37,6 +37,22 @@ class SdxlTokenizeStrategy(TokenizeStrategy):
torch.stack([self._get_input_ids(self.tokenizer2, t, self.max_length) for t in text], dim=0),
)
def tokenize_with_weights(self, text: str | List[str]) -> Tuple[List[torch.Tensor]]:
text = [text] if isinstance(text, str) else text
tokens1_list, tokens2_list = [], []
weights1_list, weights2_list = [], []
for t in text:
tokens1, weights1 = self._get_weighted_input_ids(self.tokenizer1, t, self.max_length)
tokens2, weights2 = self._get_weighted_input_ids(self.tokenizer2, t, self.max_length)
tokens1_list.append(tokens1)
tokens2_list.append(tokens2)
weights1_list.append(weights1)
weights2_list.append(weights2)
return (torch.stack(tokens1_list, dim=0), torch.stack(tokens2_list, dim=0)), (
torch.stack(weights1_list, dim=0),
torch.stack(weights2_list, dim=0),
)
class SdxlTextEncodingStrategy(TextEncodingStrategy):
def __init__(self) -> None:
@@ -98,7 +114,10 @@ class SdxlTextEncodingStrategy(TextEncodingStrategy):
):
# input_ids: b,n,77 -> b*n, 77
b_size = input_ids1.size()[0]
max_token_length = input_ids1.size()[1] * input_ids1.size()[2]
if input_ids1.size()[1] == 1:
max_token_length = None
else:
max_token_length = input_ids1.size()[1] * input_ids1.size()[2]
input_ids1 = input_ids1.reshape((-1, tokenizer1.model_max_length)) # batch_size*n, 77
input_ids2 = input_ids2.reshape((-1, tokenizer2.model_max_length)) # batch_size*n, 77
input_ids1 = input_ids1.to(text_encoder1.device)
@@ -172,6 +191,24 @@ class SdxlTextEncodingStrategy(TextEncodingStrategy):
)
return [hidden_states1, hidden_states2, pool2]
def encode_tokens_with_weights(
self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor], weights: List[torch.Tensor]
) -> List[torch.Tensor]:
hidden_states1, hidden_states2, pool2 = self.encode_tokens(tokenize_strategy, models, tokens)
# apply weights
if weights[0].shape[1] == 1: # no max_token_length
# weights: ((b, 1, 77), (b, 1, 77)), hidden_states: (b, 77, 768), (b, 77, 768)
hidden_states1 = hidden_states1 * weights[0].squeeze(1).unsqueeze(2)
hidden_states2 = hidden_states2 * weights[1].squeeze(1).unsqueeze(2)
else:
# weights: ((b, n, 77), (b, n, 77)), hidden_states: (b, n*75+2, 768), (b, n*75+2, 768)
for weight, hidden_states in zip(weights, [hidden_states1, hidden_states2]):
for i in range(weight.shape[1]):
hidden_states[:, i * 75 + 1 : i * 75 + 76] = hidden_states[:, i * 75 + 1 : i * 75 + 76] * weight[:, i, 1:-1]
return [hidden_states1, hidden_states2, pool2]
class SdxlTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
SDXL_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_te_outputs.npz"

View File

@@ -74,6 +74,7 @@ import imagesize
import cv2
import safetensors.torch
from library.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipeline
from library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline
import library.model_util as model_util
import library.huggingface_util as huggingface_util
import library.sai_model_spec as sai_model_spec
@@ -3581,7 +3582,20 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
# available backends:
# https://github.com/huggingface/accelerate/blob/d1abd59114ada8ba673e1214218cb2878c13b82d/src/accelerate/utils/dataclasses.py#L376-L388C5
# https://pytorch.org/docs/stable/torch.compiler.html
choices=["eager", "aot_eager", "inductor", "aot_ts_nvfuser", "nvprims_nvfuser", "cudagraphs", "ofi", "fx2trt", "onnxrt", "tensort", "ipex", "tvm"],
choices=[
"eager",
"aot_eager",
"inductor",
"aot_ts_nvfuser",
"nvprims_nvfuser",
"cudagraphs",
"ofi",
"fx2trt",
"onnxrt",
"tensort",
"ipex",
"tvm",
],
help="dynamo backend type (default is inductor) / dynamoのbackendの種類デフォルトは inductor",
)
parser.add_argument("--xformers", action="store_true", help="use xformers for CrossAttention / CrossAttentionにxformersを使う")
@@ -5850,8 +5864,8 @@ def sample_images_common(
pipe_class,
accelerator: Accelerator,
args: argparse.Namespace,
epoch,
steps,
epoch: int,
steps: int,
device,
vae,
tokenizer,
@@ -5910,11 +5924,7 @@ def sample_images_common(
with open(args.sample_prompts, "r", encoding="utf-8") as f:
prompts = json.load(f)
# schedulers: dict = {} cannot find where this is used
default_scheduler = get_my_scheduler(
sample_sampler=args.sample_sampler,
v_parameterization=args.v_parameterization,
)
default_scheduler = get_my_scheduler(sample_sampler=args.sample_sampler, v_parameterization=args.v_parameterization)
pipeline = pipe_class(
text_encoder=text_encoder,
@@ -5975,21 +5985,18 @@ def sample_images_common(
# clear pipeline and cache to reduce vram usage
del pipeline
# I'm not sure which of these is the correct way to clear the memory, but accelerator's device is used in the pipeline, so I'm using it here.
# with torch.cuda.device(torch.cuda.current_device()):
# torch.cuda.empty_cache()
clean_memory_on_device(accelerator.device)
torch.set_rng_state(rng_state)
if torch.cuda.is_available() and cuda_rng_state is not None:
torch.cuda.set_rng_state(cuda_rng_state)
vae.to(org_vae_device)
clean_memory_on_device(accelerator.device)
def sample_image_inference(
accelerator: Accelerator,
args: argparse.Namespace,
pipeline,
pipeline: Union[StableDiffusionLongPromptWeightingPipeline, SdxlStableDiffusionLongPromptWeightingPipeline],
save_dir,
prompt_dict,
epoch,

View File

@@ -83,6 +83,7 @@ def train(args):
tokenize_strategy = strategy_sdxl.SdxlTokenizeStrategy(args.max_token_length, args.tokenizer_cache_dir)
strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy)
tokenizer1, tokenizer2 = tokenize_strategy.tokenizer1, tokenize_strategy.tokenizer2 # this is used for sampling images
# prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization.
latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(
@@ -436,19 +437,19 @@ def train(args):
accelerator.print(f"removing old checkpoint: {old_ckpt_file}")
os.remove(old_ckpt_file)
# # For --sample_at_first
# sdxl_train_util.sample_images(
# accelerator,
# args,
# 0,
# global_step,
# accelerator.device,
# vae,
# [tokenizer1, tokenizer2],
# [text_encoder1, text_encoder2],
# unet,
# controlnet=control_net,
# )
# For --sample_at_first
sdxl_train_util.sample_images(
accelerator,
args,
0,
global_step,
accelerator.device,
vae,
[tokenizer1, tokenizer2],
[text_encoder1, text_encoder2, unwrap_model(text_encoder2)],
unet,
controlnet=control_net,
)
# training loop
for epoch in range(num_train_epochs):
@@ -484,7 +485,7 @@ def train(args):
input_ids1 = input_ids1.to(accelerator.device)
input_ids2 = input_ids2.to(accelerator.device)
encoder_hidden_states1, encoder_hidden_states2, pool2 = text_encoding_strategy.encode_tokens(
tokenize_strategy, [text_encoder1, text_encoder2], [input_ids1, input_ids2]
tokenize_strategy, [text_encoder1, text_encoder2, unwrap_model(text_encoder2)], [input_ids1, input_ids2]
)
if args.full_fp16:
encoder_hidden_states1 = encoder_hidden_states1.to(weight_dtype)
@@ -558,18 +559,18 @@ def train(args):
progress_bar.update(1)
global_step += 1
# sdxl_train_util.sample_images(
# accelerator,
# args,
# None,
# global_step,
# accelerator.device,
# vae,
# [tokenizer1, tokenizer2],
# [text_encoder1, text_encoder2],
# unet,
# controlnet=control_net,
# )
sdxl_train_util.sample_images(
accelerator,
args,
None,
global_step,
accelerator.device,
vae,
[tokenizer1, tokenizer2],
[text_encoder1, text_encoder2, unwrap_model(text_encoder2)],
unet,
controlnet=control_net,
)
# 指定ステップごとにモデルを保存
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
@@ -628,7 +629,7 @@ def train(args):
accelerator.device,
vae,
[tokenizer1, tokenizer2],
[text_encoder1, text_encoder2],
[text_encoder1, text_encoder2, unwrap_model(text_encoder2)],
unet,
controlnet=control_net,
)