Ascend attention backend (PA & MLA) (#7722)

Co-authored-by: Maksim <makcum888e@mail.ru>
Co-authored-by: VDV1985 <vladdv85@mail.ru>
python/sglang/srt/layers/attention/ascend_backend.py (new file, 219 lines)

@@ -0,0 +1,219 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional

import torch
import torch_npu
from torch.nn.functional import scaled_dot_product_attention

from sglang.srt.configs.model_config import AttentionArch
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
from sglang.srt.layers.radix_attention import AttentionType
from sglang.srt.model_executor.forward_batch_info import ForwardBatch

if TYPE_CHECKING:
    from sglang.srt.layers.radix_attention import RadixAttention
    from sglang.srt.model_executor.model_runner import ModelRunner


@dataclass
class ForwardMetadata:

    # calculated map for kv positions [bs * maxseqlen]
    block_tables: Optional[torch.Tensor] = None

    # seq len inputs
    extend_seq_lens_cpu_int: Optional[torch.Tensor] = None
    seq_lens_cpu_int: Optional[torch.Tensor] = None


class AscendAttnBackend(AttentionBackend):

    def gen_attention_mask(self, max_seq_len: int, dtype=torch.float16):
        mask_flag = torch.tril(
            torch.ones((max_seq_len, max_seq_len), dtype=torch.bool)
        ).view(max_seq_len, max_seq_len)
        mask_flag = ~mask_flag
        if dtype == torch.float16:
            mask_value = torch.finfo(torch.float32).min
        else:
            mask_value = 1
        self.mask = (
            torch.masked_fill(
                torch.zeros(size=(max_seq_len, max_seq_len)), mask_flag, mask_value
            )
            .to(dtype)
            .to(self.device)
        )
        self.mask_len = max_seq_len

    def __init__(self, model_runner: ModelRunner):
        super().__init__()
        self.forward_metadata = ForwardMetadata()
        self.device = model_runner.device
        self.gen_attention_mask(128, model_runner.dtype)
        self.page_size = model_runner.page_size
        self.use_mla = model_runner.model_config.attention_arch == AttentionArch.MLA
        if self.use_mla:
            self.kv_lora_rank = model_runner.model_config.kv_lora_rank
            self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim
        self.native_attn = TorchNativeAttnBackend(model_runner)

    def init_forward_metadata(self, forward_batch: ForwardBatch):
        """Init the metadata for a forward pass."""
        self.forward_metadata.block_tables = (
            forward_batch.req_to_token_pool.req_to_token[
                forward_batch.req_pool_indices, : forward_batch.seq_lens.max()
            ][:, :: self.page_size]
            // self.page_size
        )
        if forward_batch.extend_seq_lens is not None:
            self.forward_metadata.extend_seq_lens_cpu_int = (
                forward_batch.extend_seq_lens.cpu().int()
            )
        self.forward_metadata.seq_lens_cpu_int = forward_batch.seq_lens_cpu.int()

    def forward_extend(
        self,
        q,
        k,
        v,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache=True,
    ):
        if save_kv_cache:
            forward_batch.token_to_kv_pool.set_kv_buffer(
                layer, forward_batch.out_cache_loc, k, v
            )

        k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
        v_cache = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)

        if not self.use_mla:
            query = q.view(-1, layer.tp_q_head_num * layer.qk_head_dim)
            output = torch.empty(
                (query.shape[0], layer.tp_q_head_num * layer.v_head_dim),
                dtype=query.dtype,
                device=query.device,
            )

            torch_npu._npu_flash_attention_qlens(
                query=query,
                key_cache=k_cache,
                value_cache=v_cache,
                mask=self.mask,
                block_table=self.forward_metadata.block_tables,
                seq_len=self.forward_metadata.extend_seq_lens_cpu_int,
                context_lens=self.forward_metadata.seq_lens_cpu_int,
                scale_value=layer.scaling,
                num_heads=layer.tp_q_head_num,
                num_kv_heads=layer.tp_k_head_num,
                out=output,
            )
            return output
        else:
            if layer.qk_head_dim != layer.v_head_dim:
                o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
            else:
                o = torch.empty_like(q)

            use_gqa = layer.tp_q_head_num != layer.tp_k_head_num

            q_ = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
            o_ = o.view(-1, layer.tp_q_head_num, layer.v_head_dim)

            causal = True
            if (
                layer.is_cross_attention
                or layer.attn_type == AttentionType.ENCODER_ONLY
            ):
                causal = False

            self.native_attn._run_sdpa_forward_extend(
                q_,
                o_,
                k_cache.view(
                    -1, layer.tp_k_head_num, (self.kv_lora_rank + self.qk_rope_head_dim)
                ),
                v_cache.view(-1, layer.tp_v_head_num, self.kv_lora_rank),
                forward_batch.req_to_token_pool.req_to_token,
                forward_batch.req_pool_indices,
                forward_batch.seq_lens,
                forward_batch.extend_prefix_lens,
                forward_batch.extend_seq_lens,
                scaling=layer.scaling,
                enable_gqa=use_gqa,
                causal=causal,
            )
            return o

    def forward_decode(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache=True,
    ):
        if save_kv_cache:
            forward_batch.token_to_kv_pool.set_kv_buffer(
                layer, forward_batch.out_cache_loc, k, v
            )
        if not self.use_mla:
            k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
            v_cache = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)

            query = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
            num_tokens = query.shape[0]
            output = torch.empty(
                (num_tokens, layer.tp_q_head_num, layer.v_head_dim),
                dtype=query.dtype,
                device=query.device,
            )

            torch_npu._npu_paged_attention(
                query=query,
                key_cache=k_cache,
                value_cache=v_cache,
                num_heads=layer.tp_q_head_num,
                num_kv_heads=layer.tp_k_head_num,
                scale_value=layer.scaling,
                block_table=self.forward_metadata.block_tables,
                context_lens=self.forward_metadata.seq_lens_cpu_int,
                out=output,
            )
            return output.view(num_tokens, layer.tp_q_head_num * layer.v_head_dim)
        else:
            query = q.view(-1, layer.tp_q_head_num, layer.head_dim)
            num_tokens = query.shape[0]
            kv_c_and_k_pe_cache = forward_batch.token_to_kv_pool.get_key_buffer(
                layer.layer_id
            )
            kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.view(
                -1,
                self.page_size,
                layer.tp_k_head_num,
                self.kv_lora_rank + self.qk_rope_head_dim,
            )

            attn_output = torch.empty(
                [num_tokens, layer.tp_q_head_num, self.kv_lora_rank],
                dtype=q.dtype,
                device=q.device,
            )
            torch_npu._npu_paged_attention_mla(
                query=query,
                key_cache=kv_c_and_k_pe_cache,
                num_kv_heads=layer.tp_k_head_num,
                num_heads=layer.tp_q_head_num,
                scale_value=layer.scaling,
                block_table=self.forward_metadata.block_tables,
                context_lens=self.forward_metadata.seq_lens_cpu_int,
                mla_vheadsize=self.kv_lora_rank,
                out=attn_output,
            )
            return attn_output.view(num_tokens, layer.tp_q_head_num * self.kv_lora_rank)
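
For reference, the page-table arithmetic in init_forward_metadata above can be checked with a minimal, self-contained sketch; the request-to-token map and sizes below are made up, and the per-request indexing with req_pool_indices is omitted:

import torch

page_size = 4  # illustrative only; the Ascend backend enforces page_size = 128
req_to_token = torch.tensor(
    [
        [12, 13, 14, 15, 20, 21, 22, 23],  # request 0: slots of its 8 tokens
        [40, 41, 42, 43, 0, 0, 0, 0],  # request 1: slots of its 4 tokens
    ]
)
seq_lens = torch.tensor([8, 4])

# Keep one slot per page, then turn each slot index into a page index,
# mirroring the slicing in init_forward_metadata.
block_tables = req_to_token[:, : seq_lens.max()][:, ::page_size] // page_size
print(block_tables)  # tensor([[ 3,  5], [10,  0]])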

@@ -3,7 +3,6 @@ from typing import Callable, List, Optional, Tuple

import einops
import torch
from sgl_kernel import silu_and_mul
from torch.nn import Module

from sglang.srt.custom_op import CustomOp

@@ -50,13 +49,18 @@ from sglang.srt.utils import (
    dispose_tensor,
    get_bool_env_var,
    is_hip,
    is_npu,
    set_weight_attrs,
)

_is_hip = is_hip()
_is_npu = is_npu()
_is_fp8_fnuz = is_fp8_fnuz()
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip

if not _is_npu:
    from sgl_kernel import silu_and_mul

if _is_hip:
    from vllm._custom_ops import scaled_fp8_quant

@@ -321,6 +321,44 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
            routed_scaling_factor,
        )

    def forward_npu(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        use_grouped_topk: bool,
        top_k: int,
        router_logits: torch.Tensor,
        renormalize: bool,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        num_fused_shared_experts: int = 0,
        custom_routing_function: Optional[Callable] = None,
        correction_bias: Optional[torch.Tensor] = None,
        activation: str = "silu",
        apply_router_weight_on_input: bool = False,
        inplace: bool = True,
        no_combine: bool = False,
        routed_scaling_factor: Optional[float] = None,
    ) -> torch.Tensor:
        return moe_forward_native(
            layer,
            x,
            use_grouped_topk,
            top_k,
            router_logits,
            renormalize,
            topk_group,
            num_expert_group,
            num_fused_shared_experts,
            custom_routing_function,
            correction_bias,
            activation,
            apply_router_weight_on_input,
            inplace,
            no_combine,
            routed_scaling_factor,
        )

    def forward_tpu(self, *args, **kwargs) -> torch.Tensor:
        raise NotImplementedError("The TPU backend currently does not support MoE.")

@@ -35,6 +35,7 @@ from sglang.srt.utils import (
    is_cpu,
    is_cuda,
    is_hip,
    is_npu,
)

_is_cuda = is_cuda()

@@ -42,6 +43,7 @@ _is_hip = is_hip()
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
_is_cpu_amx_available = cpu_has_amx_support()
_is_cpu = is_cpu()
_is_npu = is_npu()

if _is_cuda:
    from sgl_kernel import moe_fused_gate

@@ -159,6 +161,9 @@ def grouped_topk_gpu(
    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"

    scores = torch.softmax(gating_output, dim=-1)
    # NPU compiler limitation
    if _is_npu and scores.dtype == torch.bfloat16:
        scores = scores.to(torch.float16)
    num_token = scores.shape[0]
    num_experts = scores.shape[1]
    group_scores = (

@@ -660,7 +660,7 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
        beta_slow: int = 1,
        mscale: float = 1,
        mscale_all_dim: float = 0,
        device: Optional[str] = "cuda",
        device: Optional[str] = "cuda" if not _is_npu else "npu",
    ) -> None:
        self.scaling_factor = scaling_factor
        self.extrapolation_factor = extrapolation_factor

@@ -679,7 +679,7 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
        )

        # Re-dispatch
        if _is_hip:
        if _is_hip or _is_npu:
            self._forward_method = self.forward_native

    def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:

@@ -1673,6 +1673,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
            )
            or global_server_args_dict["attention_backend"] == "flashmla"
            or global_server_args_dict["attention_backend"] == "cutlass_mla"
            or global_server_args_dict["attention_backend"] == "ascend"
            or global_server_args_dict["enable_two_batch_overlap"]
        ):
            seq_lens_cpu = (

@@ -1875,7 +1876,10 @@ def get_last_loc(
    req_pool_indices_tensor: torch.Tensor,
    prefix_lens_tensor: torch.Tensor,
) -> torch.Tensor:
    if global_server_args_dict["attention_backend"] != "torch_native":
    if (
        global_server_args_dict["attention_backend"] != "ascend"
        and global_server_args_dict["attention_backend"] != "torch_native"
    ):
        impl = get_last_loc_triton
    else:
        impl = get_last_loc_torch

@@ -540,3 +540,164 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
            )
        self.is_not_in_free_group = True
        self.free_group = []


def alloc_extend_kernel_ascend(
    prefix_lens,
    seq_lens,
    last_loc,
    free_pages,
    out_indices,
    page_size,
    device,
):
    extend_lens = seq_lens - prefix_lens
    end_pos = torch.cumsum(extend_lens, 0)
    start_pos = end_pos - extend_lens
    num_new_pages = (seq_lens + page_size - 1) // page_size - (
        prefix_lens + page_size - 1
    ) // page_size
    num_full_new_pages = (seq_lens) // page_size - (
        prefix_lens + page_size - 1
    ) // page_size
    need_page = num_new_pages - num_full_new_pages
    end_new_pages = torch.cumsum(num_new_pages, 0)
    start_new_pages = end_new_pages - num_new_pages
    pos_in_page = torch.arange(page_size, device=device, dtype=torch.int32)
    for i in range(len(prefix_lens)):
        num1 = (
            min(
                seq_lens[i],
                (prefix_lens[i] + page_size - 1) // page_size * page_size,
            )
            - prefix_lens[i]
        )
        if num1:
            out_indices[start_pos[i] : start_pos[i] + num1] = (
                last_loc[i] + 1 + pos_in_page[:num1].view(-1)
            )

        num2 = (
            seq_lens[i] // page_size - (prefix_lens[i] + page_size - 1) // page_size
        ) * page_size
        if num2:
            pages = (
                free_pages[start_new_pages[i] : end_new_pages[i] - need_page[i]]
                * page_size
            )
            out_indices[start_pos[i] + num1 : start_pos[i] + num1 + num2] = (
                pages.view(-1, 1) + pos_in_page.view(1, -1)
            ).view(-1)

        num3 = seq_lens[i] - seq_lens[i] // page_size * page_size
        if num3:
            out_indices[end_pos[i] - num3 : end_pos[i]] = (
                free_pages[end_new_pages[i] - 1] * page_size + pos_in_page[:num3]
            ).view(-1)
    return num_new_pages
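
To make the three-part split above concrete, here is a small worked example with made-up numbers (page_size = 4, one request whose 6-token prefix is extended to 13 tokens): num1 tops up the prefix's partially filled last page, num2 fills whole newly allocated pages, and num3 starts one final partial page.

page_size = 4
prefix_len, seq_len = 6, 13  # hypothetical request

# Tokens that still fit into the prefix's partially filled page (slots 6-7).
num1 = min(seq_len, (prefix_len + page_size - 1) // page_size * page_size) - prefix_len
# Tokens that fill whole newly allocated pages (slots 8-11).
num2 = (seq_len // page_size - (prefix_len + page_size - 1) // page_size) * page_size
# Tokens that spill into one last partially filled new page (slot 12).
num3 = seq_len - seq_len // page_size * page_size
assert (num1, num2, num3) == (2, 4, 1)
assert num1 + num2 + num3 == seq_len - prefix_len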


def alloc_decode_kernel_ascend(
    seq_lens,
    last_loc,
    free_pages,
    out_indices,
    page_size,
):
    num_new_pages = (seq_lens + page_size - 1) // page_size - (
        seq_lens - 1 + page_size - 1
    ) // page_size
    end_new_pages = torch.cumsum(num_new_pages, 0)
    start_new_pages = end_new_pages - num_new_pages
    for i in range(len(seq_lens)):
        if num_new_pages[i]:
            out_indices[i] = free_pages[start_new_pages[i]] * page_size
        else:
            out_indices[i] = last_loc[i] + 1
    return num_new_pages


class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):

    def __init__(
        self,
        size: int,
        page_size: int,
        dtype: torch.dtype,
        device: str,
        kvcache: KVCache,
    ):
        super().__init__(size, page_size, dtype, device, kvcache)
        self.ret_values = torch.empty((), dtype=torch.int32, device=self.device)

    def alloc_extend(
        self,
        prefix_lens: torch.Tensor,
        seq_lens: torch.Tensor,
        last_loc: torch.Tensor,
        extend_num_tokens: int,
    ):
        if self.debug_mode:
            assert torch.all(
                (last_loc + 1) % self.page_size == prefix_lens % self.page_size
            )

        bs = len(prefix_lens)
        out_indices = torch.empty(
            (extend_num_tokens,), dtype=torch.int32, device=self.device
        )

        self.ret_values = alloc_extend_kernel_ascend(
            prefix_lens,
            seq_lens,
            last_loc,
            self.free_pages,
            out_indices,
            self.page_size,
            self.device,
        )

        if self.debug_mode:
            assert len(torch.unique(out_indices)) == len(out_indices)

        num_new_pages = self.ret_values.sum()
        if num_new_pages > len(self.free_pages):
            return None

        self.free_pages = self.free_pages[num_new_pages:]
        return out_indices

    def alloc_decode(
        self,
        seq_lens: torch.Tensor,
        last_loc: torch.Tensor,
    ):
        if self.debug_mode:
            assert torch.all(
                (last_loc + 2) % self.page_size == seq_lens % self.page_size
            )

        bs = len(seq_lens)
        out_indices = torch.empty((bs,), dtype=torch.int32, device=self.device)

        self.ret_values = alloc_decode_kernel_ascend(
            seq_lens,
            last_loc,
            self.free_pages,
            out_indices,
            self.page_size,
        )

        if self.debug_mode:
            assert len(torch.unique(out_indices)) == len(out_indices)

        num_new_pages = self.ret_values.sum()
        if num_new_pages > len(self.free_pages):
            return None

        self.free_pages = self.free_pages[num_new_pages:]
        return out_indices

    def clear(self):
        super().clear()
        self.free_pages = self.free_pages.to(torch.int32)
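
A rough usage sketch for the allocator above (illustrative only: the sizes are made up and kvcache stands for whatever KVCache instance the model runner created): alloc_extend returns int32 KV-slot indices for the newly extended tokens, or None when the request needs more pages than remain in free_pages, in which case the caller must free memory and retry.

import torch

allocator = AscendPagedTokenToKVPoolAllocator(
    size=1024, page_size=128, dtype=torch.float16, device="npu", kvcache=kvcache
)
prefix_lens = torch.tensor([0])  # one fresh request with no cached prefix
seq_lens = torch.tensor([200])  # prefill 200 tokens
last_loc = torch.tensor([-1])  # no last prefix slot yet
slots = allocator.alloc_extend(prefix_lens, seq_lens, last_loc, extend_num_tokens=200)
if slots is None:
    pass  # fewer than 2 free pages remained; retract requests and retry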

@@ -568,6 +568,76 @@ class SWAKVPool(KVCache):
        )


class AscendTokenToKVPool(MHATokenToKVPool):

    def _create_buffers(self):
        with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
            # [size, head_num, head_dim] for each layer
            # The padded slot 0 is used for writing dummy outputs from padded tokens.
            self.k_buffer = [
                torch.zeros(
                    (
                        self.size // self.page_size + 1,
                        self.page_size,
                        self.head_num,
                        self.head_dim,
                    ),
                    dtype=self.store_dtype,
                    device=self.device,
                )
                for _ in range(self.layer_num)
            ]
            self.v_buffer = [
                torch.zeros(
                    (
                        self.size // self.page_size + 1,
                        self.page_size,
                        self.head_num,
                        self.head_dim,
                    ),
                    dtype=self.store_dtype,
                    device=self.device,
                )
                for _ in range(self.layer_num)
            ]

    def set_kv_buffer(
        self,
        layer: RadixAttention,
        loc: torch.Tensor,
        cache_k: torch.Tensor,
        cache_v: torch.Tensor,
        k_scale: Optional[float] = None,
        v_scale: Optional[float] = None,
    ):
        layer_id = layer.layer_id
        if cache_k.dtype != self.dtype:
            if k_scale is not None:
                cache_k.div_(k_scale)
            if v_scale is not None:
                cache_v.div_(v_scale)
            cache_k = cache_k.to(self.dtype)
            cache_v = cache_v.to(self.dtype)

        if self.store_dtype != self.dtype:
            cache_k = cache_k.view(self.store_dtype)
            cache_v = cache_v.view(self.store_dtype)

        import torch_npu

        torch_npu._npu_reshape_and_cache(
            key=cache_k,
            value=cache_v,
            key_cache=self.k_buffer[layer_id].view(
                -1, self.page_size, self.head_num, self.head_dim
            ),
            value_cache=self.v_buffer[layer_id].view(
                -1, self.page_size, self.head_num, self.head_dim
            ),
            slot_indices=loc,
        )
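
The buffers written above are paged: a flat KV slot index decomposes into a page index and an in-page offset. A tiny illustration with made-up numbers (the page size matches the 128 that the Ascend backend enforces):

page_size = 128
slot = 300  # hypothetical slot index handed out by the allocator
page, offset = slot // page_size, slot % page_size
assert (page, offset) == (2, 44)
# Conceptually, _npu_reshape_and_cache scatters this token's K/V vectors to
# k_buffer[layer_id][page, offset] and v_buffer[layer_id][page, offset].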


@triton.jit
def set_mla_kv_buffer_kernel(
    kv_buffer_ptr,

@@ -820,6 +890,84 @@ class MLATokenToKVPool(KVCache):
        torch.cuda.synchronize()


class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
    def __init__(
        self,
        size: int,
        page_size: int,
        dtype: torch.dtype,
        kv_lora_rank: int,
        qk_rope_head_dim: int,
        layer_num: int,
        device: str,
        enable_memory_saver: bool,
        start_layer: Optional[int] = None,
        end_layer: Optional[int] = None,
    ):
        super(MLATokenToKVPool, self).__init__(
            size,
            page_size,
            dtype,
            layer_num,
            device,
            enable_memory_saver,
            start_layer,
            end_layer,
        )

        self.kv_lora_rank = kv_lora_rank
        self.qk_rope_head_dim = qk_rope_head_dim

        self.custom_mem_pool = None

        with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
            # The padded slot 0 is used for writing dummy outputs from padded tokens.
            self.kv_buffer = [
                torch.zeros(
                    (
                        self.size // self.page_size + 1,
                        self.page_size,
                        self.kv_lora_rank + self.qk_rope_head_dim,
                    ),
                    dtype=self.store_dtype,
                    device=self.device,
                )
                for _ in range(layer_num)
            ]

        self.layer_transfer_counter = None

        kv_size = self.get_kv_size_bytes()
        logger.info(
            f"KV Cache is allocated. #tokens: {size}, KV size: {kv_size / GB:.2f} GB"
        )
        self.mem_usage = kv_size / GB

    def set_kv_buffer(
        self,
        layer: RadixAttention,
        loc: torch.Tensor,
        cache_k: torch.Tensor,
        cache_v: torch.Tensor,
    ):
        layer_id = layer.layer_id
        if cache_k.dtype != self.dtype:
            cache_k = cache_k.to(self.dtype)

        if self.store_dtype != self.dtype:
            cache_k = cache_k.view(self.store_dtype)

        import torch_npu

        torch_npu._npu_reshape_and_cache_siso(
            key=cache_k.view(-1, 1, self.kv_lora_rank + self.qk_rope_head_dim),
            key_cache=self.kv_buffer[layer_id - self.start_layer].view(
                -1, 1, 1, self.kv_lora_rank + self.qk_rope_head_dim
            ),
            slot_indices=loc,
        )
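
For orientation, the MLA pool above packs the compressed KV latent and the rope key of each token into one vector of width kv_lora_rank + qk_rope_head_dim, which is why only cache_k is written and cache_v is ignored. With DeepSeek-style sizes (illustrative values, not taken from the commit) the per-layer buffer looks like:

kv_lora_rank, qk_rope_head_dim = 512, 64  # illustrative MLA dimensions
size, page_size = 128 * 1000, 128  # made-up pool capacity
kv_buffer_shape = (size // page_size + 1, page_size, kv_lora_rank + qk_rope_head_dim)
print(kv_buffer_shape)  # (1001, 128, 576)
# forward_decode later views this buffer as
# (-1, page_size, tp_k_head_num, kv_lora_rank + qk_rope_head_dim) for the MLA kernel.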


class DoubleSparseTokenToKVPool(KVCache):
    def __init__(
        self,

@@ -39,7 +39,12 @@ import triton
import triton.language as tl

from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
from sglang.srt.utils import flatten_nested_list, get_compiler_backend, support_triton
from sglang.srt.utils import (
    flatten_nested_list,
    get_compiler_backend,
    is_npu,
    support_triton,
)

if TYPE_CHECKING:
    from sglang.srt.layers.attention.base_attn_backend import AttentionBackend

@@ -50,6 +55,8 @@ if TYPE_CHECKING:
    from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
    from sglang.srt.speculative.spec_info import SpeculativeAlgorithm

_is_npu = is_npu()


class ForwardMode(IntEnum):
    # Extend a sequence. The KV cache of the beginning part of the sequence is already computed (e.g., system prompt).

@@ -739,7 +746,7 @@ def compute_position_torch(
    return positions.to(torch.int64), extend_start_loc


@torch.compile(dynamic=True, backend=get_compiler_backend())
@torch.compile(dynamic=True, backend=get_compiler_backend(), disable=_is_npu)
def clamp_position(seq_lens):
    return torch.clamp((seq_lens - 1), min=0).to(torch.int64)

@@ -72,12 +72,15 @@ from sglang.srt.managers.schedule_batch import (
    global_server_args_dict,
)
from sglang.srt.mem_cache.allocator import (
    AscendPagedTokenToKVPoolAllocator,
    BaseTokenToKVPoolAllocator,
    PagedTokenToKVPoolAllocator,
    SWATokenToKVPoolAllocator,
    TokenToKVPoolAllocator,
)
from sglang.srt.mem_cache.memory_pool import (
    AscendMLAPagedTokenToKVPool,
    AscendTokenToKVPool,
    DoubleSparseTokenToKVPool,
    MHATokenToKVPool,
    MLATokenToKVPool,

@@ -110,6 +113,7 @@ from sglang.srt.utils import (
    is_hip,
    is_hopper_with_cuda_12_3,
    is_no_spec_infer_or_topk_one,
    is_npu,
    monkey_patch_p2p_access_check,
    monkey_patch_vllm_gguf_config,
    set_cpu_offload_max_bytes,

@@ -117,6 +121,7 @@ from sglang.srt.utils import (
)

_is_hip = is_hip()
_is_npu = is_npu()
_is_cpu_amx_available = cpu_has_amx_support()

# Use a small KV cache pool size for tests in CI

@@ -308,6 +313,7 @@ class ModelRunner:
            self.init_cuda_graphs()
        else:
            self.cuda_graph_runner = None
            self.cuda_graph_mem_usage = 0
        self.init_attention_backend()

        # auxiliary hidden capture mode. TODO: expose this to server args?

@@ -369,6 +375,8 @@ class ModelRunner:
                    server_args.attention_backend = "fa3"
                elif _is_hip:
                    server_args.attention_backend = "aiter"
                elif _is_npu:
                    server_args.attention_backend = "ascend"
                else:
                    server_args.attention_backend = (
                        "flashinfer" if is_flashinfer_available() else "triton"

@@ -388,6 +396,8 @@ class ModelRunner:
                        server_args.attention_backend = "aiter"
                    else:
                        server_args.attention_backend = "triton"
                elif _is_npu:
                    server_args.attention_backend = "ascend"
                else:
                    server_args.attention_backend = "triton"
            logger.info(

@@ -402,6 +412,7 @@ class ModelRunner:
                "triton",
                "flashmla",
                "cutlass_mla",
                "ascend",
            ]:
                logger.info(
                    f"MLA optimization is turned on. Use {server_args.attention_backend} backend."

@@ -1096,7 +1107,35 @@ class ModelRunner:
            # Draft worker shares req_to_token_pool with the target worker.
            assert self.is_draft_worker

        if self.use_mla_backend:
        if self.server_args.attention_backend == "ascend" and not self.use_mla_backend:
            self.token_to_kv_pool = AscendTokenToKVPool(
                self.max_total_num_tokens,
                page_size=self.page_size,
                dtype=self.kv_cache_dtype,
                head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
                head_dim=self.model_config.head_dim,
                layer_num=self.model_config.num_hidden_layers,
                device=self.device,
                enable_memory_saver=self.server_args.enable_memory_saver,
            )
        elif self.server_args.attention_backend == "ascend" and self.use_mla_backend:
            self.token_to_kv_pool = AscendMLAPagedTokenToKVPool(
                self.max_total_num_tokens,
                page_size=self.page_size,
                dtype=self.kv_cache_dtype,
                kv_lora_rank=self.model_config.kv_lora_rank,
                qk_rope_head_dim=self.model_config.qk_rope_head_dim,
                layer_num=(
                    self.model_config.num_hidden_layers
                    if not self.is_draft_worker
                    else self.model_config.hf_config.num_nextn_predict_layers
                ),  # PP is not compatible with mla backend
                device=self.device,
                enable_memory_saver=self.server_args.enable_memory_saver,
                start_layer=self.start_layer,
                end_layer=self.end_layer,
            )
        elif self.use_mla_backend:
            self.token_to_kv_pool = MLATokenToKVPool(
                self.max_total_num_tokens,
                page_size=self.page_size,

@@ -1176,13 +1215,22 @@ class ModelRunner:
                    kvcache=self.token_to_kv_pool,
                )
            else:
                self.token_to_kv_pool_allocator = PagedTokenToKVPoolAllocator(
                    self.max_total_num_tokens,
                    page_size=self.page_size,
                    dtype=self.kv_cache_dtype,
                    device=self.device,
                    kvcache=self.token_to_kv_pool,
                )
                if _is_npu:
                    self.token_to_kv_pool_allocator = AscendPagedTokenToKVPoolAllocator(
                        self.max_total_num_tokens,
                        page_size=self.page_size,
                        dtype=self.kv_cache_dtype,
                        device=self.device,
                        kvcache=self.token_to_kv_pool,
                    )
                else:
                    self.token_to_kv_pool_allocator = PagedTokenToKVPoolAllocator(
                        self.max_total_num_tokens,
                        page_size=self.page_size,
                        dtype=self.kv_cache_dtype,
                        device=self.device,
                        kvcache=self.token_to_kv_pool,
                    )
        else:
            assert self.is_draft_worker

@@ -1229,6 +1277,10 @@ class ModelRunner:
            from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend

            return AiterAttnBackend(self)
        elif self.server_args.attention_backend == "ascend":
            from sglang.srt.layers.attention.ascend_backend import AscendAttnBackend

            return AscendAttnBackend(self)
        elif self.server_args.attention_backend == "triton":
            assert not self.model_config.is_encoder_decoder, (
                "Cross attention is not supported in the triton attention backend. "

@@ -956,7 +956,9 @@ class DeepseekV2AttentionMLA(nn.Module):
        else:
            return AttnForwardMethod.MLA

        if self.attention_backend == "flashinfer":
        if self.attention_backend == "ascend":
            return AttnForwardMethod.MLA
        elif self.attention_backend == "flashinfer":
            # Flashinfer MLA: Do not absorb when enabling ragged prefill
            if (
                not self.flashinfer_mla_disable_ragged

@@ -380,6 +380,12 @@ class ServerArgs:
            )
            self.disable_cuda_graph = True

        if self.attention_backend == "ascend":
            logger.warning(
                "At this moment Ascend attention backend only supports a page_size of 128, change page_size to 128."
            )
            self.page_size = 128

        # Choose grammar backend
        if self.grammar_backend is None:
            self.grammar_backend = "xgrammar"

@@ -1113,6 +1119,7 @@ class ServerArgs:
                "flashmla",
                "intel_amx",
                "torch_native",
                "ascend",
                "triton",
            ],
            default=ServerArgs.attention_backend,

@@ -2399,7 +2399,7 @@ def bind_or_assign(target, source):


def support_triton(backend: str) -> bool:
    return backend not in ["torch_native", "intel_amx"]
    return backend not in ["torch_native", "intel_amx", "ascend"]


try: