Mirror of https://github.com/kohya-ss/sd-scripts.git
Synced 2026-04-10 06:54:17 +00:00
fix: revert strategy_sd.py and remove latents from huber
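Judging from the diff below, this revert removes the BREAK-splitting helpers (`_split_on_break`, `_tokenize_segments`, `break_separator`) that had been added to `SdTokenizeStrategy` and restores the earlier tokenize/encode code paths; in the training script it also drops the `latents` argument from the Huber threshold call. For orientation only, here is a minimal standalone sketch of the split-on-BREAK idea being reverted; the function name and the example prompts are illustrative and not part of the repository.

from typing import List

def split_on_break(text: str, separator: str = "BREAK") -> List[str]:
    # Case-sensitive split, dropping empty or whitespace-only segments; always
    # returns at least one segment so downstream tokenization has input.
    segments = [seg.strip() for seg in text.split(separator) if seg.strip()]
    return segments if segments else [""]

print(split_on_break("a cat sitting BREAK masterpiece, best quality"))
# ['a cat sitting', 'masterpiece, best quality']
print(split_on_break("no separator here"))
# ['no separator here']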
@@ -30,171 +30,81 @@ class SdTokenizeStrategy(TokenizeStrategy):
            )
        else:
            self.tokenizer = self._load_tokenizer(CLIPTokenizer, TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir)

        if max_length is None:
            self.max_length = self.tokenizer.model_max_length
        else:
            self.max_length = max_length + 2

        self.break_separator = "BREAK"

    def _split_on_break(self, text: str) -> List[str]:
        """Split text on BREAK separator (case-sensitive), filtering empty segments."""
        segments = text.split(self.break_separator)
        # Filter out empty or whitespace-only segments
        filtered = [seg.strip() for seg in segments if seg.strip()]
        # Return at least one segment to maintain consistency
        return filtered if filtered else [""]

    def _tokenize_segments(self, segments: List[str], weighted: bool = False) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Tokenize multiple segments and concatenate them."""
        if len(segments) == 1:
            # No BREAK present, use existing logic
            if weighted:
                return self._get_input_ids(self.tokenizer, segments[0], self.max_length, weighted=True)
            else:
                tokens = self._get_input_ids(self.tokenizer, segments[0], self.max_length)
                return tokens, None

        # Multiple segments - tokenize each separately
        all_tokens = []
        all_weights = [] if weighted else None

        for segment in segments:
            if weighted:
                seg_tokens, seg_weights = self._get_input_ids(self.tokenizer, segment, self.max_length, weighted=True)
                all_tokens.append(seg_tokens)
                all_weights.append(seg_weights)
            else:
                seg_tokens = self._get_input_ids(self.tokenizer, segment, self.max_length)
                all_tokens.append(seg_tokens)

        # Concatenate along the sequence dimension (dim=1 for tokens that are [batch, seq_len] or [n_chunks, seq_len])
        combined_tokens = torch.cat(all_tokens, dim=1) if all_tokens[0].dim() == 2 else torch.cat(all_tokens, dim=0)
        combined_weights = None
        if weighted:
            combined_weights = torch.cat(all_weights, dim=1) if all_weights[0].dim() == 2 else torch.cat(all_weights, dim=0)

        return combined_tokens, combined_weights
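
    # Note: each BREAK segment is tokenized separately and the results are
    # concatenated along the sequence dimension (e.g. two 77-id segments become
    # 154 ids), so tokenize() pads prompts to a common length before stacking.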

    def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]:
        text = [text] if isinstance(text, str) else text

        tokens_list = []
        for t in text:
            segments = self._split_on_break(t)
            tokens, _ = self._tokenize_segments(segments, weighted=False)
            tokens_list.append(tokens)

        # Pad tokens to same length for stacking
        max_length = max(t.shape[-1] for t in tokens_list)
        padded_tokens = []
        for tokens in tokens_list:
            if tokens.shape[-1] < max_length:
                # Pad with pad_token_id
                pad_size = max_length - tokens.shape[-1]
                if tokens.dim() == 2:
                    padding = torch.full((tokens.shape[0], pad_size), self.tokenizer.pad_token_id, dtype=tokens.dtype)
                    tokens = torch.cat([tokens, padding], dim=1)
                else:
                    padding = torch.full((pad_size,), self.tokenizer.pad_token_id, dtype=tokens.dtype)
                    tokens = torch.cat([tokens, padding], dim=0)
            padded_tokens.append(tokens)

        return [torch.stack(padded_tokens, dim=0)]
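
    # All prompts in the batch are padded with pad_token_id to the longest
    # tokenized length so torch.stack can combine them into a single tensor.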

    def tokenize_with_weights(self, text: str | List[str]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
        return [torch.stack([self._get_input_ids(self.tokenizer, t, self.max_length) for t in text], dim=0)]

    def tokenize_with_weights(self, text: str | List[str]) -> Tuple[List[torch.Tensor]]:
        text = [text] if isinstance(text, str) else text

        tokens_list = []
        weights_list = []
        for t in text:
            segments = self._split_on_break(t)
            tokens, weights = self._tokenize_segments(segments, weighted=True)
            tokens, weights = self._get_input_ids(self.tokenizer, t, self.max_length, weighted=True)
            tokens_list.append(tokens)
            weights_list.append(weights)

        return [torch.stack(tokens_list, dim=0)], [torch.stack(weights_list, dim=0)]


class SdTextEncodingStrategy(TextEncodingStrategy):
    def __init__(self, clip_skip: Optional[int] = None) -> None:
        self.clip_skip = clip_skip

    def _encode_with_clip_skip(self, text_encoder: Any, tokens: torch.Tensor) -> torch.Tensor:
        """Encode tokens with optional CLIP skip."""
        if self.clip_skip is None:
            return text_encoder(tokens)[0]

        enc_out = text_encoder(tokens, output_hidden_states=True, return_dict=True)
        hidden_states = enc_out["hidden_states"][-self.clip_skip]
        return text_encoder.text_model.final_layer_norm(hidden_states)
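
    # clip_skip counts hidden states from the end: clip_skip=1 selects the last
    # layer, clip_skip=2 the penultimate one; the chosen hidden state is then
    # passed through the text model's final_layer_norm.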

    def _reconstruct_embeddings(self, encoder_hidden_states: torch.Tensor, tokens: torch.Tensor,
                                max_token_length: int, model_max_length: int,
                                tokenizer: Any) -> torch.Tensor:
        """Reconstruct embeddings from chunked encoding."""
        v1 = tokenizer.pad_token_id == tokenizer.eos_token_id
        states_list = [encoder_hidden_states[:, 0].unsqueeze(1)]  # <BOS>

        if not v1:
            # v2: restore the consecutive <BOS>...<EOS> <PAD> ... chunks into a single <BOS>...<EOS> <PAD> ... sequence
            for i in range(1, max_token_length, model_max_length):
                chunk = encoder_hidden_states[:, i : i + model_max_length - 2]
                if i > 0:
                    for j in range(len(chunk)):
                        if tokens[j, 1] == tokenizer.eos_token:
                            chunk[j, 0] = chunk[j, 1]
                states_list.append(chunk)
            states_list.append(encoder_hidden_states[:, -1].unsqueeze(1))
        else:
            # v1: restore the consecutive <BOS>...<EOS> chunks into a single <BOS>...<EOS> sequence
            for i in range(1, max_token_length, model_max_length):
                states_list.append(encoder_hidden_states[:, i : i + model_max_length - 2])
            states_list.append(encoder_hidden_states[:, -1].unsqueeze(1))

        return torch.cat(states_list, dim=1)
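
    # Shape example: tokens of shape (b, 3, 77) give max_token_length = 231 and
    # model_max_length = 77, so the loop gathers three 75-token chunks and the
    # reconstructed sequence is 1 + 3*75 + 1 = 227 positions long (n*75 + 2).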

    def _apply_weights_single_chunk(self, encoder_hidden_states: torch.Tensor,
                                    weights: torch.Tensor) -> torch.Tensor:
        """Apply weights for single chunk case (no max_token_length)."""
        return encoder_hidden_states * weights.squeeze(1).unsqueeze(2)

    def _apply_weights_multi_chunk(self, encoder_hidden_states: torch.Tensor,
                                   weights: torch.Tensor) -> torch.Tensor:
        """Apply weights for multi-chunk case (with max_token_length)."""
        for i in range(weights.shape[1]):
            start_idx = i * 75 + 1
            end_idx = i * 75 + 76
            encoder_hidden_states[:, start_idx:end_idx] = (
                encoder_hidden_states[:, start_idx:end_idx] * weights[:, i, 1:-1].unsqueeze(-1)
            )
        return encoder_hidden_states
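
    # Chunk i occupies positions i*75+1 .. i*75+75 and is scaled by
    # weights[:, i, 1:-1], i.e. the per-token weights without the <BOS>/<EOS>
    # slots, so the shared <BOS>/<EOS> embeddings are left unweighted.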

    def encode_tokens(
        self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor]
    ) -> List[torch.Tensor]:
        text_encoder = models[0]
        tokens = tokens[0]
        sd_tokenize_strategy = tokenize_strategy  # type: SdTokenizeStrategy

        # tokens: b,n,77
        b_size = tokens.size()[0]
        max_token_length = tokens.size()[1] * tokens.size()[2]
        model_max_length = sd_tokenize_strategy.tokenizer.model_max_length

        tokens = tokens.reshape((-1, model_max_length))
        tokens = tokens.reshape((-1, model_max_length))  # batch_size*3, 77

        tokens = tokens.to(text_encoder.device)

        encoder_hidden_states = self._encode_with_clip_skip(text_encoder, tokens)

        if self.clip_skip is None:
            encoder_hidden_states = text_encoder(tokens)[0]
        else:
            enc_out = text_encoder(tokens, output_hidden_states=True, return_dict=True)
            encoder_hidden_states = enc_out["hidden_states"][-self.clip_skip]
            encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states)

        # bs*3, 77, 768 or 1024
        encoder_hidden_states = encoder_hidden_states.reshape((b_size, -1, encoder_hidden_states.shape[-1]))

        if max_token_length != model_max_length:
            encoder_hidden_states = self._reconstruct_embeddings(
                encoder_hidden_states, tokens, max_token_length,
                model_max_length, sd_tokenize_strategy.tokenizer
            )

            v1 = sd_tokenize_strategy.tokenizer.pad_token_id == sd_tokenize_strategy.tokenizer.eos_token_id
            if not v1:
                # v2: restore the consecutive <BOS>...<EOS> <PAD> ... chunks into a single <BOS>...<EOS> <PAD> ... sequence
                # (honestly not sure this implementation is right)
                states_list = [encoder_hidden_states[:, 0].unsqueeze(1)]  # <BOS>
                for i in range(1, max_token_length, model_max_length):
                    chunk = encoder_hidden_states[:, i : i + model_max_length - 2]  # from after <BOS> to just before the last token
                    if i > 0:
                        for j in range(len(chunk)):
                            if tokens[j, 1] == sd_tokenize_strategy.tokenizer.eos_token:
                                # empty prompt, i.e. the <BOS> <EOS> <PAD> ... pattern
                                chunk[j, 0] = chunk[j, 1]  # copy the value of the following <PAD>
                    states_list.append(chunk)  # from after <BOS> to just before <EOS>
                states_list.append(encoder_hidden_states[:, -1].unsqueeze(1))  # either <EOS> or <PAD>
                encoder_hidden_states = torch.cat(states_list, dim=1)
            else:
                # v1: restore the consecutive <BOS>...<EOS> chunks into a single <BOS>...<EOS> sequence
                states_list = [encoder_hidden_states[:, 0].unsqueeze(1)]  # <BOS>
                for i in range(1, max_token_length, model_max_length):
                    states_list.append(encoder_hidden_states[:, i : i + model_max_length - 2])  # from after <BOS> to just before <EOS>
                states_list.append(encoder_hidden_states[:, -1].unsqueeze(1))  # <EOS>
                encoder_hidden_states = torch.cat(states_list, dim=1)

        return [encoder_hidden_states]
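
    # Result: one hidden-state tensor per batch, shaped (b, 77, 768 or 1024) for
    # single-chunk prompts or (b, n*75+2, 768 or 1024) after reconstruction.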

    def encode_tokens_with_weights(
        self,
        tokenize_strategy: TokenizeStrategy,
@@ -203,15 +113,23 @@ class SdTextEncodingStrategy(TextEncodingStrategy):
        weights_list: List[torch.Tensor],
    ) -> List[torch.Tensor]:
        encoder_hidden_states = self.encode_tokens(tokenize_strategy, models, tokens_list)[0]

        weights = weights_list[0].to(encoder_hidden_states.device)

        if weights.shape[1] == 1:
            encoder_hidden_states = self._apply_weights_single_chunk(encoder_hidden_states, weights)

        # apply weights
        if weights.shape[1] == 1:  # no max_token_length
            # weights: ((b, 1, 77), (b, 1, 77)), hidden_states: (b, 77, 768), (b, 77, 768)
            encoder_hidden_states = encoder_hidden_states * weights.squeeze(1).unsqueeze(2)
        else:
            encoder_hidden_states = self._apply_weights_multi_chunk(encoder_hidden_states, weights)

            # weights: ((b, n, 77), (b, n, 77)), hidden_states: (b, n*75+2, 768), (b, n*75+2, 768)
            for i in range(weights.shape[1]):
                encoder_hidden_states[:, i * 75 + 1 : i * 75 + 76] = encoder_hidden_states[:, i * 75 + 1 : i * 75 + 76] * weights[
                    :, i, 1:-1
                ].unsqueeze(-1)

        return [encoder_hidden_states]


class SdSdxlLatentsCachingStrategy(LatentsCachingStrategy):
    # sd and sdxl share the same strategy. we can make them separate, but the difference is only the suffix.
    # and we keep the old npz for the backward compatibility.

@@ -385,7 +385,7 @@ def train(args):
            else:
                target = noise

            huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, latents, noise_scheduler)
            huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
            loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c)
            if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
                loss = apply_masked_loss(loss, batch)
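
For reference, the Huber threshold helper is now called as `get_huber_threshold_if_needed(args, timesteps, noise_scheduler)`, without `latents`, and its result is passed to `conditional_loss` with reduction "none" so the alpha mask can still be applied per element. The snippet below is a minimal, self-contained sketch of a pseudo-Huber loss gated by such a threshold; the function name `pseudo_huber_loss` and the exact formula are assumptions for illustration, not the `train_util` implementation.

import torch

def pseudo_huber_loss(pred: torch.Tensor, target: torch.Tensor, huber_c: float) -> torch.Tensor:
    # Quadratic for small errors, roughly linear once |pred - target| >> huber_c.
    diff = pred - target
    return 2 * huber_c * (torch.sqrt(diff * diff + huber_c * huber_c) - huber_c)

# usage sketch: keep the per-element loss (reduction "none") so a mask can be
# applied afterwards, then reduce to a scalar for backprop
noise_pred = torch.randn(2, 4, 64, 64)
target = torch.randn(2, 4, 64, 64)
loss = pseudo_huber_loss(noise_pred.float(), target.float(), huber_c=0.1)  # (2, 4, 64, 64)
loss = loss.mean()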