Mirror of https://github.com/kohya-ss/sd-scripts.git (synced 2026-04-10 15:00:23 +00:00)
Add sample batch size for Lumina
@@ -880,8 +880,8 @@ class NextDiT(nn.Module):
         self.n_heads = n_heads

         self.gradient_checkpointing = False
-        self.cpu_offload_checkpointing = False
-        self.blocks_to_swap = None
+        self.cpu_offload_checkpointing = False  # TODO: not yet supported
+        self.blocks_to_swap = None  # TODO: not yet supported

     @property
     def device(self):
@@ -982,8 +982,8 @@ class NextDiT(nn.Module):
         l_effective_cap_len = cap_mask.sum(dim=1).tolist()
         encoder_seq_len = cap_mask.shape[1]

         image_seq_len = (height // self.patch_size) * (width // self.patch_size)

         seq_lengths = [cap_seq_len + image_seq_len for cap_seq_len in l_effective_cap_len]
         max_seq_len = max(seq_lengths)
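Editorial aside, not part of the diff: the sequence-length bookkeeping above can be sanity-checked with small numbers. All values below are illustrative, including the assumed patch_size:

    # height/width here are latent dimensions; patch_size = 2 is assumed for illustration
    height, width, patch_size = 64, 64, 2
    image_seq_len = (height // patch_size) * (width // patch_size)  # 32 * 32 = 1024
    l_effective_cap_len = [77, 128]  # per-sample caption lengths, i.e. cap_mask.sum(dim=1).tolist()
    seq_lengths = [cap_seq_len + image_seq_len for cap_seq_len in l_effective_cap_len]  # [1101, 1152]
    max_seq_len = max(seq_lengths)  # 1152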
@@ -4,7 +4,7 @@ import math
 import os
 import numpy as np
 import time
-from typing import Callable, Dict, List, Optional, Tuple, Any, Union
+from typing import Callable, Dict, List, Optional, Tuple, Any, Union, Generator

 import torch
 from torch import Tensor
@@ -32,6 +32,59 @@ logger = logging.getLogger(__name__)
 # region sample images


+def batchify(prompt_dicts, batch_size=None) -> Generator[list[dict[str, str]], None, None]:
+    """
+    Group prompt dictionaries into batches with configurable batch size.
+
+    Args:
+        prompt_dicts (list): List of dictionaries containing prompt parameters.
+        batch_size (int, optional): Number of prompts per batch. Defaults to None.
+
+    Yields:
+        list[dict[str, str]]: Batch of prompts.
+    """
+    # Validate batch_size
+    if batch_size is not None:
+        if not isinstance(batch_size, int) or batch_size <= 0:
+            raise ValueError("batch_size must be a positive integer or None")
+
+    # Group prompts by their parameters
+    batches = {}
+    for prompt_dict in prompt_dicts:
+        # Extract parameters
+        width = int(prompt_dict.get("width", 1024))
+        height = int(prompt_dict.get("height", 1024))
+        height = max(64, height - height % 8)  # round to divisible by 8
+        width = max(64, width - width % 8)  # round to divisible by 8
+        guidance_scale = float(prompt_dict.get("scale", 3.5))
+        sample_steps = int(prompt_dict.get("sample_steps", 38))
+        seed = prompt_dict.get("seed", None)
+        seed = int(seed) if seed is not None else None
+
+        # Create a key based on the parameters
+        key = (width, height, guidance_scale, seed, sample_steps)
+
+        # Add the prompt_dict to the corresponding batch
+        if key not in batches:
+            batches[key] = []
+        batches[key].append(prompt_dict)
+
+    # Yield each batch with its parameters
+    for key in batches:
+        prompts = batches[key]
+        if batch_size is None:
+            # Yield the entire group as a single batch
+            yield prompts
+        else:
+            # Split the group into batches of size `batch_size`
+            start = 0
+            while start < len(prompts):
+                end = start + batch_size
+                batch = prompts[start:end]
+                yield batch
+                start = end
+
+
 @torch.no_grad()
 def sample_images(
     accelerator: Accelerator,
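Editorial sketch, not part of the diff: prompts that share (width, height, scale, seed, sample_steps) fall into the same bucket, and batch_size then splits each bucket. Illustrative values only:

    prompts = [
        {"prompt": "a cat", "width": 1024, "height": 1024, "seed": 1},
        {"prompt": "a dog", "width": 1024, "height": 1024, "seed": 1},
        {"prompt": "a bird", "width": 512, "height": 512, "seed": 1},
    ]
    for batch in batchify(prompts, batch_size=2):
        print([p["prompt"] for p in batch])
    # -> ['a cat', 'a dog']  (1024x1024 bucket, split at batch_size=2)
    # -> ['a bird']          (512x512 bucket)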
@@ -39,9 +92,9 @@ def sample_images(
     epoch: int,
     global_step: int,
     nextdit: lumina_models.NextDiT,
-    vae: torch.nn.Module,
+    vae: AutoEncoder,
     gemma2_model: Gemma2Model,
-    sample_prompts_gemma2_outputs: List[Tuple[Tensor, Tensor, Tensor]],
+    sample_prompts_gemma2_outputs: dict[str, Tuple[Tensor, Tensor, Tensor]],
     prompt_replacement: Optional[Tuple[str, str]] = None,
     controlnet=None,
 ):
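Editorial note, not part of the diff: the new cache type maps each prompt string to its precomputed Gemma2 outputs, consistent with the unpacking later in this file. A shape sketch with illustrative keys:

    # sample_prompts_gemma2_outputs = {
    #     "a cat": (hidden_states, input_ids, attn_mask),                 # Tensors from the Gemma2 encoder
    #     "blurry, low quality": (hidden_states, input_ids, attn_mask),   # negative prompt entry
    # }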
@@ -54,11 +107,13 @@ def sample_images(
         epoch (int): Current epoch number.
         global_step (int): Current global step number.
         nextdit (lumina_models.NextDiT): The NextDiT model instance.
-        vae (torch.nn.Module): The VAE module.
+        vae (AutoEncoder): The VAE module.
         gemma2_model (Gemma2Model): The Gemma2 model instance.
-        sample_prompts_gemma2_outputs (List[Tuple[Tensor, Tensor, Tensor]]): List of tuples containing the encoded prompts, text masks, and timestep for each sample.
-        prompt_replacement (Optional[Tuple[str, str]], optional): Tuple containing the prompt and negative prompt replacements. Defaults to None.
-        controlnet:: ControlNet model
+        sample_prompts_gemma2_outputs (dict[str, Tuple[Tensor, Tensor, Tensor]]):
+            Dictionary of tuples containing the encoded prompts, text masks, and timestep for each sample.
+        prompt_replacement (Optional[Tuple[str, str]], optional):
+            Tuple containing the prompt and negative prompt replacements. Defaults to None.
+        controlnet: ControlNet model, not yet supported

     Returns:
         None
@@ -110,9 +165,12 @@ def sample_images(
         except Exception:
             pass

+    batch_size = args.sample_batch_size or args.train_batch_size or 1
+
     if distributed_state.num_processes <= 1:
         # If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts.
-        for prompt_dict in prompts:
+        # TODO: batch prompts together with buckets of image sizes
+        for prompt_dicts in batchify(prompts, batch_size):
             sample_image_inference(
                 accelerator,
                 args,
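Editorial note, not part of the diff: the `or` chain picks the first truthy value, so --sample_batch_size wins when set, then the training batch size, then 1:

    # args.sample_batch_size=None, args.train_batch_size=4  ->  batch_size == 4
    # args.sample_batch_size=2,    args.train_batch_size=4  ->  batch_size == 2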
@@ -120,7 +178,7 @@ def sample_images(
                 gemma2_model,
                 vae,
                 save_dir,
-                prompt_dict,
+                prompt_dicts,
                 epoch,
                 global_step,
                 sample_prompts_gemma2_outputs,
@@ -135,7 +193,8 @@ def sample_images(
         per_process_prompts.append(prompts[i :: distributed_state.num_processes])

     with distributed_state.split_between_processes(per_process_prompts) as prompt_dict_lists:
-        for prompt_dict in prompt_dict_lists[0]:
+        # TODO: batch prompts together with buckets of image sizes
+        for prompt_dicts in batchify(prompt_dict_lists[0], batch_size):
             sample_image_inference(
                 accelerator,
                 args,
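Editorial sketch, not part of the diff, of the stride-based round-robin split used above:

    prompts = ["p0", "p1", "p2", "p3", "p4"]
    num_processes = 2
    per_process_prompts = [prompts[i::num_processes] for i in range(num_processes)]
    # per_process_prompts == [['p0', 'p2', 'p4'], ['p1', 'p3']]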
@@ -143,7 +202,7 @@ def sample_images(
                 gemma2_model,
                 vae,
                 save_dir,
-                prompt_dict,
+                prompt_dicts,
                 epoch,
                 global_step,
                 sample_prompts_gemma2_outputs,
@@ -166,10 +225,10 @@ def sample_image_inference(
     gemma2_model: Gemma2Model,
     vae: AutoEncoder,
     save_dir: str,
-    prompt_dict: Dict[str, str],
+    prompt_dicts: list[Dict[str, str]],
     epoch: int,
     global_step: int,
-    sample_prompts_gemma2_outputs: dict[str, List[Tuple[Tensor, Tensor, Tensor]]],
+    sample_prompts_gemma2_outputs: dict[str, Tuple[Tensor, Tensor, Tensor]],
     prompt_replacement: Optional[Tuple[str, str]] = None,
     controlnet=None,
 ):
@@ -192,43 +251,6 @@ def sample_image_inference(
     Returns:
         None
     """
-    assert isinstance(prompt_dict, dict)
-    # negative_prompt = prompt_dict.get("negative_prompt")
-    sample_steps = int(prompt_dict.get("sample_steps", 38))
-    width = int(prompt_dict.get("width", 1024))
-    height = int(prompt_dict.get("height", 1024))
-    guidance_scale = float(prompt_dict.get("scale", 3.5))
-    seed = prompt_dict.get("seed", None)
-    controlnet_image = prompt_dict.get("controlnet_image")
-    prompt: str = prompt_dict.get("prompt", "")
-    negative_prompt: str = prompt_dict.get("negative_prompt", "")
-    # sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler)
-
-    seed = int(seed) if seed is not None else None
-    assert seed is None or seed > 0, f"Invalid seed {seed}"
-
-    if prompt_replacement is not None:
-        prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
-        if negative_prompt is not None:
-            negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])
-
-    generator = torch.Generator(device=accelerator.device)
-    if seed is not None:
-        generator.manual_seed(seed)
-
-    # if negative_prompt is None:
-    #     negative_prompt = ""
-    height = max(64, height - height % 8)  # round to divisible by 8
-    width = max(64, width - width % 8)  # round to divisible by 8
-    logger.info(f"prompt: {prompt}")
-    logger.info(f"negative_prompt: {negative_prompt}")
-    logger.info(f"height: {height}")
-    logger.info(f"width: {width}")
-    logger.info(f"sample_steps: {sample_steps}")
-    logger.info(f"scale: {guidance_scale}")
-    # logger.info(f"sample_sampler: {sampler_name}")
-    if seed is not None:
-        logger.info(f"seed: {seed}")

     # encode prompts
     tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy()
@@ -237,33 +259,86 @@ def sample_image_inference(
     assert isinstance(tokenize_strategy, strategy_lumina.LuminaTokenizeStrategy)
     assert isinstance(encoding_strategy, strategy_lumina.LuminaTextEncodingStrategy)

-    system_prompt = args.system_prompt or ""
+    text_conds = []

-    # Apply system prompt to prompts
-    prompt = system_prompt + prompt
-    negative_prompt = system_prompt + negative_prompt
+    # assuming seed, width, height, sample steps, guidance are the same
+    width = int(prompt_dicts[0].get("width", 1024))
+    height = int(prompt_dicts[0].get("height", 1024))
+    height = max(64, height - height % 8)  # round to divisible by 8
+    width = max(64, width - width % 8)  # round to divisible by 8

-    # Get sample prompts from cache
-    if sample_prompts_gemma2_outputs and prompt in sample_prompts_gemma2_outputs:
-        gemma2_conds = sample_prompts_gemma2_outputs[prompt]
-        logger.info(f"Using cached Gemma2 outputs for prompt: {prompt}")
+    guidance_scale = float(prompt_dicts[0].get("scale", 3.5))
+    sample_steps = int(prompt_dicts[0].get("sample_steps", 36))
+    seed = prompt_dicts[0].get("seed", None)
+    seed = int(seed) if seed is not None else None
+    assert seed is None or seed > 0, f"Invalid seed {seed}"
+    generator = torch.Generator(device=accelerator.device)
+    if seed is not None:
+        generator.manual_seed(seed)

-    if sample_prompts_gemma2_outputs and negative_prompt in sample_prompts_gemma2_outputs:
-        neg_gemma2_conds = sample_prompts_gemma2_outputs[negative_prompt]
-        logger.info(f"Using cached Gemma2 outputs for negative prompt: {negative_prompt}")
+    for prompt_dict in prompt_dicts:
+        controlnet_image = prompt_dict.get("controlnet_image")
+        prompt: str = prompt_dict.get("prompt", "")
+        negative_prompt = prompt_dict.get("negative_prompt", "")
+        # sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler)

-    # Load sample prompts from Gemma 2
-    if gemma2_model is not None:
-        logger.info(f"Encoding prompt with Gemma2: {prompt}")
-        tokens_and_masks = tokenize_strategy.tokenize(prompt)
-        gemma2_conds = encoding_strategy.encode_tokens(tokenize_strategy, [gemma2_model], tokens_and_masks)
+        if prompt_replacement is not None:
+            prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
+            if negative_prompt is not None:
+                negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])

-        tokens_and_masks = tokenize_strategy.tokenize(negative_prompt)
-        neg_gemma2_conds = encoding_strategy.encode_tokens(tokenize_strategy, [gemma2_model], tokens_and_masks)
+        if negative_prompt is None:
+            negative_prompt = ""
+        logger.info(f"prompt: {prompt}")
+        logger.info(f"negative_prompt: {negative_prompt}")
+        logger.info(f"height: {height}")
+        logger.info(f"width: {width}")
+        logger.info(f"sample_steps: {sample_steps}")
+        logger.info(f"scale: {guidance_scale}")
+        # logger.info(f"sample_sampler: {sampler_name}")

-    # Unpack Gemma2 outputs
-    gemma2_hidden_states, input_ids, gemma2_attn_mask = gemma2_conds
-    neg_gemma2_hidden_states, neg_input_ids, neg_gemma2_attn_mask = neg_gemma2_conds
+        system_prompt = args.system_prompt or ""
+
+        # Apply system prompt to prompts
+        prompt = system_prompt + prompt
+        negative_prompt = system_prompt + negative_prompt
+
+        # Get sample prompts from cache
+        if sample_prompts_gemma2_outputs and prompt in sample_prompts_gemma2_outputs:
+            gemma2_conds = sample_prompts_gemma2_outputs[prompt]
+            logger.info(f"Using cached Gemma2 outputs for prompt: {prompt}")
+
+        if sample_prompts_gemma2_outputs and negative_prompt in sample_prompts_gemma2_outputs:
+            neg_gemma2_conds = sample_prompts_gemma2_outputs[negative_prompt]
+            logger.info(f"Using cached Gemma2 outputs for negative prompt: {negative_prompt}")
+
+        # Load sample prompts from Gemma 2
+        if gemma2_model is not None:
+            logger.info(f"Encoding prompt with Gemma2: {prompt}")
+            tokens_and_masks = tokenize_strategy.tokenize(prompt)
+            gemma2_conds = encoding_strategy.encode_tokens(tokenize_strategy, [gemma2_model], tokens_and_masks)
+
+            tokens_and_masks = tokenize_strategy.tokenize(negative_prompt)
+            neg_gemma2_conds = encoding_strategy.encode_tokens(tokenize_strategy, [gemma2_model], tokens_and_masks)
+
+        # Unpack Gemma2 outputs
+        gemma2_hidden_states, _, gemma2_attn_mask = gemma2_conds
+        neg_gemma2_hidden_states, _, neg_gemma2_attn_mask = neg_gemma2_conds
+
+        text_conds.append(
+            (
+                gemma2_hidden_states.squeeze(0),
+                gemma2_attn_mask.squeeze(0),
+                neg_gemma2_hidden_states.squeeze(0),
+                neg_gemma2_attn_mask.squeeze(0),
+            )
+        )
+
+    # Stack conditioning
+    cond_hidden_states = torch.stack([text_cond[0] for text_cond in text_conds]).to(accelerator.device)
+    cond_attn_masks = torch.stack([text_cond[1] for text_cond in text_conds]).to(accelerator.device)
+    uncond_hidden_states = torch.stack([text_cond[2] for text_cond in text_conds]).to(accelerator.device)
+    uncond_attn_masks = torch.stack([text_cond[3] for text_cond in text_conds]).to(accelerator.device)

     # sample image
     weight_dtype = vae.dtype  # TODO: give dtype as argument
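Editorial shape sketch, not part of the diff, assuming each Gemma2 output carries a leading batch dimension of 1:

    # per prompt:  hidden_states (1, seq_len, dim) --squeeze(0)--> (seq_len, dim)
    #              attn_mask     (1, seq_len)      --squeeze(0)--> (seq_len,)
    # after stack: cond_hidden_states (batch, seq_len, dim), cond_attn_masks (batch, seq_len)
    # torch.stack requires equal seq_len across the batch, i.e. padded encoder outputs.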
@@ -279,6 +354,7 @@ def sample_image_inference(
         dtype=weight_dtype,
         generator=generator,
     )
+    noise = noise.repeat(cond_hidden_states.shape[0], 1, 1, 1)

     scheduler = FlowMatchEulerDiscreteScheduler(shift=6.0)
     timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps=sample_steps)
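Editorial note, not part of the diff: one latent is drawn and then repeated per prompt, so every sample in the bucket starts from the same noise (the bucket shares one seed by construction of the batchify key):

    # noise: (1, C, H, W) --repeat(B, 1, 1, 1)--> (B, C, H, W), where B = cond_hidden_states.shape[0]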
@@ -294,10 +370,10 @@ def sample_image_inference(
         scheduler,
         nextdit,
         noise,
-        gemma2_hidden_states,
-        gemma2_attn_mask.to(accelerator.device),
-        neg_gemma2_hidden_states,
-        neg_gemma2_attn_mask.to(accelerator.device),
+        cond_hidden_states,
+        cond_attn_masks,
+        uncond_hidden_states,
+        uncond_attn_masks,
         timesteps=timesteps,
         num_inference_steps=num_inference_steps,
         guidance_scale=guidance_scale,
@@ -307,34 +383,44 @@ def sample_image_inference(
     clean_memory_on_device(accelerator.device)
     org_vae_device = vae.device  # will be on cpu
     vae.to(accelerator.device)  # distributed_state.device is same as accelerator.device
-    with accelerator.autocast():
-        x = vae.decode((x / vae.scale_factor) + vae.shift_factor)
+    for img, prompt_dict in zip(x, prompt_dicts):
+
+        img = (img / vae.scale_factor) + vae.shift_factor
+
+        with accelerator.autocast():
+            # Add a single batch image for the VAE to decode
+            img = vae.decode(img.unsqueeze(0))
+
+        img = img.clamp(-1, 1)
+        img = img.permute(0, 2, 3, 1)  # B, H, W, C
+        # Scale images back to 0 to 255
+        img = (127.5 * (img + 1.0)).float().cpu().numpy().astype(np.uint8)
+
+        # Get single image
+        image = Image.fromarray(img[0])
+
+        # adding accelerator.wait_for_everyone() here should sync up and ensure that sample images are saved in the same order as the original prompt list
+        # but adding 'enum' to the filename should be enough
+
+        ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
+        num_suffix = f"e{epoch:06d}" if epoch is not None else f"{global_step:06d}"
+        seed_suffix = "" if seed is None else f"_{seed}"
+        i: int = int(prompt_dict.get("enum", 0))
+        img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png"
+        image.save(os.path.join(save_dir, img_filename))
+
+        # send images to wandb if enabled
+        if "wandb" in [tracker.name for tracker in accelerator.trackers]:
+            wandb_tracker = accelerator.get_tracker("wandb")
+
+            import wandb
+
+            # not to commit images to avoid inconsistency between training and logging steps
+            wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False)  # positive prompt as a caption

     vae.to(org_vae_device)
     clean_memory_on_device(accelerator.device)

-    x = x.clamp(-1, 1)
-    x = x.permute(0, 2, 3, 1)
-    image = Image.fromarray((127.5 * (x + 1.0)).float().cpu().numpy().astype(np.uint8)[0])
-
-    # adding accelerator.wait_for_everyone() here should sync up and ensure that sample images are saved in the same order as the original prompt list
-    # but adding 'enum' to the filename should be enough
-
-    ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
-    num_suffix = f"e{epoch:06d}" if epoch is not None else f"{global_step:06d}"
-    seed_suffix = "" if seed is None else f"_{seed}"
-    i: int = int(prompt_dict.get("enum", 0))
-    img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png"
-    image.save(os.path.join(save_dir, img_filename))
-
-    # send images to wandb if enabled
-    if "wandb" in [tracker.name for tracker in accelerator.trackers]:
-        wandb_tracker = accelerator.get_tracker("wandb")
-
-        import wandb
-
-        # not to commit images to avoid inconsistency between training and logging steps
-        wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False)  # positive prompt as a caption


 def time_shift(mu: float, sigma: float, t: torch.Tensor):
     # the following implementation was originally for t=0: clean / t=1: noise
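Editorial note, not part of the diff: decoding each latent via img.unsqueeze(0) bounds peak VAE memory to a single image rather than the whole sample batch, at the cost of one decode call per image:

    # x: (B, C, H, W) latents; each loop iteration decodes one (1, C, H, W) slice
    # for img, prompt_dict in zip(x, prompt_dicts): vae.decode(img.unsqueeze(0))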
@@ -879,16 +965,22 @@ def add_lumina_train_arguments(parser: argparse.ArgumentParser):
         "--discrete_flow_shift",
         type=float,
         default=6.0,
-        help="Discrete flow shift for the Euler Discrete Scheduler, default is 6.0. / Euler Discrete Schedulerの離散フローシフト、デフォルトは6.0。",
+        help="Discrete flow shift for the Euler Discrete Scheduler, default is 6.0 / Euler Discrete Schedulerの離散フローシフト、デフォルトは6.0",
     )
     parser.add_argument(
         "--use_flash_attn",
         action="store_true",
-        help="Use Flash Attention for the model. / モデルにFlash Attentionを使用する。",
+        help="Use Flash Attention for the model / モデルにFlash Attentionを使用する",
     )
     parser.add_argument(
         "--system_prompt",
         type=str,
         default="",
-        help="System prompt to add to the prompt. / プロンプトに追加するシステムプロンプト。",
+        help="System prompt to add to the prompt / プロンプトに追加するシステムプロンプト",
     )
+    parser.add_argument(
+        "--sample_batch_size",
+        type=int,
+        default=None,
+        help="Batch size to use for sampling, defaults to --train_batch_size value. Sample batches are bucketed by width, height, guidance scale, and seed / サンプリングに使用するバッチサイズ。デフォルトは --train_batch_size の値です。サンプルバッチは、幅、高さ、ガイダンススケール、シードによってバケット化されます",
+    )
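Editorial sketch of a hypothetical invocation; only --sample_batch_size comes from this commit, and the script name and other flag values are illustrative:

    # accelerate launch lumina_train_network.py \
    #     --sample_prompts prompts.txt --sample_every_n_steps 200 \
    #     --train_batch_size 4 --sample_batch_size 2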
@@ -1242,6 +1242,7 @@ class NetworkTrainer:
         # For --sample_at_first
         optimizer_eval_fn()
         self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizers, text_encoder, unet)
+        progress_bar.unpause()  # Reset progress bar to before sampling images
         optimizer_train_fn()
         is_tracking = len(accelerator.trackers) > 0
         if is_tracking:
@@ -1344,6 +1345,7 @@ class NetworkTrainer:
                     self.sample_images(
                         accelerator, args, None, global_step, accelerator.device, vae, tokenizers, text_encoder, unet
                     )
+                    progress_bar.unpause()

                     # Save the model every specified number of steps
                     if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
@@ -1531,6 +1533,7 @@ class NetworkTrainer:
                 train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1)

             self.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizers, text_encoder, unet)
+            progress_bar.unpause()
             optimizer_train_fn()

         # end of epoch