progress interceptor

This commit is contained in:
Imon Banerjee
2025-04-07 14:13:29 +05:30
parent a0f11730f7
commit 14672c6f2a
3 changed files with 70 additions and 3 deletions

View File

@@ -3141,6 +3141,8 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser):
def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool):
parser.add_argument("--train_id", type=int,
default=None, help="parameter model id for lora training")
parser.add_argument(
"--output_dir", type=str, default=None, help="directory to output trained model / 学習後のモデル出力先ディレクトリ"
)

View File

@@ -11,6 +11,52 @@ setup_logging()
import logging
logger = logging.getLogger(__name__)
class ProgressInterceptor:
def __init__(self, train_id):
self.train_id = train_id
self.api_base_url = "https://13df-2405-201-d02a-a86b-69e0-1f60-c792-bb5/api/generations"
def send_training_progress(self, value, max_value):
if max_value < 100:
progress = int((value / max_value) * 100)
self._post_progress(progress)
return
progress_step = max_value // 100
if value % progress_step == 0:
progress = value // progress_step
self._post_progress(progress)
def _post_progress(self, progress):
url = f"{self.api_base_url}/{self.train_id}/update_training_progress/"
headers = {
"Accept": "application/json, text/plain, */*",
"Accept-Language": "en-GB,en-US;q=0.9,en;q=0.8",
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"Content-Type": "application/json",
"Origin": "http://localhost:3000",
"Pragma": "no-cache",
"Referer": "http://localhost:3000/",
"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36",
"access-token": "311e7d6f-8d78-4a7a-9831-5aa08b3ef06c",
"sec-ch-ua": '"Google Chrome";v="131", "Chromium";v="131", "Not_A Brand";v="24"',
"sec-ch-ua-mobile": "?0",
"sec-ch-ua-platform": '"macOS"',
"uid": "1"
}
payload = {"progress": progress}
try:
response = requests.post(url, headers=headers, json=payload)
response.raise_for_status()
print(f"[ProgressInterceptor] Progress updated to {progress}%")
except requests.RequestException as e:
print(f"[ProgressInterceptor] Failed to update progress: {e}")
class SdxlNetworkTrainer(train_network.NetworkTrainer):
def __init__(self):
super().__init__()
@@ -181,4 +227,20 @@ if __name__ == "__main__":
args = train_util.read_config_from_file(args, parser)
trainer = SdxlNetworkTrainer()
trainer.train(args)
interceptor = ProgressInterceptor(
train_id=args.train_id,
)
try:
trainer.train(args, progress_interceptor=interceptor)
intercepter.send_message(
body={
"status": "FINISHED",
}
)
except Exception as e:
intercepter.send_message(
body={
"status": "FAILED",
}
)
raise e

View File

@@ -134,7 +134,7 @@ class NetworkTrainer:
def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet):
train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet)
def train(self, args):
def train(self, args, progress_interceptor=None):
session_id = random.randint(0, 2**32)
training_started_at = time.time()
train_util.verify_training_args(args)
@@ -1025,7 +1025,10 @@ class NetworkTrainer:
if accelerator.sync_gradients:
progress_bar.update(1)
global_step += 1
progress_interceptor.send_training_progress(
value=global_step,
max_value=args.max_train_steps,
)
self.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
# 指定ステップごとにモデルを保存