diff --git a/python/sglang/srt/disaggregation/ascend/conn.py b/python/sglang/srt/disaggregation/ascend/conn.py
index 3e988c0a4..b0009fc7c 100644
--- a/python/sglang/srt/disaggregation/ascend/conn.py
+++ b/python/sglang/srt/disaggregation/ascend/conn.py
@@ -1,6 +1,12 @@
+import concurrent.futures
 import logging
+from typing import List, Tuple
+
+import numpy as np
+import numpy.typing as npt
 
 from sglang.srt.disaggregation.ascend.transfer_engine import AscendTransferEngine
+from sglang.srt.disaggregation.common.utils import group_concurrent_contiguous
 from sglang.srt.disaggregation.mooncake.conn import (
     MooncakeKVBootstrapServer,
     MooncakeKVManager,
@@ -29,6 +35,75 @@ class AscendKVManager(MooncakeKVManager):
             self.kv_args.aux_data_ptrs, self.kv_args.aux_data_lens
         )
 
+    def send_kvcache(
+        self,
+        mooncake_session_id: str,
+        prefill_kv_indices: npt.NDArray[np.int32],
+        dst_kv_ptrs: list[int],
+        dst_kv_indices: npt.NDArray[np.int32],
+        executor: concurrent.futures.ThreadPoolExecutor,
+    ):
+        # Group by indices
+        prefill_kv_blocks, dst_kv_blocks = group_concurrent_contiguous(
+            prefill_kv_indices, dst_kv_indices
+        )
+
+        num_layers = len(self.kv_args.kv_data_ptrs)
+        layers_params = [
+            (
+                self.kv_args.kv_data_ptrs[layer_id],
+                dst_kv_ptrs[layer_id],
+                self.kv_args.kv_item_lens[layer_id],
+            )
+            for layer_id in range(num_layers)
+        ]
+
+        def set_transfer_blocks(
+            src_ptr: int, dst_ptr: int, item_len: int
+        ) -> List[Tuple[int, int, int]]:
+            transfer_blocks = []
+            for prefill_index, decode_index in zip(prefill_kv_blocks, dst_kv_blocks):
+                src_addr = src_ptr + int(prefill_index[0]) * item_len
+                dst_addr = dst_ptr + int(decode_index[0]) * item_len
+                length = item_len * len(prefill_index)
+                transfer_blocks.append((src_addr, dst_addr, length))
+            return transfer_blocks
+
+        # Worker function for processing a single layer
+        def process_layer(src_ptr: int, dst_ptr: int, item_len: int) -> int:
+            transfer_blocks = set_transfer_blocks(src_ptr, dst_ptr, item_len)
+            return self._transfer_data(mooncake_session_id, transfer_blocks)
+
+        # Worker function for processing all layers in a batch
+        def process_layers(layers_params: List[Tuple[int, int, int]]) -> int:
+            transfer_blocks = []
+            for src_ptr, dst_ptr, item_len in layers_params:
+                transfer_blocks.extend(set_transfer_blocks(src_ptr, dst_ptr, item_len))
+            return self._transfer_data(mooncake_session_id, transfer_blocks)
+
+        if self.enable_custom_mem_pool:
+            futures = [
+                executor.submit(
+                    process_layer,
+                    src_ptr,
+                    dst_ptr,
+                    item_len,
+                )
+                for (src_ptr, dst_ptr, item_len) in layers_params
+            ]
+            for future in concurrent.futures.as_completed(futures):
+                status = future.result()
+                if status != 0:
+                    for f in futures:
+                        f.cancel()
+                    return status
+        else:
+            # Combining all layers' params in one batch transfer is more efficient
+            # compared to using multiple threads
+            return process_layers(layers_params)
+
+        return 0
+
 
 class AscendKVSender(MooncakeKVSender):
     pass
diff --git a/python/sglang/srt/layers/attention/ascend_backend.py b/python/sglang/srt/layers/attention/ascend_backend.py
index f5b521d20..0f826d2df 100644
--- a/python/sglang/srt/layers/attention/ascend_backend.py
+++ b/python/sglang/srt/layers/attention/ascend_backend.py
@@ -158,7 +158,7 @@ class AscendAttnBackend(AttentionBackend):
         self.graph_mode = True
 
     def get_cuda_graph_seq_len_fill_value(self):
-        return 1
+        return 0
 
     def forward_extend(
         self,
         q,
         k,
         v,
         layer: RadixAttention,
         forward_batch: ForwardBatch,
-        save_kv_cache=True,
+        save_kv_cache: bool = True,
     ):
         if not self.use_mla:
             if save_kv_cache:
@@ -253,6 +253,136 @@ class AscendAttnBackend(AttentionBackend):
 
         return attn_output
 
+    def forward_decode_graph(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache: bool = True,
+        q_rope: Optional[torch.Tensor] = None,
+        k_rope: Optional[torch.Tensor] = None,
+    ):
+        if save_kv_cache:
+            if self.use_mla:
+                k = k.view(-1, layer.tp_k_head_num, self.kv_lora_rank)
+                k_rope = k_rope.view(-1, layer.tp_k_head_num, self.qk_rope_head_dim)
+                forward_batch.token_to_kv_pool.set_kv_buffer(
+                    layer, forward_batch.out_cache_loc, k, k_rope
+                )
+            else:
+                forward_batch.token_to_kv_pool.set_kv_buffer(
+                    layer, forward_batch.out_cache_loc, k, v
+                )
+
+        if not self.use_mla:
+            k_cache = forward_batch.token_to_kv_pool.get_key_buffer(
+                layer.layer_id
+            ).view(-1, self.page_size, layer.tp_k_head_num * layer.qk_head_dim)
+            v_cache = forward_batch.token_to_kv_pool.get_value_buffer(
+                layer.layer_id
+            ).view(-1, self.page_size, layer.tp_v_head_num * layer.v_head_dim)
+            query = q.view(-1, 1, layer.tp_q_head_num * layer.qk_head_dim)
+            if self.forward_metadata.seq_lens_cpu_int is None:
+                actual_seq_len_kv = self.forward_metadata.seq_lens_cpu_list
+            else:
+                actual_seq_len_kv = (
+                    self.forward_metadata.seq_lens_cpu_int.cpu().int().tolist()
+                )
+            num_tokens = query.shape[0]
+            workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
+                query,
+                k_cache,
+                v_cache,
+                block_table=self.forward_metadata.block_tables,
+                block_size=self.page_size,
+                num_heads=layer.tp_q_head_num,
+                num_key_value_heads=layer.tp_k_head_num,
+                input_layout="BSH",
+                scale=layer.scaling,
+                actual_seq_lengths_kv=actual_seq_len_kv,
+            )
+            output = torch.empty(
+                (num_tokens, 1, layer.tp_q_head_num * layer.v_head_dim),
+                dtype=q.dtype,
+                device=q.device,
+            )
+            softmax_lse = torch.empty(1, dtype=q.dtype, device=q.device)
+            torch_npu.npu_fused_infer_attention_score.out(
+                query,
+                k_cache,
+                v_cache,
+                block_table=self.forward_metadata.block_tables,
+                block_size=self.page_size,
+                num_heads=layer.tp_q_head_num,
+                num_key_value_heads=layer.tp_k_head_num,
+                input_layout="BSH",
+                scale=layer.scaling,
+                actual_seq_lengths_kv=actual_seq_len_kv,
+                workspace=workspace,
+                out=[output, softmax_lse],
+            )
+            return output.view(num_tokens, layer.tp_q_head_num * layer.v_head_dim)
+        else:
+            c_kv, k_rope = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
+            k_rope_cache = k_rope.view(
+                -1, layer.tp_k_head_num, self.page_size, self.qk_rope_head_dim
+            )
+            c_kv_cache = c_kv.view(
+                -1, layer.tp_v_head_num, self.page_size, self.kv_lora_rank
+            )
+
+            q_nope = q.view(-1, layer.tp_q_head_num, 1, self.kv_lora_rank)
+            q_rope = q_rope.view(-1, layer.tp_q_head_num, 1, self.qk_rope_head_dim)
+            if self.forward_metadata.seq_lens_cpu_int is None:
+                actual_seq_len_kv = self.forward_metadata.seq_lens_cpu_list
+            else:
+                actual_seq_len_kv = (
+                    self.forward_metadata.seq_lens_cpu_int.cpu().int().tolist()
+                )
+
+            workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
+                q_nope,
+                c_kv_cache,
+                c_kv_cache,
+                query_rope=q_rope,
+                key_rope=k_rope_cache,
+                num_heads=layer.tp_q_head_num,
+                num_key_value_heads=layer.tp_k_head_num,
+                block_table=self.forward_metadata.block_tables,
+                block_size=self.page_size,
+                input_layout="BNSD",
+                scale=layer.scaling,
+                actual_seq_lengths_kv=actual_seq_len_kv,
+                antiquant_mode=0,
+                antiquant_scale=None,
+                sparse_mode=0,
+            )
+            output = torch.zeros_like(q_nope, dtype=q.dtype, device=q.device)
+            softmax_lse = torch.empty(1, dtype=q.dtype, device=q.device)
+
+            torch_npu.npu_fused_infer_attention_score.out(
+                q_nope,
+                c_kv_cache,
+                c_kv_cache,
+                query_rope=q_rope,
+                key_rope=k_rope_cache,
+                num_heads=layer.tp_q_head_num,
+                num_key_value_heads=layer.tp_k_head_num,
+                block_table=self.forward_metadata.block_tables,
+                block_size=self.page_size,
+                input_layout="BNSD",
+                scale=layer.scaling,
+                actual_seq_lengths_kv=actual_seq_len_kv,
+                antiquant_mode=0,
+                antiquant_scale=None,
+                sparse_mode=0,
+                workspace=workspace,
+                out=[output, softmax_lse],
+            )
+            return output.view(-1, layer.tp_q_head_num * self.kv_lora_rank)
+
     def forward_decode(
         self,
@@ -260,106 +390,73 @@ class AscendAttnBackend(AttentionBackend):
         v: torch.Tensor,
         layer: RadixAttention,
         forward_batch: ForwardBatch,
-        save_kv_cache: bool = False,
+        save_kv_cache: bool = True,
         # For multi-head latent attention
         q_rope: Optional[torch.Tensor] = None,
         k_rope: Optional[torch.Tensor] = None,
     ):
+        if self.graph_mode:
+            return self.forward_decode_graph(
+                q,
+                k,
+                v,
+                layer,
+                forward_batch,
+                save_kv_cache,
+                q_rope=q_rope,
+                k_rope=k_rope,
+            )
+
         if not self.use_mla:
             if save_kv_cache:
                 forward_batch.token_to_kv_pool.set_kv_buffer(
                     layer, forward_batch.out_cache_loc, k, v
                 )
             num_tokens = q.shape[0]
-            if self.graph_mode:
-                k_cache = forward_batch.token_to_kv_pool.get_key_buffer(
-                    layer.layer_id
-                ).view(-1, self.page_size, layer.tp_k_head_num * layer.qk_head_dim)
-                v_cache = forward_batch.token_to_kv_pool.get_value_buffer(
-                    layer.layer_id
-                ).view(-1, self.page_size, layer.tp_v_head_num * layer.v_head_dim)
-                query = q.view(-1, 1, layer.tp_q_head_num * layer.qk_head_dim)
-                workspace = (
-                    torch_npu._npu_fused_infer_attention_score_get_max_workspace(
-                        query,
-                        k_cache,
-                        v_cache,
-                        block_table=self.forward_metadata.block_tables,
-                        block_size=self.page_size,
-                        num_heads=layer.tp_q_head_num,
-                        num_key_value_heads=layer.tp_k_head_num,
-                        input_layout="BSH",
-                        scale=layer.scaling,
-                        actual_seq_lengths_kv=self.forward_metadata.seq_lens_cpu_list,
-                    )
-                )
-                attn_output = torch.empty(
-                    (num_tokens, 1, layer.tp_q_head_num * layer.v_head_dim),
-                    dtype=q.dtype,
-                    device=q.device,
-                )
-                softmax_lse = torch.empty(1, dtype=q.dtype, device=q.device)
-                torch_npu.npu_fused_infer_attention_score.out(
-                    query,
-                    k_cache,
-                    v_cache,
-                    block_table=self.forward_metadata.block_tables,
-                    block_size=self.page_size,
+            k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+            v_cache = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)
+            if self.use_fia:
+                attn_output, _ = torch.ops.npu.npu_fused_infer_attention_score(
+                    q.view(
+                        forward_batch.batch_size,
+                        -1,
+                        layer.tp_q_head_num,
+                        layer.qk_head_dim,
+                    ),
+                    k_cache.view(
+                        -1, self.page_size, layer.tp_k_head_num * layer.qk_head_dim
+                    ),
+                    v_cache.view(
+                        -1, self.page_size, layer.tp_v_head_num * layer.qk_head_dim
+                    ),
                     num_heads=layer.tp_q_head_num,
                     num_key_value_heads=layer.tp_k_head_num,
-                    input_layout="BSH",
+                    input_layout="BSND",
+                    atten_mask=None,
+                    block_size=self.page_size,
+                    block_table=self.forward_metadata.block_tables,
+                    actual_seq_lengths_kv=self.forward_metadata.seq_lens_cpu_int,
                     scale=layer.scaling,
-                    actual_seq_lengths_kv=self.forward_metadata.seq_lens_cpu_list,
-                    workspace=workspace,
-                    out=[attn_output, softmax_lse],
                 )
             else:
-                k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
-                v_cache = forward_batch.token_to_kv_pool.get_value_buffer(
-                    layer.layer_id
+                query = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
+                attn_output = torch.empty(
+                    (num_tokens, layer.tp_q_head_num, layer.v_head_dim),
+                    dtype=query.dtype,
+                    device=query.device,
                 )
-                if self.use_fia:
-                    attn_output, _ = torch.ops.npu.npu_fused_infer_attention_score(
-                        q.view(
-                            forward_batch.batch_size,
-                            -1,
-                            layer.tp_q_head_num,
-                            layer.qk_head_dim,
-                        ),
-                        k_cache.view(
-                            -1, self.page_size, layer.tp_k_head_num * layer.qk_head_dim
-                        ),
-                        v_cache.view(
-                            -1, self.page_size, layer.tp_v_head_num * layer.qk_head_dim
-                        ),
-                        num_heads=layer.tp_q_head_num,
-                        num_key_value_heads=layer.tp_k_head_num,
-                        input_layout="BSND",
-                        atten_mask=None,
-                        block_size=self.page_size,
-                        block_table=self.forward_metadata.block_tables,
-                        actual_seq_lengths_kv=self.forward_metadata.seq_lens_cpu_int,
-                        scale=layer.scaling,
-                    )
-                else:
-                    query = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
-                    attn_output = torch.empty(
-                        (num_tokens, layer.tp_q_head_num, layer.v_head_dim),
-                        dtype=query.dtype,
-                        device=query.device,
-                    )
-                    torch_npu._npu_paged_attention(
-                        query=query,
-                        key_cache=k_cache,
-                        value_cache=v_cache,
-                        num_heads=layer.tp_q_head_num,
-                        num_kv_heads=layer.tp_k_head_num,
-                        scale_value=layer.scaling,
-                        block_table=self.forward_metadata.block_tables,
-                        context_lens=self.forward_metadata.seq_lens_cpu_int,
-                        out=attn_output,
-                    )
+                torch_npu._npu_paged_attention(
+                    query=query,
+                    key_cache=k_cache,
+                    value_cache=v_cache,
+                    num_heads=layer.tp_q_head_num,
+                    num_kv_heads=layer.tp_k_head_num,
+                    scale_value=layer.scaling,
+                    block_table=self.forward_metadata.block_tables,
+                    context_lens=self.forward_metadata.seq_lens_cpu_int,
+                    out=attn_output,
+                )
             return attn_output.view(num_tokens, layer.tp_q_head_num * layer.v_head_dim)
         else:
             if save_kv_cache:
@@ -370,9 +467,7 @@ class AscendAttnBackend(AttentionBackend):
             kv_c = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
             k_pe = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)
 
-            if (self.graph_mode or self.use_fia) and (
-                layer.tp_q_head_num // layer.tp_k_head_num
-            ) >= 8:
+            if self.use_fia and (layer.tp_q_head_num // layer.tp_k_head_num) >= 8:
                 """layer.tp_q_head_num // layer.tp_k_head_num < 8 will support in the later version of CANN"""
                 kv_c = kv_c.view(
                     -1, self.page_size, layer.tp_k_head_num * self.kv_lora_rank
diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py
index e35a4e017..175914560 100644
--- a/python/sglang/srt/layers/moe/ep_moe/layer.py
+++ b/python/sglang/srt/layers/moe/ep_moe/layer.py
@@ -746,19 +746,25 @@ class DeepEPMoE(EPMoE):
         hidden_states = torch_npu.npu_grouped_matmul(
             x=[hidden_states],
             weight=[self.w13_weight],
-            scale=[self.w13_weight_scale.to(output_dtype)],
-            per_token_scale=[pertoken_scale],
             split_item=2,
             group_list_type=group_list_type,
             group_type=0,
             group_list=seg_indptr,
-            output_dtype=output_dtype,
+            output_dtype=torch.int32,
         )[0]
 
         # act_fn: swiglu
-        hidden_states = torch_npu.npu_swiglu(hidden_states)
-
-        hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant(hidden_states)
+        hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
+            x=hidden_states,
+            weight_scale=self.w13_weight_scale.to(torch.float32),
+            activation_scale=pertoken_scale,
+            bias=None,
+            quant_scale=None,
+            quant_offset=None,
+            group_index=seg_indptr,
+            activate_left=True,
+            quant_mode=1,
+        )
 
         # gmm2: down_proj
         hidden_states = torch_npu.npu_grouped_matmul(
diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py
index 3f8b4afd0..7e43a5541 100644
--- a/python/sglang/srt/layers/moe/topk.py
+++ b/python/sglang/srt/layers/moe/topk.py
@@ -304,12 +304,12 @@ class TopK(CustomOp):
         global_num_experts = router_logits.shape[-1]
 
         # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
-        if global_num_experts == 256 and self.topk_config.renormalize is True:
+        if global_num_experts == 256:
             routed_scaling_factor = self.topk_config.routed_scaling_factor or 1
             router_logits = router_logits.to(torch.float32)
 
-            return torch_npu.npu_moe_gating_top_k(
+            topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
                 router_logits,
                 k=self.topk_config.top_k,
                 bias=self.topk_config.correction_bias.to(torch.float32),
@@ -321,6 +321,16 @@ class TopK(CustomOp):
                 routed_scaling_factor=routed_scaling_factor,
                 eps=float(1e-20),
             )
+
+            if self.topk_config.renormalize:
+                topk_weights_sum = (
+                    topk_weights.sum(dim=-1, keepdim=True)
+                    if self.topk_config.num_fused_shared_experts == 0
+                    else topk_weights[:, :-1].sum(dim=-1, keepdim=True)
+                )
+                topk_weights = topk_weights / topk_weights_sum
+
+            return StandardTopKOutput(topk_weights, topk_ids, _)
         else:
             self.topk_config.torch_native = True
             return select_experts(
diff --git a/python/sglang/srt/layers/quantization/w8a8_int8.py b/python/sglang/srt/layers/quantization/w8a8_int8.py
index abcf334e0..db9bdbec9 100644
--- a/python/sglang/srt/layers/quantization/w8a8_int8.py
+++ b/python/sglang/srt/layers/quantization/w8a8_int8.py
@@ -551,7 +551,7 @@ class NPU_W8A8LinearMethodImpl:
     def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]:
         params_dict = {}
         params_dict["input_scale"] = torch.empty(1, dtype=params_dtype)
-        params_dict["input_offset"] = torch.empty(1, dtype=torch.int8)
+        params_dict["input_offset"] = torch.empty(1, dtype=params_dtype)
         return params_dict
 
     @staticmethod
@@ -582,11 +582,11 @@ class NPU_W8A8LinearMethodImpl:
         if original_dtype != torch.int8:
             x = torch_npu.npu_quantize(
                 x,
-                layer.aclnn_input_scale,
+                layer.aclnn_input_scale_reciprocal,
                 layer.aclnn_input_offset,
                 torch.qint8,
                 -1,
-                True,
+                False,
             )
         # Only fuse bias add into GEMM for rank 0 (this ensures that
         # bias will not get added more than once in Attention TP>1 case)
@@ -608,6 +608,10 @@ class NPU_W8A8LinearMethodImpl:
             layer.input_scale.data.repeat(expanding_factor).to(device="npu"),
             requires_grad=False,
         )
+        layer.aclnn_input_scale_reciprocal = 1 / torch.nn.Parameter(
+            layer.input_scale.data.repeat(expanding_factor).to(device="npu"),
+            requires_grad=False,
+        )
         layer.aclnn_input_offset = torch.nn.Parameter(
             layer.input_offset.data.repeat(expanding_factor).to(device="npu"),
             requires_grad=False,
diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py
index 142597b3a..3bde48da4 100644
--- a/python/sglang/srt/mem_cache/memory_pool.py
+++ b/python/sglang/srt/mem_cache/memory_pool.py
@@ -918,6 +918,7 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
                     layer_num,
                     self.size // self.page_size + 1,
                     self.page_size,
+                    1,
                     self.kv_lora_rank,
                 ),
                 dtype=self.store_dtype,
@@ -928,6 +929,7 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
                     layer_num,
                     self.size // self.page_size + 1,
                     self.page_size,
+                    1,
                     self.qk_rope_head_dim,
                 ),
                 dtype=self.store_dtype,
@@ -1000,9 +1002,11 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
         layer_id = layer.layer_id
         if cache_k.dtype != self.dtype:
             cache_k = cache_k.to(self.dtype)
+            cache_v = cache_v.to(self.dtype)
 
         if self.store_dtype != self.dtype:
             cache_k = cache_k.view(self.store_dtype)
+            cache_v = cache_v.view(self.store_dtype)
 
         if cache_v is None:
             cache_k, cache_v = cache_k.split(
diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index 6c942fcd1..30df6afcd 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -114,6 +114,7 @@ from sglang.srt.utils import (
     is_flashinfer_available,
     is_hip,
     is_non_idle_and_non_empty,
+    is_npu,
     is_sm100_supported,
     log_info_on_rank0,
     make_layers,
@@ -122,6 +123,7 @@ from sglang.srt.utils import (
 
 _is_hip = is_hip()
 _is_cuda = is_cuda()
+_is_npu = is_npu()
 _is_fp8_fnuz = is_fp8_fnuz()
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 _is_cpu_amx_available = cpu_has_amx_support()
@@ -1181,13 +1183,19 @@ class DeepseekV2AttentionMLA(nn.Module):
         k[..., : self.qk_nope_head_dim] = k_nope
         k[..., self.qk_nope_head_dim :] = k_pe
 
-        latent_cache[:, :, : self.kv_lora_rank] = kv_a.unsqueeze(1)
-        latent_cache[:, :, self.kv_lora_rank :] = k_pe
+        if not _is_npu:
+            latent_cache[:, :, : self.kv_lora_rank] = kv_a.unsqueeze(1)
+            latent_cache[:, :, self.kv_lora_rank :] = k_pe
 
-        # Save latent cache
-        forward_batch.token_to_kv_pool.set_kv_buffer(
-            self.attn_mha, forward_batch.out_cache_loc, latent_cache, None
-        )
+            # Save latent cache
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                self.attn_mha, forward_batch.out_cache_loc, latent_cache, None
+            )
+        else:
+            # To reduce a time-costing split operation
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                self.attn_mha, forward_batch.out_cache_loc, kv_a.unsqueeze(1), k_pe
+            )
 
         return q, k, v, forward_batch