diff --git a/python/sglang/srt/layers/communicator.py b/python/sglang/srt/layers/communicator.py index 60b4e9e5f..c386d6111 100644 --- a/python/sglang/srt/layers/communicator.py +++ b/python/sglang/srt/layers/communicator.py @@ -212,6 +212,10 @@ class LayerCommunicator: ) ) + self._speculative_algo = SpeculativeAlgorithm.from_string( + get_global_server_args().speculative_algorithm + ) + def prepare_attn( self, hidden_states: torch.Tensor, @@ -315,13 +319,10 @@ class LayerCommunicator: def should_fuse_mlp_allreduce_with_next_layer( self, forward_batch: ForwardBatch ) -> bool: - speculative_algo = SpeculativeAlgorithm.from_string( - get_global_server_args().speculative_algorithm - ) if ( is_dp_attention_enabled() - and speculative_algo is not None - and speculative_algo.is_eagle() + and self._speculative_algo is not None + and self._speculative_algo.is_eagle() ): return False diff --git a/python/sglang/srt/layers/quantization/fp8_kernel.py b/python/sglang/srt/layers/quantization/fp8_kernel.py index 9ac766a23..e98b7f6ff 100644 --- a/python/sglang/srt/layers/quantization/fp8_kernel.py +++ b/python/sglang/srt/layers/quantization/fp8_kernel.py @@ -1831,3 +1831,21 @@ def triton_scaled_mm( ) return result.to(out_dtype) + + +if _is_cuda: + if enable_sgl_per_token_group_quant_8bit: + + @torch.library.register_fake("sgl_kernel::sgl_per_token_group_quant_8bit") + def _( + input, output_q, output_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0 + ): + return + + else: + + @torch.library.register_fake("sgl_kernel::sgl_per_token_group_quant_fp8") + def _( + input, output_q, output_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0 + ): + return diff --git a/python/sglang/srt/models/qwen2_moe.py b/python/sglang/srt/models/qwen2_moe.py index 1b5738adb..575f32b60 100644 --- a/python/sglang/srt/models/qwen2_moe.py +++ b/python/sglang/srt/models/qwen2_moe.py @@ -17,6 +17,7 @@ """Inference-only Qwen2MoE model compatible with HuggingFace weights.""" import logging +from contextlib import nullcontext from typing import Any, Dict, Iterable, List, Optional, Tuple, Union import torch @@ -590,7 +591,12 @@ class Qwen2MoeModel(nn.Module): if residual is not None else hidden_states ) - with get_global_expert_distribution_recorder().with_current_layer(i): + ctx = ( + nullcontext() + if get_global_server_args().enable_piecewise_cuda_graph + else get_global_expert_distribution_recorder().with_current_layer(i) + ) + with ctx: layer = self.layers[i] hidden_states, residual = layer( positions, hidden_states, forward_batch, residual diff --git a/test/srt/test_piecewise_cuda_graph.py b/test/srt/test_piecewise_cuda_graph.py index ed41e1e04..4325d9acb 100644 --- a/test/srt/test_piecewise_cuda_graph.py +++ b/test/srt/test_piecewise_cuda_graph.py @@ -55,5 +55,45 @@ class TestPiecewiseCudaGraphBenchmark(CustomTestCase): self.assertLess(prefill_latency, 0.015) +class TestPiecewiseCudaGraphQwen3MoE(CustomTestCase): + """Test piecewise CUDA graph with Qwen3-Coder-30B-A3B-Instruct MoE model""" + + @classmethod + def setUpClass(cls): + cls.model = "Qwen/Qwen3-Coder-30B-A3B-Instruct" + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=[ + "--enable-piecewise-cuda-graph", + "--piecewise-cuda-graph-compiler", + "eager", + ], + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_gsm8k_accuracy(self): + """Test GSM8K accuracy with 8-shot setting""" + num_examples = 2000 + + args = SimpleNamespace( + base_url=self.base_url, + model=self.model, + eval_name="mgsm_en", + num_examples=num_examples, + num_threads=min(num_examples, 1024), + ) + + metrics = run_eval(args) + print(f"GSM8K Accuracy: {metrics['score']:.3f}") + + self.assertGreaterEqual(metrics["score"], 0.90) + + if __name__ == "__main__": unittest.main()