piecewise cuda graph support qwen3-moe (#11845)
This commit is contained in:
@@ -212,6 +212,10 @@ class LayerCommunicator:
|
||||
)
|
||||
)
|
||||
|
||||
self._speculative_algo = SpeculativeAlgorithm.from_string(
|
||||
get_global_server_args().speculative_algorithm
|
||||
)
|
||||
|
||||
def prepare_attn(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -315,13 +319,10 @@ class LayerCommunicator:
|
||||
def should_fuse_mlp_allreduce_with_next_layer(
|
||||
self, forward_batch: ForwardBatch
|
||||
) -> bool:
|
||||
speculative_algo = SpeculativeAlgorithm.from_string(
|
||||
get_global_server_args().speculative_algorithm
|
||||
)
|
||||
if (
|
||||
is_dp_attention_enabled()
|
||||
and speculative_algo is not None
|
||||
and speculative_algo.is_eagle()
|
||||
and self._speculative_algo is not None
|
||||
and self._speculative_algo.is_eagle()
|
||||
):
|
||||
return False
|
||||
|
||||
|
||||
@@ -1831,3 +1831,21 @@ def triton_scaled_mm(
|
||||
)
|
||||
|
||||
return result.to(out_dtype)
|
||||
|
||||
|
||||
if _is_cuda:
|
||||
if enable_sgl_per_token_group_quant_8bit:
|
||||
|
||||
@torch.library.register_fake("sgl_kernel::sgl_per_token_group_quant_8bit")
|
||||
def _(
|
||||
input, output_q, output_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0
|
||||
):
|
||||
return
|
||||
|
||||
else:
|
||||
|
||||
@torch.library.register_fake("sgl_kernel::sgl_per_token_group_quant_fp8")
|
||||
def _(
|
||||
input, output_q, output_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0
|
||||
):
|
||||
return
|
||||
|
||||
@@ -17,6 +17,7 @@
|
||||
"""Inference-only Qwen2MoE model compatible with HuggingFace weights."""
|
||||
|
||||
import logging
|
||||
from contextlib import nullcontext
|
||||
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
@@ -590,7 +591,12 @@ class Qwen2MoeModel(nn.Module):
|
||||
if residual is not None
|
||||
else hidden_states
|
||||
)
|
||||
with get_global_expert_distribution_recorder().with_current_layer(i):
|
||||
ctx = (
|
||||
nullcontext()
|
||||
if get_global_server_args().enable_piecewise_cuda_graph
|
||||
else get_global_expert_distribution_recorder().with_current_layer(i)
|
||||
)
|
||||
with ctx:
|
||||
layer = self.layers[i]
|
||||
hidden_states, residual = layer(
|
||||
positions, hidden_states, forward_batch, residual
|
||||
|
||||
@@ -55,5 +55,45 @@ class TestPiecewiseCudaGraphBenchmark(CustomTestCase):
|
||||
self.assertLess(prefill_latency, 0.015)
|
||||
|
||||
|
||||
class TestPiecewiseCudaGraphQwen3MoE(CustomTestCase):
|
||||
"""Test piecewise CUDA graph with Qwen3-Coder-30B-A3B-Instruct MoE model"""
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.model = "Qwen/Qwen3-Coder-30B-A3B-Instruct"
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
cls.process = popen_launch_server(
|
||||
cls.model,
|
||||
cls.base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
other_args=[
|
||||
"--enable-piecewise-cuda-graph",
|
||||
"--piecewise-cuda-graph-compiler",
|
||||
"eager",
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
kill_process_tree(cls.process.pid)
|
||||
|
||||
def test_gsm8k_accuracy(self):
|
||||
"""Test GSM8K accuracy with 8-shot setting"""
|
||||
num_examples = 2000
|
||||
|
||||
args = SimpleNamespace(
|
||||
base_url=self.base_url,
|
||||
model=self.model,
|
||||
eval_name="mgsm_en",
|
||||
num_examples=num_examples,
|
||||
num_threads=min(num_examples, 1024),
|
||||
)
|
||||
|
||||
metrics = run_eval(args)
|
||||
print(f"GSM8K Accuracy: {metrics['score']:.3f}")
|
||||
|
||||
self.assertGreaterEqual(metrics["score"], 0.90)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user