mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
Added sample_images() for --sample_at_first
This commit is contained in:
@@ -303,6 +303,9 @@ def train(args):
|
|||||||
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
|
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
|
||||||
current_epoch.value = epoch + 1
|
current_epoch.value = epoch + 1
|
||||||
|
|
||||||
|
# For --sample_at_first
|
||||||
|
train_util.sample_images(accelerator, args, epoch, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
||||||
|
|
||||||
for m in training_models:
|
for m in training_models:
|
||||||
m.train()
|
m.train()
|
||||||
|
|
||||||
|
|||||||
@@ -373,6 +373,20 @@ def train(args):
|
|||||||
|
|
||||||
# training loop
|
# training loop
|
||||||
for epoch in range(num_train_epochs):
|
for epoch in range(num_train_epochs):
|
||||||
|
# For --sample_at_first
|
||||||
|
train_util.sample_images(
|
||||||
|
accelerator,
|
||||||
|
args,
|
||||||
|
epoch,
|
||||||
|
global_step,
|
||||||
|
accelerator.device,
|
||||||
|
vae,
|
||||||
|
tokenizer,
|
||||||
|
text_encoder,
|
||||||
|
unet,
|
||||||
|
controlnet=controlnet,
|
||||||
|
)
|
||||||
|
|
||||||
if is_main_process:
|
if is_main_process:
|
||||||
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
|
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
|
||||||
current_epoch.value = epoch + 1
|
current_epoch.value = epoch + 1
|
||||||
|
|||||||
@@ -279,6 +279,8 @@ def train(args):
|
|||||||
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
|
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
|
||||||
current_epoch.value = epoch + 1
|
current_epoch.value = epoch + 1
|
||||||
|
|
||||||
|
train_util.sample_images(accelerator, args, epoch, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
||||||
|
|
||||||
# 指定したステップ数までText Encoderを学習する:epoch最初の状態
|
# 指定したステップ数までText Encoderを学習する:epoch最初の状態
|
||||||
unet.train()
|
unet.train()
|
||||||
# train==True is required to enable gradient_checkpointing
|
# train==True is required to enable gradient_checkpointing
|
||||||
|
|||||||
@@ -750,6 +750,8 @@ class NetworkTrainer:
|
|||||||
|
|
||||||
metadata["ss_epoch"] = str(epoch + 1)
|
metadata["ss_epoch"] = str(epoch + 1)
|
||||||
|
|
||||||
|
# For --sample_at_first
|
||||||
|
self.sample_images(accelerator, args, epoch, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
||||||
network.on_epoch_start(text_encoder, unet)
|
network.on_epoch_start(text_encoder, unet)
|
||||||
|
|
||||||
for step, batch in enumerate(train_dataloader):
|
for step, batch in enumerate(train_dataloader):
|
||||||
|
|||||||
@@ -534,6 +534,20 @@ class TextualInversionTrainer:
|
|||||||
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
|
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
|
||||||
current_epoch.value = epoch + 1
|
current_epoch.value = epoch + 1
|
||||||
|
|
||||||
|
# For --sample_at_first
|
||||||
|
self.sample_images(
|
||||||
|
accelerator,
|
||||||
|
args,
|
||||||
|
epoch,
|
||||||
|
global_step,
|
||||||
|
accelerator.device,
|
||||||
|
vae,
|
||||||
|
tokenizer_or_list,
|
||||||
|
text_encoder_or_list,
|
||||||
|
unet,
|
||||||
|
prompt_replacement,
|
||||||
|
)
|
||||||
|
|
||||||
for text_encoder in text_encoders:
|
for text_encoder in text_encoders:
|
||||||
text_encoder.train()
|
text_encoder.train()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user