Fused backward pass

Author: 2kpr
Date:   2024-04-14 09:56:58 -05:00
Parent: 71e2c91330
Commit: 4f203ce40d
3 changed files with 142 additions and 6 deletions
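
In a conventional loop, every gradient produced by loss.backward() stays resident until optimizer.step() walks over all of them, so peak VRAM holds the full set of gradients at once. Fusing the backward pass with the optimizer step means updating each parameter the moment its gradient finishes accumulating and freeing that gradient immediately. A minimal, self-contained sketch of the idea using PyTorch's per-tensor register_post_accumulate_grad_hook (available since torch 2.1); the toy model, SGD, and one-optimizer-per-parameter setup are illustrative assumptions, not this commit's actual implementation:

import torch

model = torch.nn.Linear(512, 512)  # toy model (assumption, for illustration)

# One optimizer per parameter, so a single parameter can be stepped in isolation.
opts = {p: torch.optim.SGD([p], lr=1e-3) for p in model.parameters()}

def fused_hook(param: torch.Tensor) -> None:
    # Fires as soon as this parameter's gradient is fully accumulated:
    opts[param].step()                       # apply the update now
    opts[param].zero_grad(set_to_none=True)  # free the gradient immediately

for p in model.parameters():
    p.register_post_accumulate_grad_hook(fused_hook)

# backward() now performs the optimizer step parameter-by-parameter, so no
# separate optimizer.step() follows, and at most one parameter's gradient
# is alive at any time.
loss = model(torch.randn(8, 512)).square().mean()
loss.backward()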


@@ -2920,6 +2920,11 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser):
         default=1,
         help="Polynomial power for polynomial scheduler / polynomialスケジューラでのpolynomial power",
     )
+    parser.add_argument(
+        "--fused_backward_pass",
+        action="store_true",
+        help="Combines backward pass and optimizer step to reduce VRAM usage / バックワードパスとオプティマイザステップを組み合わせてVRAMの使用量を削減します。",
+    )
 
 
 def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool):
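
Since the new switch is a plain store_true flag, runs that omit it keep the old unfused behavior. A quick standalone check of how it parses, mirroring just this hunk (a hypothetical minimal parser, not the script's full argument set):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--fused_backward_pass",
    action="store_true",
    help="Combines backward pass and optimizer step to reduce VRAM usage",
)

assert parser.parse_args([]).fused_backward_pass is False                        # default: off
assert parser.parse_args(["--fused_backward_pass"]).fused_backward_pass is True  # flag: on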
@@ -3846,6 +3851,14 @@ def get_optimizer(args, trainable_params):
         optimizer_type = "AdamW"
     optimizer_type = optimizer_type.lower()
 
+    if args.fused_backward_pass:
+        assert (
+            optimizer_type == "Adafactor".lower()
+        ), "fused_backward_pass currently only works with optimizer_type Adafactor / fused_backward_passは現在optimizer_type Adafactorでのみ機能します"
+        assert (
+            args.gradient_accumulation_steps == 1
+        ), "fused_backward_pass does not work with gradient_accumulation_steps > 1 / fused_backward_passはgradient_accumulation_steps>1では機能しません"
+
     # 引数を分解する (parse the optimizer arguments)
     optimizer_kwargs = {}
     if args.optimizer_args is not None and len(args.optimizer_args) > 0:
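
Both asserts follow from how the fusion has to work: the fused update consumes and frees each gradient as soon as it is produced, so there is nothing left to sum across micro-batches, which rules out gradient_accumulation_steps > 1; and only Adafactor is given a per-parameter step entry point by this commit. The other two changed files are not shown above, but the wiring is presumably along the following lines; the step_param name and the exact hook body are stated here as assumptions, not verbatim from the commit:

import torch

def register_fused_backward_hooks(optimizer: torch.optim.Optimizer) -> None:
    # Assumes the (patched) Adafactor exposes step_param(param, group), a
    # single-parameter step -- an assumption about this commit's API.
    for group in optimizer.param_groups:
        for param in group["params"]:
            if not param.requires_grad:
                continue

            def hook(tensor: torch.Tensor, group=group) -> None:
                optimizer.step_param(tensor, group)  # update this parameter only
                tensor.grad = None                   # free its gradient at once

            param.register_post_accumulate_grad_hook(hook)

Because the hook clears tensor.grad unconditionally after each step, accumulating gradients over several backward passes is structurally impossible in this mode, which is exactly what the second assert protects against.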