diff --git a/flux_train.py b/flux_train.py
index fad58d2b..563e845e 100644
--- a/flux_train.py
+++ b/flux_train.py
@@ -393,11 +393,13 @@ def train(args):
     # Self check parameter compatibility
     if args.optimizer_type != "adafactor":
         logger.warning("Kahan summation has been requested, but currently this is only supported by the supplied Adafactor optimizer.")
-    if not args.full_bf16:
+    elif not args.full_bf16:
         logger.warning("Kahan summation requires --full_bf16")
-    if args.blockwise_fused_optimizers:
+    elif args.blockwise_fused_optimizers:
         logger.warning("Kahan summation has been requested, but it is incompatible with --blockwise_fused_optimizer. "\
                        "Perhaps try --fused_backward_pass instead.")
+    else:
+        logger.info("Using Kahan summation")
     optimizer.use_kahan_summation = args.kahan_summation

     # prepare dataloader
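
For context, here is a minimal sketch of the compensated (Kahan) update that this flag guards. It is an illustration of the general technique only, not the repo's actual Adafactor code: the function name `kahan_param_update` and its signature are assumptions. The point it demonstrates is why the check insists on `--full_bf16`: a bf16 weight near 1.0 has a unit-in-the-last-place of about 2^-8, so small per-step updates round away to nothing unless the lost low-order bits are carried forward in a compensation buffer.

```python
# Sketch of Kahan (compensated) summation for a bf16 parameter update.
# Illustrative only; names and signature are assumptions, not the repo's API.
import torch


def kahan_param_update(param: torch.Tensor, update: torch.Tensor,
                       compensation: torch.Tensor) -> None:
    """In-place bf16 update with a compensation buffer.

    `compensation` carries the rounding error lost when the small `update`
    is added to the large bf16 `param`, so the error is re-applied on the
    next step instead of being discarded.
    """
    # Fold the previously lost low-order bits back into this step's update.
    corrected = update + compensation
    new_param = param + corrected
    # Recover what rounding just threw away: (new_param - param) is the
    # portion of `corrected` that actually landed in the parameter.
    compensation.copy_(corrected - (new_param - param))
    param.copy_(new_param)


# Without compensation, repeated 1e-4 updates to a bf16 weight at 1.0 all
# round back to 1.0; with it, they accumulate to roughly the exact sum.
p = torch.ones(4, dtype=torch.bfloat16)
comp = torch.zeros_like(p)
for _ in range(1000):
    kahan_param_update(p, torch.full_like(p, 1e-4), comp)
print(p)  # ~1.1 with compensation, versus stuck at 1.0 without it
```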