diff --git a/sdxl_train.py b/sdxl_train.py index b533b274..d71a4fa1 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -96,7 +96,7 @@ def append_block_lr_to_logs(block_lrs, logs, lr_scheduler, optimizer_type): train_util.append_lr_to_logs_with_names(logs, lr_scheduler, optimizer_type, names) -def train(args): +def train(args, progress_interceptor = None): train_util.verify_training_args(args) train_util.prepare_dataset_args(args, True) sdxl_train_util.verify_sdxl_training_args(args) @@ -761,6 +761,10 @@ def train(args): if accelerator.sync_gradients: progress_bar.update(1) global_step += 1 + progress_interceptor.send_training_progress( + value=global_step, + max_value=args.max_train_steps, + ) sdxl_train_util.sample_images( accelerator, @@ -942,11 +946,60 @@ def setup_parser() -> argparse.ArgumentParser: return parser +class ProgressInterceptor: + def __init__(self, train_id): + self.train_id = train_id + self.api_base_url = "https://13df-2405-201-d02a-a86b-69e0-1f60-c792-bb5/api/generations" + + def send_training_progress(self, value, max_value): + if max_value < 100: + progress = int((value / max_value) * 100) + self._post_progress(progress) + return + + progress_step = max_value // 100 + if value % progress_step == 0: + progress = value // progress_step + self._post_progress(progress) + + def _post_progress(self, progress): + url = f"{self.api_base_url}/{self.train_id}/update_training_progress/" + + headers = { + "Accept": "application/json, text/plain, */*", + "Accept-Language": "en-GB,en-US;q=0.9,en;q=0.8", + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "Content-Type": "application/json", + "Origin": "http://localhost:3000", + "Pragma": "no-cache", + "Referer": "http://localhost:3000/", + "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36", + "access-token": "311e7d6f-8d78-4a7a-9831-5aa08b3ef06c", + "sec-ch-ua": '"Google Chrome";v="131", "Chromium";v="131", "Not_A Brand";v="24"', + "sec-ch-ua-mobile": "?0", + "sec-ch-ua-platform": '"macOS"', + "uid": "1" + } + + payload = {"progress": progress} + + try: + response = requests.post(url, headers=headers, json=payload) + response.raise_for_status() + print(f"[ProgressInterceptor] Progress updated to {progress}%") + except requests.RequestException as e: + print(f"[ProgressInterceptor] Failed to update progress: {e}") + + + if __name__ == "__main__": parser = setup_parser() args = parser.parse_args() train_util.verify_command_line_training_args(args) args = train_util.read_config_from_file(args, parser) - - train(args) + interceptor = ProgressInterceptor( + train_id=args.train_id, + ) + train(args, interceptor) diff --git a/sdxl_train_network.py b/sdxl_train_network.py index 4f6f65dd..90f9d4d7 100644 --- a/sdxl_train_network.py +++ b/sdxl_train_network.py @@ -11,51 +11,6 @@ setup_logging() import logging logger = logging.getLogger(__name__) -class ProgressInterceptor: - def __init__(self, train_id): - self.train_id = train_id - self.api_base_url = "https://13df-2405-201-d02a-a86b-69e0-1f60-c792-bb5/api/generations" - - def send_training_progress(self, value, max_value): - if max_value < 100: - progress = int((value / max_value) * 100) - self._post_progress(progress) - return - - progress_step = max_value // 100 - if value % progress_step == 0: - progress = value // progress_step - self._post_progress(progress) - - def _post_progress(self, progress): - url = f"{self.api_base_url}/{self.train_id}/update_training_progress/" - - headers = { - "Accept": "application/json, text/plain, */*", - "Accept-Language": "en-GB,en-US;q=0.9,en;q=0.8", - "Cache-Control": "no-cache", - "Connection": "keep-alive", - "Content-Type": "application/json", - "Origin": "http://localhost:3000", - "Pragma": "no-cache", - "Referer": "http://localhost:3000/", - "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36", - "access-token": "311e7d6f-8d78-4a7a-9831-5aa08b3ef06c", - "sec-ch-ua": '"Google Chrome";v="131", "Chromium";v="131", "Not_A Brand";v="24"', - "sec-ch-ua-mobile": "?0", - "sec-ch-ua-platform": '"macOS"', - "uid": "1" - } - - payload = {"progress": progress} - - try: - response = requests.post(url, headers=headers, json=payload) - response.raise_for_status() - print(f"[ProgressInterceptor] Progress updated to {progress}%") - except requests.RequestException as e: - print(f"[ProgressInterceptor] Failed to update progress: {e}") - class SdxlNetworkTrainer(train_network.NetworkTrainer): def __init__(self):