diff --git a/.github/workflows/pr-test-npu.yml b/.github/workflows/pr-test-npu.yml
index 45c115dbe..528ba80af 100644
--- a/.github/workflows/pr-test-npu.yml
+++ b/.github/workflows/pr-test-npu.yml
@@ -47,7 +47,7 @@ jobs:
           curl -o /tmp/test.jsonl -L https://gh-proxy.test.osinfra.cn/https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl

       - name: Run test
-        timeout-minutes: 30
+        timeout-minutes: 60
         env:
           SGLANG_USE_MODELSCOPE: true
           SGLANG_IS_IN_CI: true
@@ -82,7 +82,7 @@ jobs:
           curl -o /tmp/test.jsonl -L https://gh-proxy.test.osinfra.cn/https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl

       - name: Run test
-        timeout-minutes: 30
+        timeout-minutes: 90
         env:
           SGLANG_USE_MODELSCOPE: true
           SGLANG_IS_IN_CI: true
@@ -117,7 +117,7 @@ jobs:
           curl -o /tmp/test.jsonl -L https://gh-proxy.test.osinfra.cn/https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl

       - name: Run test
-        timeout-minutes: 30
+        timeout-minutes: 60
         env:
           SGLANG_USE_MODELSCOPE: true
           SGLANG_IS_IN_CI: true
diff --git a/python/sglang/srt/layers/attention/ascend_backend.py b/python/sglang/srt/layers/attention/ascend_backend.py
index c1f4c2785..f5b521d20 100644
--- a/python/sglang/srt/layers/attention/ascend_backend.py
+++ b/python/sglang/srt/layers/attention/ascend_backend.py
@@ -12,11 +12,16 @@
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
 from sglang.srt.layers.radix_attention import AttentionType
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.utils import get_bool_env_var

 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner

+import os
+
+import numpy as np
+

 @dataclass
 class ForwardMetadata:
@@ -54,7 +59,6 @@ class AscendAttnBackend(AttentionBackend):
         super().__init__()
         self.forward_metadata = None
         self.device = model_runner.device
-        self.gen_attention_mask(128, model_runner.dtype)
         self.page_size = model_runner.page_size
         self.use_mla = model_runner.model_config.attention_arch == AttentionArch.MLA
         if self.use_mla:
@@ -65,6 +69,17 @@ class AscendAttnBackend(AttentionBackend):
         self.max_context_len = model_runner.model_config.context_len
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
         self.graph_mode = False
+        self.use_fia = get_bool_env_var("ASCEND_USE_FIA", "False")
+        if not self.use_fia:
+            self.gen_attention_mask(128, model_runner.dtype)
+        mask_length = 2048
+        self.fia_mask = ~torch.tril(
+            torch.ones(
+                (mask_length, mask_length),
+                dtype=torch.bool,
+                device=model_runner.device,
+            )
+        )

     def init_forward_metadata(self, forward_batch: ForwardBatch):
         """Init the metadata for a forward pass."""
@@ -81,6 +96,9 @@
                 forward_batch.extend_seq_lens.cpu().int()
             )
             self.forward_metadata.seq_lens_cpu_int = forward_batch.seq_lens_cpu.int()
+            self.forward_metadata.seq_lens_list_cumsum = np.cumsum(
+                forward_batch.extend_seq_lens_cpu
+            )

         self.graph_mode = False
@@ -151,71 +169,89 @@ class AscendAttnBackend(AttentionBackend):
         forward_batch: ForwardBatch,
         save_kv_cache=True,
     ):
-        if save_kv_cache:
-            forward_batch.token_to_kv_pool.set_kv_buffer(
-                layer, forward_batch.out_cache_loc, k, v
-            )
-
-        k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
-        v_cache = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)
-
         if not self.use_mla:
-            query = q.view(-1, layer.tp_q_head_num * layer.qk_head_dim)
-            output = torch.empty(
-                (query.shape[0], layer.tp_q_head_num * layer.v_head_dim),
-                dtype=query.dtype,
-                device=query.device,
-            )
+            if save_kv_cache:
+                forward_batch.token_to_kv_pool.set_kv_buffer(
+                    layer, forward_batch.out_cache_loc, k, v
+                )
+
+            k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+            v_cache = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)
+
+            if self.use_fia:
+                """FIA will support multi-batch in a later version of CANN."""
+                q = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
+                attn_output = torch.empty(
+                    (q.size(0), layer.tp_q_head_num, layer.v_head_dim),
+                    device=q.device,
+                    dtype=q.dtype,
+                )
+                q_len_offset = 0
+                for q_len in forward_batch.extend_seq_lens_cpu:
+                    attn_output[q_len_offset : q_len_offset + q_len] = (
+                        torch.ops.npu.npu_fused_infer_attention_score(
+                            q[None, q_len_offset : q_len_offset + q_len],
+                            k[None, q_len_offset : q_len_offset + q_len],
+                            v[None, q_len_offset : q_len_offset + q_len],
+                            num_heads=layer.tp_q_head_num,
+                            num_key_value_heads=layer.tp_k_head_num,
+                            input_layout="BSND",  # TODO: TND does not support q_heads != k_heads yet
+                            atten_mask=self.fia_mask.unsqueeze(0),
+                            sparse_mode=3,
+                            scale=layer.scaling,
+                            next_tokens=0,
+                        )[0]
+                    )
+                    q_len_offset += q_len
+                attn_output = attn_output.view(
+                    -1, layer.tp_q_head_num * layer.v_head_dim
+                )

-            torch_npu._npu_flash_attention_qlens(
-                query=query,
-                key_cache=k_cache,
-                value_cache=v_cache,
-                mask=self.mask,
-                block_table=self.forward_metadata.block_tables,
-                seq_len=self.forward_metadata.extend_seq_lens_cpu_int,
-                context_lens=self.forward_metadata.seq_lens_cpu_int,
-                scale_value=layer.scaling,
-                num_heads=layer.tp_q_head_num,
-                num_kv_heads=layer.tp_k_head_num,
-                out=output,
-            )
-            return output
-        else:
-            if layer.qk_head_dim != layer.v_head_dim:
-                o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
             else:
-                o = torch.empty_like(q)
+                query = q.view(-1, layer.tp_q_head_num * layer.qk_head_dim)
+                attn_output = torch.empty(
+                    (query.shape[0], layer.tp_q_head_num * layer.v_head_dim),
+                    dtype=query.dtype,
+                    device=query.device,
+                )

-            use_gqa = layer.tp_q_head_num != layer.tp_k_head_num
+                torch_npu._npu_flash_attention_qlens(
+                    query=query,
+                    key_cache=k_cache,
+                    value_cache=v_cache,
+                    mask=self.mask,
+                    block_table=self.forward_metadata.block_tables,
+                    seq_len=self.forward_metadata.extend_seq_lens_cpu_int,
+                    context_lens=self.forward_metadata.seq_lens_cpu_int,
+                    scale_value=layer.scaling,
+                    num_heads=layer.tp_q_head_num,
+                    num_kv_heads=layer.tp_k_head_num,
+                    out=attn_output,
+                )
+        else:
+            assert (
+                layer.qk_head_dim != layer.v_head_dim
+            ), "FIA only supports qk_head_dim != v_head_dim"
+            q_nope, q_rope = q.split([layer.v_head_dim, self.qk_rope_head_dim], dim=-1)
+            k_nope, k_rope = k.split([layer.v_head_dim, self.qk_rope_head_dim], dim=-1)

-            q_ = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
-            o_ = o.view(-1, layer.tp_q_head_num, layer.v_head_dim)
-
-            causal = True
-            if (
-                layer.is_cross_attention
-                or layer.attn_type == AttentionType.ENCODER_ONLY
-            ):
-                causal = False
-
-            self.native_attn._run_sdpa_forward_extend(
-                q_,
-                o_,
-                k_cache.view(
-                    -1, layer.tp_k_head_num, (self.kv_lora_rank + self.qk_rope_head_dim)
-                ),
-                v_cache.view(-1, layer.tp_v_head_num, self.kv_lora_rank),
-                forward_batch.req_to_token_pool.req_to_token,
-                forward_batch.req_pool_indices,
-                forward_batch.seq_lens,
-                forward_batch.extend_prefix_lens,
-                forward_batch.extend_seq_lens,
-                scaling=layer.scaling,
-                enable_gqa=use_gqa,
-                causal=causal,
+            attn_output, _ = torch.ops.npu.npu_fused_infer_attention_score(
+                q_nope,
+                k_nope,
+                v,
+                query_rope=q_rope,
+                key_rope=k_rope,
+                num_heads=layer.tp_q_head_num,
+                input_layout="TND",
+                atten_mask=self.fia_mask,
+                sparse_mode=3,
+                actual_seq_lengths=self.forward_metadata.seq_lens_list_cumsum,
+                actual_seq_lengths_kv=self.forward_metadata.seq_lens_list_cumsum,
+                scale=layer.scaling,
+                next_tokens=0,
             )
-            return o
+
+        return attn_output

     def forward_decode(
         self,
@@ -224,13 +260,17 @@
         v: torch.Tensor,
         layer: RadixAttention,
         forward_batch: ForwardBatch,
-        save_kv_cache=True,
+        save_kv_cache: bool = False,
+        # For multi-head latent attention
+        q_rope: Optional[torch.Tensor] = None,
+        k_rope: Optional[torch.Tensor] = None,
     ):
-        if save_kv_cache:
-            forward_batch.token_to_kv_pool.set_kv_buffer(
-                layer, forward_batch.out_cache_loc, k, v
-            )
         if not self.use_mla:
+            if save_kv_cache:
+                forward_batch.token_to_kv_pool.set_kv_buffer(
+                    layer, forward_batch.out_cache_loc, k, v
+                )
+
+            num_tokens = q.shape[0]
             if self.graph_mode:
                 k_cache = forward_batch.token_to_kv_pool.get_key_buffer(
                     layer.layer_id
@@ -239,7 +279,6 @@
                     layer.layer_id
                 ).view(-1, self.page_size, layer.tp_v_head_num * layer.v_head_dim)
                 query = q.view(-1, 1, layer.tp_q_head_num * layer.qk_head_dim)
-                num_tokens = query.shape[0]
                 workspace = (
                     torch_npu._npu_fused_infer_attention_score_get_max_workspace(
                         query,
@@ -254,7 +293,7 @@
                         actual_seq_lengths_kv=self.forward_metadata.seq_lens_cpu_list,
                     )
                 )
-                output = torch.empty(
+                attn_output = torch.empty(
                     (num_tokens, 1, layer.tp_q_head_num * layer.v_head_dim),
                     dtype=q.dtype,
                     device=q.device,
@@ -272,61 +311,129 @@
                     scale=layer.scaling,
                     actual_seq_lengths_kv=self.forward_metadata.seq_lens_cpu_list,
                     workspace=workspace,
-                    out=[output, softmax_lse],
+                    out=[attn_output, softmax_lse],
                 )
             else:
                 k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
                 v_cache = forward_batch.token_to_kv_pool.get_value_buffer(
                     layer.layer_id
                 )
+                if self.use_fia:
+                    attn_output, _ = torch.ops.npu.npu_fused_infer_attention_score(
+                        q.view(
+                            forward_batch.batch_size,
+                            -1,
+                            layer.tp_q_head_num,
+                            layer.qk_head_dim,
+                        ),
+                        k_cache.view(
+                            -1, self.page_size, layer.tp_k_head_num * layer.qk_head_dim
+                        ),
+                        v_cache.view(
+                            -1, self.page_size, layer.tp_v_head_num * layer.qk_head_dim
+                        ),
+                        num_heads=layer.tp_q_head_num,
+                        num_key_value_heads=layer.tp_k_head_num,
+                        input_layout="BSND",
+                        atten_mask=None,
+                        block_size=self.page_size,
+                        block_table=self.forward_metadata.block_tables,
+                        actual_seq_lengths_kv=self.forward_metadata.seq_lens_cpu_int,
+                        scale=layer.scaling,
+                    )
+                else:
+                    query = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
+                    attn_output = torch.empty(
+                        (num_tokens, layer.tp_q_head_num, layer.v_head_dim),
+                        dtype=query.dtype,
+                        device=query.device,
+                    )

-                query = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
-                num_tokens = query.shape[0]
-                output = torch.empty(
-                    (num_tokens, layer.tp_q_head_num, layer.v_head_dim),
-                    dtype=query.dtype,
-                    device=query.device,
+                    torch_npu._npu_paged_attention(
+                        query=query,
+                        key_cache=k_cache,
+                        value_cache=v_cache,
+                        num_heads=layer.tp_q_head_num,
+                        num_kv_heads=layer.tp_k_head_num,
+                        scale_value=layer.scaling,
+                        block_table=self.forward_metadata.block_tables,
+                        context_lens=self.forward_metadata.seq_lens_cpu_int,
+                        out=attn_output,
+                    )
+            return attn_output.view(num_tokens, layer.tp_q_head_num * layer.v_head_dim)
+        else:
+            if save_kv_cache:
+                forward_batch.token_to_kv_pool.set_kv_buffer(
+                    layer, forward_batch.out_cache_loc, k, k_rope
                 )
+            num_tokens = q.shape[0]
+            kv_c = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+            k_pe = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)

-                torch_npu._npu_paged_attention(
-                    query=query,
-                    key_cache=k_cache,
-                    value_cache=v_cache,
+            if (self.graph_mode or self.use_fia) and (
+                layer.tp_q_head_num // layer.tp_k_head_num
+            ) >= 8:
+                """layer.tp_q_head_num // layer.tp_k_head_num < 8 will be supported in a later version of CANN."""
+                kv_c = kv_c.view(
+                    -1, self.page_size, layer.tp_k_head_num * self.kv_lora_rank
+                )
+                k_pe = k_pe.view(
+                    -1, self.page_size, layer.tp_k_head_num * self.qk_rope_head_dim
+                )
+                q = q.view(
+                    forward_batch.batch_size, -1, layer.tp_q_head_num, self.kv_lora_rank
+                )
+                q_rope = q_rope.view(
+                    forward_batch.batch_size,
+                    -1,
+                    layer.tp_q_head_num,
+                    self.qk_rope_head_dim,
+                )
+                attn_output, _ = torch.ops.npu.npu_fused_infer_attention_score(
+                    q,
+                    kv_c,
+                    kv_c,
+                    query_rope=q_rope,
+                    key_rope=k_pe,
                     num_heads=layer.tp_q_head_num,
+                    num_key_value_heads=layer.tp_k_head_num,
+                    input_layout="BSND",
+                    atten_mask=None,
+                    sparse_mode=0,
+                    scale=layer.scaling,
+                    antiquant_mode=0,
+                    antiquant_scale=None,
+                    block_table=self.forward_metadata.block_tables,
+                    block_size=self.page_size,
+                    actual_seq_lengths_kv=self.forward_metadata.seq_lens_cpu_int,
+                )
+            else:
+                assert (
+                    self.graph_mode == False
+                )  # _npu_paged_attention_mla does not support graph mode
+                q = torch.cat([q, q_rope], dim=-1)
+                query = q.view(-1, layer.tp_q_head_num, layer.head_dim)
+                kv_c_and_k_pe_cache = torch.cat([kv_c, k_pe], dim=-1)
+                kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.view(
+                    -1,
+                    self.page_size,
+                    layer.tp_k_head_num,
+                    self.kv_lora_rank + self.qk_rope_head_dim,
+                )
+                attn_output = torch.empty(
+                    [num_tokens, layer.tp_q_head_num, self.kv_lora_rank],
+                    dtype=q.dtype,
+                    device=q.device,
+                )
+                torch_npu._npu_paged_attention_mla(
+                    query=query,
+                    key_cache=kv_c_and_k_pe_cache,
                     num_kv_heads=layer.tp_k_head_num,
+                    num_heads=layer.tp_q_head_num,
                     scale_value=layer.scaling,
                     block_table=self.forward_metadata.block_tables,
                     context_lens=self.forward_metadata.seq_lens_cpu_int,
-                    out=output,
+                    mla_vheadsize=self.kv_lora_rank,
+                    out=attn_output,
                 )
-            return output.view(num_tokens, layer.tp_q_head_num * layer.v_head_dim)
-        else:
-            query = q.view(-1, layer.tp_q_head_num, layer.head_dim)
-            num_tokens = query.shape[0]
-            kv_c_and_k_pe_cache = forward_batch.token_to_kv_pool.get_key_buffer(
-                layer.layer_id
-            )
-            kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.view(
-                -1,
-                self.page_size,
-                layer.tp_k_head_num,
-                self.kv_lora_rank + self.qk_rope_head_dim,
-            )
-
-            attn_output = torch.empty(
-                [num_tokens, layer.tp_q_head_num, self.kv_lora_rank],
-                dtype=q.dtype,
-                device=q.device,
-            )
-            torch_npu._npu_paged_attention_mla(
-                query=query,
-                key_cache=kv_c_and_k_pe_cache,
-                num_kv_heads=layer.tp_k_head_num,
-                num_heads=layer.tp_q_head_num,
-                scale_value=layer.scaling,
-                block_table=self.forward_metadata.block_tables,
-                context_lens=self.forward_metadata.seq_lens_cpu_int,
-                out=attn_output,
-            )
         return attn_output.view(num_tokens, layer.tp_q_head_num * self.kv_lora_rank)
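Note on the new FIA extend path above (not part of the patch): with input_layout="TND" the fused kernel expects cumulative per-request lengths, which is why init_forward_metadata stores np.cumsum(forward_batch.extend_seq_lens_cpu) as seq_lens_list_cumsum, and the BSND fallback loops over requests by slicing the packed token dimension. A minimal sketch of that bookkeeping in plain NumPy, with a hypothetical batch:

import numpy as np

extend_seq_lens_cpu = [3, 5, 2]                    # new tokens per request in this batch
seq_lens_list_cumsum = np.cumsum(extend_seq_lens_cpu)
print(seq_lens_list_cumsum)                        # [ 3  8 10] -> per-request end offsets (TND layout)

# Equivalent of the per-request loop in the BSND path: slice the packed q/k/v rows.
chunks = []
q_len_offset = 0
for q_len in extend_seq_lens_cpu:
    chunks.append((q_len_offset, q_len_offset + q_len))
    q_len_offset += q_len
print(chunks)                                      # [(0, 3), (3, 8), (8, 10)]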
diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py
index 48296752d..3f8b4afd0 100644
--- a/python/sglang/srt/layers/moe/topk.py
+++ b/python/sglang/srt/layers/moe/topk.py
@@ -304,7 +304,7 @@ class TopK(CustomOp):
         global_num_experts = router_logits.shape[-1]

         # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
-        if global_num_experts == 256 and self.topk_config.renormalize is False:
+        if global_num_experts == 256 and self.topk_config.renormalize is True:
             routed_scaling_factor = self.topk_config.routed_scaling_factor or 1

             router_logits = router_logits.to(torch.float32)
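For context on the flipped condition above (not part of the patch): the fused npu_moe_gating_top_k kernel is now taken only when the TopK config renormalizes the selected expert weights. A rough reference of what "renormalize=True" routing means in plain PyTorch — this is only one common form; grouped/sigmoid gating and routed_scaling_factor handling inside the fused kernel are not reproduced here, and the shapes are hypothetical:

import torch

router_logits = torch.randn(4, 256)                    # [num_tokens, num_experts]
probs = torch.softmax(router_logits.float(), dim=-1)
topk_weights, topk_ids = probs.topk(8, dim=-1)          # keep top-8 experts per token
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)  # renormalize to sum to 1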
diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py
index 1653d4535..142597b3a 100644
--- a/python/sglang/srt/mem_cache/memory_pool.py
+++ b/python/sglang/srt/mem_cache/memory_pool.py
@@ -36,12 +36,15 @@
 import triton.language as tl

 from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.utils import get_bool_env_var, is_cuda, next_power_of_2
+from sglang.srt.utils import get_bool_env_var, is_cuda, is_npu, next_power_of_2

 logger = logging.getLogger(__name__)

 GB = 1024 * 1024 * 1024
 _is_cuda = is_cuda()
+_is_npu = is_npu()
+if _is_npu:
+    import torch_npu


 class ReqToTokenPool:
@@ -624,8 +627,6 @@ class AscendTokenToKVPool(MHATokenToKVPool):
             cache_k = cache_k.view(self.store_dtype)
             cache_v = cache_v.view(self.store_dtype)

-        import torch_npu
-
         torch_npu._npu_reshape_and_cache(
             key=cache_k,
             value=cache_v,
@@ -912,12 +913,22 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
         with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
             # The padded slot 0 is used for writing dummy outputs from padded tokens.
-            self.kv_buffer = torch.zeros(
+            self.k_buffer = torch.zeros(
                 (
                     layer_num,
                     self.size // self.page_size + 1,
                     self.page_size,
-                    self.kv_lora_rank + self.qk_rope_head_dim,
+                    self.kv_lora_rank,
+                ),
+                dtype=self.store_dtype,
+                device=self.device,
+            )
+            self.v_buffer = torch.zeros(
+                (
+                    layer_num,
+                    self.size // self.page_size + 1,
+                    self.page_size,
+                    self.qk_rope_head_dim,
                 ),
                 dtype=self.store_dtype,
                 device=self.device,
@@ -931,12 +942,52 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
             )
         self.mem_usage = kv_size / GB

+    def get_kv_size_bytes(self):
+        assert hasattr(self, "k_buffer")
+        assert hasattr(self, "v_buffer")
+        kv_size_bytes = 0
+        for k_cache in self.k_buffer:
+            kv_size_bytes += np.prod(k_cache.shape) * k_cache.dtype.itemsize
+        for v_cache in self.v_buffer:
+            kv_size_bytes += np.prod(v_cache.shape) * v_cache.dtype.itemsize
+        return kv_size_bytes
+
+    def get_kv_buffer(self, layer_id: int):
+        if self.layer_transfer_counter is not None:
+            self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
+        return (
+            self.k_buffer[layer_id - self.start_layer],
+            self.v_buffer[layer_id - self.start_layer],
+        )
+
+    def get_key_buffer(self, layer_id: int):
+        if self.layer_transfer_counter is not None:
+            self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
+
+        if self.store_dtype != self.dtype:
+            return self.k_buffer[layer_id - self.start_layer].view(self.dtype)
+        return self.k_buffer[layer_id - self.start_layer]
+
+    def get_value_buffer(self, layer_id: int):
+        if self.layer_transfer_counter is not None:
+            self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
+
+        if self.store_dtype != self.dtype:
+            return self.v_buffer[layer_id - self.start_layer].view(self.dtype)
+        return self.v_buffer[layer_id - self.start_layer]
+
     # for disagg
     def get_contiguous_buf_infos(self):
-        # MLA has only one kv_buffer, so only the information of this buffer needs to be returned.
-        kv_data_ptrs = [self.kv_buffer[i].data_ptr() for i in range(self.layer_num)]
-        kv_data_lens = [self.kv_buffer[i].nbytes for i in range(self.layer_num)]
-        kv_item_lens = [self.kv_buffer[i][0].nbytes for i in range(self.layer_num)]
+        # The Ascend MLA pool splits the cache into k_buffer and v_buffer, so the information of both buffers is returned.
+        kv_data_ptrs = [self.k_buffer[i].data_ptr() for i in range(self.layer_num)] + [
+            self.v_buffer[i].data_ptr() for i in range(self.layer_num)
+        ]
+        kv_data_lens = [self.k_buffer[i].nbytes for i in range(self.layer_num)] + [
+            self.v_buffer[i].nbytes for i in range(self.layer_num)
+        ]
+        kv_item_lens = [self.k_buffer[i][0].nbytes for i in range(self.layer_num)] + [
+            self.v_buffer[i][0].nbytes for i in range(self.layer_num)
+        ]
         return kv_data_ptrs, kv_data_lens, kv_item_lens

     def set_kv_buffer(
@@ -953,14 +1004,22 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
         if self.store_dtype != self.dtype:
             cache_k = cache_k.view(self.store_dtype)

-        import torch_npu
+        if cache_v is None:
+            cache_k, cache_v = cache_k.split(
+                [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
+            )

-        torch_npu._npu_reshape_and_cache_siso(
-            key=cache_k.view(-1, 1, self.kv_lora_rank + self.qk_rope_head_dim),
-            key_cache=self.kv_buffer[layer_id - self.start_layer].view(
-                -1, 1, 1, self.kv_lora_rank + self.qk_rope_head_dim
+        torch_npu.npu_scatter_nd_update_(
+            self.k_buffer[layer_id - self.start_layer].view(-1, 1, self.kv_lora_rank),
+            loc.view(-1, 1),
+            cache_k.view(-1, 1, self.kv_lora_rank),
+        )
+        torch_npu.npu_scatter_nd_update_(
+            self.v_buffer[layer_id - self.start_layer].view(
+                -1, 1, self.qk_rope_head_dim
             ),
-            slot_indices=loc,
+            loc.view(-1, 1),
+            cache_v.view(-1, 1, self.qk_rope_head_dim),
         )
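On the split MLA cache above (not part of the patch): the latent "no-PE" part (kv_lora_rank) and the RoPE part (qk_rope_head_dim) now live in separate paged buffers, and each new token's rows are written at its flat slot index. A small sketch of the write semantics assumed for the scatter updates, using plain PyTorch indexing and hypothetical sizes:

import torch

kv_lora_rank, qk_rope_head_dim, page_size, num_pages = 16, 8, 4, 3
# one layer's buffers, flattened over (pages * page_size) slots
k_buffer = torch.zeros(num_pages * page_size, 1, kv_lora_rank)
v_buffer = torch.zeros(num_pages * page_size, 1, qk_rope_head_dim)

loc = torch.tensor([5, 6, 9])                          # flat slot index per new token
cache_k = torch.randn(len(loc), 1, kv_lora_rank)        # compressed (no-PE) part
cache_v = torch.randn(len(loc), 1, qk_rope_head_dim)    # RoPE part

# What the per-buffer scatter update is expected to do:
k_buffer[loc] = cache_k
v_buffer[loc] = cache_v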
diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index bf22528f0..c9305d06e 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -994,7 +994,14 @@ class DeepseekV2AttentionMLA(nn.Module):
         self.current_attention_backend = attention_backend

         if attention_backend == "ascend":
-            return AttnForwardMethod.MLA
+            if (
+                forward_batch.forward_mode.is_extend()
+                and not forward_batch.forward_mode.is_target_verify()
+                and not forward_batch.forward_mode.is_draft_extend()
+            ):
+                return AttnForwardMethod.MHA
+            else:
+                return AttnForwardMethod.MLA
         elif (
             attention_backend == "flashinfer"
             or attention_backend == "fa3"
@@ -1292,6 +1299,7 @@ class DeepseekV2AttentionMLA(nn.Module):
             or self.current_attention_backend == "flashinfer"
             or self.current_attention_backend == "cutlass_mla"
             or self.current_attention_backend == "trtllm_mla"
+            or self.current_attention_backend == "ascend"
         ):
             extra_args = {}
             if self._fuse_rope_for_trtllm_mla(forward_batch):
diff --git a/test/srt/ascend/test_ascend_mla_fia_w8a8int8.py b/test/srt/ascend/test_ascend_mla_fia_w8a8int8.py
new file mode 100644
index 000000000..6de97b04d
--- /dev/null
+++ b/test/srt/ascend/test_ascend_mla_fia_w8a8int8.py
@@ -0,0 +1,103 @@
+import os
+import unittest
+from types import SimpleNamespace
+from urllib.parse import urlparse
+
+from sglang.srt.utils import kill_process_tree
+from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
+from sglang.test.test_utils import (
+    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+    DEFAULT_URL_FOR_TEST,
+    CustomTestCase,
+    is_in_ci,
+    popen_launch_server,
+    run_bench_offline_throughput,
+)
+
+TEST_MODEL_MATRIX = {
+    "/root/.cache/modelscope/hub/models/vllm-ascend/DeepSeek-V2-Lite-W8A8": {
+        "accuracy": 0.34,
+        "latency": 1000,
+        "output_throughput": 6,
+    },
+}
+
+
+class TestAscendMlaW8A8Int8(CustomTestCase):
+
+    @classmethod
+    def setUpClass(cls):
+        cls.models = TEST_MODEL_MATRIX.keys()
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        cls.url = urlparse(DEFAULT_URL_FOR_TEST)
+        cls.common_args = [
+            "--trust-remote-code",
+            "--disable-cuda-graph",
+            "--mem-fraction-static",
+            0.8,
+            "--attention-backend",
+            "ascend",
+            "--quantization",
+            "w8a8_int8",
+            "--tp-size",
+            2,
+            "--disable-radix-cache",
+        ]
+
+    def test_a_gsm8k(self):
+        os.environ["ASCEND_USE_FIA"] = "true"
+        for model in self.models:
+            with self.subTest(model=model):
+                print(f"##=== Testing accuracy: {model} ===##")
+
+                process = popen_launch_server(
+                    model,
+                    self.base_url,
+                    timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+                    other_args=[
+                        *self.common_args,
+                    ],
+                )
+
+                try:
+                    args = SimpleNamespace(
+                        num_shots=5,
+                        data_path=None,
+                        num_questions=1319,
+                        max_new_tokens=512,
+                        parallel=128,
+                        host=f"http://{self.url.hostname}",
+                        port=int(self.url.port),
+                    )
+
+                    metrics = run_eval_few_shot_gsm8k(args)
+                    self.assertGreaterEqual(
+                        metrics["accuracy"],
+                        TEST_MODEL_MATRIX[model]["accuracy"],
+                    )
+                finally:
+                    kill_process_tree(process.pid)
+
+    def test_b_throughput(self):
+        for model in self.models:
+            with self.subTest(model=model):
+                print(f"##=== Testing throughput: {model} ===##")
+
+                output_throughput = run_bench_offline_throughput(
+                    model,
+                    [
+                        *self.common_args,
+                    ],
+                )
+
+                print(f"##=== {model} throughput: {output_throughput} ===##")
+
+                if is_in_ci():
+                    self.assertGreater(
+                        output_throughput,
+                        TEST_MODEL_MATRIX[model]["output_throughput"],
+                    )
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/test/srt/ascend/test_ascend_mla_w8a8int8.py b/test/srt/ascend/test_ascend_mla_w8a8int8.py
index cdbc52023..70f7edab4 100644
--- a/test/srt/ascend/test_ascend_mla_w8a8int8.py
+++ b/test/srt/ascend/test_ascend_mla_w8a8int8.py
@@ -40,6 +40,7 @@ class TestAscendMlaW8A8Int8(CustomTestCase):
             "w8a8_int8",
             "--tp-size",
             4,
+            "--disable-radix-cache",
         ]

     def test_a_gsm8k(self):
diff --git a/test/srt/ascend/test_ascend_tp2_fia_bf16.py b/test/srt/ascend/test_ascend_tp2_fia_bf16.py
new file mode 100644
index 000000000..bdd1c5733
--- /dev/null
+++ b/test/srt/ascend/test_ascend_tp2_fia_bf16.py
@@ -0,0 +1,101 @@
+import os
+import unittest
+from types import SimpleNamespace
+from urllib.parse import urlparse
+
+from sglang.srt.utils import kill_process_tree
+from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
+from sglang.test.test_utils import (
+    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+    DEFAULT_URL_FOR_TEST,
+    CustomTestCase,
+    is_in_ci,
+    popen_launch_server,
+    run_bench_offline_throughput,
+)
+
+TEST_MODEL_MATRIX = {
+    "Qwen/Qwen2.5-7B-Instruct": {
+        "accuracy": 0.85,
+        "latency": 180,
+        "output_throughput": 20,
+    },
+}
+
+
+class TestAscendTp2Bf16(CustomTestCase):
+
+    @classmethod
+    def setUpClass(cls):
+        cls.models = TEST_MODEL_MATRIX.keys()
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        cls.url = urlparse(DEFAULT_URL_FOR_TEST)
+        cls.common_args = [
+            "--trust-remote-code",
+            "--disable-cuda-graph",
+            "--mem-fraction-static",
+            0.8,
+            "--attention-backend",
+            "ascend",
+            "--tp-size",
+            2,
+            "--disable-radix-cache",
+        ]
+
+    def test_a_gsm8k(self):
+        os.environ["ASCEND_USE_FIA"] = "true"
+        for model in self.models:
+            with self.subTest(model=model):
+                print(f"##=== Testing accuracy: {model} ===##")
+
+                process = popen_launch_server(
+                    model,
+                    self.base_url,
+                    timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+                    other_args=[
+                        *self.common_args,
+                    ],
+                )
+
+                try:
+                    args = SimpleNamespace(
+                        num_shots=5,
+                        data_path=None,
+                        num_questions=1319,
+                        max_new_tokens=512,
+                        parallel=128,
+                        host=f"http://{self.url.hostname}",
+                        port=int(self.url.port),
+                    )
+
+                    metrics = run_eval_few_shot_gsm8k(args)
+                    self.assertGreaterEqual(
+                        metrics["accuracy"],
+                        TEST_MODEL_MATRIX[model]["accuracy"],
+                    )
+                finally:
+                    kill_process_tree(process.pid)
+
+    def test_b_throughput(self):
+        for model in self.models:
+            with self.subTest(model=model):
+                print(f"##=== Testing throughput: {model} ===##")
+
+                output_throughput = run_bench_offline_throughput(
+                    model,
+                    [
+                        *self.common_args,
+                    ],
+                )
+
+                print(f"##=== {model} throughput: {output_throughput} ===##")
+
+                if is_in_ci():
+                    self.assertGreater(
+                        output_throughput,
+                        TEST_MODEL_MATRIX[model]["output_throughput"],
+                    )
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py
index 713d4163c..003942e65 100644
--- a/test/srt/run_suite.py
+++ b/test/srt/run_suite.py
@@ -275,6 +275,8 @@ suite_ascend = {
     "per-commit-2-ascend-npu": [
         TestFile("ascend/test_ascend_tp2_bf16.py", 400),
         TestFile("ascend/test_ascend_graph_tp2_bf16.py", 400),
+        TestFile("ascend/test_ascend_tp2_fia_bf16.py", 400),
+        TestFile("ascend/test_ascend_mla_fia_w8a8int8.py", 400),
     ],
     "per-commit-4-ascend-npu": [
         TestFile("ascend/test_ascend_mla_w8a8int8.py", 400),
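Usage note on the new tests (not part of the patch): the FIA test cases opt in by exporting ASCEND_USE_FIA in the parent process before the server is launched, as test_a_gsm8k does above; AscendAttnBackend picks it up through get_bool_env_var at construction time. A minimal sketch, with the launch call abbreviated and purely illustrative:

import os

os.environ["ASCEND_USE_FIA"] = "true"   # read by AscendAttnBackend when the server starts
# process = popen_launch_server(
#     model, base_url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
#     other_args=["--attention-backend", "ascend", "--tp-size", 2],
# )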