mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-17 09:18:00 +00:00
progress interceptor
This commit is contained in:
@@ -3141,6 +3141,8 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser):
|
||||
|
||||
|
||||
def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool):
|
||||
parser.add_argument("--train_id", type=int,
|
||||
default=None, help="parameter model id for lora training")
|
||||
parser.add_argument(
|
||||
"--output_dir", type=str, default=None, help="directory to output trained model / 学習後のモデル出力先ディレクトリ"
|
||||
)
|
||||
|
||||
@@ -11,6 +11,52 @@ setup_logging()
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class ProgressInterceptor:
|
||||
def __init__(self, train_id):
|
||||
self.train_id = train_id
|
||||
self.api_base_url = "https://13df-2405-201-d02a-a86b-69e0-1f60-c792-bb5/api/generations"
|
||||
|
||||
def send_training_progress(self, value, max_value):
|
||||
if max_value < 100:
|
||||
progress = int((value / max_value) * 100)
|
||||
self._post_progress(progress)
|
||||
return
|
||||
|
||||
progress_step = max_value // 100
|
||||
if value % progress_step == 0:
|
||||
progress = value // progress_step
|
||||
self._post_progress(progress)
|
||||
|
||||
def _post_progress(self, progress):
|
||||
url = f"{self.api_base_url}/{self.train_id}/update_training_progress/"
|
||||
|
||||
headers = {
|
||||
"Accept": "application/json, text/plain, */*",
|
||||
"Accept-Language": "en-GB,en-US;q=0.9,en;q=0.8",
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"Content-Type": "application/json",
|
||||
"Origin": "http://localhost:3000",
|
||||
"Pragma": "no-cache",
|
||||
"Referer": "http://localhost:3000/",
|
||||
"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36",
|
||||
"access-token": "311e7d6f-8d78-4a7a-9831-5aa08b3ef06c",
|
||||
"sec-ch-ua": '"Google Chrome";v="131", "Chromium";v="131", "Not_A Brand";v="24"',
|
||||
"sec-ch-ua-mobile": "?0",
|
||||
"sec-ch-ua-platform": '"macOS"',
|
||||
"uid": "1"
|
||||
}
|
||||
|
||||
payload = {"progress": progress}
|
||||
|
||||
try:
|
||||
response = requests.post(url, headers=headers, json=payload)
|
||||
response.raise_for_status()
|
||||
print(f"[ProgressInterceptor] Progress updated to {progress}%")
|
||||
except requests.RequestException as e:
|
||||
print(f"[ProgressInterceptor] Failed to update progress: {e}")
|
||||
|
||||
|
||||
class SdxlNetworkTrainer(train_network.NetworkTrainer):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
@@ -181,4 +227,20 @@ if __name__ == "__main__":
|
||||
args = train_util.read_config_from_file(args, parser)
|
||||
|
||||
trainer = SdxlNetworkTrainer()
|
||||
trainer.train(args)
|
||||
interceptor = ProgressInterceptor(
|
||||
train_id=args.train_id,
|
||||
)
|
||||
try:
|
||||
trainer.train(args, progress_interceptor=interceptor)
|
||||
intercepter.send_message(
|
||||
body={
|
||||
"status": "FINISHED",
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
intercepter.send_message(
|
||||
body={
|
||||
"status": "FAILED",
|
||||
}
|
||||
)
|
||||
raise e
|
||||
|
||||
@@ -134,7 +134,7 @@ class NetworkTrainer:
|
||||
def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet):
|
||||
train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet)
|
||||
|
||||
def train(self, args):
|
||||
def train(self, args, progress_interceptor=None):
|
||||
session_id = random.randint(0, 2**32)
|
||||
training_started_at = time.time()
|
||||
train_util.verify_training_args(args)
|
||||
@@ -1025,7 +1025,10 @@ class NetworkTrainer:
|
||||
if accelerator.sync_gradients:
|
||||
progress_bar.update(1)
|
||||
global_step += 1
|
||||
|
||||
progress_interceptor.send_training_progress(
|
||||
value=global_step,
|
||||
max_value=args.max_train_steps,
|
||||
)
|
||||
self.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
||||
|
||||
# 指定ステップごとにモデルを保存
|
||||
|
||||
Reference in New Issue
Block a user