[new feat] ascend backend support fia fusion kernel (#8328)
Co-authored-by: Even Zhou <even.y.zhou@outlook.com>
This commit is contained in:
6
.github/workflows/pr-test-npu.yml
vendored
6
.github/workflows/pr-test-npu.yml
vendored
@@ -47,7 +47,7 @@ jobs:
|
||||
curl -o /tmp/test.jsonl -L https://gh-proxy.test.osinfra.cn/https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl
|
||||
|
||||
- name: Run test
|
||||
timeout-minutes: 30
|
||||
timeout-minutes: 60
|
||||
env:
|
||||
SGLANG_USE_MODELSCOPE: true
|
||||
SGLANG_IS_IN_CI: true
|
||||
@@ -82,7 +82,7 @@ jobs:
|
||||
curl -o /tmp/test.jsonl -L https://gh-proxy.test.osinfra.cn/https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl
|
||||
|
||||
- name: Run test
|
||||
timeout-minutes: 30
|
||||
timeout-minutes: 90
|
||||
env:
|
||||
SGLANG_USE_MODELSCOPE: true
|
||||
SGLANG_IS_IN_CI: true
|
||||
@@ -117,7 +117,7 @@ jobs:
|
||||
curl -o /tmp/test.jsonl -L https://gh-proxy.test.osinfra.cn/https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl
|
||||
|
||||
- name: Run test
|
||||
timeout-minutes: 30
|
||||
timeout-minutes: 60
|
||||
env:
|
||||
SGLANG_USE_MODELSCOPE: true
|
||||
SGLANG_IS_IN_CI: true
|
||||
|
||||
@@ -12,11 +12,16 @@ from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
|
||||
from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
|
||||
from sglang.srt.layers.radix_attention import AttentionType
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.utils import get_bool_env_var
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
@dataclass
|
||||
class ForwardMetadata:
|
||||
@@ -54,7 +59,6 @@ class AscendAttnBackend(AttentionBackend):
|
||||
super().__init__()
|
||||
self.forward_metadata = None
|
||||
self.device = model_runner.device
|
||||
self.gen_attention_mask(128, model_runner.dtype)
|
||||
self.page_size = model_runner.page_size
|
||||
self.use_mla = model_runner.model_config.attention_arch == AttentionArch.MLA
|
||||
if self.use_mla:
|
||||
@@ -65,6 +69,17 @@ class AscendAttnBackend(AttentionBackend):
|
||||
self.max_context_len = model_runner.model_config.context_len
|
||||
self.req_to_token = model_runner.req_to_token_pool.req_to_token
|
||||
self.graph_mode = False
|
||||
self.use_fia = get_bool_env_var("ASCEND_USE_FIA", "False")
|
||||
if not self.use_fia:
|
||||
self.gen_attention_mask(128, model_runner.dtype)
|
||||
mask_length = 2048
|
||||
self.fia_mask = ~torch.tril(
|
||||
torch.ones(
|
||||
(mask_length, mask_length),
|
||||
dtype=torch.bool,
|
||||
device=model_runner.device,
|
||||
)
|
||||
)
|
||||
|
||||
def init_forward_metadata(self, forward_batch: ForwardBatch):
|
||||
"""Init the metadata for a forward pass."""
|
||||
@@ -81,6 +96,9 @@ class AscendAttnBackend(AttentionBackend):
|
||||
forward_batch.extend_seq_lens.cpu().int()
|
||||
)
|
||||
self.forward_metadata.seq_lens_cpu_int = forward_batch.seq_lens_cpu.int()
|
||||
self.forward_metadata.seq_lens_list_cumsum = np.cumsum(
|
||||
forward_batch.extend_seq_lens_cpu
|
||||
)
|
||||
|
||||
self.graph_mode = False
|
||||
|
||||
@@ -151,71 +169,89 @@ class AscendAttnBackend(AttentionBackend):
|
||||
forward_batch: ForwardBatch,
|
||||
save_kv_cache=True,
|
||||
):
|
||||
if save_kv_cache:
|
||||
forward_batch.token_to_kv_pool.set_kv_buffer(
|
||||
layer, forward_batch.out_cache_loc, k, v
|
||||
)
|
||||
|
||||
k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
|
||||
v_cache = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)
|
||||
|
||||
if not self.use_mla:
|
||||
query = q.view(-1, layer.tp_q_head_num * layer.qk_head_dim)
|
||||
output = torch.empty(
|
||||
(query.shape[0], layer.tp_q_head_num * layer.v_head_dim),
|
||||
dtype=query.dtype,
|
||||
device=query.device,
|
||||
)
|
||||
if save_kv_cache:
|
||||
forward_batch.token_to_kv_pool.set_kv_buffer(
|
||||
layer, forward_batch.out_cache_loc, k, v
|
||||
)
|
||||
|
||||
k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
|
||||
v_cache = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)
|
||||
|
||||
if self.use_fia:
|
||||
"""FIA will support multi-bs in the later version of CANN"""
|
||||
q = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
|
||||
attn_output = torch.empty(
|
||||
(q.size(0), layer.tp_q_head_num, layer.v_head_dim),
|
||||
device=q.device,
|
||||
dtype=q.dtype,
|
||||
)
|
||||
q_len_offset = 0
|
||||
for q_len in forward_batch.extend_seq_lens_cpu:
|
||||
attn_output[q_len_offset : q_len_offset + q_len] = (
|
||||
torch.ops.npu.npu_fused_infer_attention_score(
|
||||
q[None, q_len_offset : q_len_offset + q_len],
|
||||
k[None, q_len_offset : q_len_offset + q_len],
|
||||
v[None, q_len_offset : q_len_offset + q_len],
|
||||
num_heads=layer.tp_q_head_num,
|
||||
num_key_value_heads=layer.tp_k_head_num,
|
||||
input_layout="BSND", # todo, TND not supports q_heads!=k_heads
|
||||
atten_mask=self.fia_mask.unsqueeze(0),
|
||||
sparse_mode=3,
|
||||
scale=layer.scaling,
|
||||
next_tokens=0,
|
||||
)[0]
|
||||
)
|
||||
q_len_offset += q_len
|
||||
attn_output = attn_output.view(
|
||||
-1, layer.tp_q_head_num * layer.v_head_dim
|
||||
)
|
||||
|
||||
torch_npu._npu_flash_attention_qlens(
|
||||
query=query,
|
||||
key_cache=k_cache,
|
||||
value_cache=v_cache,
|
||||
mask=self.mask,
|
||||
block_table=self.forward_metadata.block_tables,
|
||||
seq_len=self.forward_metadata.extend_seq_lens_cpu_int,
|
||||
context_lens=self.forward_metadata.seq_lens_cpu_int,
|
||||
scale_value=layer.scaling,
|
||||
num_heads=layer.tp_q_head_num,
|
||||
num_kv_heads=layer.tp_k_head_num,
|
||||
out=output,
|
||||
)
|
||||
return output
|
||||
else:
|
||||
if layer.qk_head_dim != layer.v_head_dim:
|
||||
o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
|
||||
else:
|
||||
o = torch.empty_like(q)
|
||||
query = q.view(-1, layer.tp_q_head_num * layer.qk_head_dim)
|
||||
attn_output = torch.empty(
|
||||
(query.shape[0], layer.tp_q_head_num * layer.v_head_dim),
|
||||
dtype=query.dtype,
|
||||
device=query.device,
|
||||
)
|
||||
|
||||
use_gqa = layer.tp_q_head_num != layer.tp_k_head_num
|
||||
torch_npu._npu_flash_attention_qlens(
|
||||
query=query,
|
||||
key_cache=k_cache,
|
||||
value_cache=v_cache,
|
||||
mask=self.mask,
|
||||
block_table=self.forward_metadata.block_tables,
|
||||
seq_len=self.forward_metadata.extend_seq_lens_cpu_int,
|
||||
context_lens=self.forward_metadata.seq_lens_cpu_int,
|
||||
scale_value=layer.scaling,
|
||||
num_heads=layer.tp_q_head_num,
|
||||
num_kv_heads=layer.tp_k_head_num,
|
||||
out=attn_output,
|
||||
)
|
||||
else:
|
||||
assert (
|
||||
layer.qk_head_dim != layer.v_head_dim
|
||||
), "FIA only supports qk_head_dim != v_head_dim"
|
||||
q_nope, q_rope = q.split([layer.v_head_dim, self.qk_rope_head_dim], dim=-1)
|
||||
k_nope, k_rope = k.split([layer.v_head_dim, self.qk_rope_head_dim], dim=-1)
|
||||
|
||||
q_ = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
|
||||
o_ = o.view(-1, layer.tp_q_head_num, layer.v_head_dim)
|
||||
|
||||
causal = True
|
||||
if (
|
||||
layer.is_cross_attention
|
||||
or layer.attn_type == AttentionType.ENCODER_ONLY
|
||||
):
|
||||
causal = False
|
||||
|
||||
self.native_attn._run_sdpa_forward_extend(
|
||||
q_,
|
||||
o_,
|
||||
k_cache.view(
|
||||
-1, layer.tp_k_head_num, (self.kv_lora_rank + self.qk_rope_head_dim)
|
||||
),
|
||||
v_cache.view(-1, layer.tp_v_head_num, self.kv_lora_rank),
|
||||
forward_batch.req_to_token_pool.req_to_token,
|
||||
forward_batch.req_pool_indices,
|
||||
forward_batch.seq_lens,
|
||||
forward_batch.extend_prefix_lens,
|
||||
forward_batch.extend_seq_lens,
|
||||
scaling=layer.scaling,
|
||||
enable_gqa=use_gqa,
|
||||
causal=causal,
|
||||
attn_output, _ = torch.ops.npu.npu_fused_infer_attention_score(
|
||||
q_nope,
|
||||
k_nope,
|
||||
v,
|
||||
query_rope=q_rope,
|
||||
key_rope=k_rope,
|
||||
num_heads=layer.tp_q_head_num,
|
||||
input_layout="TND",
|
||||
atten_mask=self.fia_mask,
|
||||
sparse_mode=3,
|
||||
actual_seq_lengths=self.forward_metadata.seq_lens_list_cumsum,
|
||||
actual_seq_lengths_kv=self.forward_metadata.seq_lens_list_cumsum,
|
||||
scale=layer.scaling,
|
||||
next_tokens=0,
|
||||
)
|
||||
return o
|
||||
|
||||
return attn_output
|
||||
|
||||
def forward_decode(
|
||||
self,
|
||||
@@ -224,13 +260,17 @@ class AscendAttnBackend(AttentionBackend):
|
||||
v: torch.Tensor,
|
||||
layer: RadixAttention,
|
||||
forward_batch: ForwardBatch,
|
||||
save_kv_cache=True,
|
||||
save_kv_cache: bool = False,
|
||||
# For multi-head latent attention
|
||||
q_rope: Optional[torch.Tensor] = None,
|
||||
k_rope: Optional[torch.Tensor] = None,
|
||||
):
|
||||
if save_kv_cache:
|
||||
forward_batch.token_to_kv_pool.set_kv_buffer(
|
||||
layer, forward_batch.out_cache_loc, k, v
|
||||
)
|
||||
if not self.use_mla:
|
||||
if save_kv_cache:
|
||||
forward_batch.token_to_kv_pool.set_kv_buffer(
|
||||
layer, forward_batch.out_cache_loc, k, v
|
||||
)
|
||||
num_tokens = q.shape[0]
|
||||
if self.graph_mode:
|
||||
k_cache = forward_batch.token_to_kv_pool.get_key_buffer(
|
||||
layer.layer_id
|
||||
@@ -239,7 +279,6 @@ class AscendAttnBackend(AttentionBackend):
|
||||
layer.layer_id
|
||||
).view(-1, self.page_size, layer.tp_v_head_num * layer.v_head_dim)
|
||||
query = q.view(-1, 1, layer.tp_q_head_num * layer.qk_head_dim)
|
||||
num_tokens = query.shape[0]
|
||||
workspace = (
|
||||
torch_npu._npu_fused_infer_attention_score_get_max_workspace(
|
||||
query,
|
||||
@@ -254,7 +293,7 @@ class AscendAttnBackend(AttentionBackend):
|
||||
actual_seq_lengths_kv=self.forward_metadata.seq_lens_cpu_list,
|
||||
)
|
||||
)
|
||||
output = torch.empty(
|
||||
attn_output = torch.empty(
|
||||
(num_tokens, 1, layer.tp_q_head_num * layer.v_head_dim),
|
||||
dtype=q.dtype,
|
||||
device=q.device,
|
||||
@@ -272,61 +311,129 @@ class AscendAttnBackend(AttentionBackend):
|
||||
scale=layer.scaling,
|
||||
actual_seq_lengths_kv=self.forward_metadata.seq_lens_cpu_list,
|
||||
workspace=workspace,
|
||||
out=[output, softmax_lse],
|
||||
out=[attn_output, softmax_lse],
|
||||
)
|
||||
else:
|
||||
k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
|
||||
v_cache = forward_batch.token_to_kv_pool.get_value_buffer(
|
||||
layer.layer_id
|
||||
)
|
||||
if self.use_fia:
|
||||
attn_output, _ = torch.ops.npu.npu_fused_infer_attention_score(
|
||||
q.view(
|
||||
forward_batch.batch_size,
|
||||
-1,
|
||||
layer.tp_q_head_num,
|
||||
layer.qk_head_dim,
|
||||
),
|
||||
k_cache.view(
|
||||
-1, self.page_size, layer.tp_k_head_num * layer.qk_head_dim
|
||||
),
|
||||
v_cache.view(
|
||||
-1, self.page_size, layer.tp_v_head_num * layer.qk_head_dim
|
||||
),
|
||||
num_heads=layer.tp_q_head_num,
|
||||
num_key_value_heads=layer.tp_k_head_num,
|
||||
input_layout="BSND",
|
||||
atten_mask=None,
|
||||
block_size=self.page_size,
|
||||
block_table=self.forward_metadata.block_tables,
|
||||
actual_seq_lengths_kv=self.forward_metadata.seq_lens_cpu_int,
|
||||
scale=layer.scaling,
|
||||
)
|
||||
else:
|
||||
query = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
|
||||
attn_output = torch.empty(
|
||||
(num_tokens, layer.tp_q_head_num, layer.v_head_dim),
|
||||
dtype=query.dtype,
|
||||
device=query.device,
|
||||
)
|
||||
|
||||
query = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
|
||||
num_tokens = query.shape[0]
|
||||
output = torch.empty(
|
||||
(num_tokens, layer.tp_q_head_num, layer.v_head_dim),
|
||||
dtype=query.dtype,
|
||||
device=query.device,
|
||||
torch_npu._npu_paged_attention(
|
||||
query=query,
|
||||
key_cache=k_cache,
|
||||
value_cache=v_cache,
|
||||
num_heads=layer.tp_q_head_num,
|
||||
num_kv_heads=layer.tp_k_head_num,
|
||||
scale_value=layer.scaling,
|
||||
block_table=self.forward_metadata.block_tables,
|
||||
context_lens=self.forward_metadata.seq_lens_cpu_int,
|
||||
out=attn_output,
|
||||
)
|
||||
return attn_output.view(num_tokens, layer.tp_q_head_num * layer.v_head_dim)
|
||||
else:
|
||||
if save_kv_cache:
|
||||
forward_batch.token_to_kv_pool.set_kv_buffer(
|
||||
layer, forward_batch.out_cache_loc, k, k_rope
|
||||
)
|
||||
num_tokens = q.shape[0]
|
||||
kv_c = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
|
||||
k_pe = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)
|
||||
|
||||
torch_npu._npu_paged_attention(
|
||||
query=query,
|
||||
key_cache=k_cache,
|
||||
value_cache=v_cache,
|
||||
if (self.graph_mode or self.use_fia) and (
|
||||
layer.tp_q_head_num // layer.tp_k_head_num
|
||||
) >= 8:
|
||||
"""layer.tp_q_head_num // layer.tp_k_head_num < 8 will support in the later version of CANN"""
|
||||
kv_c = kv_c.view(
|
||||
-1, self.page_size, layer.tp_k_head_num * self.kv_lora_rank
|
||||
)
|
||||
k_pe = k_pe.view(
|
||||
-1, self.page_size, layer.tp_k_head_num * self.qk_rope_head_dim
|
||||
)
|
||||
q = q.view(
|
||||
forward_batch.batch_size, -1, layer.tp_q_head_num, self.kv_lora_rank
|
||||
)
|
||||
q_rope = q_rope.view(
|
||||
forward_batch.batch_size,
|
||||
-1,
|
||||
layer.tp_q_head_num,
|
||||
self.qk_rope_head_dim,
|
||||
)
|
||||
attn_output, _ = torch.ops.npu.npu_fused_infer_attention_score(
|
||||
q,
|
||||
kv_c,
|
||||
kv_c,
|
||||
query_rope=q_rope,
|
||||
key_rope=k_pe,
|
||||
num_heads=layer.tp_q_head_num,
|
||||
num_key_value_heads=layer.tp_k_head_num,
|
||||
input_layout="BSND",
|
||||
atten_mask=None,
|
||||
sparse_mode=0,
|
||||
scale=layer.scaling,
|
||||
antiquant_mode=0,
|
||||
antiquant_scale=None,
|
||||
block_table=self.forward_metadata.block_tables,
|
||||
block_size=self.page_size,
|
||||
actual_seq_lengths_kv=self.forward_metadata.seq_lens_cpu_int,
|
||||
)
|
||||
else:
|
||||
assert (
|
||||
self.graph_mode == False
|
||||
) # _npu_paged_attention_mla not support graph mode
|
||||
q = torch.cat([q, q_rope], dim=-1)
|
||||
query = q.view(-1, layer.tp_q_head_num, layer.head_dim)
|
||||
kv_c_and_k_pe_cache = torch.cat([kv_c, k_pe], dim=-1)
|
||||
kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.view(
|
||||
-1,
|
||||
self.page_size,
|
||||
layer.tp_k_head_num,
|
||||
self.kv_lora_rank + self.qk_rope_head_dim,
|
||||
)
|
||||
attn_output = torch.empty(
|
||||
[num_tokens, layer.tp_q_head_num, self.kv_lora_rank],
|
||||
dtype=q.dtype,
|
||||
device=q.device,
|
||||
)
|
||||
torch_npu._npu_paged_attention_mla(
|
||||
query=query,
|
||||
key_cache=kv_c_and_k_pe_cache,
|
||||
num_kv_heads=layer.tp_k_head_num,
|
||||
num_heads=layer.tp_q_head_num,
|
||||
scale_value=layer.scaling,
|
||||
block_table=self.forward_metadata.block_tables,
|
||||
context_lens=self.forward_metadata.seq_lens_cpu_int,
|
||||
out=output,
|
||||
mla_vheadsize=self.kv_lora_rank,
|
||||
out=attn_output,
|
||||
)
|
||||
return output.view(num_tokens, layer.tp_q_head_num * layer.v_head_dim)
|
||||
else:
|
||||
query = q.view(-1, layer.tp_q_head_num, layer.head_dim)
|
||||
num_tokens = query.shape[0]
|
||||
kv_c_and_k_pe_cache = forward_batch.token_to_kv_pool.get_key_buffer(
|
||||
layer.layer_id
|
||||
)
|
||||
kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.view(
|
||||
-1,
|
||||
self.page_size,
|
||||
layer.tp_k_head_num,
|
||||
self.kv_lora_rank + self.qk_rope_head_dim,
|
||||
)
|
||||
|
||||
attn_output = torch.empty(
|
||||
[num_tokens, layer.tp_q_head_num, self.kv_lora_rank],
|
||||
dtype=q.dtype,
|
||||
device=q.device,
|
||||
)
|
||||
torch_npu._npu_paged_attention_mla(
|
||||
query=query,
|
||||
key_cache=kv_c_and_k_pe_cache,
|
||||
num_kv_heads=layer.tp_k_head_num,
|
||||
num_heads=layer.tp_q_head_num,
|
||||
scale_value=layer.scaling,
|
||||
block_table=self.forward_metadata.block_tables,
|
||||
context_lens=self.forward_metadata.seq_lens_cpu_int,
|
||||
mla_vheadsize=self.kv_lora_rank,
|
||||
out=attn_output,
|
||||
)
|
||||
return attn_output.view(num_tokens, layer.tp_q_head_num * self.kv_lora_rank)
|
||||
|
||||
@@ -304,7 +304,7 @@ class TopK(CustomOp):
|
||||
global_num_experts = router_logits.shape[-1]
|
||||
|
||||
# NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
|
||||
if global_num_experts == 256 and self.topk_config.renormalize is False:
|
||||
if global_num_experts == 256 and self.topk_config.renormalize is True:
|
||||
|
||||
routed_scaling_factor = self.topk_config.routed_scaling_factor or 1
|
||||
router_logits = router_logits.to(torch.float32)
|
||||
|
||||
@@ -36,12 +36,15 @@ import triton.language as tl
|
||||
|
||||
from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.utils import get_bool_env_var, is_cuda, next_power_of_2
|
||||
from sglang.srt.utils import get_bool_env_var, is_cuda, is_npu, next_power_of_2
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
GB = 1024 * 1024 * 1024
|
||||
_is_cuda = is_cuda()
|
||||
_is_npu = is_npu()
|
||||
if _is_npu:
|
||||
import torch_npu
|
||||
|
||||
|
||||
class ReqToTokenPool:
|
||||
@@ -624,8 +627,6 @@ class AscendTokenToKVPool(MHATokenToKVPool):
|
||||
cache_k = cache_k.view(self.store_dtype)
|
||||
cache_v = cache_v.view(self.store_dtype)
|
||||
|
||||
import torch_npu
|
||||
|
||||
torch_npu._npu_reshape_and_cache(
|
||||
key=cache_k,
|
||||
value=cache_v,
|
||||
@@ -912,12 +913,22 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
|
||||
|
||||
with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
|
||||
# The padded slot 0 is used for writing dummy outputs from padded tokens.
|
||||
self.kv_buffer = torch.zeros(
|
||||
self.k_buffer = torch.zeros(
|
||||
(
|
||||
layer_num,
|
||||
self.size // self.page_size + 1,
|
||||
self.page_size,
|
||||
self.kv_lora_rank + self.qk_rope_head_dim,
|
||||
self.kv_lora_rank,
|
||||
),
|
||||
dtype=self.store_dtype,
|
||||
device=self.device,
|
||||
)
|
||||
self.v_buffer = torch.zeros(
|
||||
(
|
||||
layer_num,
|
||||
self.size // self.page_size + 1,
|
||||
self.page_size,
|
||||
self.qk_rope_head_dim,
|
||||
),
|
||||
dtype=self.store_dtype,
|
||||
device=self.device,
|
||||
@@ -931,12 +942,52 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
|
||||
)
|
||||
self.mem_usage = kv_size / GB
|
||||
|
||||
def get_kv_size_bytes(self):
|
||||
assert hasattr(self, "k_buffer")
|
||||
assert hasattr(self, "v_buffer")
|
||||
kv_size_bytes = 0
|
||||
for k_cache in self.k_buffer:
|
||||
kv_size_bytes += np.prod(k_cache.shape) * k_cache.dtype.itemsize
|
||||
for v_cache in self.v_buffer:
|
||||
kv_size_bytes += np.prod(v_cache.shape) * v_cache.dtype.itemsize
|
||||
return kv_size_bytes
|
||||
|
||||
def get_kv_buffer(self, layer_id: int):
|
||||
if self.layer_transfer_counter is not None:
|
||||
self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
|
||||
return (
|
||||
self.k_buffer[layer_id - self.start_layer],
|
||||
self.v_buffer[layer_id - self.start_layer],
|
||||
)
|
||||
|
||||
def get_key_buffer(self, layer_id: int):
|
||||
if self.layer_transfer_counter is not None:
|
||||
self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
|
||||
|
||||
if self.store_dtype != self.dtype:
|
||||
return self.k_buffer[layer_id - self.start_layer].view(self.dtype)
|
||||
return self.k_buffer[layer_id - self.start_layer]
|
||||
|
||||
def get_value_buffer(self, layer_id: int):
|
||||
if self.layer_transfer_counter is not None:
|
||||
self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
|
||||
|
||||
if self.store_dtype != self.dtype:
|
||||
return self.v_buffer[layer_id - self.start_layer].view(self.dtype)
|
||||
return self.v_buffer[layer_id - self.start_layer]
|
||||
|
||||
# for disagg
|
||||
def get_contiguous_buf_infos(self):
|
||||
# MLA has only one kv_buffer, so only the information of this buffer needs to be returned.
|
||||
kv_data_ptrs = [self.kv_buffer[i].data_ptr() for i in range(self.layer_num)]
|
||||
kv_data_lens = [self.kv_buffer[i].nbytes for i in range(self.layer_num)]
|
||||
kv_item_lens = [self.kv_buffer[i][0].nbytes for i in range(self.layer_num)]
|
||||
kv_data_ptrs = [self.k_buffer[i].data_ptr() for i in range(self.layer_num)] + [
|
||||
self.v_buffer[i].data_ptr() for i in range(self.layer_num)
|
||||
]
|
||||
kv_data_lens = [self.k_buffer[i].nbytes for i in range(self.layer_num)] + [
|
||||
self.v_buffer[i].nbytes for i in range(self.layer_num)
|
||||
]
|
||||
kv_item_lens = [self.k_buffer[i][0].nbytes for i in range(self.layer_num)] + [
|
||||
self.v_buffer[i][0].nbytes for i in range(self.layer_num)
|
||||
]
|
||||
return kv_data_ptrs, kv_data_lens, kv_item_lens
|
||||
|
||||
def set_kv_buffer(
|
||||
@@ -953,14 +1004,22 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
|
||||
if self.store_dtype != self.dtype:
|
||||
cache_k = cache_k.view(self.store_dtype)
|
||||
|
||||
import torch_npu
|
||||
if cache_v is None:
|
||||
cache_k, cache_v = cache_k.split(
|
||||
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
|
||||
)
|
||||
|
||||
torch_npu._npu_reshape_and_cache_siso(
|
||||
key=cache_k.view(-1, 1, self.kv_lora_rank + self.qk_rope_head_dim),
|
||||
key_cache=self.kv_buffer[layer_id - self.start_layer].view(
|
||||
-1, 1, 1, self.kv_lora_rank + self.qk_rope_head_dim
|
||||
torch_npu.npu_scatter_nd_update_(
|
||||
self.k_buffer[layer_id - self.start_layer].view(-1, 1, self.kv_lora_rank),
|
||||
loc.view(-1, 1),
|
||||
cache_k.view(-1, 1, self.kv_lora_rank),
|
||||
)
|
||||
torch_npu.npu_scatter_nd_update_(
|
||||
self.v_buffer[layer_id - self.start_layer].view(
|
||||
-1, 1, self.qk_rope_head_dim
|
||||
),
|
||||
slot_indices=loc,
|
||||
loc.view(-1, 1),
|
||||
cache_v.view(-1, 1, self.qk_rope_head_dim),
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -994,7 +994,14 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
self.current_attention_backend = attention_backend
|
||||
|
||||
if attention_backend == "ascend":
|
||||
return AttnForwardMethod.MLA
|
||||
if (
|
||||
forward_batch.forward_mode.is_extend()
|
||||
and not forward_batch.forward_mode.is_target_verify()
|
||||
and not forward_batch.forward_mode.is_draft_extend()
|
||||
):
|
||||
return AttnForwardMethod.MHA
|
||||
else:
|
||||
return AttnForwardMethod.MLA
|
||||
elif (
|
||||
attention_backend == "flashinfer"
|
||||
or attention_backend == "fa3"
|
||||
@@ -1292,6 +1299,7 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
or self.current_attention_backend == "flashinfer"
|
||||
or self.current_attention_backend == "cutlass_mla"
|
||||
or self.current_attention_backend == "trtllm_mla"
|
||||
or self.current_attention_backend == "ascend"
|
||||
):
|
||||
extra_args = {}
|
||||
if self._fuse_rope_for_trtllm_mla(forward_batch):
|
||||
|
||||
103
test/srt/ascend/test_ascend_mla_fia_w8a8int8.py
Normal file
103
test/srt/ascend/test_ascend_mla_fia_w8a8int8.py
Normal file
@@ -0,0 +1,103 @@
|
||||
import os
|
||||
import unittest
|
||||
from types import SimpleNamespace
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from sglang.srt.utils import kill_process_tree
|
||||
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
|
||||
from sglang.test.test_utils import (
|
||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
DEFAULT_URL_FOR_TEST,
|
||||
CustomTestCase,
|
||||
is_in_ci,
|
||||
popen_launch_server,
|
||||
run_bench_offline_throughput,
|
||||
)
|
||||
|
||||
TEST_MODEL_MATRIX = {
|
||||
"/root/.cache/modelscope/hub/models/vllm-ascend/DeepSeek-V2-Lite-W8A8": {
|
||||
"accuracy": 0.34,
|
||||
"latency": 1000,
|
||||
"output_throughput": 6,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class TestAscendMlaW8A8Int8(CustomTestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.models = TEST_MODEL_MATRIX.keys()
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
cls.url = urlparse(DEFAULT_URL_FOR_TEST)
|
||||
cls.common_args = [
|
||||
"--trust-remote-code",
|
||||
"--disable-cuda-graph",
|
||||
"--mem-fraction-static",
|
||||
0.8,
|
||||
"--attention-backend",
|
||||
"ascend",
|
||||
"--quantization",
|
||||
"w8a8_int8",
|
||||
"--tp-size",
|
||||
2,
|
||||
"--disable-radix-cache",
|
||||
]
|
||||
|
||||
def test_a_gsm8k(self):
|
||||
os.environ["ASCEND_USE_FIA"] = "true"
|
||||
for model in self.models:
|
||||
with self.subTest(model=model):
|
||||
print(f"##=== Testing accuracy: {model} ===##")
|
||||
|
||||
process = popen_launch_server(
|
||||
model,
|
||||
self.base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
other_args=[
|
||||
*self.common_args,
|
||||
],
|
||||
)
|
||||
|
||||
try:
|
||||
args = SimpleNamespace(
|
||||
num_shots=5,
|
||||
data_path=None,
|
||||
num_questions=1319,
|
||||
max_new_tokens=512,
|
||||
parallel=128,
|
||||
host=f"http://{self.url.hostname}",
|
||||
port=int(self.url.port),
|
||||
)
|
||||
|
||||
metrics = run_eval_few_shot_gsm8k(args)
|
||||
self.assertGreaterEqual(
|
||||
metrics["accuracy"],
|
||||
TEST_MODEL_MATRIX[model]["accuracy"],
|
||||
)
|
||||
finally:
|
||||
kill_process_tree(process.pid)
|
||||
|
||||
def test_b_throughput(self):
|
||||
for model in self.models:
|
||||
with self.subTest(model=model):
|
||||
print(f"##=== Testing throughput: {model} ===##")
|
||||
|
||||
output_throughput = run_bench_offline_throughput(
|
||||
model,
|
||||
[
|
||||
*self.common_args,
|
||||
],
|
||||
)
|
||||
|
||||
print(f"##=== {model} throughput: {output_throughput} ===##")
|
||||
|
||||
if is_in_ci():
|
||||
self.assertGreater(
|
||||
output_throughput,
|
||||
TEST_MODEL_MATRIX[model]["output_throughput"],
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -40,6 +40,7 @@ class TestAscendMlaW8A8Int8(CustomTestCase):
|
||||
"w8a8_int8",
|
||||
"--tp-size",
|
||||
4,
|
||||
"--disable-radix-cache",
|
||||
]
|
||||
|
||||
def test_a_gsm8k(self):
|
||||
|
||||
101
test/srt/ascend/test_ascend_tp2_fia_bf16.py
Normal file
101
test/srt/ascend/test_ascend_tp2_fia_bf16.py
Normal file
@@ -0,0 +1,101 @@
|
||||
import os
|
||||
import unittest
|
||||
from types import SimpleNamespace
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from sglang.srt.utils import kill_process_tree
|
||||
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
|
||||
from sglang.test.test_utils import (
|
||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
DEFAULT_URL_FOR_TEST,
|
||||
CustomTestCase,
|
||||
is_in_ci,
|
||||
popen_launch_server,
|
||||
run_bench_offline_throughput,
|
||||
)
|
||||
|
||||
TEST_MODEL_MATRIX = {
|
||||
"Qwen/Qwen2.5-7B-Instruct": {
|
||||
"accuracy": 0.85,
|
||||
"latency": 180,
|
||||
"output_throughput": 20,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class TestAscendTp2Bf16(CustomTestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.models = TEST_MODEL_MATRIX.keys()
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
cls.url = urlparse(DEFAULT_URL_FOR_TEST)
|
||||
cls.common_args = [
|
||||
"--trust-remote-code",
|
||||
"--disable-cuda-graph",
|
||||
"--mem-fraction-static",
|
||||
0.8,
|
||||
"--attention-backend",
|
||||
"ascend",
|
||||
"--tp-size",
|
||||
2,
|
||||
"--disable-radix-cache",
|
||||
]
|
||||
|
||||
def test_a_gsm8k(self):
|
||||
os.environ["ASCEND_USE_FIA"] = "true"
|
||||
for model in self.models:
|
||||
with self.subTest(model=model):
|
||||
print(f"##=== Testing accuracy: {model} ===##")
|
||||
|
||||
process = popen_launch_server(
|
||||
model,
|
||||
self.base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
other_args=[
|
||||
*self.common_args,
|
||||
],
|
||||
)
|
||||
|
||||
try:
|
||||
args = SimpleNamespace(
|
||||
num_shots=5,
|
||||
data_path=None,
|
||||
num_questions=1319,
|
||||
max_new_tokens=512,
|
||||
parallel=128,
|
||||
host=f"http://{self.url.hostname}",
|
||||
port=int(self.url.port),
|
||||
)
|
||||
|
||||
metrics = run_eval_few_shot_gsm8k(args)
|
||||
self.assertGreaterEqual(
|
||||
metrics["accuracy"],
|
||||
TEST_MODEL_MATRIX[model]["accuracy"],
|
||||
)
|
||||
finally:
|
||||
kill_process_tree(process.pid)
|
||||
|
||||
def test_b_throughput(self):
|
||||
for model in self.models:
|
||||
with self.subTest(model=model):
|
||||
print(f"##=== Testing throughput: {model} ===##")
|
||||
|
||||
output_throughput = run_bench_offline_throughput(
|
||||
model,
|
||||
[
|
||||
*self.common_args,
|
||||
],
|
||||
)
|
||||
|
||||
print(f"##=== {model} throughput: {output_throughput} ===##")
|
||||
|
||||
if is_in_ci():
|
||||
self.assertGreater(
|
||||
output_throughput,
|
||||
TEST_MODEL_MATRIX[model]["output_throughput"],
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -275,6 +275,8 @@ suite_ascend = {
|
||||
"per-commit-2-ascend-npu": [
|
||||
TestFile("ascend/test_ascend_tp2_bf16.py", 400),
|
||||
TestFile("ascend/test_ascend_graph_tp2_bf16.py", 400),
|
||||
TestFile("ascend/test_ascend_tp2_fia_bf16.py", 400),
|
||||
TestFile("ascend/test_ascend_mla_fia_w8a8int8.py", 400),
|
||||
],
|
||||
"per-commit-4-ascend-npu": [
|
||||
TestFile("ascend/test_ascend_mla_w8a8int8.py", 400),
|
||||
|
||||
Reference in New Issue
Block a user