lora以外も対応

This commit is contained in:
u-haru
2023-03-26 02:19:55 +09:00
parent 9c80da6ac5
commit 4dc1124f93
5 changed files with 37 additions and 33 deletions

View File

@@ -25,17 +25,6 @@ from library.config_util import (
BlueprintGenerator,
)
class collater_class:
def __init__(self,epoch,step):
self.current_epoch=epoch
self.current_step=step
def __call__(self, examples):
dataset = torch.utils.data.get_worker_info().dataset
dataset.set_current_epoch(self.current_epoch.value)
dataset.set_current_step(self.current_step.value)
return examples[0]
# TODO 他のスクリプトと共通化する
def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler):
logs = {"loss/current": current_loss, "loss/average": avr_loss}
@@ -110,7 +99,7 @@ def train(args):
current_epoch = Value('i',0)
current_step = Value('i',0)
collater = collater_class(current_epoch,current_step)
collater = train_util.collater_class(current_epoch,current_step)
if args.debug_dataset:
train_util.debug_dataset(train_dataset_group)