Support Qwen3-Next on Ascend NPU (#10379)
This commit is contained in:
@@ -158,7 +158,7 @@ def _layer_norm_fwd(
|
||||
# heuristics for number of warps
|
||||
num_warps = min(max(BLOCK_N // 256, 1), 8)
|
||||
grid = (M, ngroups)
|
||||
with torch.cuda.device(x.device.index):
|
||||
with torch.get_device_module(x.device).device(x.device.index):
|
||||
_layer_norm_fwd_1pass_kernel[grid](
|
||||
x,
|
||||
out,
|
||||
|
||||
@@ -23,6 +23,22 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMo
|
||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||
from sglang.srt.models.qwen3_next import Qwen3HybridLinearDecoderLayer, fused_gdn_gating
|
||||
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
|
||||
from sglang.srt.utils import is_npu
|
||||
|
||||
if is_npu():
|
||||
from sgl_kernel_npu.fla.chunk import chunk_gated_delta_rule_npu
|
||||
from sgl_kernel_npu.fla.fused_sigmoid_gating_recurrent import (
|
||||
fused_sigmoid_gating_delta_rule_update_npu,
|
||||
)
|
||||
from sgl_kernel_npu.mamba.causal_conv1d import (
|
||||
causal_conv1d_fn_npu,
|
||||
causal_conv1d_update_npu,
|
||||
)
|
||||
|
||||
chunk_gated_delta_rule = chunk_gated_delta_rule_npu
|
||||
fused_sigmoid_gating_delta_rule_update = fused_sigmoid_gating_delta_rule_update_npu
|
||||
causal_conv1d_fn = causal_conv1d_fn_npu
|
||||
causal_conv1d_update = causal_conv1d_update_npu
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -85,10 +101,12 @@ class MambaAttnBackend(AttentionBackend):
|
||||
def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
|
||||
for i in range(max_bs):
|
||||
self.state_indices_list.append(
|
||||
torch.full((i + 1,), self.pad_slot_id, dtype=torch.int32, device="cuda")
|
||||
torch.full(
|
||||
(i + 1,), self.pad_slot_id, dtype=torch.int32, device=self.device
|
||||
)
|
||||
)
|
||||
self.query_start_loc_list.append(
|
||||
torch.empty((i + 2,), dtype=torch.int32, device="cuda")
|
||||
torch.empty((i + 2,), dtype=torch.int32, device=self.device)
|
||||
)
|
||||
|
||||
def init_forward_metadata_capture_cuda_graph(
|
||||
@@ -110,7 +128,7 @@ class MambaAttnBackend(AttentionBackend):
|
||||
bs * spec_info.draft_token_num + 1,
|
||||
step=spec_info.draft_token_num,
|
||||
dtype=torch.int32,
|
||||
device="cuda",
|
||||
device=self.device,
|
||||
)
|
||||
)
|
||||
else:
|
||||
@@ -152,7 +170,7 @@ class MambaAttnBackend(AttentionBackend):
|
||||
bs * spec_info.draft_token_num + 1,
|
||||
step=spec_info.draft_token_num,
|
||||
dtype=torch.int32,
|
||||
device="cuda",
|
||||
device=self.device,
|
||||
)
|
||||
)
|
||||
if num_padding > 0:
|
||||
|
||||
@@ -649,6 +649,7 @@ class HybridLinearKVPool(KVCache):
|
||||
self,
|
||||
size: int,
|
||||
dtype: torch.dtype,
|
||||
page_size: int,
|
||||
head_num: int,
|
||||
head_dim: int,
|
||||
full_attention_layer_ids: List[int],
|
||||
@@ -659,10 +660,14 @@ class HybridLinearKVPool(KVCache):
|
||||
self.dtype = dtype
|
||||
self.device = device
|
||||
self.full_layer_nums = len(full_attention_layer_ids)
|
||||
self.page_size = 1
|
||||
self.page_size = page_size
|
||||
# TODO MHATransposedTokenToKVPool if enable_kvcache_transpose is True
|
||||
assert not enable_kvcache_transpose
|
||||
self.full_kv_pool = MHATokenToKVPool(
|
||||
if _is_npu:
|
||||
TokenToKVPoolClass = AscendTokenToKVPool
|
||||
else:
|
||||
TokenToKVPoolClass = MHATokenToKVPool
|
||||
self.full_kv_pool = TokenToKVPoolClass(
|
||||
size=size,
|
||||
page_size=self.page_size,
|
||||
dtype=dtype,
|
||||
@@ -904,8 +909,12 @@ class AscendTokenToKVPool(MHATokenToKVPool):
|
||||
cache_v: torch.Tensor,
|
||||
k_scale: Optional[float] = None,
|
||||
v_scale: Optional[float] = None,
|
||||
layer_id_override: Optional[int] = None,
|
||||
):
|
||||
layer_id = layer.layer_id
|
||||
if layer_id_override is not None:
|
||||
layer_id = layer_id_override
|
||||
else:
|
||||
layer_id = layer.layer_id
|
||||
if cache_k.dtype != self.dtype:
|
||||
if k_scale is not None:
|
||||
cache_k.div_(k_scale)
|
||||
|
||||
@@ -1567,6 +1567,7 @@ class ModelRunner:
|
||||
)
|
||||
elif self.is_hybrid_gdn:
|
||||
self.token_to_kv_pool = HybridLinearKVPool(
|
||||
page_size=self.page_size if _is_npu else 1,
|
||||
size=self.max_total_num_tokens,
|
||||
dtype=self.kv_cache_dtype,
|
||||
head_num=self.model_config.get_num_kv_heads(
|
||||
@@ -1601,7 +1602,10 @@ class ModelRunner:
|
||||
# Initialize token_to_kv_pool_allocator
|
||||
need_sort = self.server_args.disaggregation_mode in ("decode", "prefill")
|
||||
if self.token_to_kv_pool_allocator is None:
|
||||
if self.server_args.attention_backend == "ascend":
|
||||
if _is_npu and self.server_args.attention_backend in [
|
||||
"ascend",
|
||||
"hybrid_linear_attn",
|
||||
]:
|
||||
self.token_to_kv_pool_allocator = AscendPagedTokenToKVPoolAllocator(
|
||||
self.max_total_num_tokens,
|
||||
page_size=self.page_size,
|
||||
@@ -1819,15 +1823,22 @@ class ModelRunner:
|
||||
assert (
|
||||
self.is_hybrid_gdn
|
||||
), "hybrid_linear_attn backend can only be used with hybrid GDN models."
|
||||
from sglang.srt.layers.attention.flashattention_backend import (
|
||||
FlashAttentionBackend,
|
||||
)
|
||||
from sglang.srt.layers.attention.hybrid_linear_attn_backend import (
|
||||
HybridLinearAttnBackend,
|
||||
MambaAttnBackend,
|
||||
)
|
||||
|
||||
full_attn_backend = FlashAttentionBackend(self)
|
||||
if _is_npu:
|
||||
from sglang.srt.layers.attention.ascend_backend import AscendAttnBackend
|
||||
|
||||
full_attn_backend = AscendAttnBackend(self)
|
||||
else:
|
||||
from sglang.srt.layers.attention.flashattention_backend import (
|
||||
FlashAttentionBackend,
|
||||
)
|
||||
|
||||
full_attn_backend = FlashAttentionBackend(self)
|
||||
|
||||
linear_attn_backend = MambaAttnBackend(self)
|
||||
full_attn_layers = self.model_config.hf_config.full_attention_layer_ids
|
||||
return HybridLinearAttnBackend(
|
||||
|
||||
@@ -46,10 +46,11 @@ from sglang.srt.model_loader.weight_utils import (
|
||||
sharded_weight_loader,
|
||||
)
|
||||
from sglang.srt.models.qwen2_moe import Qwen2MoeMLP, Qwen2MoeSparseMoeBlock
|
||||
from sglang.srt.utils import add_prefix, is_cuda, make_layers, set_weight_attrs
|
||||
from sglang.srt.utils import add_prefix, is_cuda, is_npu, make_layers, set_weight_attrs
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
_is_cuda = is_cuda()
|
||||
_is_npu = is_npu()
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
@@ -327,7 +328,7 @@ class Qwen3GatedDeltaNet(nn.Module):
|
||||
eps=self.layer_norm_epsilon,
|
||||
group_size=None,
|
||||
norm_before_gate=True,
|
||||
device=torch.cuda.current_device(),
|
||||
device=torch.get_device_module().current_device(),
|
||||
dtype=config.torch_dtype,
|
||||
)
|
||||
|
||||
@@ -388,7 +389,7 @@ class Qwen3GatedDeltaNet(nn.Module):
|
||||
return query, key, value, z, b, a
|
||||
|
||||
def _forward_input_proj(self, hidden_states: torch.Tensor):
|
||||
DUAL_STREAM_TOKEN_THRESHOLD = 1024
|
||||
DUAL_STREAM_TOKEN_THRESHOLD = 1024 if not _is_npu else 0
|
||||
seq_len, _ = hidden_states.shape
|
||||
if seq_len < DUAL_STREAM_TOKEN_THRESHOLD:
|
||||
current_stream = torch.cuda.current_stream()
|
||||
@@ -454,6 +455,8 @@ class Qwen3GatedDeltaNet(nn.Module):
|
||||
"dt_bias": self.dt_bias,
|
||||
"layer_id": self.layer_id,
|
||||
"seq_len": seq_len,
|
||||
"num_k_heads": self.num_k_heads,
|
||||
"num_v_heads": self.num_v_heads,
|
||||
"z": z,
|
||||
}
|
||||
|
||||
|
||||
@@ -38,6 +38,7 @@ from sglang.srt.utils import (
|
||||
is_cuda,
|
||||
is_flashinfer_available,
|
||||
is_hip,
|
||||
is_npu,
|
||||
is_port_available,
|
||||
is_remote_url,
|
||||
is_sm90_supported,
|
||||
@@ -569,7 +570,7 @@ class ServerArgs:
|
||||
)
|
||||
self.disable_cuda_graph = True
|
||||
|
||||
if self.attention_backend == "ascend":
|
||||
if is_npu() and self.attention_backend in ["ascend", "hybrid_linear_attn"]:
|
||||
logger.warning(
|
||||
"At this moment Ascend attention backend only supports a page_size of 128, change page_size to 128."
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user