mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
lora以外も対応
This commit is contained in:
15
fine_tune.py
15
fine_tune.py
@@ -6,6 +6,7 @@ import gc
|
|||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import toml
|
import toml
|
||||||
|
from multiprocessing import Value
|
||||||
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
import torch
|
import torch
|
||||||
@@ -21,10 +22,6 @@ from library.config_util import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def collate_fn(examples):
|
|
||||||
return examples[0]
|
|
||||||
|
|
||||||
|
|
||||||
def train(args):
|
def train(args):
|
||||||
train_util.verify_training_args(args)
|
train_util.verify_training_args(args)
|
||||||
train_util.prepare_dataset_args(args, True)
|
train_util.prepare_dataset_args(args, True)
|
||||||
@@ -65,6 +62,10 @@ def train(args):
|
|||||||
config_util.blueprint_args_conflict(args,blueprint)
|
config_util.blueprint_args_conflict(args,blueprint)
|
||||||
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||||
|
|
||||||
|
current_epoch = Value('i',0)
|
||||||
|
current_step = Value('i',0)
|
||||||
|
collater = train_util.collater_class(current_epoch,current_step)
|
||||||
|
|
||||||
if args.debug_dataset:
|
if args.debug_dataset:
|
||||||
train_util.debug_dataset(train_dataset_group)
|
train_util.debug_dataset(train_dataset_group)
|
||||||
return
|
return
|
||||||
@@ -188,7 +189,7 @@ def train(args):
|
|||||||
train_dataset_group,
|
train_dataset_group,
|
||||||
batch_size=1,
|
batch_size=1,
|
||||||
shuffle=True,
|
shuffle=True,
|
||||||
collate_fn=collate_fn,
|
collate_fn=collater,
|
||||||
num_workers=n_workers,
|
num_workers=n_workers,
|
||||||
persistent_workers=args.persistent_data_loader_workers,
|
persistent_workers=args.persistent_data_loader_workers,
|
||||||
)
|
)
|
||||||
@@ -259,14 +260,14 @@ def train(args):
|
|||||||
|
|
||||||
for epoch in range(num_train_epochs):
|
for epoch in range(num_train_epochs):
|
||||||
print(f"epoch {epoch+1}/{num_train_epochs}")
|
print(f"epoch {epoch+1}/{num_train_epochs}")
|
||||||
train_dataset_group.set_current_epoch(epoch + 1)
|
current_epoch.value = epoch+1
|
||||||
train_dataset_group.set_current_step(global_step)
|
|
||||||
|
|
||||||
for m in training_models:
|
for m in training_models:
|
||||||
m.train()
|
m.train()
|
||||||
|
|
||||||
loss_total = 0
|
loss_total = 0
|
||||||
for step, batch in enumerate(train_dataloader):
|
for step, batch in enumerate(train_dataloader):
|
||||||
|
current_step.value = global_step
|
||||||
with accelerator.accumulate(training_models[0]): # 複数モデルに対応していない模様だがとりあえずこうしておく
|
with accelerator.accumulate(training_models[0]): # 複数モデルに対応していない模様だがとりあえずこうしておく
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
if "latents" in batch and batch["latents"] is not None:
|
if "latents" in batch and batch["latents"] is not None:
|
||||||
|
|||||||
@@ -2987,3 +2987,14 @@ class ImageLoadingDataset(torch.utils.data.Dataset):
|
|||||||
|
|
||||||
|
|
||||||
# endregion
|
# endregion
|
||||||
|
|
||||||
|
# For use as a DataLoader collate_fn; epoch and step are multiprocessing.Value objects.
class collater_class:
    """Collate callable that forwards the trainer's epoch/step to the dataset.

    The training loop writes the current epoch and global step into two
    shared ``multiprocessing.Value`` counters.  Every time the DataLoader
    collates a batch, this object copies those counters into the dataset
    via ``set_current_epoch`` / ``set_current_step`` and then returns the
    first (and, with ``batch_size=1``, only) example unchanged.

    Args:
        epoch: shared counter exposing the current epoch through ``.value``.
        step: shared counter exposing the current global step through ``.value``.
        dataset: optional dataset to update when collation runs in the main
            process (``num_workers=0``).  Unused when a worker process is
            doing the collation, because the worker's own dataset copy is
            taken from ``get_worker_info()`` instead.
    """

    def __init__(self, epoch, step, dataset=None):
        self.current_epoch = epoch
        self.current_step = step
        self.dataset = dataset

    def __call__(self, examples):
        """Sync epoch/step into the dataset and return ``examples[0]``."""
        worker_info = torch.utils.data.get_worker_info()
        # get_worker_info() returns None outside a DataLoader worker process;
        # dereferencing it unconditionally crashes single-process loading.
        dataset = self.dataset if worker_info is None else worker_info.dataset
        if dataset is not None:
            dataset.set_current_epoch(self.current_epoch.value)
            dataset.set_current_step(self.current_step.value)
        return examples[0]
|
||||||
15
train_db.py
15
train_db.py
@@ -8,6 +8,7 @@ import itertools
|
|||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import toml
|
import toml
|
||||||
|
from multiprocessing import Value
|
||||||
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
import torch
|
import torch
|
||||||
@@ -23,10 +24,6 @@ from library.config_util import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def collate_fn(examples):
|
|
||||||
return examples[0]
|
|
||||||
|
|
||||||
|
|
||||||
def train(args):
|
def train(args):
|
||||||
train_util.verify_training_args(args)
|
train_util.verify_training_args(args)
|
||||||
train_util.prepare_dataset_args(args, False)
|
train_util.prepare_dataset_args(args, False)
|
||||||
@@ -60,6 +57,10 @@ def train(args):
|
|||||||
config_util.blueprint_args_conflict(args,blueprint)
|
config_util.blueprint_args_conflict(args,blueprint)
|
||||||
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||||
|
|
||||||
|
current_epoch = Value('i',0)
|
||||||
|
current_step = Value('i',0)
|
||||||
|
collater = train_util.collater_class(current_epoch,current_step)
|
||||||
|
|
||||||
if args.no_token_padding:
|
if args.no_token_padding:
|
||||||
train_dataset_group.disable_token_padding()
|
train_dataset_group.disable_token_padding()
|
||||||
|
|
||||||
@@ -153,7 +154,7 @@ def train(args):
|
|||||||
train_dataset_group,
|
train_dataset_group,
|
||||||
batch_size=1,
|
batch_size=1,
|
||||||
shuffle=True,
|
shuffle=True,
|
||||||
collate_fn=collate_fn,
|
collate_fn=collater,
|
||||||
num_workers=n_workers,
|
num_workers=n_workers,
|
||||||
persistent_workers=args.persistent_data_loader_workers,
|
persistent_workers=args.persistent_data_loader_workers,
|
||||||
)
|
)
|
||||||
@@ -233,8 +234,7 @@ def train(args):
|
|||||||
loss_total = 0.0
|
loss_total = 0.0
|
||||||
for epoch in range(num_train_epochs):
|
for epoch in range(num_train_epochs):
|
||||||
print(f"epoch {epoch+1}/{num_train_epochs}")
|
print(f"epoch {epoch+1}/{num_train_epochs}")
|
||||||
train_dataset_group.set_current_epoch(epoch + 1)
|
current_epoch.value = epoch+1
|
||||||
train_dataset_group.set_current_step(global_step)
|
|
||||||
|
|
||||||
# 指定したステップ数までText Encoderを学習する:epoch最初の状態
|
# 指定したステップ数までText Encoderを学習する:epoch最初の状態
|
||||||
unet.train()
|
unet.train()
|
||||||
@@ -243,6 +243,7 @@ def train(args):
|
|||||||
text_encoder.train()
|
text_encoder.train()
|
||||||
|
|
||||||
for step, batch in enumerate(train_dataloader):
|
for step, batch in enumerate(train_dataloader):
|
||||||
|
current_step.value = global_step
|
||||||
# 指定したステップ数でText Encoderの学習を止める
|
# 指定したステップ数でText Encoderの学習を止める
|
||||||
if global_step == args.stop_text_encoder_training:
|
if global_step == args.stop_text_encoder_training:
|
||||||
print(f"stop text encoder training at step {global_step}")
|
print(f"stop text encoder training at step {global_step}")
|
||||||
|
|||||||
@@ -25,17 +25,6 @@ from library.config_util import (
|
|||||||
BlueprintGenerator,
|
BlueprintGenerator,
|
||||||
)
|
)
|
||||||
|
|
||||||
class collater_class:
|
|
||||||
def __init__(self,epoch,step):
|
|
||||||
self.current_epoch=epoch
|
|
||||||
self.current_step=step
|
|
||||||
def __call__(self, examples):
|
|
||||||
dataset = torch.utils.data.get_worker_info().dataset
|
|
||||||
dataset.set_current_epoch(self.current_epoch.value)
|
|
||||||
dataset.set_current_step(self.current_step.value)
|
|
||||||
return examples[0]
|
|
||||||
|
|
||||||
|
|
||||||
# TODO 他のスクリプトと共通化する
|
# TODO 他のスクリプトと共通化する
|
||||||
def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler):
|
def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler):
|
||||||
logs = {"loss/current": current_loss, "loss/average": avr_loss}
|
logs = {"loss/current": current_loss, "loss/average": avr_loss}
|
||||||
@@ -110,7 +99,7 @@ def train(args):
|
|||||||
|
|
||||||
current_epoch = Value('i',0)
|
current_epoch = Value('i',0)
|
||||||
current_step = Value('i',0)
|
current_step = Value('i',0)
|
||||||
collater = collater_class(current_epoch,current_step)
|
collater = train_util.collater_class(current_epoch,current_step)
|
||||||
|
|
||||||
if args.debug_dataset:
|
if args.debug_dataset:
|
||||||
train_util.debug_dataset(train_dataset_group)
|
train_util.debug_dataset(train_dataset_group)
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import gc
|
|||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import toml
|
import toml
|
||||||
|
from multiprocessing import Value
|
||||||
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
import torch
|
import torch
|
||||||
@@ -71,10 +72,6 @@ imagenet_style_templates_small = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def collate_fn(examples):
|
|
||||||
return examples[0]
|
|
||||||
|
|
||||||
|
|
||||||
def train(args):
|
def train(args):
|
||||||
if args.output_name is None:
|
if args.output_name is None:
|
||||||
args.output_name = args.token_string
|
args.output_name = args.token_string
|
||||||
@@ -186,6 +183,10 @@ def train(args):
|
|||||||
config_util.blueprint_args_conflict(args,blueprint)
|
config_util.blueprint_args_conflict(args,blueprint)
|
||||||
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||||
|
|
||||||
|
current_epoch = Value('i',0)
|
||||||
|
current_step = Value('i',0)
|
||||||
|
collater = train_util.collater_class(current_epoch,current_step)
|
||||||
|
|
||||||
# make captions: tokenstring tokenstring1 tokenstring2 ...tokenstringn という文字列に書き換える超乱暴な実装
|
# make captions: tokenstring tokenstring1 tokenstring2 ...tokenstringn という文字列に書き換える超乱暴な実装
|
||||||
if use_template:
|
if use_template:
|
||||||
print("use template for training captions. is object: {args.use_object_template}")
|
print("use template for training captions. is object: {args.use_object_template}")
|
||||||
@@ -251,7 +252,7 @@ def train(args):
|
|||||||
train_dataset_group,
|
train_dataset_group,
|
||||||
batch_size=1,
|
batch_size=1,
|
||||||
shuffle=True,
|
shuffle=True,
|
||||||
collate_fn=collate_fn,
|
collate_fn=collater,
|
||||||
num_workers=n_workers,
|
num_workers=n_workers,
|
||||||
persistent_workers=args.persistent_data_loader_workers,
|
persistent_workers=args.persistent_data_loader_workers,
|
||||||
)
|
)
|
||||||
@@ -335,13 +336,14 @@ def train(args):
|
|||||||
|
|
||||||
for epoch in range(num_train_epochs):
|
for epoch in range(num_train_epochs):
|
||||||
print(f"epoch {epoch+1}/{num_train_epochs}")
|
print(f"epoch {epoch+1}/{num_train_epochs}")
|
||||||
train_dataset_group.set_current_epoch(epoch + 1)
|
current_epoch.value = epoch+1
|
||||||
train_dataset_group.set_current_step(global_step)
|
|
||||||
|
|
||||||
text_encoder.train()
|
text_encoder.train()
|
||||||
|
|
||||||
loss_total = 0
|
loss_total = 0
|
||||||
|
|
||||||
for step, batch in enumerate(train_dataloader):
|
for step, batch in enumerate(train_dataloader):
|
||||||
|
current_step.value = global_step
|
||||||
with accelerator.accumulate(text_encoder):
|
with accelerator.accumulate(text_encoder):
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
if "latents" in batch and batch["latents"] is not None:
|
if "latents" in batch and batch["latents"] is not None:
|
||||||
|
|||||||
Reference in New Issue
Block a user