mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-15 08:36:41 +00:00
fix: review with Copilot
This commit is contained in:
@@ -72,16 +72,6 @@ def train(args):
|
||||
args.blocks_to_swap is None or args.blocks_to_swap == 0
|
||||
) or not args.unsloth_offload_checkpointing, "blocks_to_swap is not supported with unsloth_offload_checkpointing"
|
||||
|
||||
# # Flash attention: validate availability
|
||||
# if args.flash_attn:
|
||||
# try:
|
||||
# import flash_attn # noqa: F401
|
||||
|
||||
# logger.info("Flash Attention enabled for DiT blocks")
|
||||
# except ImportError:
|
||||
# logger.warning("flash_attn package not installed, falling back to PyTorch SDPA")
|
||||
# args.flash_attn = False
|
||||
|
||||
cache_latents = args.cache_latents
|
||||
use_dreambooth_method = args.in_json is None
|
||||
|
||||
@@ -348,7 +338,7 @@ def train(args):
|
||||
assert args.mixed_precision == "bf16", "full_bf16 requires mixed_precision='bf16'"
|
||||
accelerator.print("enable full bf16 training.")
|
||||
else:
|
||||
dit_weight_dtype = torch.float32 # Default to float32
|
||||
dit_weight_dtype = torch.float32 # If neither full_fp16 nor full_bf16, the model weights should be in float32
|
||||
dit.to(dit_weight_dtype) # convert dit to target weight dtype
|
||||
|
||||
# move text encoder to GPU if not cached
|
||||
@@ -431,6 +421,7 @@ def train(args):
|
||||
global_step = 0
|
||||
|
||||
noise_scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift)
|
||||
# Copy for noise and timestep generation, because noise_scheduler may be changed during training in the future
|
||||
noise_scheduler_copy = copy.deepcopy(noise_scheduler)
|
||||
|
||||
if accelerator.is_main_process:
|
||||
@@ -540,7 +531,7 @@ def train(args):
|
||||
|
||||
# Get noisy model input and timesteps
|
||||
noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps(
|
||||
args, noise_scheduler, latents, noise, accelerator.device, dit_weight_dtype
|
||||
args, noise_scheduler_copy, latents, noise, accelerator.device, dit_weight_dtype
|
||||
)
|
||||
timesteps = timesteps / 1000.0 # scale to [0, 1] range. timesteps is float32
|
||||
|
||||
|
||||
@@ -352,10 +352,6 @@ class Attention(nn.Module):
|
||||
|
||||
return q, k, v
|
||||
|
||||
# def compute_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
|
||||
# result = self.attn_op(q, k, v) # [B, S, H, D]
|
||||
# return self.output_dropout(self.output_proj(result))
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
import os
|
||||
from typing import Dict, List, Optional, Union
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from safetensors.torch import load_file, save_file
|
||||
from accelerate.utils import set_module_tensor_to_device # kept for potential future use
|
||||
from accelerate import init_empty_weights
|
||||
@@ -38,11 +37,11 @@ def load_anima_model(
|
||||
loading_device: Union[str, torch.device],
|
||||
dit_weight_dtype: Optional[torch.dtype],
|
||||
fp8_scaled: bool = False,
|
||||
lora_weights_list: Optional[Dict[str, torch.Tensor]] = None,
|
||||
lora_weights_list: Optional[List[Dict[str, torch.Tensor]]] = None,
|
||||
lora_multipliers: Optional[list[float]] = None,
|
||||
) -> anima_models.Anima:
|
||||
"""
|
||||
Load a HunyuanImage model from the specified checkpoint.
|
||||
Load Anima model from the specified checkpoint.
|
||||
|
||||
Args:
|
||||
device (Union[str, torch.device]): Device for optimization or merging
|
||||
@@ -53,7 +52,7 @@ def load_anima_model(
|
||||
dit_weight_dtype (Optional[torch.dtype]): Data type of the DiT weights.
|
||||
If None, it will be loaded as is (same as the state_dict) or scaled for fp8. If not None, model weights will be cast to this dtype.
|
||||
fp8_scaled (bool): Whether to use fp8 scaling for the model weights.
|
||||
lora_weights_list (Optional[Dict[str, torch.Tensor]]): LoRA weights to apply, if any.
|
||||
lora_weights_list (Optional[List[Dict[str, torch.Tensor]]]): LoRA weights to apply, if any.
|
||||
lora_multipliers (Optional[List[float]]): LoRA multipliers for the weights, if any.
|
||||
"""
|
||||
# dit_weight_dtype is None for fp8_scaled
|
||||
|
||||
@@ -28,8 +28,6 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
|
||||
import logging
|
||||
|
||||
from library.safetensors_utils import load_safetensors
|
||||
|
||||
from library.utils import setup_logging
|
||||
|
||||
Reference in New Issue
Block a user