mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-15 00:32:25 +00:00
feat: enhance block swap functionality for inference and training in Anima model
This commit is contained in:
@@ -401,6 +401,12 @@ class Attention(nn.Module):
|
||||
rope_emb: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
q, k, v = self.compute_qkv(x, context, rope_emb=rope_emb)
|
||||
if q.dtype != v.dtype:
|
||||
if (not attn_params.supports_fp32 or attn_params.requires_same_dtype) and torch.is_autocast_enabled():
|
||||
# FlashAttention requires fp16/bf16, xformers require same dtype; only cast when autocast is active.
|
||||
target_dtype = v.dtype # v has fp16/bf16 dtype
|
||||
q = q.to(target_dtype)
|
||||
k = k.to(target_dtype)
|
||||
# return self.compute_attention(q, k, v)
|
||||
qkv = [q, k, v]
|
||||
del q, k, v
|
||||
@@ -1304,6 +1310,20 @@ class Anima(nn.Module):
|
||||
if self.blocks_to_swap:
|
||||
self.blocks = save_blocks
|
||||
|
||||
def switch_block_swap_for_inference(self):
    """Switch block swapping to forward-only mode.

    Does nothing when ``self.blocks_to_swap`` is ``None`` or ``0``.
    Otherwise it calls ``set_forward_only(True)`` on ``self.offloader``,
    re-runs ``prepare_block_swap_before_forward()`` and prints a status line.
    """
    swap_count = self.blocks_to_swap
    if swap_count is None or swap_count == 0:
        # Block swap is not configured; there is nothing to switch.
        return

    self.offloader.set_forward_only(True)
    # Run the pre-forward preparation again now that the mode has changed.
    self.prepare_block_swap_before_forward()
    print("Anima: Block swap set to forward only.")
|
||||
|
||||
def switch_block_swap_for_training(self):
    """Switch block swapping back to forward-and-backward mode.

    Does nothing when ``self.blocks_to_swap`` is ``None`` or ``0``.
    Otherwise it calls ``set_forward_only(False)`` on ``self.offloader``,
    re-runs ``prepare_block_swap_before_forward()`` and prints a status line.
    """
    swap_count = self.blocks_to_swap
    if swap_count is None or swap_count == 0:
        # Block swap is not configured; there is nothing to switch.
        return

    self.offloader.set_forward_only(False)
    # Run the pre-forward preparation again now that the mode has changed.
    self.prepare_block_swap_before_forward()
    print("Anima: Block swap set to forward and backward.")
|
||||
|
||||
def prepare_block_swap_before_forward(self):
|
||||
if self.blocks_to_swap is None or self.blocks_to_swap == 0:
|
||||
return
|
||||
|
||||
@@ -444,7 +444,7 @@ def sample_images(
|
||||
args: argparse.Namespace,
|
||||
epoch,
|
||||
steps,
|
||||
dit,
|
||||
dit: anima_models.Anima,
|
||||
vae,
|
||||
text_encoder,
|
||||
tokenize_strategy,
|
||||
@@ -479,6 +479,8 @@ def sample_images(
|
||||
if text_encoder is not None:
|
||||
text_encoder = accelerator.unwrap_model(text_encoder)
|
||||
|
||||
dit.switch_block_swap_for_inference()
|
||||
|
||||
prompts = train_util.load_prompts(args.sample_prompts)
|
||||
save_dir = os.path.join(args.output_dir, "sample")
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
@@ -514,6 +516,7 @@ def sample_images(
|
||||
if cuda_rng_state is not None:
|
||||
torch.cuda.set_rng_state(cuda_rng_state)
|
||||
|
||||
dit.switch_block_swap_for_training()
|
||||
clean_memory_on_device(accelerator.device)
|
||||
|
||||
|
||||
|
||||
@@ -154,7 +154,9 @@ def load_anima_model(
|
||||
lora_multipliers (Optional[List[float]]): LoRA multipliers for the weights, if any.
|
||||
"""
|
||||
# dit_weight_dtype is None for fp8_scaled
|
||||
assert (not fp8_scaled and dit_weight_dtype is not None) or (fp8_scaled and dit_weight_dtype is None)
|
||||
assert (
|
||||
not fp8_scaled and dit_weight_dtype is not None
|
||||
) or dit_weight_dtype is None, "dit_weight_dtype should be None when fp8_scaled is True"
|
||||
|
||||
device = torch.device(device)
|
||||
loading_device = torch.device(loading_device)
|
||||
|
||||
Reference in New Issue
Block a user