mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-18 01:30:02 +00:00
fix typos
This commit is contained in:
@@ -250,7 +250,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
|||||||
# )
|
# )
|
||||||
|
|
||||||
with accelerator.autocast():
|
with accelerator.autocast():
|
||||||
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing)
|
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
|
||||||
model_pred = unet(
|
model_pred = unet(
|
||||||
img=packed_noisy_model_input,
|
img=packed_noisy_model_input,
|
||||||
img_ids=img_ids,
|
img_ids=img_ids,
|
||||||
|
|||||||
@@ -685,11 +685,11 @@ class DoubleStreamBlock(nn.Module):
|
|||||||
attn = attention(q, k, v, pe=pe)
|
attn = attention(q, k, v, pe=pe)
|
||||||
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
|
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
|
||||||
|
|
||||||
# calculate the img bloks
|
# calculate the img blocks
|
||||||
img = img + img_mod1.gate * self.img_attn.proj(img_attn)
|
img = img + img_mod1.gate * self.img_attn.proj(img_attn)
|
||||||
img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
|
img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
|
||||||
|
|
||||||
# calculate the txt bloks
|
# calculate the txt blocks
|
||||||
txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
|
txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
|
||||||
txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
|
txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
|
||||||
return img, txt
|
return img, txt
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ MODEL_VERSION_FLUX_V1 = "flux1"
|
|||||||
|
|
||||||
|
|
||||||
def load_flow_model(name: str, ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device]) -> flux_models.Flux:
|
def load_flow_model(name: str, ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device]) -> flux_models.Flux:
|
||||||
logger.info(f"Bulding Flux model {name}")
|
logger.info(f"Building Flux model {name}")
|
||||||
with torch.device("meta"):
|
with torch.device("meta"):
|
||||||
model = flux_models.Flux(flux_models.configs[name].params).to(dtype)
|
model = flux_models.Flux(flux_models.configs[name].params).to(dtype)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user