From 46f9e24b24bdbc4c3d839fc3362c2883d637fb2f Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Fri, 10 Oct 2025 14:54:37 -0400 Subject: [PATCH] fix: revert strategy_sd.py and remove latents from huber --- library/strategy_sd.py | 196 ++++++++++++----------------------------- train_db.py | 2 +- 2 files changed, 58 insertions(+), 140 deletions(-) diff --git a/library/strategy_sd.py b/library/strategy_sd.py index 45a59d72..d0a3a68b 100644 --- a/library/strategy_sd.py +++ b/library/strategy_sd.py @@ -30,171 +30,81 @@ class SdTokenizeStrategy(TokenizeStrategy): ) else: self.tokenizer = self._load_tokenizer(CLIPTokenizer, TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir) + if max_length is None: self.max_length = self.tokenizer.model_max_length else: self.max_length = max_length + 2 - - self.break_separator = "BREAK" - - def _split_on_break(self, text: str) -> List[str]: - """Split text on BREAK separator (case-sensitive), filtering empty segments.""" - segments = text.split(self.break_separator) - # Filter out empty or whitespace-only segments - filtered = [seg.strip() for seg in segments if seg.strip()] - # Return at least one segment to maintain consistency - return filtered if filtered else [""] - - def _tokenize_segments(self, segments: List[str], weighted: bool = False) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: - """Tokenize multiple segments and concatenate them.""" - if len(segments) == 1: - # No BREAK present, use existing logic - if weighted: - return self._get_input_ids(self.tokenizer, segments[0], self.max_length, weighted=True) - else: - tokens = self._get_input_ids(self.tokenizer, segments[0], self.max_length) - return tokens, None - - # Multiple segments - tokenize each separately - all_tokens = [] - all_weights = [] if weighted else None - - for segment in segments: - if weighted: - seg_tokens, seg_weights = self._get_input_ids(self.tokenizer, segment, self.max_length, weighted=True) - all_tokens.append(seg_tokens) - all_weights.append(seg_weights) - else: - seg_tokens = self._get_input_ids(self.tokenizer, segment, self.max_length) - all_tokens.append(seg_tokens) - - # Concatenate along the sequence dimension (dim=1 for tokens that are [batch, seq_len] or [n_chunks, seq_len]) - combined_tokens = torch.cat(all_tokens, dim=1) if all_tokens[0].dim() == 2 else torch.cat(all_tokens, dim=0) - combined_weights = None - if weighted: - combined_weights = torch.cat(all_weights, dim=1) if all_weights[0].dim() == 2 else torch.cat(all_weights, dim=0) - - return combined_tokens, combined_weights - + def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]: text = [text] if isinstance(text, str) else text - - tokens_list = [] - for t in text: - segments = self._split_on_break(t) - tokens, _ = self._tokenize_segments(segments, weighted=False) - tokens_list.append(tokens) - - # Pad tokens to same length for stacking - max_length = max(t.shape[-1] for t in tokens_list) - padded_tokens = [] - for tokens in tokens_list: - if tokens.shape[-1] < max_length: - # Pad with pad_token_id - pad_size = max_length - tokens.shape[-1] - if tokens.dim() == 2: - padding = torch.full((tokens.shape[0], pad_size), self.tokenizer.pad_token_id, dtype=tokens.dtype) - tokens = torch.cat([tokens, padding], dim=1) - else: - padding = torch.full((pad_size,), self.tokenizer.pad_token_id, dtype=tokens.dtype) - tokens = torch.cat([tokens, padding], dim=0) - padded_tokens.append(tokens) - - return [torch.stack(padded_tokens, dim=0)] - - def tokenize_with_weights(self, text: str | List[str]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + return [torch.stack([self._get_input_ids(self.tokenizer, t, self.max_length) for t in text], dim=0)] + + def tokenize_with_weights(self, text: str | List[str]) -> Tuple[List[torch.Tensor]]: text = [text] if isinstance(text, str) else text - tokens_list = [] weights_list = [] for t in text: - segments = self._split_on_break(t) - tokens, weights = self._tokenize_segments(segments, weighted=True) + tokens, weights = self._get_input_ids(self.tokenizer, t, self.max_length, weighted=True) tokens_list.append(tokens) weights_list.append(weights) - return [torch.stack(tokens_list, dim=0)], [torch.stack(weights_list, dim=0)] class SdTextEncodingStrategy(TextEncodingStrategy): def __init__(self, clip_skip: Optional[int] = None) -> None: self.clip_skip = clip_skip - - def _encode_with_clip_skip(self, text_encoder: Any, tokens: torch.Tensor) -> torch.Tensor: - """Encode tokens with optional CLIP skip.""" - if self.clip_skip is None: - return text_encoder(tokens)[0] - - enc_out = text_encoder(tokens, output_hidden_states=True, return_dict=True) - hidden_states = enc_out["hidden_states"][-self.clip_skip] - return text_encoder.text_model.final_layer_norm(hidden_states) - - def _reconstruct_embeddings(self, encoder_hidden_states: torch.Tensor, tokens: torch.Tensor, - max_token_length: int, model_max_length: int, - tokenizer: Any) -> torch.Tensor: - """Reconstruct embeddings from chunked encoding.""" - v1 = tokenizer.pad_token_id == tokenizer.eos_token_id - states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # - - if not v1: - # v2: ... ... の三連を ... ... へ戻す - for i in range(1, max_token_length, model_max_length): - chunk = encoder_hidden_states[:, i : i + model_max_length - 2] - if i > 0: - for j in range(len(chunk)): - if tokens[j, 1] == tokenizer.eos_token: - chunk[j, 0] = chunk[j, 1] - states_list.append(chunk) - states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) - else: - # v1: ... の三連を ... へ戻す - for i in range(1, max_token_length, model_max_length): - states_list.append(encoder_hidden_states[:, i : i + model_max_length - 2]) - states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) - - return torch.cat(states_list, dim=1) - - def _apply_weights_single_chunk(self, encoder_hidden_states: torch.Tensor, - weights: torch.Tensor) -> torch.Tensor: - """Apply weights for single chunk case (no max_token_length).""" - return encoder_hidden_states * weights.squeeze(1).unsqueeze(2) - - def _apply_weights_multi_chunk(self, encoder_hidden_states: torch.Tensor, - weights: torch.Tensor) -> torch.Tensor: - """Apply weights for multi-chunk case (with max_token_length).""" - for i in range(weights.shape[1]): - start_idx = i * 75 + 1 - end_idx = i * 75 + 76 - encoder_hidden_states[:, start_idx:end_idx] = ( - encoder_hidden_states[:, start_idx:end_idx] * weights[:, i, 1:-1].unsqueeze(-1) - ) - return encoder_hidden_states - + def encode_tokens( self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor] ) -> List[torch.Tensor]: text_encoder = models[0] tokens = tokens[0] sd_tokenize_strategy = tokenize_strategy # type: SdTokenizeStrategy - + + # tokens: b,n,77 b_size = tokens.size()[0] max_token_length = tokens.size()[1] * tokens.size()[2] model_max_length = sd_tokenize_strategy.tokenizer.model_max_length - - tokens = tokens.reshape((-1, model_max_length)) + tokens = tokens.reshape((-1, model_max_length)) # batch_size*3, 77 + tokens = tokens.to(text_encoder.device) - - encoder_hidden_states = self._encode_with_clip_skip(text_encoder, tokens) + + if self.clip_skip is None: + encoder_hidden_states = text_encoder(tokens)[0] + else: + enc_out = text_encoder(tokens, output_hidden_states=True, return_dict=True) + encoder_hidden_states = enc_out["hidden_states"][-self.clip_skip] + encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states) + + # bs*3, 77, 768 or 1024 encoder_hidden_states = encoder_hidden_states.reshape((b_size, -1, encoder_hidden_states.shape[-1])) - + if max_token_length != model_max_length: - encoder_hidden_states = self._reconstruct_embeddings( - encoder_hidden_states, tokens, max_token_length, - model_max_length, sd_tokenize_strategy.tokenizer - ) - + v1 = sd_tokenize_strategy.tokenizer.pad_token_id == sd_tokenize_strategy.tokenizer.eos_token_id + if not v1: + # v2: ... ... の三連を ... ... へ戻す 正直この実装でいいのかわからん + states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # + for i in range(1, max_token_length, model_max_length): + chunk = encoder_hidden_states[:, i : i + model_max_length - 2] # の後から 最後の前まで + if i > 0: + for j in range(len(chunk)): + if tokens[j, 1] == sd_tokenize_strategy.tokenizer.eos_token: + # 空、つまり ...のパターン + chunk[j, 0] = chunk[j, 1] # 次の の値をコピーする + states_list.append(chunk) # の後から の前まで + states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # のどちらか + encoder_hidden_states = torch.cat(states_list, dim=1) + else: + # v1: ... の三連を ... へ戻す + states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # + for i in range(1, max_token_length, model_max_length): + states_list.append(encoder_hidden_states[:, i : i + model_max_length - 2]) # の後から の前まで + states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # + encoder_hidden_states = torch.cat(states_list, dim=1) + return [encoder_hidden_states] - + def encode_tokens_with_weights( self, tokenize_strategy: TokenizeStrategy, @@ -203,15 +113,23 @@ class SdTextEncodingStrategy(TextEncodingStrategy): weights_list: List[torch.Tensor], ) -> List[torch.Tensor]: encoder_hidden_states = self.encode_tokens(tokenize_strategy, models, tokens_list)[0] + weights = weights_list[0].to(encoder_hidden_states.device) - - if weights.shape[1] == 1: - encoder_hidden_states = self._apply_weights_single_chunk(encoder_hidden_states, weights) + + # apply weights + if weights.shape[1] == 1: # no max_token_length + # weights: ((b, 1, 77), (b, 1, 77)), hidden_states: (b, 77, 768), (b, 77, 768) + encoder_hidden_states = encoder_hidden_states * weights.squeeze(1).unsqueeze(2) else: - encoder_hidden_states = self._apply_weights_multi_chunk(encoder_hidden_states, weights) - + # weights: ((b, n, 77), (b, n, 77)), hidden_states: (b, n*75+2, 768), (b, n*75+2, 768) + for i in range(weights.shape[1]): + encoder_hidden_states[:, i * 75 + 1 : i * 75 + 76] = encoder_hidden_states[:, i * 75 + 1 : i * 75 + 76] * weights[ + :, i, 1:-1 + ].unsqueeze(-1) + return [encoder_hidden_states] + class SdSdxlLatentsCachingStrategy(LatentsCachingStrategy): # sd and sdxl share the same strategy. we can make them separate, but the difference is only the suffix. # and we keep the old npz for the backward compatibility. diff --git a/train_db.py b/train_db.py index 7209f7dc..689d6c97 100644 --- a/train_db.py +++ b/train_db.py @@ -385,7 +385,7 @@ def train(args): else: target = noise - huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, latents, noise_scheduler) + huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c) if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): loss = apply_masked_loss(loss, batch)