diff --git a/flux_train.py b/flux_train.py index 9bb58c0f..5228ec13 100644 --- a/flux_train.py +++ b/flux_train.py @@ -384,7 +384,7 @@ def train(args): optimizer_train_fn = lambda: None # dummy function optimizer_eval_fn = lambda: None # dummy function - if (args.optimizer_type not in fused_optimizers_supported) and args.full_bf16: + if (args.optimizer_type in fused_optimizers_supported) and args.full_bf16: logger.warning("Use of --blockwise_fused_optimizers is preventing stochastic/Kahan weight updates.") else: _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize)