mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 14:34:23 +00:00
Compare commits
4 Commits
006d98ff27
...
32e1591621
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
32e1591621 | ||
|
|
51435f1718 | ||
|
|
fa53f71ec0 | ||
|
|
ceec25dc97 |
@@ -50,6 +50,9 @@ Stable Diffusion等の画像生成モデルの学習、モデルによる画像
|
|||||||
|
|
||||||
### 更新履歴
|
### 更新履歴
|
||||||
|
|
||||||
|
- **Version 0.10.3 (2026-04-02):**
|
||||||
|
- Animaでfp16で学習する際の安定性をさらに改善しました。[PR #2302](https://github.com/kohya-ss/sd-scripts/pull/2302) 問題をご報告いただいた方々に深く感謝します。
|
||||||
|
|
||||||
- **Version 0.10.2 (2026-03-30):**
|
- **Version 0.10.2 (2026-03-30):**
|
||||||
- SD/SDXLのLECO学習に対応しました。[PR #2285](https://github.com/kohya-ss/sd-scripts/pull/2285) および [PR #2294](https://github.com/kohya-ss/sd-scripts/pull/2294) umisetokikaze氏に深く感謝します。
|
- SD/SDXLのLECO学習に対応しました。[PR #2285](https://github.com/kohya-ss/sd-scripts/pull/2285) および [PR #2294](https://github.com/kohya-ss/sd-scripts/pull/2294) umisetokikaze氏に深く感謝します。
|
||||||
- 詳細は[ドキュメント](./docs/train_leco.md)をご覧ください。
|
- 詳細は[ドキュメント](./docs/train_leco.md)をご覧ください。
|
||||||
|
|||||||
@@ -47,6 +47,9 @@ If you find this project helpful, please consider supporting its development via
|
|||||||
|
|
||||||
### Change History
|
### Change History
|
||||||
|
|
||||||
|
- **Version 0.10.3 (2026-04-02):**
|
||||||
|
- Stability when training with fp16 on Anima has been further improved. See [PR #2302](https://github.com/kohya-ss/sd-scripts/pull/2302) for details. We deeply appreciate those who reported the issue.
|
||||||
|
|
||||||
- **Version 0.10.2 (2026-03-30):**
|
- **Version 0.10.2 (2026-03-30):**
|
||||||
- LECO training for SD/SDXL is now supported. Many thanks to umisetokikaze for [PR #2285](https://github.com/kohya-ss/sd-scripts/pull/2285) and [PR #2294](https://github.com/kohya-ss/sd-scripts/pull/2294).
|
- LECO training for SD/SDXL is now supported. Many thanks to umisetokikaze for [PR #2285](https://github.com/kohya-ss/sd-scripts/pull/2285) and [PR #2294](https://github.com/kohya-ss/sd-scripts/pull/2294).
|
||||||
- Please refer to the [documentation](./docs/train_leco.md) for details.
|
- Please refer to the [documentation](./docs/train_leco.md) for details.
|
||||||
|
|||||||
@@ -738,9 +738,9 @@ class FinalLayer(nn.Module):
|
|||||||
x_B_T_H_W_D: torch.Tensor,
|
x_B_T_H_W_D: torch.Tensor,
|
||||||
emb_B_T_D: torch.Tensor,
|
emb_B_T_D: torch.Tensor,
|
||||||
adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
|
adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
|
||||||
|
use_fp32: bool = False,
|
||||||
):
|
):
|
||||||
# Compute AdaLN modulation parameters (in float32 when fp16 to avoid overflow in Linear layers)
|
# Compute AdaLN modulation parameters (in float32 when fp16 to avoid overflow in Linear layers)
|
||||||
use_fp32 = x_B_T_H_W_D.dtype == torch.float16
|
|
||||||
with torch.autocast(device_type=x_B_T_H_W_D.device.type, dtype=torch.float32, enabled=use_fp32):
|
with torch.autocast(device_type=x_B_T_H_W_D.device.type, dtype=torch.float32, enabled=use_fp32):
|
||||||
if self.use_adaln_lora:
|
if self.use_adaln_lora:
|
||||||
assert adaln_lora_B_T_3D is not None
|
assert adaln_lora_B_T_3D is not None
|
||||||
@@ -863,11 +863,11 @@ class Block(nn.Module):
|
|||||||
emb_B_T_D: torch.Tensor,
|
emb_B_T_D: torch.Tensor,
|
||||||
crossattn_emb: torch.Tensor,
|
crossattn_emb: torch.Tensor,
|
||||||
attn_params: attention.AttentionParams,
|
attn_params: attention.AttentionParams,
|
||||||
|
use_fp32: bool = False,
|
||||||
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
|
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
|
||||||
adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
|
adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
|
||||||
extra_per_block_pos_emb: Optional[torch.Tensor] = None,
|
extra_per_block_pos_emb: Optional[torch.Tensor] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
use_fp32 = x_B_T_H_W_D.dtype == torch.float16
|
|
||||||
if use_fp32:
|
if use_fp32:
|
||||||
# Cast to float32 for better numerical stability in residual connections. Each module will cast back to float16 by enclosing autocast context.
|
# Cast to float32 for better numerical stability in residual connections. Each module will cast back to float16 by enclosing autocast context.
|
||||||
x_B_T_H_W_D = x_B_T_H_W_D.float()
|
x_B_T_H_W_D = x_B_T_H_W_D.float()
|
||||||
@@ -959,6 +959,7 @@ class Block(nn.Module):
|
|||||||
emb_B_T_D: torch.Tensor,
|
emb_B_T_D: torch.Tensor,
|
||||||
crossattn_emb: torch.Tensor,
|
crossattn_emb: torch.Tensor,
|
||||||
attn_params: attention.AttentionParams,
|
attn_params: attention.AttentionParams,
|
||||||
|
use_fp32: bool = False,
|
||||||
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
|
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
|
||||||
adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
|
adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
|
||||||
extra_per_block_pos_emb: Optional[torch.Tensor] = None,
|
extra_per_block_pos_emb: Optional[torch.Tensor] = None,
|
||||||
@@ -972,6 +973,7 @@ class Block(nn.Module):
|
|||||||
emb_B_T_D,
|
emb_B_T_D,
|
||||||
crossattn_emb,
|
crossattn_emb,
|
||||||
attn_params,
|
attn_params,
|
||||||
|
use_fp32,
|
||||||
rope_emb_L_1_1_D,
|
rope_emb_L_1_1_D,
|
||||||
adaln_lora_B_T_3D,
|
adaln_lora_B_T_3D,
|
||||||
extra_per_block_pos_emb,
|
extra_per_block_pos_emb,
|
||||||
@@ -994,6 +996,7 @@ class Block(nn.Module):
|
|||||||
emb_B_T_D,
|
emb_B_T_D,
|
||||||
crossattn_emb,
|
crossattn_emb,
|
||||||
attn_params,
|
attn_params,
|
||||||
|
use_fp32,
|
||||||
rope_emb_L_1_1_D,
|
rope_emb_L_1_1_D,
|
||||||
adaln_lora_B_T_3D,
|
adaln_lora_B_T_3D,
|
||||||
extra_per_block_pos_emb,
|
extra_per_block_pos_emb,
|
||||||
@@ -1007,6 +1010,7 @@ class Block(nn.Module):
|
|||||||
emb_B_T_D,
|
emb_B_T_D,
|
||||||
crossattn_emb,
|
crossattn_emb,
|
||||||
attn_params,
|
attn_params,
|
||||||
|
use_fp32,
|
||||||
rope_emb_L_1_1_D,
|
rope_emb_L_1_1_D,
|
||||||
adaln_lora_B_T_3D,
|
adaln_lora_B_T_3D,
|
||||||
extra_per_block_pos_emb,
|
extra_per_block_pos_emb,
|
||||||
@@ -1018,6 +1022,7 @@ class Block(nn.Module):
|
|||||||
emb_B_T_D,
|
emb_B_T_D,
|
||||||
crossattn_emb,
|
crossattn_emb,
|
||||||
attn_params,
|
attn_params,
|
||||||
|
use_fp32,
|
||||||
rope_emb_L_1_1_D,
|
rope_emb_L_1_1_D,
|
||||||
adaln_lora_B_T_3D,
|
adaln_lora_B_T_3D,
|
||||||
extra_per_block_pos_emb,
|
extra_per_block_pos_emb,
|
||||||
@@ -1338,16 +1343,19 @@ class Anima(nn.Module):
|
|||||||
|
|
||||||
attn_params = attention.AttentionParams.create_attention_params(self.attn_mode, self.split_attn)
|
attn_params = attention.AttentionParams.create_attention_params(self.attn_mode, self.split_attn)
|
||||||
|
|
||||||
|
# Determine whether to use float32 for block computations based on input dtype (use float32 for better stability when input is float16)
|
||||||
|
use_fp32 = x_B_T_H_W_D.dtype == torch.float16
|
||||||
|
|
||||||
for block_idx, block in enumerate(self.blocks):
|
for block_idx, block in enumerate(self.blocks):
|
||||||
if self.blocks_to_swap:
|
if self.blocks_to_swap:
|
||||||
self.offloader.wait_for_block(block_idx)
|
self.offloader.wait_for_block(block_idx)
|
||||||
|
|
||||||
x_B_T_H_W_D = block(x_B_T_H_W_D, t_embedding_B_T_D, crossattn_emb, attn_params, **block_kwargs)
|
x_B_T_H_W_D = block(x_B_T_H_W_D, t_embedding_B_T_D, crossattn_emb, attn_params, use_fp32, **block_kwargs)
|
||||||
|
|
||||||
if self.blocks_to_swap:
|
if self.blocks_to_swap:
|
||||||
self.offloader.submit_move_blocks(self.blocks, block_idx)
|
self.offloader.submit_move_blocks(self.blocks, block_idx)
|
||||||
|
|
||||||
x_B_T_H_W_O = self.final_layer(x_B_T_H_W_D, t_embedding_B_T_D, adaln_lora_B_T_3D=adaln_lora_B_T_3D)
|
x_B_T_H_W_O = self.final_layer(x_B_T_H_W_D, t_embedding_B_T_D, adaln_lora_B_T_3D=adaln_lora_B_T_3D, use_fp32=use_fp32)
|
||||||
x_B_C_Tt_Hp_Wp = self.unpatchify(x_B_T_H_W_O)
|
x_B_C_Tt_Hp_Wp = self.unpatchify(x_B_T_H_W_O)
|
||||||
return x_B_C_Tt_Hp_Wp
|
return x_B_C_Tt_Hp_Wp
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user