mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
@@ -275,7 +275,7 @@ def train(args):
|
|||||||
with accelerator.accumulate(training_models[0]): # 複数モデルに対応していない模様だがとりあえずこうしておく
|
with accelerator.accumulate(training_models[0]): # 複数モデルに対応していない模様だがとりあえずこうしておく
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
if "latents" in batch and batch["latents"] is not None:
|
if "latents" in batch and batch["latents"] is not None:
|
||||||
latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
|
latents = batch["latents"].to(accelerator.device) # .to(dtype=weight_dtype)
|
||||||
else:
|
else:
|
||||||
# latentに変換
|
# latentに変換
|
||||||
latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
|
latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
|
||||||
@@ -313,6 +313,7 @@ def train(args):
|
|||||||
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
||||||
|
|
||||||
# Predict the noise residual
|
# Predict the noise residual
|
||||||
|
with accelerator.autocast():
|
||||||
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
||||||
|
|
||||||
if args.v_parameterization:
|
if args.v_parameterization:
|
||||||
|
|||||||
@@ -185,8 +185,8 @@ def train(args):
|
|||||||
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
|
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
|
||||||
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||||
|
|
||||||
current_epoch = Value('i',0)
|
current_epoch = Value("i", 0)
|
||||||
current_step = Value('i',0)
|
current_step = Value("i", 0)
|
||||||
ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None
|
ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None
|
||||||
collater = train_util.collater_class(current_epoch, current_step, ds_for_collater)
|
collater = train_util.collater_class(current_epoch, current_step, ds_for_collater)
|
||||||
|
|
||||||
@@ -264,7 +264,9 @@ def train(args):
|
|||||||
|
|
||||||
# 学習ステップ数を計算する
|
# 学習ステップ数を計算する
|
||||||
if args.max_train_epochs is not None:
|
if args.max_train_epochs is not None:
|
||||||
args.max_train_steps = args.max_train_epochs * math.ceil(len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps)
|
args.max_train_steps = args.max_train_epochs * math.ceil(
|
||||||
|
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
|
||||||
|
)
|
||||||
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
|
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
|
||||||
|
|
||||||
# データセット側にも学習ステップを送信
|
# データセット側にも学習ステップを送信
|
||||||
@@ -359,7 +361,7 @@ def train(args):
|
|||||||
|
|
||||||
# Get the text embedding for conditioning
|
# Get the text embedding for conditioning
|
||||||
input_ids = batch["input_ids"].to(accelerator.device)
|
input_ids = batch["input_ids"].to(accelerator.device)
|
||||||
# weight_dtype) use float instead of fp16/bf16 because text encoder is float
|
# use float instead of fp16/bf16 because text encoder is float
|
||||||
encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizer, text_encoder, torch.float)
|
encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizer, text_encoder, torch.float)
|
||||||
|
|
||||||
# Sample noise that we'll add to the latents
|
# Sample noise that we'll add to the latents
|
||||||
@@ -377,6 +379,7 @@ def train(args):
|
|||||||
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
||||||
|
|
||||||
# Predict the noise residual
|
# Predict the noise residual
|
||||||
|
with accelerator.autocast():
|
||||||
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
||||||
|
|
||||||
if args.v_parameterization:
|
if args.v_parameterization:
|
||||||
|
|||||||
@@ -418,6 +418,7 @@ def train(args):
|
|||||||
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
||||||
|
|
||||||
# Predict the noise residual
|
# Predict the noise residual
|
||||||
|
with accelerator.autocast():
|
||||||
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states=encoder_hidden_states).sample
|
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states=encoder_hidden_states).sample
|
||||||
|
|
||||||
if args.v_parameterization:
|
if args.v_parameterization:
|
||||||
|
|||||||
Reference in New Issue
Block a user