Patch summary (free text before the first "diff --git" header):
  * python/sglang/srt/models/qwen3_moe.py: build routed_experts_weights_of_layer
    from range(self.start_layer, self.end_layer) instead of enumerating every
    entry of self.model.layers.
  * test/srt/test_pp_single_node.py: move TestQwenPPTieWeightsAccuracy to port
    23335 and add TestQwenMoePPAccuracy (port 23336, Qwen/Qwen3-30B-A3B), which
    runs the GSM8K eval with --pp-size 1 and 2 and compares accuracy.

diff --git a/python/sglang/srt/models/qwen3_moe.py b/python/sglang/srt/models/qwen3_moe.py
index af7a47651..2c1ca2f8b 100644
--- a/python/sglang/srt/models/qwen3_moe.py
+++ b/python/sglang/srt/models/qwen3_moe.py
@@ -812,9 +812,9 @@ class Qwen3MoeForCausalLM(nn.Module):
                         logger.warning(f"Parameter {name} not found in params_dict")
 
         self.routed_experts_weights_of_layer = {
-            layer_id: layer.mlp.get_moe_weights()
-            for layer_id, layer in enumerate(self.model.layers)
-            if isinstance(layer.mlp, Qwen3MoeSparseMoeBlock)
+            layer_id: self.model.layers[layer_id].mlp.get_moe_weights()
+            for layer_id in range(self.start_layer, self.end_layer)
+            if isinstance(self.model.layers[layer_id].mlp, Qwen3MoeSparseMoeBlock)
         }
 
     @classmethod
diff --git a/test/srt/test_pp_single_node.py b/test/srt/test_pp_single_node.py
index efd894fab..dbac4c771 100644
--- a/test/srt/test_pp_single_node.py
+++ b/test/srt/test_pp_single_node.py
@@ -121,7 +121,7 @@ class TestQwenPPAccuracy(unittest.TestCase):
 class TestQwenPPTieWeightsAccuracy(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
-        cls.base_url = "http://127.0.0.1:23334"  # different ports to avoid conflicts
+        cls.base_url = "http://127.0.0.1:23335"  # different ports to avoid conflicts
         cls.model_name = (
             "Qwen/Qwen3-0.6B"  # qwen3 < 8B all have tie_word_embeddings = True
         )
@@ -176,6 +176,62 @@ class TestQwenPPTieWeightsAccuracy(unittest.TestCase):
         )
 
 
+class TestQwenMoePPAccuracy(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.base_url = "http://127.0.0.1:23336"  # different ports to avoid conflicts
+        cls.model_name = "Qwen/Qwen3-30B-A3B"  # replace with your Qwen Model if needed
+
+    def run_gsm8k_test(self, pp_size):
+        process = popen_launch_server(
+            self.model_name,
+            self.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=[
+                "--pp-size",
+                pp_size,
+                "--chunked-prefill-size",
+                256,
+            ],
+        )
+
+        try:
+            args = SimpleNamespace(
+                num_shots=5,
+                data_path=None,
+                num_questions=200,
+                max_new_tokens=512,
+                parallel=128,
+                host="http://127.0.0.1",
+                port=int(self.base_url.split(":")[-1]),
+            )
+            metrics = run_eval(args)
+            time.sleep(5)
+            return metrics
+        finally:
+            kill_process_tree(process.pid)
+
+    def test_baseline_accuracy(self):
+        metrics = self.run_gsm8k_test(pp_size=1)
+        print(f"[Qwen Baseline] {metrics=}")
+        self.assertGreater(metrics["accuracy"], 0.74)
+
+    def test_pp_consistency(self):
+        baseline = self.run_gsm8k_test(pp_size=1)
+        pp_metrics = self.run_gsm8k_test(pp_size=2)
+
+        print(f"[Qwen PP Comparison] Baseline: {baseline} | PP: {pp_metrics}")
+
+        self.assertGreaterEqual(
+            pp_metrics["accuracy"],
+            baseline["accuracy"] - 0.01,
+            msg=(
+                f"PP accuracy dropped more than 1% compared to baseline. "
+                f"Baseline: {baseline['accuracy']:.2%}, PP: {pp_metrics['accuracy']:.2%}"
+            ),
+        )
+
+
 class TestFixedBugs(unittest.TestCase):
     def test_chunked_prefill_with_small_bs(self):
         model = DEFAULT_MODEL_NAME_FOR_TEST