From 16cd550c8554796d2b20b39162dbab7db8355476 Mon Sep 17 00:00:00 2001 From: Even Zhou Date: Sat, 13 Sep 2025 07:31:37 +0800 Subject: [PATCH] Support Qwen3-Next on Ascend NPU (#10379) --- .../workflows/release-docker-npu-nightly.yml | 2 +- .github/workflows/release-docker-npu.yml | 2 +- docker/Dockerfile.npu | 11 +++++--- .../layers/attention/fla/layernorm_gated.py | 2 +- .../attention/hybrid_linear_attn_backend.py | 26 ++++++++++++++++--- python/sglang/srt/mem_cache/memory_pool.py | 15 ++++++++--- .../sglang/srt/model_executor/model_runner.py | 21 +++++++++++---- python/sglang/srt/models/qwen3_next.py | 9 ++++--- python/sglang/srt/server_args.py | 3 ++- scripts/ci/npu_ci_install_dependency.sh | 14 +++++++--- 10 files changed, 79 insertions(+), 26 deletions(-) diff --git a/.github/workflows/release-docker-npu-nightly.yml b/.github/workflows/release-docker-npu-nightly.yml index 527a0cdc2..dff45f2ac 100644 --- a/.github/workflows/release-docker-npu-nightly.yml +++ b/.github/workflows/release-docker-npu-nightly.yml @@ -73,6 +73,6 @@ jobs: push: ${{ github.repository == 'sgl-project/sglang' && github.event_name != 'pull_request' }} provenance: false build-args: | - SGLANG_KERNEL_NPU_TAG=20250901 + SGLANG_KERNEL_NPU_TAG=20250913 CANN_VERSION=${{ matrix.cann_version }} DEVICE_TYPE=${{ matrix.device_type }} diff --git a/.github/workflows/release-docker-npu.yml b/.github/workflows/release-docker-npu.yml index f9d52eb4b..8fa6a983e 100644 --- a/.github/workflows/release-docker-npu.yml +++ b/.github/workflows/release-docker-npu.yml @@ -69,6 +69,6 @@ jobs: push: ${{ github.repository == 'sgl-project/sglang' && github.event_name != 'pull_request' }} provenance: false build-args: | - SGLANG_KERNEL_NPU_TAG=20250901 + SGLANG_KERNEL_NPU_TAG=20250913 CANN_VERSION=${{ matrix.cann_version }} DEVICE_TYPE=${{ matrix.device_type }} diff --git a/docker/Dockerfile.npu b/docker/Dockerfile.npu index 3f9b0ae42..01f9cf7e7 100644 --- a/docker/Dockerfile.npu +++ b/docker/Dockerfile.npu @@ -13,7 +13,8 @@ ARG PYTORCH_VERSION=2.6.0 ARG TORCHVISION_VERSION=0.21.0 ARG PTA_URL="https://gitee.com/ascend/pytorch/releases/download/v7.1.0.1-pytorch2.6.0/torch_npu-2.6.0.post1-cp311-cp311-manylinux_2_28_aarch64.whl" ARG VLLM_TAG=v0.8.5 -ARG TRITON_ASCEND_URL=https://sglang-ascend.obs.cn-east-3.myhuaweicloud.com/sglang/triton_ascend-3.2.0.dev20250729-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl +ARG TRITON_ASCEND_URL="https://sglang-ascend.obs.cn-east-3.myhuaweicloud.com/sglang/triton_ascend-3.2.0%2Bgitb0ea0850-cp311-cp311-linux_aarch64.whl" +ARG BISHENG_URL="https://sglang-ascend.obs.cn-east-3.myhuaweicloud.com/sglang/Ascend-BiSheng-toolkit_aarch64.run" ARG SGLANG_TAG=main ARG ASCEND_CANN_PATH=/usr/local/Ascend/ascend-toolkit ARG SGLANG_KERNEL_NPU_TAG=main @@ -81,13 +82,17 @@ RUN git clone https://github.com/sgl-project/sglang --branch $SGLANG_TAG && \ rm -rf sglang # Install Deep-ep -RUN git clone --branch $SGLANG_KERNEL_NPU_TAG https://github.com/sgl-project/sgl-kernel-npu.git \ +# pin wheel to 0.45.1 ref: https://github.com/pypa/wheel/issues/662 +RUN pip install wheel==0.45.1 && git clone --branch $SGLANG_KERNEL_NPU_TAG https://github.com/sgl-project/sgl-kernel-npu.git \ && export LD_LIBRARY_PATH=${ASCEND_CANN_PATH}/latest/runtime/lib64/stub:$LD_LIBRARY_PATH && \ source ${ASCEND_CANN_PATH}/set_env.sh && \ cd sgl-kernel-npu && \ bash build.sh \ - && pip install output/deep_ep*.whl --no-cache-dir \ + && pip install output/deep_ep*.whl output/sgl_kernel_npu*.whl --no-cache-dir \ && cd .. && rm -rf sgl-kernel-npu \ && cd "$(pip show deep-ep | awk '/^Location:/ {print $2}')" && ln -s deep_ep/deep_ep_cpp*.so +# Install Bisheng +RUN wget ${BISHENG_URL} && chmod a+x Ascend-BiSheng-toolkit_aarch64.run && ./Ascend-BiSheng-toolkit_aarch64.run --install && rm Ascend-BiSheng-toolkit_aarch64.run + CMD ["/bin/bash"] diff --git a/python/sglang/srt/layers/attention/fla/layernorm_gated.py b/python/sglang/srt/layers/attention/fla/layernorm_gated.py index bd53d0d64..89482245b 100644 --- a/python/sglang/srt/layers/attention/fla/layernorm_gated.py +++ b/python/sglang/srt/layers/attention/fla/layernorm_gated.py @@ -158,7 +158,7 @@ def _layer_norm_fwd( # heuristics for number of warps num_warps = min(max(BLOCK_N // 256, 1), 8) grid = (M, ngroups) - with torch.cuda.device(x.device.index): + with torch.get_device_module(x.device).device(x.device.index): _layer_norm_fwd_1pass_kernel[grid]( x, out, diff --git a/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py b/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py index a676573f2..82baea216 100644 --- a/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py +++ b/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py @@ -23,6 +23,22 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMo from sglang.srt.model_executor.model_runner import ModelRunner from sglang.srt.models.qwen3_next import Qwen3HybridLinearDecoderLayer, fused_gdn_gating from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput +from sglang.srt.utils import is_npu + +if is_npu(): + from sgl_kernel_npu.fla.chunk import chunk_gated_delta_rule_npu + from sgl_kernel_npu.fla.fused_sigmoid_gating_recurrent import ( + fused_sigmoid_gating_delta_rule_update_npu, + ) + from sgl_kernel_npu.mamba.causal_conv1d import ( + causal_conv1d_fn_npu, + causal_conv1d_update_npu, + ) + + chunk_gated_delta_rule = chunk_gated_delta_rule_npu + fused_sigmoid_gating_delta_rule_update = fused_sigmoid_gating_delta_rule_update_npu + causal_conv1d_fn = causal_conv1d_fn_npu + causal_conv1d_update = causal_conv1d_update_npu @dataclass @@ -85,10 +101,12 @@ class MambaAttnBackend(AttentionBackend): def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int): for i in range(max_bs): self.state_indices_list.append( - torch.full((i + 1,), self.pad_slot_id, dtype=torch.int32, device="cuda") + torch.full( + (i + 1,), self.pad_slot_id, dtype=torch.int32, device=self.device + ) ) self.query_start_loc_list.append( - torch.empty((i + 2,), dtype=torch.int32, device="cuda") + torch.empty((i + 2,), dtype=torch.int32, device=self.device) ) def init_forward_metadata_capture_cuda_graph( @@ -110,7 +128,7 @@ class MambaAttnBackend(AttentionBackend): bs * spec_info.draft_token_num + 1, step=spec_info.draft_token_num, dtype=torch.int32, - device="cuda", + device=self.device, ) ) else: @@ -152,7 +170,7 @@ class MambaAttnBackend(AttentionBackend): bs * spec_info.draft_token_num + 1, step=spec_info.draft_token_num, dtype=torch.int32, - device="cuda", + device=self.device, ) ) if num_padding > 0: diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py index 80c549033..ddcb19ea2 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -649,6 +649,7 @@ class HybridLinearKVPool(KVCache): self, size: int, dtype: torch.dtype, + page_size: int, head_num: int, head_dim: int, full_attention_layer_ids: List[int], @@ -659,10 +660,14 @@ class HybridLinearKVPool(KVCache): self.dtype = dtype self.device = device self.full_layer_nums = len(full_attention_layer_ids) - self.page_size = 1 + self.page_size = page_size # TODO MHATransposedTokenToKVPool if enable_kvcache_transpose is True assert not enable_kvcache_transpose - self.full_kv_pool = MHATokenToKVPool( + if _is_npu: + TokenToKVPoolClass = AscendTokenToKVPool + else: + TokenToKVPoolClass = MHATokenToKVPool + self.full_kv_pool = TokenToKVPoolClass( size=size, page_size=self.page_size, dtype=dtype, @@ -904,8 +909,12 @@ class AscendTokenToKVPool(MHATokenToKVPool): cache_v: torch.Tensor, k_scale: Optional[float] = None, v_scale: Optional[float] = None, + layer_id_override: Optional[int] = None, ): - layer_id = layer.layer_id + if layer_id_override is not None: + layer_id = layer_id_override + else: + layer_id = layer.layer_id if cache_k.dtype != self.dtype: if k_scale is not None: cache_k.div_(k_scale) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 8960b0bc8..36785d86f 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -1567,6 +1567,7 @@ class ModelRunner: ) elif self.is_hybrid_gdn: self.token_to_kv_pool = HybridLinearKVPool( + page_size=self.page_size if _is_npu else 1, size=self.max_total_num_tokens, dtype=self.kv_cache_dtype, head_num=self.model_config.get_num_kv_heads( @@ -1601,7 +1602,10 @@ class ModelRunner: # Initialize token_to_kv_pool_allocator need_sort = self.server_args.disaggregation_mode in ("decode", "prefill") if self.token_to_kv_pool_allocator is None: - if self.server_args.attention_backend == "ascend": + if _is_npu and self.server_args.attention_backend in [ + "ascend", + "hybrid_linear_attn", + ]: self.token_to_kv_pool_allocator = AscendPagedTokenToKVPoolAllocator( self.max_total_num_tokens, page_size=self.page_size, @@ -1819,15 +1823,22 @@ class ModelRunner: assert ( self.is_hybrid_gdn ), "hybrid_linear_attn backend can only be used with hybrid GDN models." - from sglang.srt.layers.attention.flashattention_backend import ( - FlashAttentionBackend, - ) from sglang.srt.layers.attention.hybrid_linear_attn_backend import ( HybridLinearAttnBackend, MambaAttnBackend, ) - full_attn_backend = FlashAttentionBackend(self) + if _is_npu: + from sglang.srt.layers.attention.ascend_backend import AscendAttnBackend + + full_attn_backend = AscendAttnBackend(self) + else: + from sglang.srt.layers.attention.flashattention_backend import ( + FlashAttentionBackend, + ) + + full_attn_backend = FlashAttentionBackend(self) + linear_attn_backend = MambaAttnBackend(self) full_attn_layers = self.model_config.hf_config.full_attention_layer_ids return HybridLinearAttnBackend( diff --git a/python/sglang/srt/models/qwen3_next.py b/python/sglang/srt/models/qwen3_next.py index cdba9975f..52d8c6faf 100644 --- a/python/sglang/srt/models/qwen3_next.py +++ b/python/sglang/srt/models/qwen3_next.py @@ -46,10 +46,11 @@ from sglang.srt.model_loader.weight_utils import ( sharded_weight_loader, ) from sglang.srt.models.qwen2_moe import Qwen2MoeMLP, Qwen2MoeSparseMoeBlock -from sglang.srt.utils import add_prefix, is_cuda, make_layers, set_weight_attrs +from sglang.srt.utils import add_prefix, is_cuda, is_npu, make_layers, set_weight_attrs logger = logging.getLogger(__name__) _is_cuda = is_cuda() +_is_npu = is_npu() import triton import triton.language as tl @@ -327,7 +328,7 @@ class Qwen3GatedDeltaNet(nn.Module): eps=self.layer_norm_epsilon, group_size=None, norm_before_gate=True, - device=torch.cuda.current_device(), + device=torch.get_device_module().current_device(), dtype=config.torch_dtype, ) @@ -388,7 +389,7 @@ class Qwen3GatedDeltaNet(nn.Module): return query, key, value, z, b, a def _forward_input_proj(self, hidden_states: torch.Tensor): - DUAL_STREAM_TOKEN_THRESHOLD = 1024 + DUAL_STREAM_TOKEN_THRESHOLD = 1024 if not _is_npu else 0 seq_len, _ = hidden_states.shape if seq_len < DUAL_STREAM_TOKEN_THRESHOLD: current_stream = torch.cuda.current_stream() @@ -454,6 +455,8 @@ class Qwen3GatedDeltaNet(nn.Module): "dt_bias": self.dt_bias, "layer_id": self.layer_id, "seq_len": seq_len, + "num_k_heads": self.num_k_heads, + "num_v_heads": self.num_v_heads, "z": z, } diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index f041011e0..897c732b5 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -38,6 +38,7 @@ from sglang.srt.utils import ( is_cuda, is_flashinfer_available, is_hip, + is_npu, is_port_available, is_remote_url, is_sm90_supported, @@ -569,7 +570,7 @@ class ServerArgs: ) self.disable_cuda_graph = True - if self.attention_backend == "ascend": + if is_npu() and self.attention_backend in ["ascend", "hybrid_linear_attn"]: logger.warning( "At this moment Ascend attention backend only supports a page_size of 128, change page_size to 128." ) diff --git a/scripts/ci/npu_ci_install_dependency.sh b/scripts/ci/npu_ci_install_dependency.sh index 71cf46f7f..97c7dba1b 100755 --- a/scripts/ci/npu_ci_install_dependency.sh +++ b/scripts/ci/npu_ci_install_dependency.sh @@ -45,16 +45,22 @@ wget -O "${PTA_NAME}" "${PTA_URL}" && ${PIP_INSTALL} "./${PTA_NAME}" ### Install Triton-Ascend -TRITON_ASCEND_NAME="triton_ascend-3.2.0.dev20250729-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl" -TRITON_ASCEND_URL="https://sglang-ascend.obs.cn-east-3.myhuaweicloud.com/sglang/${TRITON_ASCEND_NAME}" +TRITON_ASCEND_NAME="triton_ascend-3.2.0+gitb0ea0850-cp311-cp311-linux_aarch64.whl" +TRITON_ASCEND_URL="https://sglang-ascend.obs.cn-east-3.myhuaweicloud.com/sglang/triton_ascend-3.2.0%2Bgitb0ea0850-cp311-cp311-linux_aarch64.whl" ${PIP_INSTALL} attrs==24.2.0 numpy==1.26.4 scipy==1.13.1 decorator==5.1.1 psutil==6.0.0 pytest==8.3.2 pytest-xdist==3.6.1 pyyaml pybind11 wget -O "${TRITON_ASCEND_NAME}" "${TRITON_ASCEND_URL}" && ${PIP_INSTALL} "./${TRITON_ASCEND_NAME}" +### Install BiSheng +BISHENG_NAME="Ascend-BiSheng-toolkit_aarch64.run" +BISHENG_URL="https://sglang-ascend.obs.cn-east-3.myhuaweicloud.com/sglang/${BISHENG_NAME}" +wget -O "${BISHENG_NAME}" "${BISHENG_URL}" && chmod a+x "${BISHENG_NAME}" && "./${BISHENG_NAME}" --install && rm "${BISHENG_NAME}" + + ### Install sgl-kernel-npu -SGL_KERNEL_NPU_TAG="20250901" +SGL_KERNEL_NPU_TAG="20250913" git clone --depth 1 https://github.com/sgl-project/sgl-kernel-npu.git --branch ${SGL_KERNEL_NPU_TAG} -(cd sgl-kernel-npu && bash ./build.sh -a deepep && pip install output/deep_ep*.whl && cd "$(pip show deep-ep | grep -E '^Location:' | awk '{print $2}')" && ln -s deep_ep/deep_ep_cpp*.so) +(cd sgl-kernel-npu && bash ./build.sh && pip install output/deep_ep*.whl output/sgl_kernel_npu*.whl && cd "$(pip show deep-ep | grep -E '^Location:' | awk '{print $2}')" && ln -s deep_ep/deep_ep_cpp*.so) ### Install SGLang