Merge branch 'dev' into masked-loss

This commit is contained in:
Kohya S
2024-03-26 19:39:30 +09:00
18 changed files with 424 additions and 250 deletions

View File

@@ -69,6 +69,7 @@ from library.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipel
import library.model_util as model_util
import library.huggingface_util as huggingface_util
import library.sai_model_spec as sai_model_spec
import library.deepspeed_utils as deepspeed_utils
from library.utils import setup_logging
setup_logging()
@@ -4120,6 +4121,10 @@ def load_tokenizer(args: argparse.Namespace):
def prepare_accelerator(args: argparse.Namespace):
"""
this function also prepares deepspeed plugin
"""
if args.logging_dir is None:
logging_dir = None
else:
@@ -4165,6 +4170,8 @@ def prepare_accelerator(args: argparse.Namespace):
),
)
kwargs_handlers = list(filter(lambda x: x is not None, kwargs_handlers))
deepspeed_plugin = deepspeed_utils.prepare_deepspeed_plugin(args)
accelerator = Accelerator(
gradient_accumulation_steps=args.gradient_accumulation_steps,
mixed_precision=args.mixed_precision,
@@ -4172,6 +4179,7 @@ def prepare_accelerator(args: argparse.Namespace):
project_dir=logging_dir,
kwargs_handlers=kwargs_handlers,
dynamo_backend=dynamo_backend,
deepspeed_plugin=deepspeed_plugin,
)
print("accelerator device:", accelerator.device)
return accelerator
@@ -4242,7 +4250,6 @@ def _load_target_model(args: argparse.Namespace, weight_dtype, device="cpu", une
def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projection_in_v2=False):
# load models for each process
for pi in range(accelerator.state.num_processes):
if pi == accelerator.state.local_process_index:
logger.info(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}")
@@ -4253,7 +4260,6 @@ def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projectio
accelerator.device if args.lowram else "cpu",
unet_use_linear_projection_in_v2=unet_use_linear_projection_in_v2,
)
# work on low-ram device
if args.lowram:
text_encoder.to(accelerator.device)
@@ -4262,7 +4268,6 @@ def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projectio
clean_memory_on_device(accelerator.device)
accelerator.wait_for_everyone()
return text_encoder, vae, unet, load_stable_diffusion_format