Mirror of https://github.com/kohya-ss/sd-scripts.git (synced 2026-04-08 22:35:09 +00:00)
fix full_fp16 compatibility and train_step
@@ -3166,6 +3166,11 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
         action="store_true",
         help="Flag to indicate whether to save 16-bit model. Only applicable with ZeRO Stage-3."
     )
+    parser.add_argument(
+        "--fp16_master_weights_and_gradients",
+        action="store_true",
+        help="fp16_master_and_gradients requires optimizer to support keeping fp16 master and gradients while keeping the optimizer states in fp32."
+    )
 
 def verify_training_args(args: argparse.Namespace):
     if args.v_parameterization and not args.v2:
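For readers who want to see the new option in isolation, here is a minimal, self-contained sketch (not part of the commit) that registers the same store_true flag on a bare ArgumentParser; the option name and help text are taken from the hunk above, the standalone parser is illustrative:

# Illustrative sketch, not from the commit: the flag added in this hunk,
# registered on a standalone argparse parser.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--fp16_master_weights_and_gradients",
    action="store_true",
    help="fp16_master_and_gradients requires optimizer to support keeping fp16 master "
    "and gradients while keeping the optimizer states in fp32.",
)

args = parser.parse_args(["--fp16_master_weights_and_gradients"])
print(args.fp16_master_weights_and_gradients)  # True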
@@ -3966,6 +3971,8 @@ def prepare_accelerator(args: argparse.Namespace):
         deepspeed_plugin.deepspeed_config['train_micro_batch_size_per_gpu'] = args.train_batch_size
         deepspeed_plugin.deepspeed_config['train_batch_size'] = \
             args.train_batch_size * args.gradient_accumulation_steps * int(os.environ['WORLD_SIZE'])
+        if args.full_fp16 or args.fp16_master_weights_and_gradients:
+            deepspeed_plugin.deepspeed_config['fp16_master_weights_and_gradients'] = True
 
     accelerator = Accelerator(
         gradient_accumulation_steps=args.gradient_accumulation_steps,
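As a rough illustration of the arithmetic and the config keys this hunk touches, the sketch below (not from the commit; all numeric values are hypothetical and no accelerate or DeepSpeed call is made) reproduces the global batch size computation and the keys written into deepspeed_plugin.deepspeed_config:

# Illustrative sketch with hypothetical values; mirrors the keys set in
# prepare_accelerator() above without importing accelerate or DeepSpeed.
train_batch_size = 4             # per-GPU micro batch size (hypothetical)
gradient_accumulation_steps = 2  # hypothetical
world_size = 8                   # i.e. int(os.environ['WORLD_SIZE']), hypothetical
full_fp16 = True                 # as if --full_fp16 were passed

deepspeed_config = {
    "train_micro_batch_size_per_gpu": train_batch_size,
    # global batch = micro batch * gradient accumulation steps * number of processes
    "train_batch_size": train_batch_size * gradient_accumulation_steps * world_size,  # 4 * 2 * 8 = 64
}
if full_fp16:
    # The commit enables this key whenever --full_fp16 (or the new flag) is set,
    # which appears to be the full_fp16 compatibility fix the commit title refers to.
    deepspeed_config["fp16_master_weights_and_gradients"] = True

print(deepspeed_config)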