diff --git a/flux_train.py b/flux_train.py
index ad32ebc6..99ca4641 100644
--- a/flux_train.py
+++ b/flux_train.py
@@ -385,7 +385,7 @@ def train(args):
         optimizer_eval_fn = lambda: None  # dummy function
         if (args.optimizer_type not in fused_optimizers_supported) and args.full_bf16:
-            logger.warning("Use of --blockwise_fused_optimizers with Adafactor optimizer prevents stochastic/Kahan weight updates.")
+            logger.warning("Use of --blockwise_fused_optimizers is preventing stochastic/Kahan weight updates.")
     else:
         _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize)
         optimizer_train_fn, optimizer_eval_fn = train_util.get_optimizer_train_eval_fn(optimizer, args)
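
For context on what the warning is guarding against: in full-bf16 training, a plain `param -= lr * grad` silently drops any update too small for bf16's 8-bit mantissa to represent, which is the loss that stochastic rounding or Kahan-style compensated updates avoid. The following is a minimal illustrative sketch of a Kahan-compensated bf16 step, not code from this repository; the helper name `kahan_sgd_step_`, the tensor shapes, and the learning rate are all assumptions made for the demo.

```python
import torch

def kahan_sgd_step_(param: torch.Tensor, grad: torch.Tensor,
                    comp: torch.Tensor, lr: float) -> None:
    """In-place bf16 SGD step with a Kahan compensation buffer (illustrative)."""
    # Form the intended update in fp32, adding back the error that earlier
    # bf16 roundings discarded.
    update = -lr * grad.float() + comp.float()
    new_param = param.float() + update
    rounded = new_param.to(torch.bfloat16)
    # Stash this step's rounding error so it is re-applied on the next step.
    comp.copy_((new_param - rounded.float()).to(torch.bfloat16))
    param.copy_(rounded)

p = torch.ones(4, dtype=torch.bfloat16)
g = torch.full((4,), 1e-4, dtype=torch.bfloat16)  # update far below bf16 spacing at 1.0
c = torch.zeros_like(p)
for _ in range(1000):
    kahan_sgd_step_(p, g, c, lr=1.0)
# With compensation, p drifts toward ~0.9 as expected; a naive bf16
# `p -= g` would stay pinned at 1.0, because 1.0 - 1e-4 rounds back to 1.0.
print(p)
```

The fused optimizers referenced by `fused_optimizers_supported` perform this kind of compensated (or stochastically rounded) update inside the optimizer step itself, which is why falling outside that set under `--full_bf16` is worth a warning.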