diff --git a/main.py b/main.py index 48f5f14..f49a33b 100644 --- a/main.py +++ b/main.py @@ -1,31 +1,43 @@ -import requests +""" +xc_validation_strategy — 主入口 + +启动后执行一次模型验证任务批量提交,之后保持 HTTP 服务存活。 +同时暴露 /health(K8s 探活)和 /status(运行状态)。 +""" + import json +import os +import signal +import threading +import traceback +from datetime import datetime +from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer from typing import List, Tuple -# ========== 全局配置 ========== -BASE_URL = "https://modelhub.org.cn" -LOGIN_ENDPOINT = "/adminApi/user/login" -SUBMIT_TEST_TASK_ENDPOINT = "/adminApi/async/task/create-contest-task" +import requests -USER_ACCOUNT = "zhoushasha@4paradigm.com" -USER_PASSWORD = "4pdpassword" +# ══════════════════════════════════════════════════════════ +# 配置(全部从环境变量读取,不硬编码敏感信息) +# ══════════════════════════════════════════════════════════ +BASE_URL = os.environ.get("BASE_URL", "https://modelhub.org.cn") +LOGIN_ENDPOINT = "/adminApi/user/login" +SUBMIT_ENDPOINT = "/adminApi/async/task/create-contest-task" +USER_ACCOUNT = os.environ["USER_ACCOUNT"] # 必填 +USER_PASSWORD = os.environ["USER_PASSWORD"] # 必填 +CONTEST_API_TOKEN = os.environ["CONTEST_API_TOKEN"] # 必填 +STRATEGY_ID = os.environ.get("STRATEGY_ID", "") # 平台注入 +CONTRIBUTORS = os.environ.get("CONTRIBUTORS", USER_ACCOUNT) +GPU_TYPE = os.environ.get("GPU_TYPE", "Cambricon_mlu-370-x8") +TASK_TYPE = os.environ.get("TASK_TYPE", "text-generation") -CONTEST_API_TOKEN = "ef1ef82f3c9efee413d602345fbe224d" +HTTP_HOST = "0.0.0.0" +HTTP_PORT = 8080 - -CONTRIBUTORS = "zhoushasha" - - -GPU_TYPE = "Cambricon_mlu-370-x8" - -TASK_TYPE = "text-generation" - -HEADERS = {"Content-Type": "application/json"} - -# ======== 模型列表(保持不变)======== +# ══════════════════════════════════════════════════════════ +# 模型列表 +# ══════════════════════════════════════════════════════════ ALL_MODEL_IDS = [ - "AI-ModelScope/gemma-2b", "AI-ModelScope/falcon-mamba-7b", "katanemo/deepseek-2", @@ -46,27 +58,78 @@ ALL_MODEL_IDS = [ "Qwen/CodeQwen1.5-7B-Chat", "OpenBMB/cpm-bee-10b", "OpenBMB/MiniCPM3-4B", - - ] -# === 登录获取 token === -def login(): - payload = {"userAccount": USER_ACCOUNT, "userPassword": USER_PASSWORD} - print("🔑 正在登录...") - resp = requests.post(BASE_URL + LOGIN_ENDPOINT, headers=HEADERS, json=payload) - if resp.status_code != 200: - raise Exception(f"HTTP 登录失败: {resp.text}") +# ══════════════════════════════════════════════════════════ +# 全局状态(供 /status 展示) +# ══════════════════════════════════════════════════════════ +_state = { + "strategy_id": STRATEGY_ID, + "phase": "starting", # starting | submitting | done | error + "total": len(ALL_MODEL_IDS), + "submitted": 0, + "failed": 0, + "started_at": None, + "finished_at": None, +} +_shutdown = threading.Event() + +# ══════════════════════════════════════════════════════════ +# HTTP 服务 +# ══════════════════════════════════════════════════════════ +class Handler(BaseHTTPRequestHandler): + def do_GET(self): + if self.path == "/health": + self._json({"status": "ok"}) + elif self.path == "/status": + self._json(_state) + else: + self._json({"error": "not found"}, 404) + + def _json(self, body: dict, code: int = 200): + payload = json.dumps(body, default=str).encode() + self.send_response(code) + self.send_header("Content-Type", "application/json") + self.send_header("Content-Length", str(len(payload))) + self.end_headers() + self.wfile.write(payload) + + def log_message(self, fmt, *args): + print(f"[http] {self.address_string()} {fmt % args}", flush=True) + + +def _run_http(): + server = ThreadingHTTPServer((HTTP_HOST, HTTP_PORT), Handler) + server.timeout = 1 + print(f"[http] 监听 {HTTP_HOST}:{HTTP_PORT}", flush=True) + while not _shutdown.is_set(): + server.handle_request() + server.server_close() + print("[http] 已关闭", flush=True) + +# ══════════════════════════════════════════════════════════ +# 业务逻辑 +# ══════════════════════════════════════════════════════════ +def _login() -> str: + headers = {"Content-Type": "application/json"} + resp = requests.post( + BASE_URL + LOGIN_ENDPOINT, + headers=headers, + json={"userAccount": USER_ACCOUNT, "userPassword": USER_PASSWORD}, + timeout=30, + ) data = resp.json() if data.get("code") != 0: - raise Exception(f"业务登录失败: {data.get('message')}") - token = data["data"]["token"] - print("✅ 登录成功!") - return token + raise RuntimeError(f"登录失败: {data.get('message')}") + print("[worker] 登录成功", flush=True) + return data["data"]["token"] -# === 提交单个模型的测试任务(vLLM + kunlunxin_p-800)=== -def submit_test_task(token: str, model_id: str) -> Tuple[str, str]: - auth_headers = {**HEADERS, "Authorization": f"Bearer {token}"} + +def _submit_task(token: str, model_id: str) -> Tuple[bool, str]: + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {token}", + } config_content = f"""docker_image: harbor.4pd.io/hardcore-tech/cambricon-mlu370-pytorch:v25.01-torch2.5.0-torchmlu1.24.1-ubuntu22.04-py310 nv_docker_image: harbor.4pd.io/dooke/vllm/vllm/vllm-openai:v0.11.0 framework: vllm @@ -90,66 +153,104 @@ ref_config: value: 8192 command: ["vllm", "serve", "/model", "--port", "80", "--served-model-name", "llm", "--max-model-len", "8192", "--trust-remote-code", "--dtype", "float16"] """ - task_data = { + payload = { "contestApiToken": CONTEST_API_TOKEN, - "contributors": CONTRIBUTORS, - "gpuTypes": [GPU_TYPE], - "taskType": TASK_TYPE, - "modelId": model_id, - "framework": "vllm", + "contributors": CONTRIBUTORS, + "gpuTypes": [GPU_TYPE], + "taskType": TASK_TYPE, + "modelId": model_id, + "framework": "vllm", + "strategyId": STRATEGY_ID, # 平台要求 "submissionConfig": [{ - "config": config_content, - "gpuType": GPU_TYPE, - "taskType": TASK_TYPE - }] + "config": config_content, + "gpuType": GPU_TYPE, + "taskType": TASK_TYPE, + }], } - print(f"📤 提交验证任务: {model_id} (GPU: {GPU_TYPE})") try: - resp = requests.post(BASE_URL + SUBMIT_TEST_TASK_ENDPOINT, json=task_data, headers=auth_headers, timeout=15) - if resp.status_code == 200: - result = resp.json() - if result.get("code") == 0: - task_id = result.get("data", {}).get("id") - print(f"✅ 验证任务提交成功! Task ID: {task_id}") - return task_id, model_id - else: - print(f"❌ 验证任务业务错误 ({model_id}): {result.get('message')}") - return None, model_id + resp = requests.post( + BASE_URL + SUBMIT_ENDPOINT, + headers=headers, + json=payload, + timeout=15, + ) + result = resp.json() + if result.get("code") == 0: + task_id = result.get("data", {}).get("id", "") + print(f"[worker] OK {model_id} task_id={task_id}", flush=True) + return True, task_id else: - print(f"❌ 验证任务 HTTP 错误 ({model_id}): {resp.status_code} - {resp.text}") - return None, model_id + print(f"[worker] FAIL {model_id}: {result.get('message')}", flush=True) + return False, "" except Exception as e: - print(f"💥 提交验证任务异常 ({model_id}): {e}") - return None, model_id + print(f"[worker] ERROR {model_id}: {e}", flush=True) + return False, "" -# === 主函数:仅提交验证任务 === -def main(): - if not ALL_MODEL_IDS: - print("❌ 模型列表为空,请在 ALL_MODEL_IDS 中填入模型ID") + +def _run_worker(): + _state["started_at"] = datetime.utcnow().isoformat() + _state["phase"] = "submitting" + + successful: List[Tuple[str, str]] = [] + try: + token = _login() + except Exception: + traceback.print_exc() + _state["phase"] = "error" return - token = login() - total_count = len(ALL_MODEL_IDS) - print(f"📊 共 {total_count} 个模型待提交验证任务") - - successful_tasks: List[Tuple[str, str]] = [] # (task_id, model_id) - for model_id in ALL_MODEL_IDS: - task_id, mid = submit_test_task(token, model_id) - if task_id: - successful_tasks.append((task_id, mid)) + if _shutdown.is_set(): + break + ok, task_id = _submit_task(token, model_id) + if ok: + _state["submitted"] += 1 + successful.append((task_id, model_id)) + else: + _state["failed"] += 1 - # 写入成功提交的 task_id 和 model_id 到文件 - with open("submitted_validation_tasks.txt", "w", encoding="utf-8") as f: - for tid, mid in successful_tasks: - f.write(f"{tid}\t{mid}\n") + # 写入结果文件 + try: + with open("submitted_validation_tasks.txt", "w", encoding="utf-8") as f: + for tid, mid in successful: + f.write(f"{tid}\t{mid}\n") + except Exception: + pass + + _state["finished_at"] = datetime.utcnow().isoformat() + _state["phase"] = "done" + print( + f"[worker] 完成 submitted={_state['submitted']} failed={_state['failed']}", + flush=True, + ) + # 提交完成后继续保持进程存活,等待平台停止 + +# ══════════════════════════════════════════════════════════ +# 入口 +# ══════════════════════════════════════════════════════════ +def _handle_signal(signum, _frame): + print(f"[main] 收到信号 {signum},正在关闭...", flush=True) + _shutdown.set() + + +def main(): + signal.signal(signal.SIGTERM, _handle_signal) + signal.signal(signal.SIGINT, _handle_signal) + + # HTTP 服务线程 + http_thread = threading.Thread(target=_run_http, daemon=False) + http_thread.start() + + # 提交任务线程 + worker_thread = threading.Thread(target=_run_worker, daemon=True) + worker_thread.start() + + # 主线程等待 shutdown + _shutdown.wait() + print("[main] 等待 HTTP 服务关闭...", flush=True) + http_thread.join(timeout=5) + print("[main] 退出", flush=True) - # 最终统计 - print("\n" + "=" * 60) - print(f"🎉 全部完成!") - print(f"✅ 成功提交验证任务: {len(successful_tasks)}") - print(f"📄 详情已写入: submitted_validation_tasks.txt") - print(f"📊 总计尝试: {total_count}") if __name__ == "__main__": main() \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..f229360 --- /dev/null +++ b/requirements.txt @@ -0,0 +1 @@ +requests