Files
xc_validation_strategy/main.py
2026-06-18 15:22:29 +08:00

246 lines
9.8 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
xc_validation_strategy — 主入口
启动后执行一次模型验证任务批量提交,之后保持 HTTP 服务存活。
同时暴露 /healthK8s 探活)和 /status运行状态
"""
import json
import os
import signal
import threading
from datetime import datetime
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from typing import List, Tuple
import requests
# ══════════════════════════════════════════════════════════
# Configuration (secrets are read from environment variables, never hard-coded)
# ══════════════════════════════════════════════════════════
BASE_URL = os.environ.get("BASE_URL", "https://modelhub.org.cn")
SUBMIT_ENDPOINT = "/adminApi/async/task/create-contest-task"
# SECURITY: a bearer token and a contest API token used to be committed here in
# plain text. Treat both as leaked: rotate them and supply the replacements
# through the AUTH_TOKEN / CONTEST_API_TOKEN environment variables.
# AUTH_TOKEN can be obtained with:
#   curl -X POST https://modelhub.org.cn/adminApi/user/login
AUTH_TOKEN = os.environ.get("AUTH_TOKEN", "")
CONTEST_API_TOKEN = os.environ.get("CONTEST_API_TOKEN", "")
CONTRIBUTORS = "zhoushasha"
GPU_TYPE = "Cambricon_mlu-370-x8"
TASK_TYPE = "text-generation"
STRATEGY_ID = os.environ.get("STRATEGY_ID", "")  # injected by the platform; do not edit
HTTP_HOST = "0.0.0.0"
HTTP_PORT = 8080
# ══════════════════════════════════════════════════════════
# 模型列表
# ══════════════════════════════════════════════════════════
# Model identifiers ("<owner>/<model>") submitted for validation, in submission order.
ALL_MODEL_IDS = [
    "l3utterfly/mistral-7b-v0.1-layla-v4",
    "OpenBuddy/openbuddy-mistral-7b-v13.1",
    "allenai/truthfulqa-info-judge-llama2-7B",
    "l3utterfly/mistral-7b-v0.1-layla-v1",
    "l3utterfly/minima-3b-layla-v2",
    "l3utterfly/tinyllama-1.1b-layla-v4",
    "l3utterfly/mistral-7b-v0.1-layla-v2",
    "l3utterfly/tinyllama-1.1b-layla-v1",
    "Duxiaoman-DI/XuanYuan-13B-Chat",
    "l3utterfly/minima-3b-layla-v1",
    "AI-ModelScope/gemma-2-2b",
    "baichuan-inc/Baichuan-13B-Base",
    "LGAI-EXAONE/EXAONE-Deep-2.4B",
    "NousResearch/DeepHermes-3-Llama-3-3B-Preview",
    "Fengshenbang/Ziya2-13B-Base",
    "prithivMLmods/QwQ-MathOct-7B",
    "l3utterfly/phi-2-layla-v1-chatml",
    "argilla/notus-7b-v1",
    "prithivMLmods/Doopler-Augment-3B-Cox",
    "prithivMLmods/Blaze.1-32B-Instruct",
    "CohereLabs/aya-expanse-8B",
    "Magpie-Align/MagpieLM-4B-SFT-v0.1",
    "Magpie-Align/MagpieLM-8B-SFT-v0.1",
    "Magpie-Align/Llama-3-8B-Magpie-Align-SFT-v0.2",
    "Magpie-Align/MagpieLM-8B-Chat-v0.1",
    "Magpie-Align/Llama-3.1-8B-Magpie-Align-SFT-v0.1",
    "Magpie-Align/Llama-3-8B-Magpie-Air-SFT-300K-v0.1",
    "prithivMLmods/Tulu-MathLingo-8B",
    "prithivMLmods/Triangulum-5B",
    "prithivMLmods/Viper-Coder-v0.1",
]
# ══════════════════════════════════════════════════════════
# 全局状态(供 /status 展示)
# ══════════════════════════════════════════════════════════
# Mutable progress record: updated by the worker thread and served verbatim
# by the /status endpoint.
_state = dict(
    strategy_id=STRATEGY_ID,
    phase="starting",  # starting | submitting | done | error
    total=len(ALL_MODEL_IDS),
    submitted=0,
    failed=0,
    started_at=None,
    finished_at=None,
)
# Set by the signal handler; the HTTP loop, the worker loop and main() all watch it.
_shutdown = threading.Event()
# ══════════════════════════════════════════════════════════
# HTTP 服务
# ══════════════════════════════════════════════════════════
class Handler(BaseHTTPRequestHandler):
    """Minimal JSON-over-HTTP handler serving /health and /status."""

    def do_GET(self):
        """Answer GET requests: known paths return JSON with 200, anything else 404."""
        # Each route maps to a zero-argument producer so the body is only
        # built for the path that was actually requested.
        routes = {
            "/health": lambda: {"status": "ok"},
            "/status": lambda: _state,
        }
        producer = routes.get(self.path)
        if producer is None:
            self._json({"error": "not found"}, 404)
            return
        self._json(producer())

    def _json(self, body: dict, code: int = 200):
        """Serialize ``body`` as JSON and send it with HTTP status ``code``."""
        # default=str makes values json cannot encode natively fall back to str().
        encoded = json.dumps(body, default=str).encode()
        self.send_response(code)
        response_headers = (
            ("Content-Type", "application/json"),
            ("Content-Length", str(len(encoded))),
        )
        for header_name, header_value in response_headers:
            self.send_header(header_name, header_value)
        self.end_headers()
        self.wfile.write(encoded)

    def log_message(self, fmt, *args):
        """Send the access log to stdout with an ``[http]`` prefix."""
        client = self.address_string()
        line = fmt % args
        print(f"[http] {client} {line}", flush=True)
def _run_http():
    """Serve HTTP requests until the shutdown event is set.

    Requests are handled one ``handle_request()`` call at a time with a
    1-second timeout, so the loop re-checks ``_shutdown`` at least once per
    second even when no traffic arrives.
    """
    # The context manager calls server_close() on exit, so the listening
    # socket is released even if handle_request() raises.
    with ThreadingHTTPServer((HTTP_HOST, HTTP_PORT), Handler) as server:
        server.timeout = 1
        print(f"[http] 监听 {HTTP_HOST}:{HTTP_PORT}", flush=True)
        while not _shutdown.is_set():
            server.handle_request()
    print("[http] 已关闭", flush=True)
# ══════════════════════════════════════════════════════════
# 业务逻辑
# ══════════════════════════════════════════════════════════
def _submit_task(token: str, model_id: str) -> Tuple[bool, str]:
    """Submit one validation task for ``model_id`` to the contest-task endpoint.

    Args:
        token: Bearer token sent in the ``Authorization`` header.
        model_id: Model identifier; also interpolated into the model path of
            the submission config.

    Returns:
        ``(True, task_id)`` when the response JSON has ``code == 0``
        (``task_id`` is ``""`` if the response carries no id), otherwise
        ``(False, "")``. Never raises: every failure is logged and reported
        through the return value so one bad model cannot abort the batch.
    """
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {token}",
    }
    # NOTE(review): this YAML is reproduced exactly as found. Its nesting looks
    # like it was lost (every key sits at column 0) — confirm against the
    # platform's expected config layout before relying on it.
    config_content = f"""docker_image: harbor.4pd.io/hardcore-tech/cambricon-mlu370-pytorch:v25.01-torch2.5.0-torchmlu1.24.1-ubuntu22.04-py310
nv_docker_image: harbor.4pd.io/dooke/vllm/vllm/vllm-openai:v0.11.0
framework: vllm
storage: gpfs
modelhub_options:
srcRelativePath: leaderboard/modelHubXC/{model_id}
mountPoint: /model
sut_config:
values:
gpu_num: 1
env:
- name: MAX_MODEL_LEN
value: 8192
command: ["vllm", "serve", "/model", "--port", "8000", "--served-model-name", "llm", "--max-model-len", "8192", "--trust-remote-code", "--dtype", "float16"]
ref_config:
values:
cpu_num: 2
gpu_num: 1
env:
- name: MAX_MODEL_LEN
value: 8192
command: ["vllm", "serve", "/model", "--port", "80", "--served-model-name", "llm", "--max-model-len", "8192", "--trust-remote-code", "--dtype", "float16"]
"""
    payload = {
        "contestApiToken": CONTEST_API_TOKEN,
        "contributors": CONTRIBUTORS,
        "gpuTypes": [GPU_TYPE],
        "taskType": TASK_TYPE,
        "modelId": model_id,
        "framework": "vllm",
        "strategyId": STRATEGY_ID,  # required by the platform
        "submissionConfig": [{
            "config": config_content,
            "gpuType": GPU_TYPE,
            "taskType": TASK_TYPE,
        }],
    }
    # Log a copy with the API token masked so the secret never reaches the logs.
    loggable = {**payload, "contestApiToken": "***"}
    print(f"[payload] {json.dumps(loggable, indent=2, ensure_ascii=False)}", flush=True)

    try:
        resp = requests.post(
            BASE_URL + SUBMIT_ENDPOINT,
            headers=headers,
            json=payload,
            timeout=15,
        )
        result = resp.json()
    except Exception as e:
        # Deliberately broad: network errors and undecodable bodies alike are
        # logged and counted as a failed submission for this model only.
        print(f"[worker] ERROR {model_id}: {e}", flush=True)
        return False, ""

    if not isinstance(result, dict):
        print(
            f"[worker] FAIL {model_id}: unexpected response body "
            f"(HTTP {resp.status_code})",
            flush=True,
        )
        return False, ""

    if result.get("code") != 0:
        print(f"[worker] FAIL {model_id}: {result.get('message')}", flush=True)
        return False, ""

    # "data" may be absent or null on success; that must not turn an accepted
    # submission into a reported failure.
    data = result.get("data")
    raw_id = data.get("id", "") if isinstance(data, dict) else ""
    task_id = "" if raw_id is None else str(raw_id)
    print(f"[worker] OK {model_id} task_id={task_id}", flush=True)
    return True, task_id
def _run_worker():
    """Submit every model in ``ALL_MODEL_IDS`` once and record the outcome.

    Updates the shared ``_state`` dict (phase, counters, timestamps) as it
    goes, stops early if the shutdown event is set, and finally writes the
    successful ``task_id<TAB>model_id`` pairs to
    ``submitted_validation_tasks.txt`` in the current working directory.
    """
    # Local import keeps this block self-contained; the module only imports
    # ``datetime`` from the datetime package.
    from datetime import timezone

    def _utc_now_iso() -> str:
        # Same naive-UTC ISO string that datetime.utcnow().isoformat() gave,
        # without calling the deprecated utcnow().
        return datetime.now(timezone.utc).replace(tzinfo=None).isoformat()

    _state["started_at"] = _utc_now_iso()
    _state["phase"] = "submitting"
    successful: List[Tuple[str, str]] = []
    token = AUTH_TOKEN
    print("[worker] 使用预设 Token跳过登录", flush=True)
    for model_id in ALL_MODEL_IDS:
        if _shutdown.is_set():
            break
        ok, task_id = _submit_task(token, model_id)
        if ok:
            _state["submitted"] += 1
            successful.append((task_id, model_id))
        else:
            _state["failed"] += 1

    # Write the result file. This stays best-effort (a failure must not kill
    # the worker), but the failure is now logged instead of silently dropped.
    try:
        with open("submitted_validation_tasks.txt", "w", encoding="utf-8") as f:
            f.writelines(f"{tid}\t{mid}\n" for tid, mid in successful)
    except Exception as exc:
        print(f"[worker] WARN could not write submitted_validation_tasks.txt: {exc}", flush=True)

    _state["finished_at"] = _utc_now_iso()
    _state["phase"] = "done"
    print(
        f"[worker] 完成 submitted={_state['submitted']} failed={_state['failed']}",
        flush=True,
    )
    # After submission finishes the process stays alive until the platform stops it.
# ══════════════════════════════════════════════════════════
# 入口
# ══════════════════════════════════════════════════════════
def _handle_signal(signum, _frame):
    """Signal handler: log the received signal number and request shutdown."""
    notice = f"[main] 收到信号 {signum},正在关闭..."
    print(notice, flush=True)
    # Setting the event wakes main() and lets the HTTP and worker loops exit.
    _shutdown.set()
def main():
    """Install signal handlers, start the HTTP and worker threads, wait for shutdown."""
    for sig in (signal.SIGTERM, signal.SIGINT):
        signal.signal(sig, _handle_signal)

    # The HTTP thread is non-daemon and is joined below; the worker is a
    # daemon thread, so it does not keep the process alive on exit.
    http_thread = threading.Thread(target=_run_http, daemon=False)
    worker_thread = threading.Thread(target=_run_worker, daemon=True)
    http_thread.start()
    worker_thread.start()

    # The main thread parks here until a signal sets the shutdown event.
    _shutdown.wait()
    print("[main] 等待 HTTP 服务关闭...", flush=True)
    http_thread.join(timeout=5)
    print("[main] 退出", flush=True)


if __name__ == "__main__":
    main()