diff --git a/library/sd3_models.py b/library/sd3_models.py index 7041420c..e4c0790d 100644 --- a/library/sd3_models.py +++ b/library/sd3_models.py @@ -1124,7 +1124,12 @@ class Upsample(torch.nn.Module): self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device) def forward(self, x): + org_dtype = x.dtype + if x.dtype == torch.bfloat16: + x = x.to(torch.float32) x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + if x.dtype != org_dtype: + x = x.to(org_dtype) x = self.conv(x) return x @@ -1263,11 +1268,11 @@ class SDVAE(torch.nn.Module): def dtype(self): return next(self.parameters()).dtype - @torch.autocast("cuda", dtype=torch.float16) + # @torch.autocast("cuda", dtype=torch.float16) def decode(self, latent): return self.decoder(latent) - @torch.autocast("cuda", dtype=torch.float16) + # @torch.autocast("cuda", dtype=torch.float16) def encode(self, image): hidden = self.encoder(image) mean, logvar = torch.chunk(hidden, 2, dim=1) @@ -1630,10 +1635,25 @@ class T5LayerNorm(torch.nn.Module): self.weight = torch.nn.Parameter(torch.ones(hidden_size, dtype=dtype, device=device)) self.variance_epsilon = eps - def forward(self, x): - variance = x.pow(2).mean(-1, keepdim=True) - x = x * torch.rsqrt(variance + self.variance_epsilon) - return self.weight.to(device=x.device, dtype=x.dtype) * x + # def forward(self, x): + # variance = x.pow(2).mean(-1, keepdim=True) + # x = x * torch.rsqrt(variance + self.variance_epsilon) + # return self.weight.to(device=x.device, dtype=x.dtype) * x + + # copy from transformers' T5LayerNorm + def forward(self, hidden_states): + # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean + # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus varience is calculated + # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for + # half-precision inputs is done in fp32 + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + + # convert into half-precision if necessary + if self.weight.dtype in [torch.float16, torch.bfloat16]: + hidden_states = hidden_states.to(self.weight.dtype) + + return self.weight * hidden_states class T5DenseGatedActDense(torch.nn.Module): @@ -1775,7 +1795,27 @@ class T5Block(torch.nn.Module): def forward(self, x, past_bias=None): x, past_bias = self.layer[0](x, past_bias) + + # copy from transformers' T5Block + # clamp inf values to enable fp16 training + if x.dtype == torch.float16: + clamp_value = torch.where( + torch.isinf(x).any(), + torch.finfo(x.dtype).max - 1000, + torch.finfo(x.dtype).max, + ) + x = torch.clamp(x, min=-clamp_value, max=clamp_value) + x = self.layer[-1](x) + # clamp inf values to enable fp16 training + if x.dtype == torch.float16: + clamp_value = torch.where( + torch.isinf(x).any(), + torch.finfo(x.dtype).max - 1000, + torch.finfo(x.dtype).max, + ) + x = torch.clamp(x, min=-clamp_value, max=clamp_value) + return x, past_bias @@ -1896,4 +1936,96 @@ def create_t5xxl(device="cpu", dtype=torch.float32, state_dict: Optional[Dict[st return t5 +""" + # snippet for using the T5 model from transformers + + from transformers import T5EncoderModel, T5Config + import accelerate + import json + + T5_CONFIG_JSON = "" +{ + "architectures": [ + "T5EncoderModel" + ], + "classifier_dropout": 0.0, + "d_ff": 10240, + "d_kv": 64, + "d_model": 4096, + "decoder_start_token_id": 0, + "dense_act_fn": "gelu_new", + "dropout_rate": 0.1, + "eos_token_id": 1, + "feed_forward_proj": "gated-gelu", + "initializer_factor": 1.0, + "is_encoder_decoder": true, + "is_gated_act": true, + "layer_norm_epsilon": 1e-06, + "model_type": "t5", + "num_decoder_layers": 24, + "num_heads": 64, + "num_layers": 24, + "output_past": true, + "pad_token_id": 0, + "relative_attention_max_distance": 128, + "relative_attention_num_buckets": 32, + "tie_word_embeddings": false, + "torch_dtype": "float16", + "transformers_version": "4.41.2", + "use_cache": true, + "vocab_size": 32128 +} +"" + config = json.loads(T5_CONFIG_JSON) + config = T5Config(**config) + + # model = T5EncoderModel.from_pretrained("stabilityai/stable-diffusion-3-medium-diffusers", subfolder="text_encoder_3") + # print(model.config) + # # model(**load_model.config) + + # with accelerate.init_empty_weights(): + model = T5EncoderModel._from_config(config) # , torch_dtype=dtype) + for key in list(state_dict.keys()): + if key.startswith("transformer."): + new_key = key[len("transformer.") :] + state_dict[new_key] = state_dict.pop(key) + + info = model.load_state_dict(state_dict) + print(info) + model.set_attn_mode = lambda x: None + # model.to("cpu") + + _self = model + + def enc(list_of_token_weight_pairs): + has_batch = isinstance(list_of_token_weight_pairs[0][0], list) + + if has_batch: + list_of_tokens = [] + for pairs in list_of_token_weight_pairs: + tokens = [a[0] for a in pairs[0]] # I'm not sure why this is [0] + list_of_tokens.append(tokens) + else: + list_of_tokens = [[a[0] for a in list_of_token_weight_pairs[0]]] + + list_of_tokens = np.array(list_of_tokens) + list_of_tokens = torch.from_numpy(list_of_tokens).to("cuda", dtype=torch.long) + out = _self(list_of_tokens) + pooled = None + if has_batch: + return out, pooled + else: + if pooled is not None: + first_pooled = pooled[0:1] + else: + first_pooled = pooled + return out[0], first_pooled + # output = [out[0:1]] + # return torch.cat(output, dim=-2), first_pooled + + model.encode_token_weights = enc + + return model +""" + # endregion