Merge pull request #2178 from kohya-ss/update-libraries

Feat: Update libraries, remove warnings
This commit is contained in:
Kohya S.
2025-08-28 08:42:53 +09:00
committed by GitHub
13 changed files with 135 additions and 78 deletions

View File

@@ -22,7 +22,7 @@ jobs:
matrix:
os: [ubuntu-latest]
python-version: ["3.10"] # Python versions to test
pytorch-version: ["2.4.0"] # PyTorch versions to test
pytorch-version: ["2.4.0", "2.6.0"] # PyTorch versions to test
steps:
- uses: actions/checkout@v4

View File

@@ -4,18 +4,29 @@ This repository contains training, generation and utility scripts for Stable Dif
This feature is experimental. The options and the training script may change in the future. Please let us know if you have any ideas to improve the training.
__Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchvision==0.19.0` with CUDA 12.4. We also updated `accelerate` to 0.33.0 just to be safe. `requirements.txt` is also updated, so please update the requirements.__
__Please update PyTorch to 2.6.0 or later. We have tested with `torch==2.6.0` and `torchvision==0.21.0` with CUDA 12.4. `requirements.txt` is also updated, so please update the requirements.__
The command to install PyTorch is as follows:
`pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124`
`pip3 install torch==2.6.0 torchvision==0.21.0 --index-url https://download.pytorch.org/whl/cu124`
If you are using DeepSpeed, please install DeepSpeed with `pip install deepspeed==0.16.7`.
For RTX 50 series GPUs, PyTorch 2.8.0 with CUDA 12.8/12.9 should be used. `requirements.txt` will work with this version.
If you are using DeepSpeed, please install DeepSpeed with `pip install deepspeed` (appropriate version is not confirmed yet).
- [FLUX.1 training](#flux1-training)
- [SD3 training](#sd3-training)
### Recent Updates
Aug 28, 2025:
- In order to support the latest GPUs and features, we have updated the **PyTorch and library versions**. PR [#2178](https://github.com/kohya-ss/sd-scripts/pull/2178) There are many changes, so please let us know if you encounter any issues.
- The PyTorch version used for testing has been updated to 2.6.0. We have confirmed that it works with PyTorch 2.6.0 and later.
- The `requirements.txt` has been updated, so please update your dependencies.
- You can update the dependencies with `pip install -r requirements.txt`.
- The version specification for `bitsandbytes` has been removed. If you encounter errors on RTX 50 series GPUs, please update it with `pip install -U bitsandbytes`.
- We have modified each script to minimize warnings as much as possible.
- The modified scripts will still work with the old library versions, but please update the libraries when convenient.
Jul 30, 2025:
- **Breaking Change**: For FLUX.1 and Chroma training, the CFG (Classifier-Free Guidance, using negative prompts) scale option for sample image generation during training has been changed from `--g` to `--l`. The `--g` option is now used for the embedded guidance scale. Please update your prompts accordingly. See [Sample Image Generation During Training](#sample-image-generation-during-training) for details.

View File

@@ -6,6 +6,7 @@ import os
import torch
from library.device_utils import init_ipex
init_ipex()
import diffusers
@@ -14,8 +15,10 @@ from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline # ,
from safetensors.torch import load_file, save_file
from library.original_unet import UNet2DConditionModel
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
# DiffUsers版StableDiffusionのモデルパラメータ
@@ -974,7 +977,7 @@ def load_checkpoint_with_text_encoder_conversion(ckpt_path, device="cpu"):
checkpoint = None
state_dict = load_file(ckpt_path) # , device) # may causes error
else:
checkpoint = torch.load(ckpt_path, map_location=device)
checkpoint = torch.load(ckpt_path, map_location=device, weights_only=False)
if "state_dict" in checkpoint:
state_dict = checkpoint["state_dict"]
else:

View File

@@ -114,8 +114,10 @@ from torch import nn
from torch.nn import functional as F
from einops import rearrange
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
BLOCK_OUT_CHANNELS: Tuple[int] = (320, 640, 1280, 1280)
@@ -530,7 +532,9 @@ class DownBlock2D(nn.Module):
return custom_forward
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
)
else:
hidden_states = resnet(hidden_states, temb)
@@ -626,15 +630,9 @@ class CrossAttention(nn.Module):
hidden_states,
encoder_hidden_states,
attention_mask,
) = translate_attention_names_from_diffusers(
hidden_states=hidden_states, context=context, mask=mask, **kwargs
)
) = translate_attention_names_from_diffusers(hidden_states=hidden_states, context=context, mask=mask, **kwargs)
return self.processor(
attn=self,
hidden_states=hidden_states,
encoder_hidden_states=context,
attention_mask=mask,
**kwargs
attn=self, hidden_states=hidden_states, encoder_hidden_states=context, attention_mask=mask, **kwargs
)
if self.use_memory_efficient_attention_xformers:
return self.forward_memory_efficient_xformers(hidden_states, context, mask)
@@ -748,13 +746,14 @@ class CrossAttention(nn.Module):
out = self.to_out[0](out)
return out
def translate_attention_names_from_diffusers(
hidden_states: torch.FloatTensor,
context: Optional[torch.FloatTensor] = None,
mask: Optional[torch.FloatTensor] = None,
# HF naming
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None
attention_mask: Optional[torch.FloatTensor] = None,
):
# translate from hugging face diffusers
context = context if context is not None else encoder_hidden_states
@@ -764,6 +763,7 @@ def translate_attention_names_from_diffusers(
return hidden_states, context, mask
# feedforward
class GEGLU(nn.Module):
r"""
@@ -1015,9 +1015,11 @@ class CrossAttnDownBlock2D(nn.Module):
return custom_forward
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states
create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
)
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states, use_reentrant=False
)[0]
else:
hidden_states = resnet(hidden_states, temb)
@@ -1098,10 +1100,12 @@ class UNetMidBlock2DCrossAttn(nn.Module):
if attn is not None:
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states
create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states, use_reentrant=False
)[0]
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
)
else:
if attn is not None:
hidden_states = attn(hidden_states, encoder_hidden_states).sample
@@ -1201,7 +1205,9 @@ class UpBlock2D(nn.Module):
return custom_forward
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
)
else:
hidden_states = resnet(hidden_states, temb)
@@ -1296,9 +1302,11 @@ class CrossAttnUpBlock2D(nn.Module):
return custom_forward
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states
create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
)
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states, use_reentrant=False
)[0]
else:
hidden_states = resnet(hidden_states, temb)

View File

@@ -683,7 +683,7 @@ class BaseDataset(torch.utils.data.Dataset):
resolution: Optional[Tuple[int, int]],
network_multiplier: float,
debug_dataset: bool,
resize_interpolation: Optional[str] = None
resize_interpolation: Optional[str] = None,
) -> None:
super().__init__()
@@ -719,7 +719,9 @@ class BaseDataset(torch.utils.data.Dataset):
self.image_transforms = IMAGE_TRANSFORMS
if resize_interpolation is not None:
assert validate_interpolation_fn(resize_interpolation), f"Resize interpolation \"{resize_interpolation}\" is not a valid interpolation"
assert validate_interpolation_fn(
resize_interpolation
), f'Resize interpolation "{resize_interpolation}" is not a valid interpolation'
self.resize_interpolation = resize_interpolation
self.image_data: Dict[str, ImageInfo] = {}
@@ -1613,7 +1615,11 @@ class BaseDataset(torch.utils.data.Dataset):
if self.enable_bucket:
img, original_size, crop_ltrb = trim_and_resize_if_required(
subset.random_crop, img, image_info.bucket_reso, image_info.resized_size, resize_interpolation=image_info.resize_interpolation
subset.random_crop,
img,
image_info.bucket_reso,
image_info.resized_size,
resize_interpolation=image_info.resize_interpolation,
)
else:
if face_cx > 0: # 顔位置情報あり
@@ -2101,7 +2107,9 @@ class DreamBoothDataset(BaseDataset):
for img_path, caption, size in zip(img_paths, captions, sizes):
info = ImageInfo(img_path, num_repeats, caption, subset.is_reg, img_path)
info.resize_interpolation = subset.resize_interpolation if subset.resize_interpolation is not None else self.resize_interpolation
info.resize_interpolation = (
subset.resize_interpolation if subset.resize_interpolation is not None else self.resize_interpolation
)
if size is not None:
info.image_size = size
if subset.is_reg:
@@ -2385,7 +2393,7 @@ class ControlNetDataset(BaseDataset):
bucket_no_upscale: bool,
debug_dataset: bool,
validation_split: float,
validation_seed: Optional[int],
validation_seed: Optional[int],
resize_interpolation: Optional[str] = None,
) -> None:
super().__init__(resolution, network_multiplier, debug_dataset, resize_interpolation)
@@ -2448,7 +2456,7 @@ class ControlNetDataset(BaseDataset):
self.num_train_images = self.dreambooth_dataset_delegate.num_train_images
self.num_reg_images = self.dreambooth_dataset_delegate.num_reg_images
self.validation_split = validation_split
self.validation_seed = validation_seed
self.validation_seed = validation_seed
self.resize_interpolation = resize_interpolation
# assert all conditioning data exists
@@ -2538,7 +2546,14 @@ class ControlNetDataset(BaseDataset):
cond_img.shape[0] == original_size_hw[0] and cond_img.shape[1] == original_size_hw[1]
), f"size of conditioning image is not match / 画像サイズが合いません: {image_info.absolute_path}"
cond_img = resize_image(cond_img, original_size_hw[1], original_size_hw[0], target_size_hw[1], target_size_hw[0], self.resize_interpolation)
cond_img = resize_image(
cond_img,
original_size_hw[1],
original_size_hw[0],
target_size_hw[1],
target_size_hw[0],
self.resize_interpolation,
)
# TODO support random crop
# 現在サポートしているcropはrandomではなく中央のみ
@@ -2552,7 +2567,14 @@ class ControlNetDataset(BaseDataset):
# ), f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}"
# resize to target
if cond_img.shape[0] != target_size_hw[0] or cond_img.shape[1] != target_size_hw[1]:
cond_img = resize_image(cond_img, cond_img.shape[0], cond_img.shape[1], target_size_hw[1], target_size_hw[0], self.resize_interpolation)
cond_img = resize_image(
cond_img,
cond_img.shape[0],
cond_img.shape[1],
target_size_hw[1],
target_size_hw[0],
self.resize_interpolation,
)
if flipped:
cond_img = cond_img[:, ::-1, :].copy() # copy to avoid negative stride
@@ -3000,7 +3022,9 @@ def load_images_and_masks_for_caching(
for info in image_infos:
image = load_image(info.absolute_path, use_alpha_mask) if info.image is None else np.array(info.image, np.uint8)
# TODO 画像のメタデータが壊れていて、メタデータから割り当てたbucketと実際の画像サイズが一致しない場合があるのでチェック追加要
image, original_size, crop_ltrb = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size, resize_interpolation=info.resize_interpolation)
image, original_size, crop_ltrb = trim_and_resize_if_required(
random_crop, image, info.bucket_reso, info.resized_size, resize_interpolation=info.resize_interpolation
)
original_sizes.append(original_size)
crop_ltrbs.append(crop_ltrb)
@@ -3041,7 +3065,9 @@ def cache_batch_latents(
for info in image_infos:
image = load_image(info.absolute_path, use_alpha_mask) if info.image is None else np.array(info.image, np.uint8)
# TODO 画像のメタデータが壊れていて、メタデータから割り当てたbucketと実際の画像サイズが一致しない場合があるのでチェック追加要
image, original_size, crop_ltrb = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size, resize_interpolation=info.resize_interpolation)
image, original_size, crop_ltrb = trim_and_resize_if_required(
random_crop, image, info.bucket_reso, info.resized_size, resize_interpolation=info.resize_interpolation
)
info.latents_original_size = original_size
info.latents_crop_ltrb = crop_ltrb
@@ -3482,9 +3508,9 @@ def get_sai_model_spec(
textual_inversion: bool,
is_stable_diffusion_ckpt: Optional[bool] = None, # None for TI and LoRA
sd3: str = None,
flux: str = None, # "dev", "schnell" or "chroma"
flux: str = None, # "dev", "schnell" or "chroma"
lumina: str = None,
optional_metadata: dict[str, str] | None = None
optional_metadata: dict[str, str] | None = None,
):
timestamp = time.time()
@@ -3513,7 +3539,7 @@ def get_sai_model_spec(
# Extract metadata_* fields from args and merge with optional_metadata
extracted_metadata = {}
# Extract all metadata_* attributes from args
for attr_name in dir(args):
if attr_name.startswith("metadata_") and not attr_name.startswith("metadata___"):
@@ -3523,7 +3549,7 @@ def get_sai_model_spec(
field_name = attr_name[9:] # len("metadata_") = 9
if field_name not in ["title", "author", "description", "license", "tags"]:
extracted_metadata[field_name] = value
# Merge extracted metadata with provided optional_metadata
all_optional_metadata = {**extracted_metadata}
if optional_metadata:
@@ -3546,7 +3572,7 @@ def get_sai_model_spec(
tags=args.metadata_tags,
timesteps=timesteps,
clip_skip=args.clip_skip, # None or int
model_config=model_config,
model_config=model_config,
optional_metadata=all_optional_metadata if all_optional_metadata else None,
)
return metadata
@@ -3562,7 +3588,7 @@ def get_sai_model_spec_dataclass(
sd3: str = None,
flux: str = None,
lumina: str = None,
optional_metadata: dict[str, str] | None = None
optional_metadata: dict[str, str] | None = None,
) -> sai_model_spec.ModelSpecMetadata:
"""
Get ModelSpec metadata as a dataclass - preferred for new code.
@@ -5558,11 +5584,12 @@ def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projectio
def patch_accelerator_for_fp16_training(accelerator):
from accelerate import DistributedType
if accelerator.distributed_type == DistributedType.DEEPSPEED:
return
org_unscale_grads = accelerator.scaler._unscale_grads_
def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16):
@@ -6054,7 +6081,6 @@ def get_noise_noisy_latents_and_timesteps(
b_size = latents.shape[0]
min_timestep = 0 if args.min_timestep is None else args.min_timestep
max_timestep = noise_scheduler.config.num_train_timesteps if args.max_timestep is None else args.max_timestep
timesteps = get_timesteps(min_timestep, max_timestep, b_size, latents.device)
# Add noise to the latents according to the noise magnitude at each timestep
@@ -6279,7 +6305,6 @@ def line_to_prompt_dict(line: str) -> dict:
prompt_dict["renorm_cfg"] = float(m.group(1))
continue
except ValueError as ex:
logger.error(f"Exception in parsing / 解析エラー: {parg}")
logger.error(ex)
@@ -6328,7 +6353,7 @@ def sample_images_common(
vae,
tokenizer,
text_encoder,
unet,
unet_wrapped,
prompt_replacement=None,
controlnet=None,
):
@@ -6363,7 +6388,7 @@ def sample_images_common(
vae.to(distributed_state.device) # distributed_state.device is same as accelerator.device
# unwrap unet and text_encoder(s)
unet = accelerator.unwrap_model(unet)
unet = accelerator.unwrap_model(unet_wrapped)
if isinstance(text_encoder, (list, tuple)):
text_encoder = [accelerator.unwrap_model(te) for te in text_encoder]
else:
@@ -6509,7 +6534,7 @@ def sample_image_inference(
logger.info(f"sample_sampler: {sampler_name}")
if seed is not None:
logger.info(f"seed: {seed}")
with accelerator.autocast():
with accelerator.autocast(), torch.no_grad():
latents = pipeline(
prompt=prompt,
height=height,
@@ -6647,4 +6672,3 @@ class LossRecorder:
if losses == 0:
return 0
return self.loss_total / losses

View File

View File

View File

@@ -0,0 +1,4 @@
# dummy module for pytorch_lightning
class ModelCheckpoint:
    """No-op stand-in for pytorch_lightning's ``ModelCheckpoint``.

    This dummy lets modules that merely reference the name import
    successfully without the real pytorch-lightning package installed;
    it intentionally provides no behavior.
    """

View File

@@ -1,28 +1,29 @@
accelerate==0.33.0
transformers==4.44.0
diffusers[torch]==0.25.0
ftfy==6.1.1
accelerate==1.6.0
transformers==4.54.1
diffusers[torch]==0.32.1
ftfy==6.3.1
# albumentations==1.3.0
opencv-python==4.8.1.78
opencv-python==4.10.0.84
einops==0.7.0
pytorch-lightning==1.9.0
bitsandbytes==0.44.0
lion-pytorch==0.0.6
# pytorch-lightning==1.9.0
bitsandbytes
lion-pytorch==0.2.3
schedulefree==1.4
pytorch-optimizer==3.7.0
prodigy-plus-schedule-free==1.9.0
prodigy-plus-schedule-free==1.9.2
prodigyopt==1.1.2
tensorboard
safetensors==0.4.4
safetensors==0.4.5
# gradio==3.16.2
altair==4.2.2
easygui==0.98.3
# altair==4.2.2
# easygui==0.98.3
toml==0.10.2
voluptuous==0.13.1
huggingface-hub==0.24.5
voluptuous==0.15.2
huggingface-hub==0.34.3
# for Image utils
imagesize==1.4.1
numpy<=2.0
numpy
# <=2.0
# for BLIP captioning
# requests==2.28.2
# timm==0.6.12
@@ -41,8 +42,8 @@ numpy<=2.0
# open clip for SDXL
# open-clip-torch==2.20.0
# For logging
rich==13.7.0
rich==14.1.0
# for T5XXL tokenizer (SD3/FLUX)
sentencepiece==0.2.0
sentencepiece==0.2.1
# for kohya_ss library
-e .

View File

@@ -23,7 +23,12 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
self.vae_scale_factor = sdxl_model_util.VAE_SCALE_FACTOR
self.is_sdxl = True
def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]):
def assert_extra_args(
self,
args,
train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset],
val_dataset_group: Optional[train_util.DatasetGroup],
):
sdxl_train_util.verify_sdxl_training_args(args)
if args.cache_text_encoder_outputs:

View File

@@ -20,7 +20,6 @@ class SdxlTextualInversionTrainer(train_textual_inversion.TextualInversionTraine
self.is_sdxl = True
def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]):
super().assert_extra_args(args, train_dataset_group, val_dataset_group)
sdxl_train_util.verify_sdxl_training_args(args, supportTextEncoderCaching=False)
train_dataset_group.verify_bucket_reso_steps(32)

View File

@@ -12,7 +12,7 @@ import toml
from tqdm import tqdm
import torch
from library import deepspeed_utils
from library import deepspeed_utils, strategy_base, strategy_sd
from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
@@ -73,7 +73,14 @@ def train(args):
args.seed = random.randint(0, 2**32)
set_seed(args.seed)
tokenizer = train_util.load_tokenizer(args)
tokenize_strategy = strategy_sd.SdTokenizeStrategy(args.v2, args.max_token_length, args.tokenizer_cache_dir)
strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy)
tokenizer = tokenize_strategy.tokenizer
# prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization.
latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(
True, args.cache_latents_to_disk, args.vae_batch_size, False
)
strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy)
# データセットを準備する
blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, False, True, True))
@@ -100,7 +107,7 @@ def train(args):
]
}
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
blueprint = blueprint_generator.generate(user_config, args)
train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
current_epoch = Value("i", 0)
@@ -243,12 +250,7 @@ def train(args):
vae.requires_grad_(False)
vae.eval()
with torch.no_grad():
train_dataset_group.cache_latents(
vae,
args.vae_batch_size,
args.cache_latents_to_disk,
accelerator.is_main_process,
)
train_dataset_group.new_cache_latents(vae, accelerator)
vae.to("cpu")
clean_memory_on_device(accelerator.device)
@@ -267,6 +269,7 @@ def train(args):
# dataloaderを準備する
# DataLoaderのプロセス数0 は persistent_workers が使えないので注意
train_dataset_group.set_current_strategies()
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
train_dataloader = torch.utils.data.DataLoader(
@@ -451,7 +454,7 @@ def train(args):
latents = latents * 0.18215
b_size = latents.shape[0]
input_ids = batch["input_ids"].to(accelerator.device)
input_ids = batch["input_ids_list"][0].to(accelerator.device)
encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizer, text_encoder, weight_dtype)
# Sample noise that we'll add to the latents

View File

@@ -414,13 +414,12 @@ class NetworkTrainer:
if text_encoder_outputs_list is not None:
text_encoder_conds = text_encoder_outputs_list # List of text encoder outputs
if len(text_encoder_conds) == 0 or text_encoder_conds[0] is None or train_text_encoder:
# TODO this does not work if 'some text_encoders are trained' and 'some are not and not cached'
with torch.set_grad_enabled(is_train and train_text_encoder), accelerator.autocast():
# Get the text embedding for conditioning
if args.weighted_captions:
input_ids_list, weights_list = tokenize_strategy.tokenize_with_weights(batch['captions'])
input_ids_list, weights_list = tokenize_strategy.tokenize_with_weights(batch["captions"])
encoded_text_encoder_conds = text_encoding_strategy.encode_tokens_with_weights(
tokenize_strategy,
self.get_models_for_text_encoding(args, accelerator, text_encoders),
@@ -1340,7 +1339,7 @@ class NetworkTrainer:
)
NUM_VALIDATION_TIMESTEPS = 4 # 200, 400, 600, 800 TODO make this configurable
min_timestep = 0 if args.min_timestep is None else args.min_timestep
max_timestep = noise_scheduler.num_train_timesteps if args.max_timestep is None else args.max_timestep
max_timestep = noise_scheduler.config.num_train_timesteps if args.max_timestep is None else args.max_timestep
validation_timesteps = np.linspace(min_timestep, max_timestep, (NUM_VALIDATION_TIMESTEPS + 2), dtype=int)[1:-1]
validation_total_steps = validation_steps * len(validation_timesteps)
original_args_min_timestep = args.min_timestep