Piecewise CUDA graph support for Qwen3-MoE (#11845)

This commit is contained in:
Xiaoyu Zhang
2025-10-21 10:55:49 +08:00
committed by GitHub
parent 74de76c685
commit 8374a96e49
4 changed files with 71 additions and 6 deletions

View File

@@ -212,6 +212,10 @@ class LayerCommunicator:
)
)
self._speculative_algo = SpeculativeAlgorithm.from_string(
get_global_server_args().speculative_algorithm
)
def prepare_attn(
self,
hidden_states: torch.Tensor,
@@ -315,13 +319,10 @@ class LayerCommunicator:
def should_fuse_mlp_allreduce_with_next_layer(
self, forward_batch: ForwardBatch
) -> bool:
speculative_algo = SpeculativeAlgorithm.from_string(
get_global_server_args().speculative_algorithm
)
if (
is_dp_attention_enabled()
and speculative_algo is not None
and speculative_algo.is_eagle()
and self._speculative_algo is not None
and self._speculative_algo.is_eagle()
):
return False

View File

@@ -1831,3 +1831,21 @@ def triton_scaled_mm(
)
return result.to(out_dtype)
if _is_cuda:
    # Register fake (meta) implementations for the sgl_kernel quantization
    # custom ops so torch.compile / piecewise CUDA-graph tracing can go
    # through them without executing the real CUDA kernels.
    # NOTE(review): the fakes return nothing — presumably the real kernels
    # write into the preallocated output_q/output_s tensors in place, so
    # there is no return value to fake; confirm against the kernel source.
    if enable_sgl_per_token_group_quant_8bit:
        # 8-bit variant is enabled: only that op needs a fake.
        @torch.library.register_fake("sgl_kernel::sgl_per_token_group_quant_8bit")
        def _(
            input, output_q, output_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0
        ):
            return
    else:
        # Otherwise the fp8-specific op is the one that gets traced.
        @torch.library.register_fake("sgl_kernel::sgl_per_token_group_quant_fp8")
        def _(
            input, output_q, output_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0
        ):
            return

View File

@@ -17,6 +17,7 @@
"""Inference-only Qwen2MoE model compatible with HuggingFace weights."""
import logging
from contextlib import nullcontext
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
import torch
@@ -590,7 +591,12 @@ class Qwen2MoeModel(nn.Module):
if residual is not None
else hidden_states
)
with get_global_expert_distribution_recorder().with_current_layer(i):
ctx = (
nullcontext()
if get_global_server_args().enable_piecewise_cuda_graph
else get_global_expert_distribution_recorder().with_current_layer(i)
)
with ctx:
layer = self.layers[i]
hidden_states, residual = layer(
positions, hidden_states, forward_batch, residual