Files
xc_validation_strategy/main.py

246 lines
9.8 KiB
Python
Raw Permalink Normal View History

2026-06-12 14:24:27 +08:00
"""
xc_validation_strategy 主入口
2026-06-10 21:42:41 +08:00
2026-06-12 14:24:27 +08:00
启动后执行一次模型验证任务批量提交之后保持 HTTP 服务存活
同时暴露 /healthK8s 探活 /status运行状态
"""
2026-06-10 21:42:41 +08:00
2026-06-12 14:24:27 +08:00
import json
import os
import signal
import threading
from datetime import datetime
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from typing import List, Tuple
2026-06-10 21:42:41 +08:00
2026-06-12 14:24:27 +08:00
import requests
2026-06-10 21:42:41 +08:00
2026-06-12 14:24:27 +08:00
# ══════════════════════════════════════════════════════════
# Configuration (secrets are read from environment variables, never hard-coded)
# ══════════════════════════════════════════════════════════
BASE_URL = os.environ.get("BASE_URL", "https://modelhub.org.cn")
SUBMIT_ENDPOINT = "/adminApi/async/task/create-contest-task"

# Bearer token for the admin API, sent as "Authorization: Bearer <token>".
# Obtain it via `curl -X POST https://modelhub.org.cn/adminApi/user/login`
# and export it as the AUTH_TOKEN environment variable. Empty when unset.
# NOTE(security): the tokens that used to be hard-coded in this file were
# committed to source control and must be treated as leaked and rotated.
AUTH_TOKEN = os.environ.get("AUTH_TOKEN", "")
# Contest API token placed in every submission payload; empty when unset.
CONTEST_API_TOKEN = os.environ.get("CONTEST_API_TOKEN", "")
CONTRIBUTORS = "zhoushasha"
GPU_TYPE = "Cambricon_mlu-370-x8"
TASK_TYPE = "text-generation"
STRATEGY_ID = os.environ.get("STRATEGY_ID", "")  # injected by the platform; do not edit

HTTP_HOST = "0.0.0.0"
HTTP_PORT = 8080
# ══════════════════════════════════════════════════════════
# Model list
# ══════════════════════════════════════════════════════════
# Models to validate, in submission order: the worker walks this list from
# front to back and /status reports its length as "total". Each entry is also
# the model's relative path under leaderboard/modelHubXC/ in the task config.
ALL_MODEL_IDS = [
    # entries 1-10
    "l3utterfly/mistral-7b-v0.1-layla-v4",
    "OpenBuddy/openbuddy-mistral-7b-v13.1",
    "allenai/truthfulqa-info-judge-llama2-7B",
    "l3utterfly/mistral-7b-v0.1-layla-v1",
    "l3utterfly/minima-3b-layla-v2",
    "l3utterfly/tinyllama-1.1b-layla-v4",
    "l3utterfly/mistral-7b-v0.1-layla-v2",
    "l3utterfly/tinyllama-1.1b-layla-v1",
    "Duxiaoman-DI/XuanYuan-13B-Chat",
    "l3utterfly/minima-3b-layla-v1",
    # entries 11-20
    "AI-ModelScope/gemma-2-2b",
    "baichuan-inc/Baichuan-13B-Base",
    "LGAI-EXAONE/EXAONE-Deep-2.4B",
    "NousResearch/DeepHermes-3-Llama-3-3B-Preview",
    "Fengshenbang/Ziya2-13B-Base",
    "prithivMLmods/QwQ-MathOct-7B",
    "l3utterfly/phi-2-layla-v1-chatml",
    "argilla/notus-7b-v1",
    "prithivMLmods/Doopler-Augment-3B-Cox",
    "prithivMLmods/Blaze.1-32B-Instruct",
    # entries 21-30
    "CohereLabs/aya-expanse-8B",
    "Magpie-Align/MagpieLM-4B-SFT-v0.1",
    "Magpie-Align/MagpieLM-8B-SFT-v0.1",
    "Magpie-Align/Llama-3-8B-Magpie-Align-SFT-v0.2",
    "Magpie-Align/MagpieLM-8B-Chat-v0.1",
    "Magpie-Align/Llama-3.1-8B-Magpie-Align-SFT-v0.1",
    "Magpie-Align/Llama-3-8B-Magpie-Air-SFT-300K-v0.1",
    "prithivMLmods/Tulu-MathLingo-8B",
    "prithivMLmods/Triangulum-5B",
    "prithivMLmods/Viper-Coder-v0.1",
]
# ══════════════════════════════════════════════════════════
# Global state (served by /status)
# ══════════════════════════════════════════════════════════
# Progress record written by the worker thread and returned as JSON by
# GET /status. Keyword order below is the key order in that JSON.
# phase takes one of: starting | submitting | done | error.
_state = dict(
    strategy_id=STRATEGY_ID,
    phase="starting",
    total=len(ALL_MODEL_IDS),
    submitted=0,
    failed=0,
    started_at=None,
    finished_at=None,
)
# Set by the signal handler; the HTTP loop and the worker loop poll it to stop.
_shutdown = threading.Event()
# ══════════════════════════════════════════════════════════
# HTTP service
# ══════════════════════════════════════════════════════════
class Handler(BaseHTTPRequestHandler):
    """Request handler exposing ``/health`` (liveness) and ``/status`` (progress)."""

    def do_GET(self):
        """Route a GET request by path; any query string is ignored."""
        # self.path carries the raw request target including "?query"; strip it
        # so that e.g. "/health?source=probe" still hits the health endpoint.
        path = self.path.split("?", 1)[0]
        if path == "/health":
            self._json({"status": "ok"})
        elif path == "/status":
            # Serialize a shallow snapshot of the dict the worker thread updates.
            self._json(dict(_state))
        else:
            self._json({"error": "not found"}, 404)

    def _json(self, body: dict, code: int = 200):
        """Write ``body`` as a JSON response with HTTP status ``code``.

        ``default=str`` stringifies any value json cannot encode natively.
        """
        payload = json.dumps(body, default=str).encode()
        self.send_response(code)
        self.send_header("Content-Type", "application/json")
        self.send_header("Content-Length", str(len(payload)))
        self.end_headers()
        self.wfile.write(payload)

    def log_message(self, fmt, *args):
        """Log each request line to stdout (flushed) with an ``[http]`` prefix."""
        print(f"[http] {self.address_string()} {fmt % args}", flush=True)
def _run_http():
    """Serve HTTP requests until ``_shutdown`` is set, then close the server."""
    # The context manager closes the listening socket when the loop ends.
    with ThreadingHTTPServer((HTTP_HOST, HTTP_PORT), Handler) as server:
        # handle_request() gives up after `timeout` seconds without a request,
        # so the shutdown flag is re-checked at least once per second.
        server.timeout = 1
        print(f"[http] 监听 {HTTP_HOST}:{HTTP_PORT}", flush=True)
        while not _shutdown.is_set():
            server.handle_request()
    print("[http] 已关闭", flush=True)
# ══════════════════════════════════════════════════════════
# Business logic
# ══════════════════════════════════════════════════════════
def _submit_task(token: str, model_id: str) -> Tuple[bool, str]:
    """Submit one model-validation task to the contest API.

    Args:
        token: Bearer token sent in the ``Authorization`` header.
        model_id: Model identifier; also used as the model's relative path
            under ``leaderboard/modelHubXC/`` in the task config.

    Returns:
        ``(True, task_id)`` when the API answers with ``code == 0``
        (``task_id`` is ``""`` if the response carries no id), otherwise
        ``(False, "")``. Network and response-decoding errors are logged and
        reported as failure rather than raised.
    """
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {token}",
    }

    # NOTE(review): confirm this YAML nesting against the platform's task
    # config schema before relying on it.
    config_content = (
        "docker_image: harbor.4pd.io/hardcore-tech/cambricon-mlu370-pytorch:v25.01-torch2.5.0-torchmlu1.24.1-ubuntu22.04-py310\n"
        "nv_docker_image: harbor.4pd.io/dooke/vllm/vllm/vllm-openai:v0.11.0\n"
        "framework: vllm\n"
        "storage: gpfs\n"
        "modelhub_options:\n"
        f"  srcRelativePath: leaderboard/modelHubXC/{model_id}\n"
        "  mountPoint: /model\n"
        "sut_config:\n"
        "  values:\n"
        "    gpu_num: 1\n"
        "    env:\n"
        "      - name: MAX_MODEL_LEN\n"
        "        value: 8192\n"
        '    command: ["vllm", "serve", "/model", "--port", "8000", "--served-model-name", "llm", "--max-model-len", "8192", "--trust-remote-code", "--dtype", "float16"]\n'
        "ref_config:\n"
        "  values:\n"
        "    cpu_num: 2\n"
        "    gpu_num: 1\n"
        "    env:\n"
        "      - name: MAX_MODEL_LEN\n"
        "        value: 8192\n"
        '    command: ["vllm", "serve", "/model", "--port", "80", "--served-model-name", "llm", "--max-model-len", "8192", "--trust-remote-code", "--dtype", "float16"]\n'
    )

    payload = {
        "contestApiToken": CONTEST_API_TOKEN,
        "contributors": CONTRIBUTORS,
        "gpuTypes": [GPU_TYPE],
        "taskType": TASK_TYPE,
        "modelId": model_id,
        "framework": "vllm",
        "strategyId": STRATEGY_ID,  # required by the platform
        "submissionConfig": [{
            "config": config_content,
            "gpuType": GPU_TYPE,
            "taskType": TASK_TYPE,
        }],
    }

    # Log the request without leaking the contest token into stdout.
    loggable = {**payload, "contestApiToken": "***"}
    print(f"[payload] {json.dumps(loggable, indent=2, ensure_ascii=False)}", flush=True)

    try:
        resp = requests.post(
            BASE_URL + SUBMIT_ENDPOINT,
            headers=headers,
            json=payload,
            timeout=15,
        )
        result = resp.json()
    except Exception as e:  # best-effort per task: log and report failure
        print(f"[worker] ERROR {model_id}: {e}", flush=True)
        return False, ""

    if not isinstance(result, dict):
        print(f"[worker] FAIL {model_id}: unexpected response {result!r}", flush=True)
        return False, ""
    if result.get("code") != 0:
        print(f"[worker] FAIL {model_id}: {result.get('message')}", flush=True)
        return False, ""

    # "data" may be missing or null even on success; the submission still
    # counts as accepted in that case.
    data = result.get("data")
    raw_id = data.get("id") if isinstance(data, dict) else None
    task_id = "" if raw_id is None else str(raw_id)
    print(f"[worker] OK {model_id} task_id={task_id}", flush=True)
    return True, task_id
def _run_worker():
    """Submit a validation task for every model in ``ALL_MODEL_IDS``.

    Updates the shared ``_state`` counters as it goes. Writes the successful
    ``task_id<TAB>model_id`` pairs to ``submitted_validation_tasks.txt`` and
    then sets ``phase`` to ``"done"``. ``phase`` becomes ``"error"`` instead
    when no auth token is configured or an unexpected exception escapes the
    submission loop. Stops early once ``_shutdown`` is set.
    """
    from datetime import timezone  # the module imports only `datetime`

    _state["started_at"] = datetime.now(timezone.utc).isoformat()
    _state["phase"] = "submitting"

    token = AUTH_TOKEN
    if not token:
        # Without a bearer token the submissions would presumably all be
        # rejected; fail fast and surface the problem through /status.
        print("[worker] AUTH_TOKEN is empty; no tasks submitted", flush=True)
        _state["finished_at"] = datetime.now(timezone.utc).isoformat()
        _state["phase"] = "error"
        return
    print("[worker] 使用预设 Token跳过登录", flush=True)

    successful: List[Tuple[str, str]] = []
    final_phase = "done"
    try:
        for model_id in ALL_MODEL_IDS:
            if _shutdown.is_set():
                break
            ok, task_id = _submit_task(token, model_id)
            if ok:
                _state["submitted"] += 1
                successful.append((task_id, model_id))
            else:
                _state["failed"] += 1
    except Exception as exc:  # thread boundary: record the failure, keep the service up
        final_phase = "error"
        print(f"[worker] ERROR unexpected failure: {exc!r}", flush=True)

    # Persist the submitted tasks even after an error so their ids are not lost.
    try:
        with open("submitted_validation_tasks.txt", "w", encoding="utf-8") as f:
            f.writelines(f"{tid}\t{mid}\n" for tid, mid in successful)
    except OSError as exc:
        print(f"[worker] failed to write submitted_validation_tasks.txt: {exc}", flush=True)

    _state["finished_at"] = datetime.now(timezone.utc).isoformat()
    _state["phase"] = final_phase
    print(
        f"[worker] 完成 submitted={_state['submitted']} failed={_state['failed']}",
        flush=True,
    )
    # After this returns the process stays alive: main() keeps waiting for a
    # shutdown signal from the platform.
# ══════════════════════════════════════════════════════════
# Entry point
# ══════════════════════════════════════════════════════════
def _handle_signal(signum, _frame):
    """Handle SIGTERM/SIGINT: announce the signal and request a graceful stop."""
    notice = f"[main] 收到信号 {signum},正在关闭..."
    print(notice, flush=True)
    _shutdown.set()  # wakes main() and stops the HTTP and worker loops
def main():
    """Start the HTTP and worker threads, then block until shutdown is requested."""
    for sig in (signal.SIGTERM, signal.SIGINT):
        signal.signal(sig, _handle_signal)

    # HTTP server thread: non-daemon, and joined (with a timeout) below.
    http_thread = threading.Thread(target=_run_http, daemon=False)
    # Submission thread: daemon, so it does not hold the interpreter open.
    worker_thread = threading.Thread(target=_run_worker, daemon=True)
    for thread in (http_thread, worker_thread):
        thread.start()

    # The main thread idles here until a signal handler sets the event.
    _shutdown.wait()
    print("[main] 等待 HTTP 服务关闭...", flush=True)
    http_thread.join(timeout=5)
    print("[main] 退出", flush=True)


if __name__ == "__main__":
    main()