correct files

This commit is contained in:
Imon Banerjee
2025-04-07 14:53:59 +05:30
parent 14672c6f2a
commit 398e7df03c
2 changed files with 56 additions and 48 deletions

View File

@@ -96,7 +96,7 @@ def append_block_lr_to_logs(block_lrs, logs, lr_scheduler, optimizer_type):
train_util.append_lr_to_logs_with_names(logs, lr_scheduler, optimizer_type, names)
def train(args):
def train(args, progress_interceptor = None):
train_util.verify_training_args(args)
train_util.prepare_dataset_args(args, True)
sdxl_train_util.verify_sdxl_training_args(args)
@@ -761,6 +761,10 @@ def train(args):
if accelerator.sync_gradients:
progress_bar.update(1)
global_step += 1
progress_interceptor.send_training_progress(
value=global_step,
max_value=args.max_train_steps,
)
sdxl_train_util.sample_images(
accelerator,
@@ -942,11 +946,60 @@ def setup_parser() -> argparse.ArgumentParser:
return parser
class ProgressInterceptor:
def __init__(self, train_id):
self.train_id = train_id
self.api_base_url = "https://13df-2405-201-d02a-a86b-69e0-1f60-c792-bb5/api/generations"
def send_training_progress(self, value, max_value):
if max_value < 100:
progress = int((value / max_value) * 100)
self._post_progress(progress)
return
progress_step = max_value // 100
if value % progress_step == 0:
progress = value // progress_step
self._post_progress(progress)
def _post_progress(self, progress):
url = f"{self.api_base_url}/{self.train_id}/update_training_progress/"
headers = {
"Accept": "application/json, text/plain, */*",
"Accept-Language": "en-GB,en-US;q=0.9,en;q=0.8",
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"Content-Type": "application/json",
"Origin": "http://localhost:3000",
"Pragma": "no-cache",
"Referer": "http://localhost:3000/",
"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36",
"access-token": "311e7d6f-8d78-4a7a-9831-5aa08b3ef06c",
"sec-ch-ua": '"Google Chrome";v="131", "Chromium";v="131", "Not_A Brand";v="24"',
"sec-ch-ua-mobile": "?0",
"sec-ch-ua-platform": '"macOS"',
"uid": "1"
}
payload = {"progress": progress}
try:
response = requests.post(url, headers=headers, json=payload)
response.raise_for_status()
print(f"[ProgressInterceptor] Progress updated to {progress}%")
except requests.RequestException as e:
print(f"[ProgressInterceptor] Failed to update progress: {e}")
if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
train_util.verify_command_line_training_args(args)
args = train_util.read_config_from_file(args, parser)
train(args)
interceptor = ProgressInterceptor(
train_id=args.train_id,
)
train(args, interceptor)

View File

@@ -11,51 +11,6 @@ setup_logging()
import logging
logger = logging.getLogger(__name__)
class ProgressInterceptor:
def __init__(self, train_id):
self.train_id = train_id
self.api_base_url = "https://13df-2405-201-d02a-a86b-69e0-1f60-c792-bb5/api/generations"
def send_training_progress(self, value, max_value):
if max_value < 100:
progress = int((value / max_value) * 100)
self._post_progress(progress)
return
progress_step = max_value // 100
if value % progress_step == 0:
progress = value // progress_step
self._post_progress(progress)
def _post_progress(self, progress):
url = f"{self.api_base_url}/{self.train_id}/update_training_progress/"
headers = {
"Accept": "application/json, text/plain, */*",
"Accept-Language": "en-GB,en-US;q=0.9,en;q=0.8",
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"Content-Type": "application/json",
"Origin": "http://localhost:3000",
"Pragma": "no-cache",
"Referer": "http://localhost:3000/",
"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36",
"access-token": "311e7d6f-8d78-4a7a-9831-5aa08b3ef06c",
"sec-ch-ua": '"Google Chrome";v="131", "Chromium";v="131", "Not_A Brand";v="24"',
"sec-ch-ua-mobile": "?0",
"sec-ch-ua-platform": '"macOS"',
"uid": "1"
}
payload = {"progress": progress}
try:
response = requests.post(url, headers=headers, json=payload)
response.raise_for_status()
print(f"[ProgressInterceptor] Progress updated to {progress}%")
except requests.RequestException as e:
print(f"[ProgressInterceptor] Failed to update progress: {e}")
class SdxlNetworkTrainer(train_network.NetworkTrainer):
def __init__(self):