From 1e0e549766a0b13164283d759120200d73d71858 Mon Sep 17 00:00:00 2001 From: ronnie_zheng Date: Thu, 3 Jul 2025 19:23:19 +0300 Subject: [PATCH] Ascend attention backend(PA&MLA) (#7722) Co-authored-by: Maksim Co-authored-by: VDV1985 --- docs/backend/attention_backend.md | 6 + .../srt/layers/attention/ascend_backend.py | 219 ++++++++++++++++++ python/sglang/srt/layers/moe/ep_moe/layer.py | 6 +- .../srt/layers/moe/fused_moe_triton/layer.py | 38 +++ python/sglang/srt/layers/moe/topk.py | 5 + python/sglang/srt/layers/rotary_embedding.py | 4 +- python/sglang/srt/managers/schedule_batch.py | 6 +- python/sglang/srt/mem_cache/allocator.py | 161 +++++++++++++ python/sglang/srt/mem_cache/memory_pool.py | 148 ++++++++++++ .../srt/model_executor/forward_batch_info.py | 11 +- .../sglang/srt/model_executor/model_runner.py | 68 +++++- python/sglang/srt/models/deepseek_v2.py | 4 +- python/sglang/srt/server_args.py | 7 + python/sglang/srt/utils.py | 2 +- test/srt/run_suite.py | 3 + test/srt/test_ascend_attention_backend.py | 74 ++++++ test/srt/test_ascend_mla_backend.py | 96 ++++++++ 17 files changed, 842 insertions(+), 16 deletions(-) create mode 100644 python/sglang/srt/layers/attention/ascend_backend.py create mode 100644 test/srt/test_ascend_attention_backend.py create mode 100644 test/srt/test_ascend_mla_backend.py diff --git a/docs/backend/attention_backend.md b/docs/backend/attention_backend.md index ad5ddfde9..4e9ecf8e2 100644 --- a/docs/backend/attention_backend.md +++ b/docs/backend/attention_backend.md @@ -9,6 +9,7 @@ | **Triton** | ❌ | ✅ | ✅ | ✅ | ❌ | | **Torch Native** | ❌ | ❌ | ❌ | ❌ | ❌ | | **FlashMLA** | ✅ | ✅ | ✅ | ❌ | ❌ | +| **Ascend** | ✅ | ❌ | ❌ | ❌ | ❌ | Note: Every kernel backend is compatible with a page size > 1 by specifying an argument such as `--page-size 16`. This is because a page size of 16 can be converted to a page size of 1 in the kernel backend. 
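For intuition on the page-size note above, this is the same conversion the new Ascend backend performs in `init_forward_metadata`: a token-level `req_to_token` map is turned into a page-level block table by sampling one slot per page and dividing by the page size. A minimal PyTorch sketch (shapes and values are illustrative, not from this patch):

```python
import torch

page_size = 16
# Hypothetical token-level map: one row per request, one KV slot per token.
req_to_token = torch.arange(64).unsqueeze(0)  # shape [bs=1, max_seq_len=64]

# Sample one slot per page, then convert slot indices to page indices,
# mirroring req_to_token[..., ::page_size] // page_size in the backend.
block_tables = req_to_token[:, ::page_size] // page_size
print(block_tables)  # tensor([[0, 1, 2, 3]])
```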
@@ -46,3 +47,8 @@ python3 -m sglang.launch_server --model meta-llama/Meta-Llama-3.1-8B-Instruct -- python3 -m sglang.launch_server --tp 8 --model deepseek-ai/DeepSeek-R1 --attention-backend flashmla --trust-remote-code python3 -m sglang.launch_server --tp 8 --model deepseek-ai/DeepSeek-R1 --attention-backend flashmla --kv-cache-dtype fp8_e4m3 --trust-remote-code ``` + +- Ascend +```bash +python3 -m sglang.launch_server --model meta-llama/Meta-Llama-3.1-8B-Instruct --attention-backend ascend +``` diff --git a/python/sglang/srt/layers/attention/ascend_backend.py b/python/sglang/srt/layers/attention/ascend_backend.py new file mode 100644 index 000000000..7bce68655 --- /dev/null +++ b/python/sglang/srt/layers/attention/ascend_backend.py @@ -0,0 +1,219 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING, Optional + +import torch +import torch_npu +from torch.nn.functional import scaled_dot_product_attention + +from sglang.srt.configs.model_config import AttentionArch +from sglang.srt.layers.attention.base_attn_backend import AttentionBackend +from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend +from sglang.srt.layers.radix_attention import AttentionType +from sglang.srt.model_executor.forward_batch_info import ForwardBatch + +if TYPE_CHECKING: + from sglang.srt.layers.radix_attention import RadixAttention + from sglang.srt.model_executor.model_runner import ModelRunner + + +@dataclass +class ForwardMetadata: + + # calculated map for kv positions [bs * maxseqlen] + block_tables: Optional[torch.Tensor] = None + + # seq len inputs + extend_seq_lens_cpu_int: Optional[torch.Tensor] = None + seq_lens_cpu_int: Optional[torch.Tensor] = None + + +class AscendAttnBackend(AttentionBackend): + + def gen_attention_mask(self, max_seq_len: int, dtype=torch.float16): + mask_flag = torch.tril( + torch.ones((max_seq_len, max_seq_len), dtype=torch.bool) + ).view(max_seq_len, max_seq_len) + mask_flag = ~mask_flag + if dtype == torch.float16: + mask_value = torch.finfo(torch.float32).min + else: + mask_value = 1 + self.mask = ( + torch.masked_fill( + torch.zeros(size=(max_seq_len, max_seq_len)), mask_flag, mask_value + ) + .to(dtype) + .to(self.device) + ) + self.mask_len = max_seq_len + + def __init__(self, model_runner: ModelRunner): + super().__init__() + self.forward_metadata = ForwardMetadata() + self.device = model_runner.device + self.gen_attention_mask(128, model_runner.dtype) + self.page_size = model_runner.page_size + self.use_mla = model_runner.model_config.attention_arch == AttentionArch.MLA + if self.use_mla: + self.kv_lora_rank = model_runner.model_config.kv_lora_rank + self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim + self.native_attn = TorchNativeAttnBackend(model_runner) + + def init_forward_metadata(self, forward_batch: ForwardBatch): + """Init the metadata for a forward pass.""" + self.forward_metadata.block_tables = ( + forward_batch.req_to_token_pool.req_to_token[ + forward_batch.req_pool_indices, : forward_batch.seq_lens.max() + ][:, :: self.page_size] + // self.page_size + ) + if forward_batch.extend_seq_lens is not None: + self.forward_metadata.extend_seq_lens_cpu_int = ( + forward_batch.extend_seq_lens.cpu().int() + ) + self.forward_metadata.seq_lens_cpu_int = forward_batch.seq_lens_cpu.int() + + def forward_extend( + self, + q, + k, + v, + layer: RadixAttention, + forward_batch: ForwardBatch, + save_kv_cache=True, + ): + if save_kv_cache: + 
forward_batch.token_to_kv_pool.set_kv_buffer( + layer, forward_batch.out_cache_loc, k, v + ) + + k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id) + v_cache = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id) + + if not self.use_mla: + query = q.view(-1, layer.tp_q_head_num * layer.qk_head_dim) + output = torch.empty( + (query.shape[0], layer.tp_q_head_num * layer.v_head_dim), + dtype=query.dtype, + device=query.device, + ) + + torch_npu._npu_flash_attention_qlens( + query=query, + key_cache=k_cache, + value_cache=v_cache, + mask=self.mask, + block_table=self.forward_metadata.block_tables, + seq_len=self.forward_metadata.extend_seq_lens_cpu_int, + context_lens=self.forward_metadata.seq_lens_cpu_int, + scale_value=layer.scaling, + num_heads=layer.tp_q_head_num, + num_kv_heads=layer.tp_k_head_num, + out=output, + ) + return output + else: + if layer.qk_head_dim != layer.v_head_dim: + o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim)) + else: + o = torch.empty_like(q) + + use_gqa = layer.tp_q_head_num != layer.tp_k_head_num + + q_ = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim) + o_ = o.view(-1, layer.tp_q_head_num, layer.v_head_dim) + + causal = True + if ( + layer.is_cross_attention + or layer.attn_type == AttentionType.ENCODER_ONLY + ): + causal = False + + self.native_attn._run_sdpa_forward_extend( + q_, + o_, + k_cache.view( + -1, layer.tp_k_head_num, (self.kv_lora_rank + self.qk_rope_head_dim) + ), + v_cache.view(-1, layer.tp_v_head_num, self.kv_lora_rank), + forward_batch.req_to_token_pool.req_to_token, + forward_batch.req_pool_indices, + forward_batch.seq_lens, + forward_batch.extend_prefix_lens, + forward_batch.extend_seq_lens, + scaling=layer.scaling, + enable_gqa=use_gqa, + causal=causal, + ) + return o + + def forward_decode( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + layer: RadixAttention, + forward_batch: ForwardBatch, + save_kv_cache=True, + ): + if save_kv_cache: + forward_batch.token_to_kv_pool.set_kv_buffer( + layer, forward_batch.out_cache_loc, k, v + ) + if not self.use_mla: + k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id) + v_cache = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id) + + query = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim) + num_tokens = query.shape[0] + output = torch.empty( + (num_tokens, layer.tp_q_head_num, layer.v_head_dim), + dtype=query.dtype, + device=query.device, + ) + + torch_npu._npu_paged_attention( + query=query, + key_cache=k_cache, + value_cache=v_cache, + num_heads=layer.tp_q_head_num, + num_kv_heads=layer.tp_k_head_num, + scale_value=layer.scaling, + block_table=self.forward_metadata.block_tables, + context_lens=self.forward_metadata.seq_lens_cpu_int, + out=output, + ) + return output.view(num_tokens, layer.tp_q_head_num * layer.v_head_dim) + else: + query = q.view(-1, layer.tp_q_head_num, layer.head_dim) + num_tokens = query.shape[0] + kv_c_and_k_pe_cache = forward_batch.token_to_kv_pool.get_key_buffer( + layer.layer_id + ) + kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.view( + -1, + self.page_size, + layer.tp_k_head_num, + self.kv_lora_rank + self.qk_rope_head_dim, + ) + + attn_output = torch.empty( + [num_tokens, layer.tp_q_head_num, self.kv_lora_rank], + dtype=q.dtype, + device=q.device, + ) + torch_npu._npu_paged_attention_mla( + query=query, + key_cache=kv_c_and_k_pe_cache, + num_kv_heads=layer.tp_k_head_num, + num_heads=layer.tp_q_head_num, + scale_value=layer.scaling, + 
block_table=self.forward_metadata.block_tables, + context_lens=self.forward_metadata.seq_lens_cpu_int, + mla_vheadsize=self.kv_lora_rank, + out=attn_output, + ) + return attn_output.view(num_tokens, layer.tp_q_head_num * self.kv_lora_rank) diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py index d5cf3b568..10319800b 100644 --- a/python/sglang/srt/layers/moe/ep_moe/layer.py +++ b/python/sglang/srt/layers/moe/ep_moe/layer.py @@ -3,7 +3,6 @@ from typing import Callable, List, Optional, Tuple import einops import torch -from sgl_kernel import silu_and_mul from torch.nn import Module from sglang.srt.custom_op import CustomOp @@ -50,13 +49,18 @@ from sglang.srt.utils import ( dispose_tensor, get_bool_env_var, is_hip, + is_npu, set_weight_attrs, ) _is_hip = is_hip() +_is_npu = is_npu() _is_fp8_fnuz = is_fp8_fnuz() _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip +if not _is_npu: + from sgl_kernel import silu_and_mul + if _is_hip: from vllm._custom_ops import scaled_fp8_quant diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py index 723601737..8cc068dbf 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py @@ -321,6 +321,44 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): routed_scaling_factor, ) + def forward_npu( + self, + layer: torch.nn.Module, + x: torch.Tensor, + use_grouped_topk: bool, + top_k: int, + router_logits: torch.Tensor, + renormalize: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + num_fused_shared_experts: int = 0, + custom_routing_function: Optional[Callable] = None, + correction_bias: Optional[torch.Tensor] = None, + activation: str = "silu", + apply_router_weight_on_input: bool = False, + inplace: bool = True, + no_combine: bool = False, + routed_scaling_factor: Optional[float] = None, + ) -> torch.Tensor: + return moe_forward_native( + layer, + x, + use_grouped_topk, + top_k, + router_logits, + renormalize, + topk_group, + num_expert_group, + num_fused_shared_experts, + custom_routing_function, + correction_bias, + activation, + apply_router_weight_on_input, + inplace, + no_combine, + routed_scaling_factor, + ) + def forward_tpu(self, *args, **kwargs) -> torch.Tensor: raise NotImplementedError("The TPU backend currently does not support MoE.") diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py index 610931cc8..908927b88 100644 --- a/python/sglang/srt/layers/moe/topk.py +++ b/python/sglang/srt/layers/moe/topk.py @@ -35,6 +35,7 @@ from sglang.srt.utils import ( is_cpu, is_cuda, is_hip, + is_npu, ) _is_cuda = is_cuda() @@ -42,6 +43,7 @@ _is_hip = is_hip() _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip _is_cpu_amx_available = cpu_has_amx_support() _is_cpu = is_cpu() +_is_npu = is_npu() if _is_cuda: from sgl_kernel import moe_fused_gate @@ -159,6 +161,9 @@ def grouped_topk_gpu( assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch" scores = torch.softmax(gating_output, dim=-1) + # NPU compiler limitation + if _is_npu and scores.dtype == torch.bfloat16: + scores = scores.to(torch.float16) num_token = scores.shape[0] num_experts = scores.shape[1] group_scores = ( diff --git a/python/sglang/srt/layers/rotary_embedding.py b/python/sglang/srt/layers/rotary_embedding.py index c81954318..0be507f84 100644 --- 
a/python/sglang/srt/layers/rotary_embedding.py +++ b/python/sglang/srt/layers/rotary_embedding.py @@ -660,7 +660,7 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding): beta_slow: int = 1, mscale: float = 1, mscale_all_dim: float = 0, - device: Optional[str] = "cuda", + device: Optional[str] = "cuda" if not _is_npu else "npu", ) -> None: self.scaling_factor = scaling_factor self.extrapolation_factor = extrapolation_factor @@ -679,7 +679,7 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding): ) # Re-dispatch - if _is_hip: + if _is_hip or _is_npu: self._forward_method = self.forward_native def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor: diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index b257fe6ef..1039cd693 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -1673,6 +1673,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): ) or global_server_args_dict["attention_backend"] == "flashmla" or global_server_args_dict["attention_backend"] == "cutlass_mla" + or global_server_args_dict["attention_backend"] == "ascend" or global_server_args_dict["enable_two_batch_overlap"] ): seq_lens_cpu = ( @@ -1875,7 +1876,10 @@ def get_last_loc( req_pool_indices_tensor: torch.Tensor, prefix_lens_tensor: torch.Tensor, ) -> torch.Tensor: - if global_server_args_dict["attention_backend"] != "torch_native": + if ( + global_server_args_dict["attention_backend"] != "ascend" + and global_server_args_dict["attention_backend"] != "torch_native" + ): impl = get_last_loc_triton else: impl = get_last_loc_torch diff --git a/python/sglang/srt/mem_cache/allocator.py b/python/sglang/srt/mem_cache/allocator.py index 6bcabf648..6d06fa103 100644 --- a/python/sglang/srt/mem_cache/allocator.py +++ b/python/sglang/srt/mem_cache/allocator.py @@ -540,3 +540,164 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator): ) self.is_not_in_free_group = True self.free_group = [] + + +def alloc_extend_kernel_ascend( + prefix_lens, + seq_lens, + last_loc, + free_pages, + out_indices, + page_size, + device, +): + extend_lens = seq_lens - prefix_lens + end_pos = torch.cumsum(extend_lens, 0) + start_pos = end_pos - extend_lens + num_new_pages = (seq_lens + page_size - 1) // page_size - ( + prefix_lens + page_size - 1 + ) // page_size + num_full_new_pages = (seq_lens) // page_size - ( + prefix_lens + page_size - 1 + ) // page_size + need_page = num_new_pages - num_full_new_pages + end_new_pages = torch.cumsum(num_new_pages, 0) + start_new_pages = end_new_pages - num_new_pages + pos_in_page = torch.arange(page_size, device=device, dtype=torch.int32) + for i in range(len(prefix_lens)): + num1 = ( + min( + seq_lens[i], + (prefix_lens[i] + page_size - 1) // page_size * page_size, + ) + - prefix_lens[i] + ) + if num1: + out_indices[start_pos[i] : start_pos[i] + num1] = ( + last_loc[i] + 1 + pos_in_page[:num1].view(-1) + ) + + num2 = ( + seq_lens[i] // page_size - (prefix_lens[i] + page_size - 1) // page_size + ) * page_size + if num2: + pages = ( + free_pages[start_new_pages[i] : end_new_pages[i] - need_page[i]] + * page_size + ) + out_indices[start_pos[i] + num1 : start_pos[i] + num1 + num2] = ( + pages.view(-1, 1) + pos_in_page.view(1, -1) + ).view(-1) + + num3 = seq_lens[i] - seq_lens[i] // page_size * page_size + if num3: + out_indices[end_pos[i] - num3 : end_pos[i]] = ( + free_pages[end_new_pages[i] - 1] * page_size + pos_in_page[:num3] + ).view(-1) + return num_new_pages + + +def 
alloc_decode_kernel_ascend( + seq_lens, + last_loc, + free_pages, + out_indices, + page_size, +): + num_new_pages = (seq_lens + page_size - 1) // page_size - ( + seq_lens - 1 + page_size - 1 + ) // page_size + end_new_pages = torch.cumsum(num_new_pages, 0) + start_new_pages = end_new_pages - num_new_pages + for i in range(len(seq_lens)): + if num_new_pages[i]: + out_indices[i] = free_pages[start_new_pages[i]] * page_size + else: + out_indices[i] = last_loc[i] + 1 + return num_new_pages + + +class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator): + + def __init__( + self, + size: int, + page_size: int, + dtype: torch.dtype, + device: str, + kvcache: KVCache, + ): + super().__init__(size, page_size, dtype, device, kvcache) + self.ret_values = torch.empty((), dtype=torch.int32, device=self.device) + + def alloc_extend( + self, + prefix_lens: torch.Tensor, + seq_lens: torch.Tensor, + last_loc: torch.Tensor, + extend_num_tokens: int, + ): + if self.debug_mode: + assert torch.all( + (last_loc + 1) % self.page_size == prefix_lens % self.page_size + ) + + bs = len(prefix_lens) + out_indices = torch.empty( + (extend_num_tokens,), dtype=torch.int32, device=self.device + ) + + self.ret_values = alloc_extend_kernel_ascend( + prefix_lens, + seq_lens, + last_loc, + self.free_pages, + out_indices, + self.page_size, + self.device, + ) + + if self.debug_mode: + assert len(torch.unique(out_indices)) == len(out_indices) + + num_new_pages = self.ret_values.sum() + if num_new_pages > len(self.free_pages): + return None + + self.free_pages = self.free_pages[num_new_pages:] + return out_indices + + def alloc_decode( + self, + seq_lens: torch.Tensor, + last_loc: torch.Tensor, + ): + if self.debug_mode: + assert torch.all( + (last_loc + 2) % self.page_size == seq_lens % self.page_size + ) + + bs = len(seq_lens) + out_indices = torch.empty((bs,), dtype=torch.int32, device=self.device) + + self.ret_values = alloc_decode_kernel_ascend( + seq_lens, + last_loc, + self.free_pages, + out_indices, + self.page_size, + ) + + if self.debug_mode: + assert len(torch.unique(out_indices)) == len(out_indices) + + num_new_pages = self.ret_values.sum() + if num_new_pages > len(self.free_pages): + return None + + self.free_pages = self.free_pages[num_new_pages:] + return out_indices + + def clear(self): + super().clear() + self.free_pages = self.free_pages.to(torch.int32) diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py index 4e3f40371..00ad66552 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -568,6 +568,76 @@ class SWAKVPool(KVCache): ) +class AscendTokenToKVPool(MHATokenToKVPool): + + def _create_buffers(self): + with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE): + # [size, head_num, head_dim] for each layer + # The padded slot 0 is used for writing dummy outputs from padded tokens. 
+ self.k_buffer = [ + torch.zeros( + ( + self.size // self.page_size + 1, + self.page_size, + self.head_num, + self.head_dim, + ), + dtype=self.store_dtype, + device=self.device, + ) + for _ in range(self.layer_num) + ] + self.v_buffer = [ + torch.zeros( + ( + self.size // self.page_size + 1, + self.page_size, + self.head_num, + self.head_dim, + ), + dtype=self.store_dtype, + device=self.device, + ) + for _ in range(self.layer_num) + ] + + def set_kv_buffer( + self, + layer: RadixAttention, + loc: torch.Tensor, + cache_k: torch.Tensor, + cache_v: torch.Tensor, + k_scale: Optional[float] = None, + v_scale: Optional[float] = None, + ): + layer_id = layer.layer_id + if cache_k.dtype != self.dtype: + if k_scale is not None: + cache_k.div_(k_scale) + if v_scale is not None: + cache_v.div_(v_scale) + cache_k = cache_k.to(self.dtype) + cache_v = cache_v.to(self.dtype) + + if self.store_dtype != self.dtype: + cache_k = cache_k.view(self.store_dtype) + cache_v = cache_v.view(self.store_dtype) + + import torch_npu + + torch_npu._npu_reshape_and_cache( + key=cache_k, + value=cache_v, + key_cache=self.k_buffer[layer_id].view( + -1, self.page_size, self.head_num, self.head_dim + ), + value_cache=self.v_buffer[layer_id].view( + -1, self.page_size, self.head_num, self.head_dim + ), + slot_indices=loc, + ) + + @triton.jit def set_mla_kv_buffer_kernel( kv_buffer_ptr, @@ -820,6 +890,84 @@ class MLATokenToKVPool(KVCache): torch.cuda.synchronize() +class AscendMLAPagedTokenToKVPool(MLATokenToKVPool): + def __init__( + self, + size: int, + page_size: int, + dtype: torch.dtype, + kv_lora_rank: int, + qk_rope_head_dim: int, + layer_num: int, + device: str, + enable_memory_saver: bool, + start_layer: Optional[int] = None, + end_layer: Optional[int] = None, + ): + super(MLATokenToKVPool, self).__init__( + size, + page_size, + dtype, + layer_num, + device, + enable_memory_saver, + start_layer, + end_layer, + ) + + self.kv_lora_rank = kv_lora_rank + self.qk_rope_head_dim = qk_rope_head_dim + + self.custom_mem_pool = None + + with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE): + # The padded slot 0 is used for writing dummy outputs from padded tokens. + self.kv_buffer = [ + torch.zeros( + ( + self.size // self.page_size + 1, + self.page_size, + self.kv_lora_rank + self.qk_rope_head_dim, + ), + dtype=self.store_dtype, + device=self.device, + ) + for _ in range(layer_num) + ] + + self.layer_transfer_counter = None + + kv_size = self.get_kv_size_bytes() + logger.info( + f"KV Cache is allocated. 
#tokens: {size}, KV size: {kv_size / GB:.2f} GB" + ) + self.mem_usage = kv_size / GB + + def set_kv_buffer( + self, + layer: RadixAttention, + loc: torch.Tensor, + cache_k: torch.Tensor, + cache_v: torch.Tensor, + ): + layer_id = layer.layer_id + if cache_k.dtype != self.dtype: + cache_k = cache_k.to(self.dtype) + + if self.store_dtype != self.dtype: + cache_k = cache_k.view(self.store_dtype) + + import torch_npu + + torch_npu._npu_reshape_and_cache_siso( + key=cache_k.view(-1, 1, self.kv_lora_rank + self.qk_rope_head_dim), + key_cache=self.kv_buffer[layer_id - self.start_layer].view( + -1, 1, 1, self.kv_lora_rank + self.qk_rope_head_dim + ), + slot_indices=loc, + ) + + class DoubleSparseTokenToKVPool(KVCache): def __init__( self, diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index cc01e963e..65dd8d428 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -39,7 +39,12 @@ import triton import triton.language as tl from sglang.srt.layers.rotary_embedding import MRotaryEmbedding -from sglang.srt.utils import flatten_nested_list, get_compiler_backend, support_triton +from sglang.srt.utils import ( + flatten_nested_list, + get_compiler_backend, + is_npu, + support_triton, +) if TYPE_CHECKING: from sglang.srt.layers.attention.base_attn_backend import AttentionBackend @@ -50,6 +55,8 @@ if TYPE_CHECKING: from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput from sglang.srt.speculative.spec_info import SpeculativeAlgorithm +_is_npu = is_npu() + class ForwardMode(IntEnum): # Extend a sequence. The KV cache of the beginning part of the sequence is already computed (e.g., system prompt). @@ -739,7 +746,7 @@ def compute_position_torch( return positions.to(torch.int64), extend_start_loc -@torch.compile(dynamic=True, backend=get_compiler_backend()) +@torch.compile(dynamic=True, backend=get_compiler_backend(), disable=_is_npu) def clamp_position(seq_lens): return torch.clamp((seq_lens - 1), min=0).to(torch.int64) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 435024475..de976c9af 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -72,12 +72,15 @@ from sglang.srt.managers.schedule_batch import ( global_server_args_dict, ) from sglang.srt.mem_cache.allocator import ( + AscendPagedTokenToKVPoolAllocator, BaseTokenToKVPoolAllocator, PagedTokenToKVPoolAllocator, SWATokenToKVPoolAllocator, TokenToKVPoolAllocator, ) from sglang.srt.mem_cache.memory_pool import ( + AscendMLAPagedTokenToKVPool, + AscendTokenToKVPool, DoubleSparseTokenToKVPool, MHATokenToKVPool, MLATokenToKVPool, @@ -110,6 +113,7 @@ from sglang.srt.utils import ( is_hip, is_hopper_with_cuda_12_3, is_no_spec_infer_or_topk_one, + is_npu, monkey_patch_p2p_access_check, monkey_patch_vllm_gguf_config, set_cpu_offload_max_bytes, @@ -117,6 +121,7 @@ from sglang.srt.utils import ( ) _is_hip = is_hip() +_is_npu = is_npu() _is_cpu_amx_available = cpu_has_amx_support() # Use a small KV cache pool size for tests in CI @@ -308,6 +313,7 @@ class ModelRunner: self.init_cuda_graphs() else: self.cuda_graph_runner = None + self.cuda_graph_mem_usage = 0 self.init_attention_backend() # auxiliary hidden capture mode. TODO: expose this to server args? 
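As a reference point for `AscendMLAPagedTokenToKVPool.set_kv_buffer` above: the patch delegates the actual write to the opaque `torch_npu._npu_reshape_and_cache_siso` kernel, whose assumed effect is a flat indexed scatter of compressed-KV rows into the paged buffer. A CPU-only sketch of that assumed semantics (dimensions are hypothetical DeepSeek-like values, not taken from this diff):

```python
import torch

page_size, width = 128, 576              # width = kv_lora_rank + qk_rope_head_dim
num_pages = 8
# Paged MLA buffer; the extra page holds the padded slot 0 mentioned above.
kv_buffer = torch.zeros(num_pages + 1, page_size, width)

loc = torch.tensor([0, 1, 130])          # flat slot indices from the allocator
cache_k = torch.randn(3, width)          # compressed KV + rope rows to store

# page = loc // page_size, offset = loc % page_size; flattening the first two
# dims reduces the scatter to a single indexed copy.
kv_buffer.view(-1, width)[loc] = cache_k
```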
@@ -369,6 +375,8 @@ class ModelRunner: server_args.attention_backend = "fa3" elif _is_hip: server_args.attention_backend = "aiter" + elif _is_npu: + server_args.attention_backend = "ascend" else: server_args.attention_backend = ( "flashinfer" if is_flashinfer_available() else "triton" @@ -388,6 +396,8 @@ class ModelRunner: server_args.attention_backend = "aiter" else: server_args.attention_backend = "triton" + elif _is_npu: + server_args.attention_backend = "ascend" else: server_args.attention_backend = "triton" logger.info( @@ -402,6 +412,7 @@ class ModelRunner: "triton", "flashmla", "cutlass_mla", + "ascend", ]: logger.info( f"MLA optimization is turned on. Use {server_args.attention_backend} backend." @@ -1096,7 +1107,35 @@ class ModelRunner: # Draft worker shares req_to_token_pool with the target worker. assert self.is_draft_worker - if self.use_mla_backend: + if self.server_args.attention_backend == "ascend" and not self.use_mla_backend: + self.token_to_kv_pool = AscendTokenToKVPool( + self.max_total_num_tokens, + page_size=self.page_size, + dtype=self.kv_cache_dtype, + head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()), + head_dim=self.model_config.head_dim, + layer_num=self.model_config.num_hidden_layers, + device=self.device, + enable_memory_saver=self.server_args.enable_memory_saver, + ) + elif self.server_args.attention_backend == "ascend" and self.use_mla_backend: + self.token_to_kv_pool = AscendMLAPagedTokenToKVPool( + self.max_total_num_tokens, + page_size=self.page_size, + dtype=self.kv_cache_dtype, + kv_lora_rank=self.model_config.kv_lora_rank, + qk_rope_head_dim=self.model_config.qk_rope_head_dim, + layer_num=( + self.model_config.num_hidden_layers + if not self.is_draft_worker + else self.model_config.hf_config.num_nextn_predict_layers + ), # PP is not compatible with mla backend + device=self.device, + enable_memory_saver=self.server_args.enable_memory_saver, + start_layer=self.start_layer, + end_layer=self.end_layer, + ) + elif self.use_mla_backend: self.token_to_kv_pool = MLATokenToKVPool( self.max_total_num_tokens, page_size=self.page_size, @@ -1176,13 +1215,22 @@ class ModelRunner: kvcache=self.token_to_kv_pool, ) else: - self.token_to_kv_pool_allocator = PagedTokenToKVPoolAllocator( - self.max_total_num_tokens, - page_size=self.page_size, - dtype=self.kv_cache_dtype, - device=self.device, - kvcache=self.token_to_kv_pool, - ) + if _is_npu: + self.token_to_kv_pool_allocator = AscendPagedTokenToKVPoolAllocator( + self.max_total_num_tokens, + page_size=self.page_size, + dtype=self.kv_cache_dtype, + device=self.device, + kvcache=self.token_to_kv_pool, + ) + else: + self.token_to_kv_pool_allocator = PagedTokenToKVPoolAllocator( + self.max_total_num_tokens, + page_size=self.page_size, + dtype=self.kv_cache_dtype, + device=self.device, + kvcache=self.token_to_kv_pool, + ) else: assert self.is_draft_worker @@ -1229,6 +1277,10 @@ class ModelRunner: from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend return AiterAttnBackend(self) + elif self.server_args.attention_backend == "ascend": + from sglang.srt.layers.attention.ascend_backend import AscendAttnBackend + + return AscendAttnBackend(self) elif self.server_args.attention_backend == "triton": assert not self.model_config.is_encoder_decoder, ( "Cross attention is not supported in the triton attention backend. 
" diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 8e9c4c496..73b4271f4 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -956,7 +956,9 @@ class DeepseekV2AttentionMLA(nn.Module): else: return AttnForwardMethod.MLA - if self.attention_backend == "flashinfer": + if self.attention_backend == "ascend": + return AttnForwardMethod.MLA + elif self.attention_backend == "flashinfer": # Flashinfer MLA: Do not absorb when enabling ragged prefill if ( not self.flashinfer_mla_disable_ragged diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 76e0272a8..ef957dd12 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -380,6 +380,12 @@ class ServerArgs: ) self.disable_cuda_graph = True + if self.attention_backend == "ascend": + logger.warning( + "At this moment Ascend attention backend only supports a page_size of 128, change page_size to 128." + ) + self.page_size = 128 + # Choose grammar backend if self.grammar_backend is None: self.grammar_backend = "xgrammar" @@ -1113,6 +1119,7 @@ class ServerArgs: "flashmla", "intel_amx", "torch_native", + "ascend", "triton", ], default=ServerArgs.attention_backend, diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 72cba9aac..608eae654 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -2399,7 +2399,7 @@ def bind_or_assign(target, source): def support_triton(backend: str) -> bool: - return backend not in ["torch_native", "intel_amx"] + return backend not in ["torch_native", "intel_amx", "ascend"] try: diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index 4d914f980..b98e37ca8 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -143,6 +143,9 @@ suites = { # TestFile("test_vision_chunked_prefill.py", 175), # Disabled temporarily and track in #7701 TestFile("test_reasoning_parser.py", 5), ], + "per-commit-npu": [ + TestFile("test_ascend_attention_backend.py", 400), + ], "per-commit-2-gpu": [ TestFile("models/lora/test_lora_tp.py", 116), TestFile("test_data_parallelism.py", 73), diff --git a/test/srt/test_ascend_attention_backend.py b/test/srt/test_ascend_attention_backend.py new file mode 100644 index 000000000..4ca6bba8f --- /dev/null +++ b/test/srt/test_ascend_attention_backend.py @@ -0,0 +1,74 @@ +""" +Usage: +python3 -m unittest test_ascend_attention_backend.TestAscendAttnBackend.test_gsm8k +""" + +import unittest +from types import SimpleNamespace +from urllib.parse import urlparse + +from sglang.srt.utils import kill_process_tree +from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k +from sglang.test.run_eval import run_eval +from sglang.test.test_utils import ( + DEFAULT_MODEL_NAME_FOR_TEST, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + CustomTestCase, + is_in_ci, + popen_launch_server, + run_bench_offline_throughput, +) + + +class TestAscendAttnBackend(CustomTestCase): + def test_latency(self): + output_throughput = run_bench_offline_throughput( + DEFAULT_MODEL_NAME_FOR_TEST, + [ + "--attention-backend", + "ascend", + ], + ) + + print(f"{output_throughput=}") + + if is_in_ci(): + self.assertGreater(output_throughput, 18) + + def test_gsm8k(self): + model = DEFAULT_MODEL_NAME_FOR_TEST + base_url = DEFAULT_URL_FOR_TEST + url = urlparse(base_url) + process = popen_launch_server( + model, + base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=[ + "--attention-backend", 
+ "ascend", + "--mem-fraction-static", + 0.8, + ], + ) + + try: + args = SimpleNamespace( + num_shots=5, + data_path=None, + num_questions=1319, + max_new_tokens=512, + parallel=128, + host=f"http://{url.hostname}", + port=int(url.port), + ) + + metrics = run_eval_few_shot_gsm8k(args) + self.assertGreaterEqual(metrics["accuracy"], 0.62) + self.assertLessEqual(metrics["latency"], 150) + finally: + kill_process_tree(process.pid) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/srt/test_ascend_mla_backend.py b/test/srt/test_ascend_mla_backend.py new file mode 100644 index 000000000..0db2f3b3e --- /dev/null +++ b/test/srt/test_ascend_mla_backend.py @@ -0,0 +1,96 @@ +""" +Usage: +python3 -m unittest test_ascend_mla_backend.TestAscendMLABackend.test_gsm8k +""" + +import os +import unittest +from types import SimpleNamespace +from urllib.parse import urlparse + +from sglang.srt.utils import kill_process_tree +from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k +from sglang.test.run_eval import run_eval +from sglang.test.test_utils import ( + DEFAULT_MLA_MODEL_NAME_FOR_TEST, + DEFAULT_MODEL_NAME_FOR_TEST, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + CustomTestCase, + is_in_ci, + popen_launch_server, + run_bench_offline_throughput, +) + +if "ASCEND_RT_VISIBLE_DEVICES" not in os.environ: + os.environ["ASCEND_RT_VISIBLE_DEVICES"] = "0,1,2,3" +DEFAULT_PORT_FOR_SRT_TEST_RUNNER = ( + 7000 + int(os.environ.get("ASCEND_RT_VISIBLE_DEVICES", "0")[0]) * 100 +) +DEFAULT_URL_FOR_TEST = f"http://127.0.0.1:{DEFAULT_PORT_FOR_SRT_TEST_RUNNER + 1000}" +DEFAULT_MODEL_NAME_FOR_TEST = "/models/DeepSeek-V2-Lite-Chat" +if not os.path.exists(DEFAULT_MODEL_NAME_FOR_TEST): + DEFAULT_MODEL_NAME_FOR_TEST = DEFAULT_MLA_MODEL_NAME_FOR_TEST + + +class TestAscendMLABackend(CustomTestCase): + def test_latency(self): + output_throughput = run_bench_offline_throughput( + DEFAULT_MODEL_NAME_FOR_TEST, + [ + "--attention-backend", + "ascend", + "--mem-fraction-static", + 0.7, + "--tp-size", + "4", + "--trust-remote-code", + "--disable-cuda-graph", + ], + ) + + print(f"{output_throughput=}") + + if is_in_ci(): + self.assertGreater(output_throughput, 18) + + def test_gsm8k(self): + model = DEFAULT_MODEL_NAME_FOR_TEST + base_url = DEFAULT_URL_FOR_TEST + url = urlparse(base_url) + process = popen_launch_server( + model, + base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=[ + "--attention-backend", + "ascend", + "--mem-fraction-static", + 0.7, + "--tp-size", + "4", + "--trust-remote-code", + "--disable-cuda-graph", + ], + ) + + try: + args = SimpleNamespace( + num_shots=5, + data_path=None, + num_questions=128, + max_new_tokens=512, + parallel=128, + host=f"http://{url.hostname}", + port=int(url.port), + ) + + metrics = run_eval_few_shot_gsm8k(args) + self.assertGreaterEqual(metrics["accuracy"], 0.62) + self.assertGreaterEqual(metrics["output_throughput"], 50) + finally: + kill_process_tree(process.pid) + + +if __name__ == "__main__": + unittest.main()