# Source: sglang/test/srt/test_cutedsl_flashinfer_8gpu.py
# (file-viewer metadata: 78 lines, 2.1 KiB, Python)

import os
import unittest
from types import SimpleNamespace
import requests
from sglang.srt.utils import kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
from sglang.test.test_utils import (
DEFAULT_DEEPSEEK_NVFP4_MODEL_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_server,
try_cached_model,
)
class TestDeepseekR1Nvfp4CuteDSLDeepEP(CustomTestCase):
    """End-to-end GSM8K accuracy check for the DeepSeek NVFP4 test model.

    Launches one SGLang server for the whole class with ``modelopt_fp4``
    quantization, ``--tp 8 --dp 8`` with DP attention, DeepEP MoE in
    ``low_latency`` mode and the FlashInfer CuteDSL MoE path enabled, then
    runs a 5-shot GSM8K evaluation against that server.
    """

    @classmethod
    def setUpClass(cls):
        """Start the server shared by every test in this class."""
        cls.model = try_cached_model(DEFAULT_DEEPSEEK_NVFP4_MODEL_FOR_TEST)
        cls.base_url = DEFAULT_URL_FOR_TEST
        other_args = [
            "--trust-remote-code",
            "--disable-radix-cache",
            "--max-running-requests",
            "256",
            "--chunked-prefill-size",
            "2048",
            "--tp",
            "8",
            "--dp",
            "8",
            "--enable-dp-attention",
            "--enable-ep-moe",
            "--quantization",
            "modelopt_fp4",
            "--enable-flashinfer-cutedsl-moe",
            "--enable-deepep-moe",
            "--deepep-mode",
            "low_latency",
        ]
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=other_args,
            env={
                **os.environ,
                # NOTE(review): DeepEP knobs read by the server process;
                # presumably BF16 dispatch and a per-rank cap on dispatched
                # tokens (same value as --max-running-requests) — confirm
                # against the DeepEP backend before changing either.
                "SGLANG_DEEPEP_BF16_DISPATCH": "1",
                "SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK": "256",
            },
        )

    @classmethod
    def tearDownClass(cls):
        """Kill the server process started in setUpClass and its children."""
        kill_process_tree(cls.process.pid)

    def test_gsm8k(self):
        """Run 5-shot GSM8K (512 questions) and require accuracy > 0.92."""
        from urllib.parse import urlparse

        # Evaluate against the exact server launched in setUpClass: take the
        # scheme, host and port from base_url instead of hard-coding the host
        # and splitting the URL on ":" (which breaks on a trailing path/slash).
        parsed = urlparse(self.base_url)
        self.assertIsNotNone(
            parsed.port, f"base_url has no explicit port: {self.base_url!r}"
        )
        args = SimpleNamespace(
            num_shots=5,
            data_path=None,
            num_questions=512,
            parallel=512,
            max_new_tokens=512,
            host=f"{parsed.scheme}://{parsed.hostname}",
            port=parsed.port,
        )
        metrics = run_eval_few_shot_gsm8k(args)
        print(f"Eval accuracy of GSM8K: {metrics=}")
        self.assertGreater(metrics["accuracy"], 0.92)
if __name__ == "__main__":
    # Entry point for running this file directly instead of via a test runner.
    unittest.main()