make tracker init_kwargs configurable

This commit is contained in:
ddPn08
2023-07-09 16:31:38 +09:00
parent f54b784d88
commit b841dd78fe
8 changed files with 40 additions and 7 deletions

View File

@@ -6,6 +6,7 @@ import gc
 import math
 import os
 from multiprocessing import Value
+import toml
 from tqdm import tqdm
 import torch
@@ -275,7 +276,10 @@ def train(args):
     prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device)
     if accelerator.is_main_process:
-        accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name)
+        init_kwargs = {}
+        if args.log_tracker_config is not None:
+            init_kwargs = toml.load(args.log_tracker_config)
+        accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
     for epoch in range(num_train_epochs):
         accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")

View File

@@ -2445,6 +2445,12 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
         default=None,
         help="name of tracker to use for logging, default is script-specific default name / ログ出力に使用するtrackerの名前、省略時はスクリプトごとのデフォルト名",
     )
+    parser.add_argument(
+        "--log_tracker_config",
+        type=str,
+        default=None,
+        help="path to tracker config file to use for logging / ログ出力に使用するtrackerの設定ファイルのパス",
+    )
     parser.add_argument(
         "--wandb_api_key",
         type=str,

View File

@@ -5,6 +5,7 @@ import gc
 import math
 import os
 from multiprocessing import Value
+import toml
 from tqdm import tqdm
 import torch
@@ -350,7 +351,10 @@ def train(args):
     prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device)
     if accelerator.is_main_process:
-        accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name)
+        init_kwargs = {}
+        if args.log_tracker_config is not None:
+            init_kwargs = toml.load(args.log_tracker_config)
+        accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
     for epoch in range(num_train_epochs):
         accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")

View File

@@ -7,6 +7,7 @@ import random
 import time
 from multiprocessing import Value
 from types import SimpleNamespace
+import toml
 from tqdm import tqdm
 import torch
@@ -324,7 +325,10 @@ def train(args):
         clip_sample=False,
     )
     if accelerator.is_main_process:
-        accelerator.init_trackers("controlnet_train" if args.log_tracker_name is None else args.log_tracker_name)
+        init_kwargs = {}
+        if args.log_tracker_config is not None:
+            init_kwargs = toml.load(args.log_tracker_config)
+        accelerator.init_trackers("controlnet_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
     loss_list = []
     loss_total = 0.0

View File

@@ -7,6 +7,7 @@ import itertools
 import math
 import os
 from multiprocessing import Value
+import toml
 from tqdm import tqdm
 import torch
@@ -248,7 +249,10 @@ def train(args):
     prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device)
     if accelerator.is_main_process:
-        accelerator.init_trackers("dreambooth" if args.log_tracker_name is None else args.log_tracker_name)
+        init_kwargs = {}
+        if args.log_tracker_config is not None:
+            init_kwargs = toml.load(args.log_tracker_config)
+        accelerator.init_trackers("dreambooth" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
     loss_list = []
     loss_total = 0.0

View File

@@ -8,6 +8,7 @@ import random
 import time
 import json
 from multiprocessing import Value
+import toml
 from tqdm import tqdm
 import torch
@@ -672,7 +673,10 @@ class NetworkTrainer:
         prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device)
         if accelerator.is_main_process:
-            accelerator.init_trackers("network_train" if args.log_tracker_name is None else args.log_tracker_name)
+            init_kwargs = {}
+            if args.log_tracker_config is not None:
+                init_kwargs = toml.load(args.log_tracker_config)
+            accelerator.init_trackers("network_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
         loss_list = []
         loss_total = 0.0

View File

@@ -3,6 +3,7 @@ import gc
 import math
 import os
 from multiprocessing import Value
+import toml
 from tqdm import tqdm
 import torch
@@ -493,7 +494,10 @@ class TextualInversionTrainer:
         prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device)
         if accelerator.is_main_process:
-            accelerator.init_trackers("textual_inversion" if args.log_tracker_name is None else args.log_tracker_name)
+            init_kwargs = {}
+            if args.log_tracker_config is not None:
+                init_kwargs = toml.load(args.log_tracker_config)
+            accelerator.init_trackers("textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
         # function for saving/removing
         def save_model(ckpt_name, embs_list, steps, epoch_no, force_sync_upload=False):

View File

@@ -386,7 +386,10 @@ def train(args):
     prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device)
     if accelerator.is_main_process:
-        accelerator.init_trackers("textual_inversion" if args.log_tracker_name is None else args.log_tracker_name)
+        init_kwargs = {}
+        if args.log_tracker_config is not None:
+            init_kwargs = toml.load(args.log_tracker_config)
+        accelerator.init_trackers("textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
     # function for saving/removing
     def save_model(ckpt_name, embs, steps, epoch_no, force_sync_upload=False):