[new feat] ascend backend support fia fusion kernel (#8328)
Co-authored-by: Even Zhou <even.y.zhou@outlook.com>
This commit is contained in:
6
.github/workflows/pr-test-npu.yml
vendored
6
.github/workflows/pr-test-npu.yml
vendored
@@ -47,7 +47,7 @@ jobs:
|
|||||||
curl -o /tmp/test.jsonl -L https://gh-proxy.test.osinfra.cn/https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl
|
curl -o /tmp/test.jsonl -L https://gh-proxy.test.osinfra.cn/https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl
|
||||||
|
|
||||||
- name: Run test
|
- name: Run test
|
||||||
timeout-minutes: 30
|
timeout-minutes: 60
|
||||||
env:
|
env:
|
||||||
SGLANG_USE_MODELSCOPE: true
|
SGLANG_USE_MODELSCOPE: true
|
||||||
SGLANG_IS_IN_CI: true
|
SGLANG_IS_IN_CI: true
|
||||||
@@ -82,7 +82,7 @@ jobs:
|
|||||||
curl -o /tmp/test.jsonl -L https://gh-proxy.test.osinfra.cn/https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl
|
curl -o /tmp/test.jsonl -L https://gh-proxy.test.osinfra.cn/https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl
|
||||||
|
|
||||||
- name: Run test
|
- name: Run test
|
||||||
timeout-minutes: 30
|
timeout-minutes: 90
|
||||||
env:
|
env:
|
||||||
SGLANG_USE_MODELSCOPE: true
|
SGLANG_USE_MODELSCOPE: true
|
||||||
SGLANG_IS_IN_CI: true
|
SGLANG_IS_IN_CI: true
|
||||||
@@ -117,7 +117,7 @@ jobs:
|
|||||||
curl -o /tmp/test.jsonl -L https://gh-proxy.test.osinfra.cn/https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl
|
curl -o /tmp/test.jsonl -L https://gh-proxy.test.osinfra.cn/https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl
|
||||||
|
|
||||||
- name: Run test
|
- name: Run test
|
||||||
timeout-minutes: 30
|
timeout-minutes: 60
|
||||||
env:
|
env:
|
||||||
SGLANG_USE_MODELSCOPE: true
|
SGLANG_USE_MODELSCOPE: true
|
||||||
SGLANG_IS_IN_CI: true
|
SGLANG_IS_IN_CI: true
|
||||||
|
|||||||
@@ -12,11 +12,16 @@ from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
|
|||||||
from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
|
from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
|
||||||
from sglang.srt.layers.radix_attention import AttentionType
|
from sglang.srt.layers.radix_attention import AttentionType
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||||
|
from sglang.srt.utils import get_bool_env_var
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ForwardMetadata:
|
class ForwardMetadata:
|
||||||
@@ -54,7 +59,6 @@ class AscendAttnBackend(AttentionBackend):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.forward_metadata = None
|
self.forward_metadata = None
|
||||||
self.device = model_runner.device
|
self.device = model_runner.device
|
||||||
self.gen_attention_mask(128, model_runner.dtype)
|
|
||||||
self.page_size = model_runner.page_size
|
self.page_size = model_runner.page_size
|
||||||
self.use_mla = model_runner.model_config.attention_arch == AttentionArch.MLA
|
self.use_mla = model_runner.model_config.attention_arch == AttentionArch.MLA
|
||||||
if self.use_mla:
|
if self.use_mla:
|
||||||
@@ -65,6 +69,17 @@ class AscendAttnBackend(AttentionBackend):
|
|||||||
self.max_context_len = model_runner.model_config.context_len
|
self.max_context_len = model_runner.model_config.context_len
|
||||||
self.req_to_token = model_runner.req_to_token_pool.req_to_token
|
self.req_to_token = model_runner.req_to_token_pool.req_to_token
|
||||||
self.graph_mode = False
|
self.graph_mode = False
|
||||||
|
self.use_fia = get_bool_env_var("ASCEND_USE_FIA", "False")
|
||||||
|
if not self.use_fia:
|
||||||
|
self.gen_attention_mask(128, model_runner.dtype)
|
||||||
|
mask_length = 2048
|
||||||
|
self.fia_mask = ~torch.tril(
|
||||||
|
torch.ones(
|
||||||
|
(mask_length, mask_length),
|
||||||
|
dtype=torch.bool,
|
||||||
|
device=model_runner.device,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
def init_forward_metadata(self, forward_batch: ForwardBatch):
|
def init_forward_metadata(self, forward_batch: ForwardBatch):
|
||||||
"""Init the metadata for a forward pass."""
|
"""Init the metadata for a forward pass."""
|
||||||
@@ -81,6 +96,9 @@ class AscendAttnBackend(AttentionBackend):
|
|||||||
forward_batch.extend_seq_lens.cpu().int()
|
forward_batch.extend_seq_lens.cpu().int()
|
||||||
)
|
)
|
||||||
self.forward_metadata.seq_lens_cpu_int = forward_batch.seq_lens_cpu.int()
|
self.forward_metadata.seq_lens_cpu_int = forward_batch.seq_lens_cpu.int()
|
||||||
|
self.forward_metadata.seq_lens_list_cumsum = np.cumsum(
|
||||||
|
forward_batch.extend_seq_lens_cpu
|
||||||
|
)
|
||||||
|
|
||||||
self.graph_mode = False
|
self.graph_mode = False
|
||||||
|
|
||||||
@@ -151,71 +169,89 @@ class AscendAttnBackend(AttentionBackend):
|
|||||||
forward_batch: ForwardBatch,
|
forward_batch: ForwardBatch,
|
||||||
save_kv_cache=True,
|
save_kv_cache=True,
|
||||||
):
|
):
|
||||||
if save_kv_cache:
|
|
||||||
forward_batch.token_to_kv_pool.set_kv_buffer(
|
|
||||||
layer, forward_batch.out_cache_loc, k, v
|
|
||||||
)
|
|
||||||
|
|
||||||
k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
|
|
||||||
v_cache = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)
|
|
||||||
|
|
||||||
if not self.use_mla:
|
if not self.use_mla:
|
||||||
query = q.view(-1, layer.tp_q_head_num * layer.qk_head_dim)
|
if save_kv_cache:
|
||||||
output = torch.empty(
|
forward_batch.token_to_kv_pool.set_kv_buffer(
|
||||||
(query.shape[0], layer.tp_q_head_num * layer.v_head_dim),
|
layer, forward_batch.out_cache_loc, k, v
|
||||||
dtype=query.dtype,
|
)
|
||||||
device=query.device,
|
|
||||||
)
|
k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
|
||||||
|
v_cache = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)
|
||||||
|
|
||||||
|
if self.use_fia:
|
||||||
|
"""FIA will support multi-bs in the later version of CANN"""
|
||||||
|
q = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
|
||||||
|
attn_output = torch.empty(
|
||||||
|
(q.size(0), layer.tp_q_head_num, layer.v_head_dim),
|
||||||
|
device=q.device,
|
||||||
|
dtype=q.dtype,
|
||||||
|
)
|
||||||
|
q_len_offset = 0
|
||||||
|
for q_len in forward_batch.extend_seq_lens_cpu:
|
||||||
|
attn_output[q_len_offset : q_len_offset + q_len] = (
|
||||||
|
torch.ops.npu.npu_fused_infer_attention_score(
|
||||||
|
q[None, q_len_offset : q_len_offset + q_len],
|
||||||
|
k[None, q_len_offset : q_len_offset + q_len],
|
||||||
|
v[None, q_len_offset : q_len_offset + q_len],
|
||||||
|
num_heads=layer.tp_q_head_num,
|
||||||
|
num_key_value_heads=layer.tp_k_head_num,
|
||||||
|
input_layout="BSND", # todo, TND not supports q_heads!=k_heads
|
||||||
|
atten_mask=self.fia_mask.unsqueeze(0),
|
||||||
|
sparse_mode=3,
|
||||||
|
scale=layer.scaling,
|
||||||
|
next_tokens=0,
|
||||||
|
)[0]
|
||||||
|
)
|
||||||
|
q_len_offset += q_len
|
||||||
|
attn_output = attn_output.view(
|
||||||
|
-1, layer.tp_q_head_num * layer.v_head_dim
|
||||||
|
)
|
||||||
|
|
||||||
torch_npu._npu_flash_attention_qlens(
|
|
||||||
query=query,
|
|
||||||
key_cache=k_cache,
|
|
||||||
value_cache=v_cache,
|
|
||||||
mask=self.mask,
|
|
||||||
block_table=self.forward_metadata.block_tables,
|
|
||||||
seq_len=self.forward_metadata.extend_seq_lens_cpu_int,
|
|
||||||
context_lens=self.forward_metadata.seq_lens_cpu_int,
|
|
||||||
scale_value=layer.scaling,
|
|
||||||
num_heads=layer.tp_q_head_num,
|
|
||||||
num_kv_heads=layer.tp_k_head_num,
|
|
||||||
out=output,
|
|
||||||
)
|
|
||||||
return output
|
|
||||||
else:
|
|
||||||
if layer.qk_head_dim != layer.v_head_dim:
|
|
||||||
o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
|
|
||||||
else:
|
else:
|
||||||
o = torch.empty_like(q)
|
query = q.view(-1, layer.tp_q_head_num * layer.qk_head_dim)
|
||||||
|
attn_output = torch.empty(
|
||||||
|
(query.shape[0], layer.tp_q_head_num * layer.v_head_dim),
|
||||||
|
dtype=query.dtype,
|
||||||
|
device=query.device,
|
||||||
|
)
|
||||||
|
|
||||||
use_gqa = layer.tp_q_head_num != layer.tp_k_head_num
|
torch_npu._npu_flash_attention_qlens(
|
||||||
|
query=query,
|
||||||
|
key_cache=k_cache,
|
||||||
|
value_cache=v_cache,
|
||||||
|
mask=self.mask,
|
||||||
|
block_table=self.forward_metadata.block_tables,
|
||||||
|
seq_len=self.forward_metadata.extend_seq_lens_cpu_int,
|
||||||
|
context_lens=self.forward_metadata.seq_lens_cpu_int,
|
||||||
|
scale_value=layer.scaling,
|
||||||
|
num_heads=layer.tp_q_head_num,
|
||||||
|
num_kv_heads=layer.tp_k_head_num,
|
||||||
|
out=attn_output,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
assert (
|
||||||
|
layer.qk_head_dim != layer.v_head_dim
|
||||||
|
), "FIA only supports qk_head_dim != v_head_dim"
|
||||||
|
q_nope, q_rope = q.split([layer.v_head_dim, self.qk_rope_head_dim], dim=-1)
|
||||||
|
k_nope, k_rope = k.split([layer.v_head_dim, self.qk_rope_head_dim], dim=-1)
|
||||||
|
|
||||||
q_ = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
|
attn_output, _ = torch.ops.npu.npu_fused_infer_attention_score(
|
||||||
o_ = o.view(-1, layer.tp_q_head_num, layer.v_head_dim)
|
q_nope,
|
||||||
|
k_nope,
|
||||||
causal = True
|
v,
|
||||||
if (
|
query_rope=q_rope,
|
||||||
layer.is_cross_attention
|
key_rope=k_rope,
|
||||||
or layer.attn_type == AttentionType.ENCODER_ONLY
|
num_heads=layer.tp_q_head_num,
|
||||||
):
|
input_layout="TND",
|
||||||
causal = False
|
atten_mask=self.fia_mask,
|
||||||
|
sparse_mode=3,
|
||||||
self.native_attn._run_sdpa_forward_extend(
|
actual_seq_lengths=self.forward_metadata.seq_lens_list_cumsum,
|
||||||
q_,
|
actual_seq_lengths_kv=self.forward_metadata.seq_lens_list_cumsum,
|
||||||
o_,
|
scale=layer.scaling,
|
||||||
k_cache.view(
|
next_tokens=0,
|
||||||
-1, layer.tp_k_head_num, (self.kv_lora_rank + self.qk_rope_head_dim)
|
|
||||||
),
|
|
||||||
v_cache.view(-1, layer.tp_v_head_num, self.kv_lora_rank),
|
|
||||||
forward_batch.req_to_token_pool.req_to_token,
|
|
||||||
forward_batch.req_pool_indices,
|
|
||||||
forward_batch.seq_lens,
|
|
||||||
forward_batch.extend_prefix_lens,
|
|
||||||
forward_batch.extend_seq_lens,
|
|
||||||
scaling=layer.scaling,
|
|
||||||
enable_gqa=use_gqa,
|
|
||||||
causal=causal,
|
|
||||||
)
|
)
|
||||||
return o
|
|
||||||
|
return attn_output
|
||||||
|
|
||||||
def forward_decode(
|
def forward_decode(
|
||||||
self,
|
self,
|
||||||
@@ -224,13 +260,17 @@ class AscendAttnBackend(AttentionBackend):
|
|||||||
v: torch.Tensor,
|
v: torch.Tensor,
|
||||||
layer: RadixAttention,
|
layer: RadixAttention,
|
||||||
forward_batch: ForwardBatch,
|
forward_batch: ForwardBatch,
|
||||||
save_kv_cache=True,
|
save_kv_cache: bool = False,
|
||||||
|
# For multi-head latent attention
|
||||||
|
q_rope: Optional[torch.Tensor] = None,
|
||||||
|
k_rope: Optional[torch.Tensor] = None,
|
||||||
):
|
):
|
||||||
if save_kv_cache:
|
|
||||||
forward_batch.token_to_kv_pool.set_kv_buffer(
|
|
||||||
layer, forward_batch.out_cache_loc, k, v
|
|
||||||
)
|
|
||||||
if not self.use_mla:
|
if not self.use_mla:
|
||||||
|
if save_kv_cache:
|
||||||
|
forward_batch.token_to_kv_pool.set_kv_buffer(
|
||||||
|
layer, forward_batch.out_cache_loc, k, v
|
||||||
|
)
|
||||||
|
num_tokens = q.shape[0]
|
||||||
if self.graph_mode:
|
if self.graph_mode:
|
||||||
k_cache = forward_batch.token_to_kv_pool.get_key_buffer(
|
k_cache = forward_batch.token_to_kv_pool.get_key_buffer(
|
||||||
layer.layer_id
|
layer.layer_id
|
||||||
@@ -239,7 +279,6 @@ class AscendAttnBackend(AttentionBackend):
|
|||||||
layer.layer_id
|
layer.layer_id
|
||||||
).view(-1, self.page_size, layer.tp_v_head_num * layer.v_head_dim)
|
).view(-1, self.page_size, layer.tp_v_head_num * layer.v_head_dim)
|
||||||
query = q.view(-1, 1, layer.tp_q_head_num * layer.qk_head_dim)
|
query = q.view(-1, 1, layer.tp_q_head_num * layer.qk_head_dim)
|
||||||
num_tokens = query.shape[0]
|
|
||||||
workspace = (
|
workspace = (
|
||||||
torch_npu._npu_fused_infer_attention_score_get_max_workspace(
|
torch_npu._npu_fused_infer_attention_score_get_max_workspace(
|
||||||
query,
|
query,
|
||||||
@@ -254,7 +293,7 @@ class AscendAttnBackend(AttentionBackend):
|
|||||||
actual_seq_lengths_kv=self.forward_metadata.seq_lens_cpu_list,
|
actual_seq_lengths_kv=self.forward_metadata.seq_lens_cpu_list,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
output = torch.empty(
|
attn_output = torch.empty(
|
||||||
(num_tokens, 1, layer.tp_q_head_num * layer.v_head_dim),
|
(num_tokens, 1, layer.tp_q_head_num * layer.v_head_dim),
|
||||||
dtype=q.dtype,
|
dtype=q.dtype,
|
||||||
device=q.device,
|
device=q.device,
|
||||||
@@ -272,61 +311,129 @@ class AscendAttnBackend(AttentionBackend):
|
|||||||
scale=layer.scaling,
|
scale=layer.scaling,
|
||||||
actual_seq_lengths_kv=self.forward_metadata.seq_lens_cpu_list,
|
actual_seq_lengths_kv=self.forward_metadata.seq_lens_cpu_list,
|
||||||
workspace=workspace,
|
workspace=workspace,
|
||||||
out=[output, softmax_lse],
|
out=[attn_output, softmax_lse],
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
|
k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
|
||||||
v_cache = forward_batch.token_to_kv_pool.get_value_buffer(
|
v_cache = forward_batch.token_to_kv_pool.get_value_buffer(
|
||||||
layer.layer_id
|
layer.layer_id
|
||||||
)
|
)
|
||||||
|
if self.use_fia:
|
||||||
|
attn_output, _ = torch.ops.npu.npu_fused_infer_attention_score(
|
||||||
|
q.view(
|
||||||
|
forward_batch.batch_size,
|
||||||
|
-1,
|
||||||
|
layer.tp_q_head_num,
|
||||||
|
layer.qk_head_dim,
|
||||||
|
),
|
||||||
|
k_cache.view(
|
||||||
|
-1, self.page_size, layer.tp_k_head_num * layer.qk_head_dim
|
||||||
|
),
|
||||||
|
v_cache.view(
|
||||||
|
-1, self.page_size, layer.tp_v_head_num * layer.qk_head_dim
|
||||||
|
),
|
||||||
|
num_heads=layer.tp_q_head_num,
|
||||||
|
num_key_value_heads=layer.tp_k_head_num,
|
||||||
|
input_layout="BSND",
|
||||||
|
atten_mask=None,
|
||||||
|
block_size=self.page_size,
|
||||||
|
block_table=self.forward_metadata.block_tables,
|
||||||
|
actual_seq_lengths_kv=self.forward_metadata.seq_lens_cpu_int,
|
||||||
|
scale=layer.scaling,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
query = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
|
||||||
|
attn_output = torch.empty(
|
||||||
|
(num_tokens, layer.tp_q_head_num, layer.v_head_dim),
|
||||||
|
dtype=query.dtype,
|
||||||
|
device=query.device,
|
||||||
|
)
|
||||||
|
|
||||||
query = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
|
torch_npu._npu_paged_attention(
|
||||||
num_tokens = query.shape[0]
|
query=query,
|
||||||
output = torch.empty(
|
key_cache=k_cache,
|
||||||
(num_tokens, layer.tp_q_head_num, layer.v_head_dim),
|
value_cache=v_cache,
|
||||||
dtype=query.dtype,
|
num_heads=layer.tp_q_head_num,
|
||||||
device=query.device,
|
num_kv_heads=layer.tp_k_head_num,
|
||||||
|
scale_value=layer.scaling,
|
||||||
|
block_table=self.forward_metadata.block_tables,
|
||||||
|
context_lens=self.forward_metadata.seq_lens_cpu_int,
|
||||||
|
out=attn_output,
|
||||||
|
)
|
||||||
|
return attn_output.view(num_tokens, layer.tp_q_head_num * layer.v_head_dim)
|
||||||
|
else:
|
||||||
|
if save_kv_cache:
|
||||||
|
forward_batch.token_to_kv_pool.set_kv_buffer(
|
||||||
|
layer, forward_batch.out_cache_loc, k, k_rope
|
||||||
)
|
)
|
||||||
|
num_tokens = q.shape[0]
|
||||||
|
kv_c = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
|
||||||
|
k_pe = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)
|
||||||
|
|
||||||
torch_npu._npu_paged_attention(
|
if (self.graph_mode or self.use_fia) and (
|
||||||
query=query,
|
layer.tp_q_head_num // layer.tp_k_head_num
|
||||||
key_cache=k_cache,
|
) >= 8:
|
||||||
value_cache=v_cache,
|
"""layer.tp_q_head_num // layer.tp_k_head_num < 8 will support in the later version of CANN"""
|
||||||
|
kv_c = kv_c.view(
|
||||||
|
-1, self.page_size, layer.tp_k_head_num * self.kv_lora_rank
|
||||||
|
)
|
||||||
|
k_pe = k_pe.view(
|
||||||
|
-1, self.page_size, layer.tp_k_head_num * self.qk_rope_head_dim
|
||||||
|
)
|
||||||
|
q = q.view(
|
||||||
|
forward_batch.batch_size, -1, layer.tp_q_head_num, self.kv_lora_rank
|
||||||
|
)
|
||||||
|
q_rope = q_rope.view(
|
||||||
|
forward_batch.batch_size,
|
||||||
|
-1,
|
||||||
|
layer.tp_q_head_num,
|
||||||
|
self.qk_rope_head_dim,
|
||||||
|
)
|
||||||
|
attn_output, _ = torch.ops.npu.npu_fused_infer_attention_score(
|
||||||
|
q,
|
||||||
|
kv_c,
|
||||||
|
kv_c,
|
||||||
|
query_rope=q_rope,
|
||||||
|
key_rope=k_pe,
|
||||||
num_heads=layer.tp_q_head_num,
|
num_heads=layer.tp_q_head_num,
|
||||||
|
num_key_value_heads=layer.tp_k_head_num,
|
||||||
|
input_layout="BSND",
|
||||||
|
atten_mask=None,
|
||||||
|
sparse_mode=0,
|
||||||
|
scale=layer.scaling,
|
||||||
|
antiquant_mode=0,
|
||||||
|
antiquant_scale=None,
|
||||||
|
block_table=self.forward_metadata.block_tables,
|
||||||
|
block_size=self.page_size,
|
||||||
|
actual_seq_lengths_kv=self.forward_metadata.seq_lens_cpu_int,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
assert (
|
||||||
|
self.graph_mode == False
|
||||||
|
) # _npu_paged_attention_mla not support graph mode
|
||||||
|
q = torch.cat([q, q_rope], dim=-1)
|
||||||
|
query = q.view(-1, layer.tp_q_head_num, layer.head_dim)
|
||||||
|
kv_c_and_k_pe_cache = torch.cat([kv_c, k_pe], dim=-1)
|
||||||
|
kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.view(
|
||||||
|
-1,
|
||||||
|
self.page_size,
|
||||||
|
layer.tp_k_head_num,
|
||||||
|
self.kv_lora_rank + self.qk_rope_head_dim,
|
||||||
|
)
|
||||||
|
attn_output = torch.empty(
|
||||||
|
[num_tokens, layer.tp_q_head_num, self.kv_lora_rank],
|
||||||
|
dtype=q.dtype,
|
||||||
|
device=q.device,
|
||||||
|
)
|
||||||
|
torch_npu._npu_paged_attention_mla(
|
||||||
|
query=query,
|
||||||
|
key_cache=kv_c_and_k_pe_cache,
|
||||||
num_kv_heads=layer.tp_k_head_num,
|
num_kv_heads=layer.tp_k_head_num,
|
||||||
|
num_heads=layer.tp_q_head_num,
|
||||||
scale_value=layer.scaling,
|
scale_value=layer.scaling,
|
||||||
block_table=self.forward_metadata.block_tables,
|
block_table=self.forward_metadata.block_tables,
|
||||||
context_lens=self.forward_metadata.seq_lens_cpu_int,
|
context_lens=self.forward_metadata.seq_lens_cpu_int,
|
||||||
out=output,
|
mla_vheadsize=self.kv_lora_rank,
|
||||||
|
out=attn_output,
|
||||||
)
|
)
|
||||||
return output.view(num_tokens, layer.tp_q_head_num * layer.v_head_dim)
|
|
||||||
else:
|
|
||||||
query = q.view(-1, layer.tp_q_head_num, layer.head_dim)
|
|
||||||
num_tokens = query.shape[0]
|
|
||||||
kv_c_and_k_pe_cache = forward_batch.token_to_kv_pool.get_key_buffer(
|
|
||||||
layer.layer_id
|
|
||||||
)
|
|
||||||
kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.view(
|
|
||||||
-1,
|
|
||||||
self.page_size,
|
|
||||||
layer.tp_k_head_num,
|
|
||||||
self.kv_lora_rank + self.qk_rope_head_dim,
|
|
||||||
)
|
|
||||||
|
|
||||||
attn_output = torch.empty(
|
|
||||||
[num_tokens, layer.tp_q_head_num, self.kv_lora_rank],
|
|
||||||
dtype=q.dtype,
|
|
||||||
device=q.device,
|
|
||||||
)
|
|
||||||
torch_npu._npu_paged_attention_mla(
|
|
||||||
query=query,
|
|
||||||
key_cache=kv_c_and_k_pe_cache,
|
|
||||||
num_kv_heads=layer.tp_k_head_num,
|
|
||||||
num_heads=layer.tp_q_head_num,
|
|
||||||
scale_value=layer.scaling,
|
|
||||||
block_table=self.forward_metadata.block_tables,
|
|
||||||
context_lens=self.forward_metadata.seq_lens_cpu_int,
|
|
||||||
mla_vheadsize=self.kv_lora_rank,
|
|
||||||
out=attn_output,
|
|
||||||
)
|
|
||||||
return attn_output.view(num_tokens, layer.tp_q_head_num * self.kv_lora_rank)
|
return attn_output.view(num_tokens, layer.tp_q_head_num * self.kv_lora_rank)
|
||||||
|
|||||||
@@ -304,7 +304,7 @@ class TopK(CustomOp):
|
|||||||
global_num_experts = router_logits.shape[-1]
|
global_num_experts = router_logits.shape[-1]
|
||||||
|
|
||||||
# NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
|
# NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
|
||||||
if global_num_experts == 256 and self.topk_config.renormalize is False:
|
if global_num_experts == 256 and self.topk_config.renormalize is True:
|
||||||
|
|
||||||
routed_scaling_factor = self.topk_config.routed_scaling_factor or 1
|
routed_scaling_factor = self.topk_config.routed_scaling_factor or 1
|
||||||
router_logits = router_logits.to(torch.float32)
|
router_logits = router_logits.to(torch.float32)
|
||||||
|
|||||||
@@ -36,12 +36,15 @@ import triton.language as tl
|
|||||||
|
|
||||||
from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE
|
from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.utils import get_bool_env_var, is_cuda, next_power_of_2
|
from sglang.srt.utils import get_bool_env_var, is_cuda, is_npu, next_power_of_2
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
GB = 1024 * 1024 * 1024
|
GB = 1024 * 1024 * 1024
|
||||||
_is_cuda = is_cuda()
|
_is_cuda = is_cuda()
|
||||||
|
_is_npu = is_npu()
|
||||||
|
if _is_npu:
|
||||||
|
import torch_npu
|
||||||
|
|
||||||
|
|
||||||
class ReqToTokenPool:
|
class ReqToTokenPool:
|
||||||
@@ -624,8 +627,6 @@ class AscendTokenToKVPool(MHATokenToKVPool):
|
|||||||
cache_k = cache_k.view(self.store_dtype)
|
cache_k = cache_k.view(self.store_dtype)
|
||||||
cache_v = cache_v.view(self.store_dtype)
|
cache_v = cache_v.view(self.store_dtype)
|
||||||
|
|
||||||
import torch_npu
|
|
||||||
|
|
||||||
torch_npu._npu_reshape_and_cache(
|
torch_npu._npu_reshape_and_cache(
|
||||||
key=cache_k,
|
key=cache_k,
|
||||||
value=cache_v,
|
value=cache_v,
|
||||||
@@ -912,12 +913,22 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
|
|||||||
|
|
||||||
with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
|
with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
|
||||||
# The padded slot 0 is used for writing dummy outputs from padded tokens.
|
# The padded slot 0 is used for writing dummy outputs from padded tokens.
|
||||||
self.kv_buffer = torch.zeros(
|
self.k_buffer = torch.zeros(
|
||||||
(
|
(
|
||||||
layer_num,
|
layer_num,
|
||||||
self.size // self.page_size + 1,
|
self.size // self.page_size + 1,
|
||||||
self.page_size,
|
self.page_size,
|
||||||
self.kv_lora_rank + self.qk_rope_head_dim,
|
self.kv_lora_rank,
|
||||||
|
),
|
||||||
|
dtype=self.store_dtype,
|
||||||
|
device=self.device,
|
||||||
|
)
|
||||||
|
self.v_buffer = torch.zeros(
|
||||||
|
(
|
||||||
|
layer_num,
|
||||||
|
self.size // self.page_size + 1,
|
||||||
|
self.page_size,
|
||||||
|
self.qk_rope_head_dim,
|
||||||
),
|
),
|
||||||
dtype=self.store_dtype,
|
dtype=self.store_dtype,
|
||||||
device=self.device,
|
device=self.device,
|
||||||
@@ -931,12 +942,52 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
|
|||||||
)
|
)
|
||||||
self.mem_usage = kv_size / GB
|
self.mem_usage = kv_size / GB
|
||||||
|
|
||||||
|
def get_kv_size_bytes(self):
|
||||||
|
assert hasattr(self, "k_buffer")
|
||||||
|
assert hasattr(self, "v_buffer")
|
||||||
|
kv_size_bytes = 0
|
||||||
|
for k_cache in self.k_buffer:
|
||||||
|
kv_size_bytes += np.prod(k_cache.shape) * k_cache.dtype.itemsize
|
||||||
|
for v_cache in self.v_buffer:
|
||||||
|
kv_size_bytes += np.prod(v_cache.shape) * v_cache.dtype.itemsize
|
||||||
|
return kv_size_bytes
|
||||||
|
|
||||||
|
def get_kv_buffer(self, layer_id: int):
|
||||||
|
if self.layer_transfer_counter is not None:
|
||||||
|
self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
|
||||||
|
return (
|
||||||
|
self.k_buffer[layer_id - self.start_layer],
|
||||||
|
self.v_buffer[layer_id - self.start_layer],
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_key_buffer(self, layer_id: int):
|
||||||
|
if self.layer_transfer_counter is not None:
|
||||||
|
self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
|
||||||
|
|
||||||
|
if self.store_dtype != self.dtype:
|
||||||
|
return self.k_buffer[layer_id - self.start_layer].view(self.dtype)
|
||||||
|
return self.k_buffer[layer_id - self.start_layer]
|
||||||
|
|
||||||
|
def get_value_buffer(self, layer_id: int):
|
||||||
|
if self.layer_transfer_counter is not None:
|
||||||
|
self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
|
||||||
|
|
||||||
|
if self.store_dtype != self.dtype:
|
||||||
|
return self.v_buffer[layer_id - self.start_layer].view(self.dtype)
|
||||||
|
return self.v_buffer[layer_id - self.start_layer]
|
||||||
|
|
||||||
# for disagg
|
# for disagg
|
||||||
def get_contiguous_buf_infos(self):
|
def get_contiguous_buf_infos(self):
|
||||||
# MLA has only one kv_buffer, so only the information of this buffer needs to be returned.
|
# MLA has only one kv_buffer, so only the information of this buffer needs to be returned.
|
||||||
kv_data_ptrs = [self.kv_buffer[i].data_ptr() for i in range(self.layer_num)]
|
kv_data_ptrs = [self.k_buffer[i].data_ptr() for i in range(self.layer_num)] + [
|
||||||
kv_data_lens = [self.kv_buffer[i].nbytes for i in range(self.layer_num)]
|
self.v_buffer[i].data_ptr() for i in range(self.layer_num)
|
||||||
kv_item_lens = [self.kv_buffer[i][0].nbytes for i in range(self.layer_num)]
|
]
|
||||||
|
kv_data_lens = [self.k_buffer[i].nbytes for i in range(self.layer_num)] + [
|
||||||
|
self.v_buffer[i].nbytes for i in range(self.layer_num)
|
||||||
|
]
|
||||||
|
kv_item_lens = [self.k_buffer[i][0].nbytes for i in range(self.layer_num)] + [
|
||||||
|
self.v_buffer[i][0].nbytes for i in range(self.layer_num)
|
||||||
|
]
|
||||||
return kv_data_ptrs, kv_data_lens, kv_item_lens
|
return kv_data_ptrs, kv_data_lens, kv_item_lens
|
||||||
|
|
||||||
def set_kv_buffer(
|
def set_kv_buffer(
|
||||||
@@ -953,14 +1004,22 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
|
|||||||
if self.store_dtype != self.dtype:
|
if self.store_dtype != self.dtype:
|
||||||
cache_k = cache_k.view(self.store_dtype)
|
cache_k = cache_k.view(self.store_dtype)
|
||||||
|
|
||||||
import torch_npu
|
if cache_v is None:
|
||||||
|
cache_k, cache_v = cache_k.split(
|
||||||
|
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
|
||||||
|
)
|
||||||
|
|
||||||
torch_npu._npu_reshape_and_cache_siso(
|
torch_npu.npu_scatter_nd_update_(
|
||||||
key=cache_k.view(-1, 1, self.kv_lora_rank + self.qk_rope_head_dim),
|
self.k_buffer[layer_id - self.start_layer].view(-1, 1, self.kv_lora_rank),
|
||||||
key_cache=self.kv_buffer[layer_id - self.start_layer].view(
|
loc.view(-1, 1),
|
||||||
-1, 1, 1, self.kv_lora_rank + self.qk_rope_head_dim
|
cache_k.view(-1, 1, self.kv_lora_rank),
|
||||||
|
)
|
||||||
|
torch_npu.npu_scatter_nd_update_(
|
||||||
|
self.v_buffer[layer_id - self.start_layer].view(
|
||||||
|
-1, 1, self.qk_rope_head_dim
|
||||||
),
|
),
|
||||||
slot_indices=loc,
|
loc.view(-1, 1),
|
||||||
|
cache_v.view(-1, 1, self.qk_rope_head_dim),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -994,7 +994,14 @@ class DeepseekV2AttentionMLA(nn.Module):
|
|||||||
self.current_attention_backend = attention_backend
|
self.current_attention_backend = attention_backend
|
||||||
|
|
||||||
if attention_backend == "ascend":
|
if attention_backend == "ascend":
|
||||||
return AttnForwardMethod.MLA
|
if (
|
||||||
|
forward_batch.forward_mode.is_extend()
|
||||||
|
and not forward_batch.forward_mode.is_target_verify()
|
||||||
|
and not forward_batch.forward_mode.is_draft_extend()
|
||||||
|
):
|
||||||
|
return AttnForwardMethod.MHA
|
||||||
|
else:
|
||||||
|
return AttnForwardMethod.MLA
|
||||||
elif (
|
elif (
|
||||||
attention_backend == "flashinfer"
|
attention_backend == "flashinfer"
|
||||||
or attention_backend == "fa3"
|
or attention_backend == "fa3"
|
||||||
@@ -1292,6 +1299,7 @@ class DeepseekV2AttentionMLA(nn.Module):
|
|||||||
or self.current_attention_backend == "flashinfer"
|
or self.current_attention_backend == "flashinfer"
|
||||||
or self.current_attention_backend == "cutlass_mla"
|
or self.current_attention_backend == "cutlass_mla"
|
||||||
or self.current_attention_backend == "trtllm_mla"
|
or self.current_attention_backend == "trtllm_mla"
|
||||||
|
or self.current_attention_backend == "ascend"
|
||||||
):
|
):
|
||||||
extra_args = {}
|
extra_args = {}
|
||||||
if self._fuse_rope_for_trtllm_mla(forward_batch):
|
if self._fuse_rope_for_trtllm_mla(forward_batch):
|
||||||
|
|||||||
103
test/srt/ascend/test_ascend_mla_fia_w8a8int8.py
Normal file
103
test/srt/ascend/test_ascend_mla_fia_w8a8int8.py
Normal file
@@ -0,0 +1,103 @@
|
|||||||
|
import os
|
||||||
|
import unittest
|
||||||
|
from types import SimpleNamespace
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
from sglang.srt.utils import kill_process_tree
|
||||||
|
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
|
||||||
|
from sglang.test.test_utils import (
|
||||||
|
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
|
DEFAULT_URL_FOR_TEST,
|
||||||
|
CustomTestCase,
|
||||||
|
is_in_ci,
|
||||||
|
popen_launch_server,
|
||||||
|
run_bench_offline_throughput,
|
||||||
|
)
|
||||||
|
|
||||||
|
TEST_MODEL_MATRIX = {
|
||||||
|
"/root/.cache/modelscope/hub/models/vllm-ascend/DeepSeek-V2-Lite-W8A8": {
|
||||||
|
"accuracy": 0.34,
|
||||||
|
"latency": 1000,
|
||||||
|
"output_throughput": 6,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class TestAscendMlaW8A8Int8(CustomTestCase):
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
cls.models = TEST_MODEL_MATRIX.keys()
|
||||||
|
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||||
|
cls.url = urlparse(DEFAULT_URL_FOR_TEST)
|
||||||
|
cls.common_args = [
|
||||||
|
"--trust-remote-code",
|
||||||
|
"--disable-cuda-graph",
|
||||||
|
"--mem-fraction-static",
|
||||||
|
0.8,
|
||||||
|
"--attention-backend",
|
||||||
|
"ascend",
|
||||||
|
"--quantization",
|
||||||
|
"w8a8_int8",
|
||||||
|
"--tp-size",
|
||||||
|
2,
|
||||||
|
"--disable-radix-cache",
|
||||||
|
]
|
||||||
|
|
||||||
|
def test_a_gsm8k(self):
|
||||||
|
os.environ["ASCEND_USE_FIA"] = "true"
|
||||||
|
for model in self.models:
|
||||||
|
with self.subTest(model=model):
|
||||||
|
print(f"##=== Testing accuracy: {model} ===##")
|
||||||
|
|
||||||
|
process = popen_launch_server(
|
||||||
|
model,
|
||||||
|
self.base_url,
|
||||||
|
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
|
other_args=[
|
||||||
|
*self.common_args,
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
args = SimpleNamespace(
|
||||||
|
num_shots=5,
|
||||||
|
data_path=None,
|
||||||
|
num_questions=1319,
|
||||||
|
max_new_tokens=512,
|
||||||
|
parallel=128,
|
||||||
|
host=f"http://{self.url.hostname}",
|
||||||
|
port=int(self.url.port),
|
||||||
|
)
|
||||||
|
|
||||||
|
metrics = run_eval_few_shot_gsm8k(args)
|
||||||
|
self.assertGreaterEqual(
|
||||||
|
metrics["accuracy"],
|
||||||
|
TEST_MODEL_MATRIX[model]["accuracy"],
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
kill_process_tree(process.pid)
|
||||||
|
|
||||||
|
def test_b_throughput(self):
|
||||||
|
for model in self.models:
|
||||||
|
with self.subTest(model=model):
|
||||||
|
print(f"##=== Testing throughput: {model} ===##")
|
||||||
|
|
||||||
|
output_throughput = run_bench_offline_throughput(
|
||||||
|
model,
|
||||||
|
[
|
||||||
|
*self.common_args,
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"##=== {model} throughput: {output_throughput} ===##")
|
||||||
|
|
||||||
|
if is_in_ci():
|
||||||
|
self.assertGreater(
|
||||||
|
output_throughput,
|
||||||
|
TEST_MODEL_MATRIX[model]["output_throughput"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
@@ -40,6 +40,7 @@ class TestAscendMlaW8A8Int8(CustomTestCase):
|
|||||||
"w8a8_int8",
|
"w8a8_int8",
|
||||||
"--tp-size",
|
"--tp-size",
|
||||||
4,
|
4,
|
||||||
|
"--disable-radix-cache",
|
||||||
]
|
]
|
||||||
|
|
||||||
def test_a_gsm8k(self):
|
def test_a_gsm8k(self):
|
||||||
|
|||||||
101
test/srt/ascend/test_ascend_tp2_fia_bf16.py
Normal file
101
test/srt/ascend/test_ascend_tp2_fia_bf16.py
Normal file
@@ -0,0 +1,101 @@
|
|||||||
|
import os
|
||||||
|
import unittest
|
||||||
|
from types import SimpleNamespace
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
from sglang.srt.utils import kill_process_tree
|
||||||
|
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
|
||||||
|
from sglang.test.test_utils import (
|
||||||
|
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
|
DEFAULT_URL_FOR_TEST,
|
||||||
|
CustomTestCase,
|
||||||
|
is_in_ci,
|
||||||
|
popen_launch_server,
|
||||||
|
run_bench_offline_throughput,
|
||||||
|
)
|
||||||
|
|
||||||
|
TEST_MODEL_MATRIX = {
|
||||||
|
"Qwen/Qwen2.5-7B-Instruct": {
|
||||||
|
"accuracy": 0.85,
|
||||||
|
"latency": 180,
|
||||||
|
"output_throughput": 20,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class TestAscendTp2Bf16(CustomTestCase):
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
cls.models = TEST_MODEL_MATRIX.keys()
|
||||||
|
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||||
|
cls.url = urlparse(DEFAULT_URL_FOR_TEST)
|
||||||
|
cls.common_args = [
|
||||||
|
"--trust-remote-code",
|
||||||
|
"--disable-cuda-graph",
|
||||||
|
"--mem-fraction-static",
|
||||||
|
0.8,
|
||||||
|
"--attention-backend",
|
||||||
|
"ascend",
|
||||||
|
"--tp-size",
|
||||||
|
2,
|
||||||
|
"--disable-radix-cache",
|
||||||
|
]
|
||||||
|
|
||||||
|
def test_a_gsm8k(self):
|
||||||
|
os.environ["ASCEND_USE_FIA"] = "true"
|
||||||
|
for model in self.models:
|
||||||
|
with self.subTest(model=model):
|
||||||
|
print(f"##=== Testing accuracy: {model} ===##")
|
||||||
|
|
||||||
|
process = popen_launch_server(
|
||||||
|
model,
|
||||||
|
self.base_url,
|
||||||
|
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
|
other_args=[
|
||||||
|
*self.common_args,
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
args = SimpleNamespace(
|
||||||
|
num_shots=5,
|
||||||
|
data_path=None,
|
||||||
|
num_questions=1319,
|
||||||
|
max_new_tokens=512,
|
||||||
|
parallel=128,
|
||||||
|
host=f"http://{self.url.hostname}",
|
||||||
|
port=int(self.url.port),
|
||||||
|
)
|
||||||
|
|
||||||
|
metrics = run_eval_few_shot_gsm8k(args)
|
||||||
|
self.assertGreaterEqual(
|
||||||
|
metrics["accuracy"],
|
||||||
|
TEST_MODEL_MATRIX[model]["accuracy"],
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
kill_process_tree(process.pid)
|
||||||
|
|
||||||
|
def test_b_throughput(self):
|
||||||
|
for model in self.models:
|
||||||
|
with self.subTest(model=model):
|
||||||
|
print(f"##=== Testing throughput: {model} ===##")
|
||||||
|
|
||||||
|
output_throughput = run_bench_offline_throughput(
|
||||||
|
model,
|
||||||
|
[
|
||||||
|
*self.common_args,
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"##=== {model} throughput: {output_throughput} ===##")
|
||||||
|
|
||||||
|
if is_in_ci():
|
||||||
|
self.assertGreater(
|
||||||
|
output_throughput,
|
||||||
|
TEST_MODEL_MATRIX[model]["output_throughput"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
@@ -275,6 +275,8 @@ suite_ascend = {
|
|||||||
"per-commit-2-ascend-npu": [
|
"per-commit-2-ascend-npu": [
|
||||||
TestFile("ascend/test_ascend_tp2_bf16.py", 400),
|
TestFile("ascend/test_ascend_tp2_bf16.py", 400),
|
||||||
TestFile("ascend/test_ascend_graph_tp2_bf16.py", 400),
|
TestFile("ascend/test_ascend_graph_tp2_bf16.py", 400),
|
||||||
|
TestFile("ascend/test_ascend_tp2_fia_bf16.py", 400),
|
||||||
|
TestFile("ascend/test_ascend_mla_fia_w8a8int8.py", 400),
|
||||||
],
|
],
|
||||||
"per-commit-4-ascend-npu": [
|
"per-commit-4-ascend-npu": [
|
||||||
TestFile("ascend/test_ascend_mla_w8a8int8.py", 400),
|
TestFile("ascend/test_ascend_mla_w8a8int8.py", 400),
|
||||||
|
|||||||
Reference in New Issue
Block a user