From 808d2d1f48e2f4e544d47464edb2727c03da2f53 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Fri, 9 Aug 2024 23:02:51 +0900 Subject: [PATCH] fix typos --- flux_train_network.py | 2 +- library/flux_models.py | 4 ++-- library/flux_utils.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/flux_train_network.py b/flux_train_network.py index 7c762c86..e4be97ad 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -250,7 +250,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): # ) with accelerator.autocast(): - # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing) + # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing) model_pred = unet( img=packed_noisy_model_input, img_ids=img_ids, diff --git a/library/flux_models.py b/library/flux_models.py index d0955e37..92c79bcc 100644 --- a/library/flux_models.py +++ b/library/flux_models.py @@ -685,11 +685,11 @@ class DoubleStreamBlock(nn.Module): attn = attention(q, k, v, pe=pe) txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :] - # calculate the img bloks + # calculate the img blocks img = img + img_mod1.gate * self.img_attn.proj(img_attn) img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift) - # calculate the txt bloks + # calculate the txt blocks txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn) txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift) return img, txt diff --git a/library/flux_utils.py b/library/flux_utils.py index ba828d50..166cd833 100644 --- a/library/flux_utils.py +++ b/library/flux_utils.py @@ -20,7 +20,7 @@ MODEL_VERSION_FLUX_V1 = "flux1" def load_flow_model(name: str, ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device]) -> flux_models.Flux: - logger.info(f"Bulding Flux model {name}") + logger.info(f"Building Flux model {name}") with torch.device("meta"): model = flux_models.Flux(flux_models.configs[name].params).to(dtype)