piecewise cuda graph support qwen3-moe (#11845)

This commit is contained in:
Xiaoyu Zhang
2025-10-21 10:55:49 +08:00
committed by GitHub
parent 74de76c685
commit 8374a96e49
4 changed files with 71 additions and 6 deletions

View File

@@ -212,6 +212,10 @@ class LayerCommunicator:
)
)
self._speculative_algo = SpeculativeAlgorithm.from_string(
get_global_server_args().speculative_algorithm
)
def prepare_attn(
self,
hidden_states: torch.Tensor,
@@ -315,13 +319,10 @@ class LayerCommunicator:
def should_fuse_mlp_allreduce_with_next_layer(
self, forward_batch: ForwardBatch
) -> bool:
speculative_algo = SpeculativeAlgorithm.from_string(
get_global_server_args().speculative_algorithm
)
if (
is_dp_attention_enabled() is_dp_attention_enabled()
and speculative_algo is not None and self._speculative_algo is not None
and speculative_algo.is_eagle() and self._speculative_algo.is_eagle()
):
return False

View File

@@ -1831,3 +1831,21 @@ def triton_scaled_mm(
)
return result.to(out_dtype)
# Piecewise CUDA graph capture traces these sgl_kernel quantization ops with
# torch.compile, which needs a fake (meta) implementation for each custom op.
# The real kernels presumably write their results into output_q / output_s
# in place (outputs are passed as arguments — TODO confirm against sgl_kernel),
# so the fake implementation has nothing to return.
if _is_cuda:
    # Exactly one of the two ops is registered, matching whichever kernel the
    # per-token-group quantization flag selects at import time.
    _fake_quant_op = (
        "sgl_kernel::sgl_per_token_group_quant_8bit"
        if enable_sgl_per_token_group_quant_8bit
        else "sgl_kernel::sgl_per_token_group_quant_fp8"
    )

    @torch.library.register_fake(_fake_quant_op)
    def _(input, output_q, output_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0):
        return

View File

@@ -17,6 +17,7 @@
"""Inference-only Qwen2MoE model compatible with HuggingFace weights."""
import logging
from contextlib import nullcontext
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
import torch
@@ -590,7 +591,12 @@ class Qwen2MoeModel(nn.Module):
if residual is not None if residual is not None
else hidden_states else hidden_states
) )
with get_global_expert_distribution_recorder().with_current_layer(i): ctx = (
nullcontext()
if get_global_server_args().enable_piecewise_cuda_graph
else get_global_expert_distribution_recorder().with_current_layer(i)
)
with ctx:
layer = self.layers[i]
hidden_states, residual = layer(
positions, hidden_states, forward_batch, residual

View File

@@ -55,5 +55,45 @@ class TestPiecewiseCudaGraphBenchmark(CustomTestCase):
self.assertLess(prefill_latency, 0.015)
class TestPiecewiseCudaGraphQwen3MoE(CustomTestCase):
    """Accuracy test for piecewise CUDA graph with the Qwen3-Coder-30B-A3B-Instruct MoE model.

    Launches one server for the whole class with --enable-piecewise-cuda-graph
    and verifies that accuracy on the mgsm_en eval (the English split of MGSM,
    i.e. GSM8K-style math problems) is not degraded by the piecewise-graph path.
    """

    @classmethod
    def setUpClass(cls):
        cls.model = "Qwen/Qwen3-Coder-30B-A3B-Instruct"
        cls.base_url = DEFAULT_URL_FOR_TEST
        # The "eager" compiler backend avoids a full compile pass; we are
        # testing the piecewise CUDA graph capture path, not the compiler.
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=[
                "--enable-piecewise-cuda-graph",
                "--piecewise-cuda-graph-compiler",
                "eager",
            ],
        )

    @classmethod
    def tearDownClass(cls):
        # Kill the server and any child processes it spawned.
        kill_process_tree(cls.process.pid)

    def test_gsm8k_accuracy(self):
        """Run the mgsm_en eval and require a score of at least 0.90."""
        num_examples = 2000
        args = SimpleNamespace(
            base_url=self.base_url,
            model=self.model,
            eval_name="mgsm_en",
            num_examples=num_examples,
            # Cap concurrency; with 2000 examples this resolves to 1024 threads.
            num_threads=min(num_examples, 1024),
        )
        metrics = run_eval(args)
        # Label matches the eval actually run (mgsm_en), not plain GSM8K.
        print(f"MGSM-EN Accuracy: {metrics['score']:.3f}")
        self.assertGreaterEqual(metrics["score"], 0.90)
if __name__ == "__main__":
unittest.main()