Support Qwen3-Next on Ascend NPU (#10379)

This commit is contained in:
Even Zhou
2025-09-13 07:31:37 +08:00
committed by GitHub
parent d5e2a37414
commit 16cd550c85
10 changed files with 79 additions and 26 deletions

View File

@@ -158,7 +158,7 @@ def _layer_norm_fwd(
# heuristics for number of warps
num_warps = min(max(BLOCK_N // 256, 1), 8)
grid = (M, ngroups)
with torch.cuda.device(x.device.index):
with torch.get_device_module(x.device).device(x.device.index):
_layer_norm_fwd_1pass_kernel[grid](
x,
out,

View File

@@ -23,6 +23,22 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMo
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.models.qwen3_next import Qwen3HybridLinearDecoderLayer, fused_gdn_gating
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
from sglang.srt.utils import is_npu
if is_npu():
from sgl_kernel_npu.fla.chunk import chunk_gated_delta_rule_npu
from sgl_kernel_npu.fla.fused_sigmoid_gating_recurrent import (
fused_sigmoid_gating_delta_rule_update_npu,
)
from sgl_kernel_npu.mamba.causal_conv1d import (
causal_conv1d_fn_npu,
causal_conv1d_update_npu,
)
chunk_gated_delta_rule = chunk_gated_delta_rule_npu
fused_sigmoid_gating_delta_rule_update = fused_sigmoid_gating_delta_rule_update_npu
causal_conv1d_fn = causal_conv1d_fn_npu
causal_conv1d_update = causal_conv1d_update_npu
@dataclass
@@ -85,10 +101,12 @@ class MambaAttnBackend(AttentionBackend):
def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
for i in range(max_bs):
self.state_indices_list.append(
torch.full((i + 1,), self.pad_slot_id, dtype=torch.int32, device="cuda")
torch.full(
(i + 1,), self.pad_slot_id, dtype=torch.int32, device=self.device
)
)
self.query_start_loc_list.append(
torch.empty((i + 2,), dtype=torch.int32, device="cuda")
torch.empty((i + 2,), dtype=torch.int32, device=self.device)
)
def init_forward_metadata_capture_cuda_graph(
@@ -110,7 +128,7 @@ class MambaAttnBackend(AttentionBackend):
bs * spec_info.draft_token_num + 1,
step=spec_info.draft_token_num,
dtype=torch.int32,
device="cuda",
device=self.device,
)
)
else:
@@ -152,7 +170,7 @@ class MambaAttnBackend(AttentionBackend):
bs * spec_info.draft_token_num + 1,
step=spec_info.draft_token_num,
dtype=torch.int32,
device="cuda",
device=self.device,
)
)
if num_padding > 0:

View File

@@ -649,6 +649,7 @@ class HybridLinearKVPool(KVCache):
self,
size: int,
dtype: torch.dtype,
page_size: int,
head_num: int,
head_dim: int,
full_attention_layer_ids: List[int],
@@ -659,10 +660,14 @@ class HybridLinearKVPool(KVCache):
self.dtype = dtype
self.device = device
self.full_layer_nums = len(full_attention_layer_ids)
self.page_size = 1
self.page_size = page_size
# TODO MHATransposedTokenToKVPool if enable_kvcache_transpose is True
assert not enable_kvcache_transpose
self.full_kv_pool = MHATokenToKVPool(
if _is_npu:
TokenToKVPoolClass = AscendTokenToKVPool
else:
TokenToKVPoolClass = MHATokenToKVPool
self.full_kv_pool = TokenToKVPoolClass(
size=size,
page_size=self.page_size,
dtype=dtype,
@@ -904,8 +909,12 @@ class AscendTokenToKVPool(MHATokenToKVPool):
cache_v: torch.Tensor,
k_scale: Optional[float] = None,
v_scale: Optional[float] = None,
layer_id_override: Optional[int] = None,
):
layer_id = layer.layer_id
if layer_id_override is not None:
layer_id = layer_id_override
else:
layer_id = layer.layer_id
if cache_k.dtype != self.dtype:
if k_scale is not None:
cache_k.div_(k_scale)

View File

@@ -1567,6 +1567,7 @@ class ModelRunner:
)
elif self.is_hybrid_gdn:
self.token_to_kv_pool = HybridLinearKVPool(
page_size=self.page_size if _is_npu else 1,
size=self.max_total_num_tokens,
dtype=self.kv_cache_dtype,
head_num=self.model_config.get_num_kv_heads(
@@ -1601,7 +1602,10 @@ class ModelRunner:
# Initialize token_to_kv_pool_allocator
need_sort = self.server_args.disaggregation_mode in ("decode", "prefill")
if self.token_to_kv_pool_allocator is None:
if self.server_args.attention_backend == "ascend":
if _is_npu and self.server_args.attention_backend in [
"ascend",
"hybrid_linear_attn",
]:
self.token_to_kv_pool_allocator = AscendPagedTokenToKVPoolAllocator(
self.max_total_num_tokens,
page_size=self.page_size,
@@ -1819,15 +1823,22 @@ class ModelRunner:
assert (
self.is_hybrid_gdn
), "hybrid_linear_attn backend can only be used with hybrid GDN models."
from sglang.srt.layers.attention.flashattention_backend import (
FlashAttentionBackend,
)
from sglang.srt.layers.attention.hybrid_linear_attn_backend import (
HybridLinearAttnBackend,
MambaAttnBackend,
)
full_attn_backend = FlashAttentionBackend(self)
if _is_npu:
from sglang.srt.layers.attention.ascend_backend import AscendAttnBackend
full_attn_backend = AscendAttnBackend(self)
else:
from sglang.srt.layers.attention.flashattention_backend import (
FlashAttentionBackend,
)
full_attn_backend = FlashAttentionBackend(self)
linear_attn_backend = MambaAttnBackend(self)
full_attn_layers = self.model_config.hf_config.full_attention_layer_ids
return HybridLinearAttnBackend(

View File

@@ -46,10 +46,11 @@ from sglang.srt.model_loader.weight_utils import (
sharded_weight_loader,
)
from sglang.srt.models.qwen2_moe import Qwen2MoeMLP, Qwen2MoeSparseMoeBlock
from sglang.srt.utils import add_prefix, is_cuda, make_layers, set_weight_attrs
from sglang.srt.utils import add_prefix, is_cuda, is_npu, make_layers, set_weight_attrs
logger = logging.getLogger(__name__)
_is_cuda = is_cuda()
_is_npu = is_npu()
import triton
import triton.language as tl
@@ -327,7 +328,7 @@ class Qwen3GatedDeltaNet(nn.Module):
eps=self.layer_norm_epsilon,
group_size=None,
norm_before_gate=True,
device=torch.cuda.current_device(),
device=torch.get_device_module().current_device(),
dtype=config.torch_dtype,
)
@@ -388,7 +389,7 @@ class Qwen3GatedDeltaNet(nn.Module):
return query, key, value, z, b, a
def _forward_input_proj(self, hidden_states: torch.Tensor):
DUAL_STREAM_TOKEN_THRESHOLD = 1024
DUAL_STREAM_TOKEN_THRESHOLD = 1024 if not _is_npu else 0
seq_len, _ = hidden_states.shape
if seq_len < DUAL_STREAM_TOKEN_THRESHOLD:
current_stream = torch.cuda.current_stream()
@@ -454,6 +455,8 @@ class Qwen3GatedDeltaNet(nn.Module):
"dt_bias": self.dt_bias,
"layer_id": self.layer_id,
"seq_len": seq_len,
"num_k_heads": self.num_k_heads,
"num_v_heads": self.num_v_heads,
"z": z,
}

View File

@@ -38,6 +38,7 @@ from sglang.srt.utils import (
is_cuda,
is_flashinfer_available,
is_hip,
is_npu,
is_port_available,
is_remote_url,
is_sm90_supported,
@@ -569,7 +570,7 @@ class ServerArgs:
)
self.disable_cuda_graph = True
if self.attention_backend == "ascend":
if is_npu() and self.attention_backend in ["ascend", "hybrid_linear_attn"]:
logger.warning(
"At this moment Ascend attention backend only supports a page_size of 128, change page_size to 128."
)