diff --git a/flux_train_network.py b/flux_train_network.py index be8c62ca..75daeeab 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -36,6 +36,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): self.is_schnell: Optional[bool] = None self.is_swapping_blocks: bool = False self.model_type: Optional[str] = None + self.args = None def assert_extra_args( self, diff --git a/train_network.py b/train_network.py index 6180d2f8..3e733982 100644 --- a/train_network.py +++ b/train_network.py @@ -480,6 +480,7 @@ class NetworkTrainer: return loss.mean() def train(self, args): + self.args = args # store args for later use session_id = random.randint(0, 2**32) training_started_at = time.time() train_util.verify_training_args(args)