[PD] Add PD support for hybrid model (Qwen3-Next, DeepSeek V3.2 Exp) (#10912)
Signed-off-by: Shangming Cai <csmthu@gmail.com>
Co-authored-by: hzh0425 <hzh0425@apache.org>
Co-authored-by: ZeldaHuang <hzm414167@alibaba-inc.com>
This commit is contained in:
@@ -163,6 +163,7 @@ suites = {
|
||||
TestFile("test_deepseek_v3_basic.py", 275),
|
||||
TestFile("test_deepseek_v3_mtp.py", 275),
|
||||
TestFile("test_disaggregation_different_tp.py", 600),
|
||||
TestFile("test_disaggregation_hybrid_attention.py", 200),
|
||||
TestFile("test_disaggregation_pp.py", 140),
|
||||
],
|
||||
"per-commit-4-gpu-b200": [
|
||||
|
||||
83
test/srt/test_disaggregation_hybrid_attention.py
Normal file
83
test/srt/test_disaggregation_hybrid_attention.py
Normal file
@@ -0,0 +1,83 @@
|
||||
import os
|
||||
import unittest
|
||||
from types import SimpleNamespace
|
||||
|
||||
from sglang.srt.environ import envs
|
||||
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
|
||||
from sglang.test.test_disaggregation_utils import TestDisaggregationBase
|
||||
from sglang.test.test_utils import (
|
||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
popen_launch_pd_server,
|
||||
)
|
||||
|
||||
|
||||
class TestDisaggregationHybridAttentionMamba(TestDisaggregationBase):
    """GSM8K accuracy check for prefill/decode disaggregated serving.

    The class launches one prefill server and one decode server for
    ``Qwen/Qwen3-Next-80B-A3B-Instruct``, starts the load balancer in front
    of them, and then runs a few-shot GSM8K evaluation through the load
    balancer.
    """

    @classmethod
    def setUpClass(cls):
        """Start the prefill server, the decode server and the load balancer."""
        super().setUpClass()
        cls.model = "Qwen/Qwen3-Next-80B-A3B-Instruct"

        # Kick off both launches first (the launch calls do not block) ...
        cls.start_prefill()
        cls.start_decode()

        # ... then wait for the prefill and decode health endpoints, in that order.
        for server_url in (cls.prefill_url, cls.decode_url):
            cls.wait_server_ready(server_url + "/health")

        cls.launch_lb()

    @classmethod
    def start_prefill(cls):
        """Launch the prefill-mode server with ``--tp 4``.

        The process handle is stored on ``cls.process_prefill``.
        """
        launch_flags = [
            "--trust-remote-code",
            "--disaggregation-mode",
            "prefill",
            "--tp",
            "4",
        ]
        # Transfer backend / RDMA device flags are supplied by the base class.
        launch_flags.extend(cls.transfer_backend + cls.rdma_devices)

        cls.process_prefill = popen_launch_pd_server(
            cls.model,
            cls.prefill_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=launch_flags,
        )

    @classmethod
    def start_decode(cls):
        """Launch the decode-mode server with ``--tp 4 --base-gpu-id 4``.

        The process handle is stored on ``cls.process_decode``.
        """
        launch_flags = [
            "--trust-remote-code",
            "--disaggregation-mode",
            "decode",
            "--tp",
            "4",
            # NOTE(review): presumably keeps decode off the GPUs used by the
            # prefill server -- confirm against the server argument docs.
            "--base-gpu-id",
            "4",
        ]
        # Transfer backend / RDMA device flags are supplied by the base class.
        launch_flags.extend(cls.transfer_backend + cls.rdma_devices)

        cls.process_decode = popen_launch_pd_server(
            cls.model,
            cls.decode_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=launch_flags,
        )

    def test_gsm8k(self):
        """Run 5-shot GSM8K on 200 questions and require accuracy above 0.93."""
        eval_settings = {
            "num_shots": 5,
            "data_path": None,
            "num_questions": 200,
            "max_new_tokens": 512,
            "parallel": 128,
            # Requests are sent to the load balancer, not to a server directly.
            "host": f"http://{self.base_host}",
            "port": int(self.lb_port),
        }
        args = SimpleNamespace(**eval_settings)

        metrics = run_eval_few_shot_gsm8k(args)
        print(f"Evaluation metrics: {metrics}")

        accuracy = metrics["accuracy"]
        self.assertGreater(accuracy, 0.93)
|
||||
|
||||
|
||||
# Hand control to the unittest runner when this file is executed as a script.
if __name__ == "__main__":
    unittest.main()
|
||||
@@ -42,6 +42,7 @@ class TestMamba(unittest.TestCase):
|
||||
full_attention_layer_ids=full_attention_layer_ids,
|
||||
enable_kvcache_transpose=False,
|
||||
device=device,
|
||||
mamba_pool=None,
|
||||
)
|
||||
assert pool._transfer_full_attention_id(global_interval - 1) == 0
|
||||
assert pool._transfer_full_attention_id(2 * global_interval - 1) == 1
|
||||
@@ -173,6 +174,7 @@ class TestMamba(unittest.TestCase):
|
||||
full_attention_layer_ids=full_attention_layer_ids,
|
||||
enable_kvcache_transpose=False,
|
||||
device=device,
|
||||
mamba_pool=req_to_token_pool.mamba_pool,
|
||||
)
|
||||
|
||||
# setup token to kv pool allocator
|
||||
|
||||
Reference in New Issue
Block a user