fix: review with Copilot

This commit is contained in:
Kohya S
2026-02-12 08:05:41 +09:00
parent 57aa70ea9e
commit cc1f57bc70
4 changed files with 6 additions and 22 deletions

View File

@@ -72,16 +72,6 @@ def train(args):
args.blocks_to_swap is None or args.blocks_to_swap == 0
) or not args.unsloth_offload_checkpointing, "blocks_to_swap is not supported with unsloth_offload_checkpointing"
# # Flash attention: validate availability
# if args.flash_attn:
# try:
# import flash_attn # noqa: F401
# logger.info("Flash Attention enabled for DiT blocks")
# except ImportError:
# logger.warning("flash_attn package not installed, falling back to PyTorch SDPA")
# args.flash_attn = False
cache_latents = args.cache_latents
use_dreambooth_method = args.in_json is None
@@ -348,7 +338,7 @@ def train(args):
assert args.mixed_precision == "bf16", "full_bf16 requires mixed_precision='bf16'"
accelerator.print("enable full bf16 training.")
else:
dit_weight_dtype = torch.float32 # Default to float32
dit_weight_dtype = torch.float32 # If neither full_fp16 nor full_bf16, the model weights should be in float32
dit.to(dit_weight_dtype) # convert dit to target weight dtype
# move text encoder to GPU if not cached
@@ -431,6 +421,7 @@ def train(args):
global_step = 0
noise_scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift)
# Copy for noise and timestep generation, because noise_scheduler may be changed during training in the future
noise_scheduler_copy = copy.deepcopy(noise_scheduler)
if accelerator.is_main_process:
@@ -540,7 +531,7 @@ def train(args):
# Get noisy model input and timesteps
noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps(
args, noise_scheduler, latents, noise, accelerator.device, dit_weight_dtype
args, noise_scheduler_copy, latents, noise, accelerator.device, dit_weight_dtype
)
timesteps = timesteps / 1000.0 # scale to [0, 1] range. timesteps is float32

View File

@@ -352,10 +352,6 @@ class Attention(nn.Module):
return q, k, v
# def compute_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
# result = self.attn_op(q, k, v) # [B, S, H, D]
# return self.output_dropout(self.output_proj(result))
def forward(
self,
x: torch.Tensor,

View File

@@ -3,7 +3,6 @@
import os
from typing import Dict, List, Optional, Union
import torch
import torch.nn as nn
from safetensors.torch import load_file, save_file
from accelerate.utils import set_module_tensor_to_device # kept for potential future use
from accelerate import init_empty_weights
@@ -38,11 +37,11 @@ def load_anima_model(
loading_device: Union[str, torch.device],
dit_weight_dtype: Optional[torch.dtype],
fp8_scaled: bool = False,
lora_weights_list: Optional[Dict[str, torch.Tensor]] = None,
lora_weights_list: Optional[List[Dict[str, torch.Tensor]]] = None,
lora_multipliers: Optional[list[float]] = None,
) -> anima_models.Anima:
"""
Load a HunyuanImage model from the specified checkpoint.
Load Anima model from the specified checkpoint.
Args:
device (Union[str, torch.device]): Device for optimization or merging
@@ -53,7 +52,7 @@ def load_anima_model(
dit_weight_dtype (Optional[torch.dtype]): Data type of the DiT weights.
If None, it will be loaded as is (same as the state_dict) or scaled for fp8. If not None, model weights will be cast to this dtype.
fp8_scaled (bool): Whether to use fp8 scaling for the model weights.
lora_weights_list (Optional[Dict[str, torch.Tensor]]): LoRA weights to apply, if any.
lora_weights_list (Optional[List[Dict[str, torch.Tensor]]]): LoRA weights to apply, if any.
lora_multipliers (Optional[List[float]]): LoRA multipliers for the weights, if any.
"""
# dit_weight_dtype is None for fp8_scaled

View File

@@ -28,8 +28,6 @@ import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import logging
from library.safetensors_utils import load_safetensors
from library.utils import setup_logging