mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
lora以外も対応
This commit is contained in:
@@ -25,17 +25,6 @@ from library.config_util import (
|
||||
BlueprintGenerator,
|
||||
)
|
||||
|
||||
class collater_class:
|
||||
def __init__(self,epoch,step):
|
||||
self.current_epoch=epoch
|
||||
self.current_step=step
|
||||
def __call__(self, examples):
|
||||
dataset = torch.utils.data.get_worker_info().dataset
|
||||
dataset.set_current_epoch(self.current_epoch.value)
|
||||
dataset.set_current_step(self.current_step.value)
|
||||
return examples[0]
|
||||
|
||||
|
||||
# TODO 他のスクリプトと共通化する
|
||||
def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler):
|
||||
logs = {"loss/current": current_loss, "loss/average": avr_loss}
|
||||
@@ -110,7 +99,7 @@ def train(args):
|
||||
|
||||
current_epoch = Value('i',0)
|
||||
current_step = Value('i',0)
|
||||
collater = collater_class(current_epoch,current_step)
|
||||
collater = train_util.collater_class(current_epoch,current_step)
|
||||
|
||||
if args.debug_dataset:
|
||||
train_util.debug_dataset(train_dataset_group)
|
||||
|
||||
Reference in New Issue
Block a user