mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-15 16:39:42 +00:00
correct files
This commit is contained in:
@@ -96,7 +96,7 @@ def append_block_lr_to_logs(block_lrs, logs, lr_scheduler, optimizer_type):
|
||||
train_util.append_lr_to_logs_with_names(logs, lr_scheduler, optimizer_type, names)
|
||||
|
||||
|
||||
def train(args):
|
||||
def train(args, progress_interceptor = None):
|
||||
train_util.verify_training_args(args)
|
||||
train_util.prepare_dataset_args(args, True)
|
||||
sdxl_train_util.verify_sdxl_training_args(args)
|
||||
@@ -761,6 +761,10 @@ def train(args):
|
||||
if accelerator.sync_gradients:
|
||||
progress_bar.update(1)
|
||||
global_step += 1
|
||||
progress_interceptor.send_training_progress(
|
||||
value=global_step,
|
||||
max_value=args.max_train_steps,
|
||||
)
|
||||
|
||||
sdxl_train_util.sample_images(
|
||||
accelerator,
|
||||
@@ -942,11 +946,60 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
return parser
|
||||
|
||||
|
||||
class ProgressInterceptor:
|
||||
def __init__(self, train_id):
|
||||
self.train_id = train_id
|
||||
self.api_base_url = "https://13df-2405-201-d02a-a86b-69e0-1f60-c792-bb5/api/generations"
|
||||
|
||||
def send_training_progress(self, value, max_value):
|
||||
if max_value < 100:
|
||||
progress = int((value / max_value) * 100)
|
||||
self._post_progress(progress)
|
||||
return
|
||||
|
||||
progress_step = max_value // 100
|
||||
if value % progress_step == 0:
|
||||
progress = value // progress_step
|
||||
self._post_progress(progress)
|
||||
|
||||
def _post_progress(self, progress):
|
||||
url = f"{self.api_base_url}/{self.train_id}/update_training_progress/"
|
||||
|
||||
headers = {
|
||||
"Accept": "application/json, text/plain, */*",
|
||||
"Accept-Language": "en-GB,en-US;q=0.9,en;q=0.8",
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"Content-Type": "application/json",
|
||||
"Origin": "http://localhost:3000",
|
||||
"Pragma": "no-cache",
|
||||
"Referer": "http://localhost:3000/",
|
||||
"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36",
|
||||
"access-token": "311e7d6f-8d78-4a7a-9831-5aa08b3ef06c",
|
||||
"sec-ch-ua": '"Google Chrome";v="131", "Chromium";v="131", "Not_A Brand";v="24"',
|
||||
"sec-ch-ua-mobile": "?0",
|
||||
"sec-ch-ua-platform": '"macOS"',
|
||||
"uid": "1"
|
||||
}
|
||||
|
||||
payload = {"progress": progress}
|
||||
|
||||
try:
|
||||
response = requests.post(url, headers=headers, json=payload)
|
||||
response.raise_for_status()
|
||||
print(f"[ProgressInterceptor] Progress updated to {progress}%")
|
||||
except requests.RequestException as e:
|
||||
print(f"[ProgressInterceptor] Failed to update progress: {e}")
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = setup_parser()
|
||||
|
||||
args = parser.parse_args()
|
||||
train_util.verify_command_line_training_args(args)
|
||||
args = train_util.read_config_from_file(args, parser)
|
||||
|
||||
train(args)
|
||||
interceptor = ProgressInterceptor(
|
||||
train_id=args.train_id,
|
||||
)
|
||||
train(args, interceptor)
|
||||
|
||||
@@ -11,51 +11,6 @@ setup_logging()
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class ProgressInterceptor:
|
||||
def __init__(self, train_id):
|
||||
self.train_id = train_id
|
||||
self.api_base_url = "https://13df-2405-201-d02a-a86b-69e0-1f60-c792-bb5/api/generations"
|
||||
|
||||
def send_training_progress(self, value, max_value):
|
||||
if max_value < 100:
|
||||
progress = int((value / max_value) * 100)
|
||||
self._post_progress(progress)
|
||||
return
|
||||
|
||||
progress_step = max_value // 100
|
||||
if value % progress_step == 0:
|
||||
progress = value // progress_step
|
||||
self._post_progress(progress)
|
||||
|
||||
def _post_progress(self, progress):
|
||||
url = f"{self.api_base_url}/{self.train_id}/update_training_progress/"
|
||||
|
||||
headers = {
|
||||
"Accept": "application/json, text/plain, */*",
|
||||
"Accept-Language": "en-GB,en-US;q=0.9,en;q=0.8",
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"Content-Type": "application/json",
|
||||
"Origin": "http://localhost:3000",
|
||||
"Pragma": "no-cache",
|
||||
"Referer": "http://localhost:3000/",
|
||||
"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36",
|
||||
"access-token": "311e7d6f-8d78-4a7a-9831-5aa08b3ef06c",
|
||||
"sec-ch-ua": '"Google Chrome";v="131", "Chromium";v="131", "Not_A Brand";v="24"',
|
||||
"sec-ch-ua-mobile": "?0",
|
||||
"sec-ch-ua-platform": '"macOS"',
|
||||
"uid": "1"
|
||||
}
|
||||
|
||||
payload = {"progress": progress}
|
||||
|
||||
try:
|
||||
response = requests.post(url, headers=headers, json=payload)
|
||||
response.raise_for_status()
|
||||
print(f"[ProgressInterceptor] Progress updated to {progress}%")
|
||||
except requests.RequestException as e:
|
||||
print(f"[ProgressInterceptor] Failed to update progress: {e}")
|
||||
|
||||
|
||||
class SdxlNetworkTrainer(train_network.NetworkTrainer):
|
||||
def __init__(self):
|
||||
|
||||
Reference in New Issue
Block a user