diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 32348b590..d4cf85f72 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -32,6 +32,7 @@ from sglang.srt.configs.model_config import AttentionArch, ModelConfig from sglang.srt.configs.update_config import adjust_config_with_unaligned_cpu_tp from sglang.srt.constants import GPU_MEMORY_TYPE_WEIGHTS from sglang.srt.distributed import ( + get_pp_group, get_tp_group, get_world_group, init_distributed_environment, @@ -639,6 +640,7 @@ class ModelRunner: cpu_group=get_world_group().cpu_group, ) self.tp_group = get_tp_group() + self.pp_group = get_pp_group() self.attention_tp_group = get_attention_tp_group() # Check memory for tensor parallelism @@ -1825,7 +1827,10 @@ class ModelRunner: else: raise ValueError(f"Invalid forward mode: {forward_batch.forward_mode}") - if forward_batch.global_num_tokens_cpu is not None: + if ( + forward_batch.global_num_tokens_cpu is not None + and self.pp_group.is_last_rank + ): forward_batch.post_forward_mlp_sync_batch(ret) return ret, can_run_cuda_graph diff --git a/test/srt/test_pp_single_node.py b/test/srt/test_pp_single_node.py index f1fb3e212..e333c2d4c 100644 --- a/test/srt/test_pp_single_node.py +++ b/test/srt/test_pp_single_node.py @@ -14,11 +14,14 @@ import requests from sglang.bench_one_batch_server import BenchArgs as OneBatchBenchArgs from sglang.srt.server_args import ServerArgs from sglang.srt.utils import kill_process_tree -from sglang.test.few_shot_gsm8k import run_eval +from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k +from sglang.test.run_eval import run_eval from sglang.test.test_utils import ( + DEFAULT_MLA_MODEL_NAME_FOR_TEST, DEFAULT_MODEL_NAME_FOR_TEST, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_URL_FOR_TEST, + CustomTestCase, is_in_ci, popen_launch_server, run_bench_one_batch_server, @@ -57,7 +60,7 @@ class TestPPAccuracy(unittest.TestCase): host="http://127.0.0.1", port=int(self.base_url.split(":")[-1]), ) - metrics = run_eval(args) + metrics = run_eval_few_shot_gsm8k(args) print(f"{metrics=}") self.assertGreater(metrics["accuracy"], 0.74) @@ -88,6 +91,45 @@ class TestPPAccuracy(unittest.TestCase): assert len(output_top_logprobs) == 16 +class TestDPAttentionDP2PP2(CustomTestCase): + @classmethod + def setUpClass(cls): + cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=[ + "--trust-remote-code", + "--tp", + "2", + "--pp-size", + "2", + "--enable-dp-attention", + "--dp", + "2", + ], + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_mgsm_en(self): + args = SimpleNamespace( + base_url=self.base_url, + model=self.model, + eval_name="mgsm_en", + num_examples=None, + num_threads=1024, + ) + + metrics = run_eval(args) + print(f"{metrics=}") + self.assertGreater(metrics["score"], 0.8) + + class TestQwenPPAccuracy(unittest.TestCase): @classmethod def setUpClass(cls): @@ -117,7 +159,7 @@ class TestQwenPPAccuracy(unittest.TestCase): host="http://127.0.0.1", port=int(self.base_url.split(":")[-1]), ) - metrics = run_eval(args) + metrics = run_eval_few_shot_gsm8k(args) time.sleep(5) return metrics finally: @@ -172,7 +214,7 @@ class TestQwenPPTieWeightsAccuracy(unittest.TestCase): host="http://127.0.0.1", port=int(self.base_url.split(":")[-1]), ) - metrics = run_eval(args) + metrics = run_eval_few_shot_gsm8k(args) time.sleep(5) return metrics finally: @@ -224,7 +266,7 @@ class TestQwenMoePPAccuracy(unittest.TestCase): host="http://127.0.0.1", port=int(self.base_url.split(":")[-1]), ) - metrics = run_eval(args) + metrics = run_eval_few_shot_gsm8k(args) time.sleep(5) return metrics finally: