remove workaround for accelerator=0.15, fix XTI

This commit is contained in:
ykume
2023-06-11 18:32:14 +09:00
parent 33a6234b52
commit 0315611b11
7 changed files with 153 additions and 159 deletions

View File

@@ -2904,23 +2904,9 @@ def prepare_accelerator(args: argparse.Namespace):
gradient_accumulation_steps=args.gradient_accumulation_steps,
mixed_precision=args.mixed_precision,
log_with=log_with,
logging_dir=logging_dir,
project_dir=logging_dir,
)
# accelerateの互換性問題を解決する
accelerator_0_15 = True
try:
accelerator.unwrap_model("dummy", True)
print("Using accelerator 0.15.0 or above.")
except TypeError:
accelerator_0_15 = False
def unwrap_model(model):
if accelerator_0_15:
return accelerator.unwrap_model(model, True)
return accelerator.unwrap_model(model)
return accelerator, unwrap_model
return accelerator
def prepare_dtype(args: argparse.Namespace):