[Feature] Support NPUGraph for DeepSeek on Ascend NPU (#9355)
Co-authored-by: Even Zhou <even.y.zhou@outlook.com>
This commit is contained in:
@@ -1,6 +1,12 @@
|
||||
import concurrent.futures
|
||||
import logging
|
||||
from typing import List, Tuple
|
||||
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
|
||||
from sglang.srt.disaggregation.ascend.transfer_engine import AscendTransferEngine
|
||||
from sglang.srt.disaggregation.common.utils import group_concurrent_contiguous
|
||||
from sglang.srt.disaggregation.mooncake.conn import (
|
||||
MooncakeKVBootstrapServer,
|
||||
MooncakeKVManager,
|
||||
@@ -29,6 +35,75 @@ class AscendKVManager(MooncakeKVManager):
|
||||
self.kv_args.aux_data_ptrs, self.kv_args.aux_data_lens
|
||||
)
|
||||
|
||||
def send_kvcache(
    self,
    mooncake_session_id: str,
    prefill_kv_indices: npt.NDArray[np.int32],
    dst_kv_ptrs: list[int],
    dst_kv_indices: npt.NDArray[np.int32],
    executor: concurrent.futures.ThreadPoolExecutor,
):
    """Send the KV cache of every layer to the destination session.

    The source/destination index arrays are first grouped by
    ``group_concurrent_contiguous`` so that each group can be described by a
    single ``(src_addr, dst_addr, length)`` triple, then the triples are
    handed to ``self._transfer_data``.

    Args:
        mooncake_session_id: Session identifier forwarded to
            ``self._transfer_data``.
        prefill_kv_indices: Source-side KV indices.
        dst_kv_ptrs: Destination base pointer for each layer, indexed by
            layer id.
        dst_kv_indices: Destination-side KV indices.
        executor: Thread pool used only when ``self.enable_custom_mem_pool``
            is truthy (one task per layer).

    Returns:
        With the custom memory pool enabled: the first non-zero status
        reported by a per-layer transfer, or 0 if every transfer returned 0.
        Otherwise: whatever the single batched ``self._transfer_data`` call
        returns.
    """
    src_groups, dst_groups = group_concurrent_contiguous(
        prefill_kv_indices, dst_kv_indices
    )

    kv_args = self.kv_args
    # One (source base pointer, destination base pointer, item length) per layer.
    per_layer = [
        (src_base, dst_kv_ptrs[idx], kv_args.kv_item_lens[idx])
        for idx, src_base in enumerate(kv_args.kv_data_ptrs)
    ]

    def build_blocks(
        src_base: int, dst_base: int, item_len: int
    ) -> List[Tuple[int, int, int]]:
        # A group is addressed by its first index and spans
        # len(group) consecutive items.
        return [
            (
                src_base + int(src_group[0]) * item_len,
                dst_base + int(dst_group[0]) * item_len,
                item_len * len(src_group),
            )
            for src_group, dst_group in zip(src_groups, dst_groups)
        ]

    if not self.enable_custom_mem_pool:
        # Combining all layers' params in one batch transfer is more efficient
        # compared to using multiple threads
        merged: List[Tuple[int, int, int]] = []
        for params in per_layer:
            merged.extend(build_blocks(*params))
        return self._transfer_data(mooncake_session_id, merged)

    def transfer_one_layer(src_base: int, dst_base: int, item_len: int) -> int:
        # Worker: transfer the blocks of a single layer.
        return self._transfer_data(
            mooncake_session_id, build_blocks(src_base, dst_base, item_len)
        )

    pending = [
        executor.submit(transfer_one_layer, *params) for params in per_layer
    ]
    for finished in concurrent.futures.as_completed(pending):
        status = finished.result()
        if status != 0:
            # Stop on the first failure and cancel the tasks not yet started.
            for task in pending:
                task.cancel()
            return status

    return 0
|
||||
|
||||
|
||||
class AscendKVSender(MooncakeKVSender):
|
||||
pass
|
||||
|
||||
@@ -158,7 +158,7 @@ class AscendAttnBackend(AttentionBackend):
|
||||
self.graph_mode = True
|
||||
|
||||
def get_cuda_graph_seq_len_fill_value(self):
|
||||
return 1
|
||||
return 0
|
||||
|
||||
def forward_extend(
|
||||
self,
|
||||
@@ -167,7 +167,7 @@ class AscendAttnBackend(AttentionBackend):
|
||||
v,
|
||||
layer: RadixAttention,
|
||||
forward_batch: ForwardBatch,
|
||||
save_kv_cache=True,
|
||||
save_kv_cache: bool = True,
|
||||
):
|
||||
if not self.use_mla:
|
||||
if save_kv_cache:
|
||||
@@ -253,6 +253,136 @@ class AscendAttnBackend(AttentionBackend):
|
||||
|
||||
return attn_output
|
||||
|
||||
def forward_decode_graph(
    self,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    layer: RadixAttention,
    forward_batch: ForwardBatch,
    save_kv_cache: bool = True,
    q_rope: Optional[torch.Tensor] = None,
    k_rope: Optional[torch.Tensor] = None,
):
    """Decode-time attention through ``npu_fused_infer_attention_score``.

    Optionally writes the new key/value (or, with MLA, the latent KV and the
    rope key) into the KV pool, then runs the fused attention kernel against
    the paged cache using a workspace obtained from
    ``_npu_fused_infer_attention_score_get_max_workspace``.

    Args:
        q: Query tensor (the "nope" part when ``self.use_mla`` is set).
        k: Key tensor (the latent KV when ``self.use_mla`` is set).
        v: Value tensor; only stored on the non-MLA path.
        layer: Attention layer providing head counts, head dims, scaling and
            ``layer_id``.
        forward_batch: Batch providing ``token_to_kv_pool`` and
            ``out_cache_loc``.
        save_kv_cache: When true, store the incoming KV before attending.
        q_rope: Rope part of the query; viewed unconditionally on the MLA path.
        k_rope: Rope part of the key; viewed on the MLA path when
            ``save_kv_cache`` is true.

    Returns:
        A 2-D tensor: ``(num_tokens, tp_q_head_num * v_head_dim)`` on the
        non-MLA path, ``(-1, tp_q_head_num * kv_lora_rank)`` on the MLA path.
    """
    kv_pool = forward_batch.token_to_kv_pool
    metadata = self.forward_metadata

    if save_kv_cache:
        if self.use_mla:
            cache_first = k.view(-1, layer.tp_k_head_num, self.kv_lora_rank)
            cache_second = k_rope.view(
                -1, layer.tp_k_head_num, self.qk_rope_head_dim
            )
        else:
            cache_first, cache_second = k, v
        kv_pool.set_kv_buffer(
            layer, forward_batch.out_cache_loc, cache_first, cache_second
        )

    def kv_seq_lens():
        # Use the prepared Python list unless an integer tensor of sequence
        # lengths is present, in which case it is converted to a list.
        if metadata.seq_lens_cpu_int is None:
            return metadata.seq_lens_cpu_list
        return metadata.seq_lens_cpu_int.cpu().int().tolist()

    if not self.use_mla:
        k_cache = kv_pool.get_key_buffer(layer.layer_id).view(
            -1, self.page_size, layer.tp_k_head_num * layer.qk_head_dim
        )
        v_cache = kv_pool.get_value_buffer(layer.layer_id).view(
            -1, self.page_size, layer.tp_v_head_num * layer.v_head_dim
        )
        query = q.view(-1, 1, layer.tp_q_head_num * layer.qk_head_dim)
        actual_seq_len_kv = kv_seq_lens()
        num_tokens = query.shape[0]

        # Keyword arguments shared by the workspace query and the kernel call.
        fia_kwargs = dict(
            block_table=metadata.block_tables,
            block_size=self.page_size,
            num_heads=layer.tp_q_head_num,
            num_key_value_heads=layer.tp_k_head_num,
            input_layout="BSH",
            scale=layer.scaling,
            actual_seq_lengths_kv=actual_seq_len_kv,
        )
        workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
            query, k_cache, v_cache, **fia_kwargs
        )
        output = torch.empty(
            (num_tokens, 1, layer.tp_q_head_num * layer.v_head_dim),
            dtype=q.dtype,
            device=q.device,
        )
        softmax_lse = torch.empty(1, dtype=q.dtype, device=q.device)
        torch_npu.npu_fused_infer_attention_score.out(
            query,
            k_cache,
            v_cache,
            **fia_kwargs,
            workspace=workspace,
            out=[output, softmax_lse],
        )
        return output.view(num_tokens, layer.tp_q_head_num * layer.v_head_dim)

    # MLA path: the latent KV cache is passed as both key and value, and the
    # rope components are passed separately.
    c_kv, k_rope = kv_pool.get_kv_buffer(layer.layer_id)
    k_rope_cache = k_rope.view(
        -1, layer.tp_k_head_num, self.page_size, self.qk_rope_head_dim
    )
    c_kv_cache = c_kv.view(
        -1, layer.tp_v_head_num, self.page_size, self.kv_lora_rank
    )

    q_nope = q.view(-1, layer.tp_q_head_num, 1, self.kv_lora_rank)
    q_rope = q_rope.view(-1, layer.tp_q_head_num, 1, self.qk_rope_head_dim)
    actual_seq_len_kv = kv_seq_lens()

    # Keyword arguments shared by the workspace query and the kernel call.
    mla_kwargs = dict(
        query_rope=q_rope,
        key_rope=k_rope_cache,
        num_heads=layer.tp_q_head_num,
        num_key_value_heads=layer.tp_k_head_num,
        block_table=metadata.block_tables,
        block_size=self.page_size,
        input_layout="BNSD",
        scale=layer.scaling,
        actual_seq_lengths_kv=actual_seq_len_kv,
        antiquant_mode=0,
        antiquant_scale=None,
        sparse_mode=0,
    )
    workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
        q_nope, c_kv_cache, c_kv_cache, **mla_kwargs
    )
    output = torch.zeros_like(q_nope, dtype=q.dtype, device=q.device)
    softmax_lse = torch.empty(1, dtype=q.dtype, device=q.device)

    torch_npu.npu_fused_infer_attention_score.out(
        q_nope,
        c_kv_cache,
        c_kv_cache,
        **mla_kwargs,
        workspace=workspace,
        out=[output, softmax_lse],
    )
    return output.view(-1, layer.tp_q_head_num * self.kv_lora_rank)
|
||||
|
||||
def forward_decode(
|
||||
self,
|
||||
q: torch.Tensor,
|
||||
@@ -260,106 +390,73 @@ class AscendAttnBackend(AttentionBackend):
|
||||
v: torch.Tensor,
|
||||
layer: RadixAttention,
|
||||
forward_batch: ForwardBatch,
|
||||
save_kv_cache: bool = False,
|
||||
save_kv_cache: bool = True,
|
||||
# For multi-head latent attention
|
||||
q_rope: Optional[torch.Tensor] = None,
|
||||
k_rope: Optional[torch.Tensor] = None,
|
||||
):
|
||||
if self.graph_mode:
|
||||
return self.forward_decode_graph(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
layer,
|
||||
forward_batch,
|
||||
save_kv_cache,
|
||||
q_rope=q_rope,
|
||||
k_rope=k_rope,
|
||||
)
|
||||
|
||||
if not self.use_mla:
|
||||
if save_kv_cache:
|
||||
forward_batch.token_to_kv_pool.set_kv_buffer(
|
||||
layer, forward_batch.out_cache_loc, k, v
|
||||
)
|
||||
num_tokens = q.shape[0]
|
||||
if self.graph_mode:
|
||||
k_cache = forward_batch.token_to_kv_pool.get_key_buffer(
|
||||
layer.layer_id
|
||||
).view(-1, self.page_size, layer.tp_k_head_num * layer.qk_head_dim)
|
||||
v_cache = forward_batch.token_to_kv_pool.get_value_buffer(
|
||||
layer.layer_id
|
||||
).view(-1, self.page_size, layer.tp_v_head_num * layer.v_head_dim)
|
||||
query = q.view(-1, 1, layer.tp_q_head_num * layer.qk_head_dim)
|
||||
workspace = (
|
||||
torch_npu._npu_fused_infer_attention_score_get_max_workspace(
|
||||
query,
|
||||
k_cache,
|
||||
v_cache,
|
||||
block_table=self.forward_metadata.block_tables,
|
||||
block_size=self.page_size,
|
||||
num_heads=layer.tp_q_head_num,
|
||||
num_key_value_heads=layer.tp_k_head_num,
|
||||
input_layout="BSH",
|
||||
scale=layer.scaling,
|
||||
actual_seq_lengths_kv=self.forward_metadata.seq_lens_cpu_list,
|
||||
)
|
||||
)
|
||||
attn_output = torch.empty(
|
||||
(num_tokens, 1, layer.tp_q_head_num * layer.v_head_dim),
|
||||
dtype=q.dtype,
|
||||
device=q.device,
|
||||
)
|
||||
softmax_lse = torch.empty(1, dtype=q.dtype, device=q.device)
|
||||
torch_npu.npu_fused_infer_attention_score.out(
|
||||
query,
|
||||
k_cache,
|
||||
v_cache,
|
||||
block_table=self.forward_metadata.block_tables,
|
||||
block_size=self.page_size,
|
||||
k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
|
||||
v_cache = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)
|
||||
if self.use_fia:
|
||||
attn_output, _ = torch.ops.npu.npu_fused_infer_attention_score(
|
||||
q.view(
|
||||
forward_batch.batch_size,
|
||||
-1,
|
||||
layer.tp_q_head_num,
|
||||
layer.qk_head_dim,
|
||||
),
|
||||
k_cache.view(
|
||||
-1, self.page_size, layer.tp_k_head_num * layer.qk_head_dim
|
||||
),
|
||||
v_cache.view(
|
||||
-1, self.page_size, layer.tp_v_head_num * layer.qk_head_dim
|
||||
),
|
||||
num_heads=layer.tp_q_head_num,
|
||||
num_key_value_heads=layer.tp_k_head_num,
|
||||
input_layout="BSH",
|
||||
input_layout="BSND",
|
||||
atten_mask=None,
|
||||
block_size=self.page_size,
|
||||
block_table=self.forward_metadata.block_tables,
|
||||
actual_seq_lengths_kv=self.forward_metadata.seq_lens_cpu_int,
|
||||
scale=layer.scaling,
|
||||
actual_seq_lengths_kv=self.forward_metadata.seq_lens_cpu_list,
|
||||
workspace=workspace,
|
||||
out=[attn_output, softmax_lse],
|
||||
)
|
||||
else:
|
||||
k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
|
||||
v_cache = forward_batch.token_to_kv_pool.get_value_buffer(
|
||||
layer.layer_id
|
||||
query = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
|
||||
attn_output = torch.empty(
|
||||
(num_tokens, layer.tp_q_head_num, layer.v_head_dim),
|
||||
dtype=query.dtype,
|
||||
device=query.device,
|
||||
)
|
||||
if self.use_fia:
|
||||
attn_output, _ = torch.ops.npu.npu_fused_infer_attention_score(
|
||||
q.view(
|
||||
forward_batch.batch_size,
|
||||
-1,
|
||||
layer.tp_q_head_num,
|
||||
layer.qk_head_dim,
|
||||
),
|
||||
k_cache.view(
|
||||
-1, self.page_size, layer.tp_k_head_num * layer.qk_head_dim
|
||||
),
|
||||
v_cache.view(
|
||||
-1, self.page_size, layer.tp_v_head_num * layer.qk_head_dim
|
||||
),
|
||||
num_heads=layer.tp_q_head_num,
|
||||
num_key_value_heads=layer.tp_k_head_num,
|
||||
input_layout="BSND",
|
||||
atten_mask=None,
|
||||
block_size=self.page_size,
|
||||
block_table=self.forward_metadata.block_tables,
|
||||
actual_seq_lengths_kv=self.forward_metadata.seq_lens_cpu_int,
|
||||
scale=layer.scaling,
|
||||
)
|
||||
else:
|
||||
query = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
|
||||
attn_output = torch.empty(
|
||||
(num_tokens, layer.tp_q_head_num, layer.v_head_dim),
|
||||
dtype=query.dtype,
|
||||
device=query.device,
|
||||
)
|
||||
|
||||
torch_npu._npu_paged_attention(
|
||||
query=query,
|
||||
key_cache=k_cache,
|
||||
value_cache=v_cache,
|
||||
num_heads=layer.tp_q_head_num,
|
||||
num_kv_heads=layer.tp_k_head_num,
|
||||
scale_value=layer.scaling,
|
||||
block_table=self.forward_metadata.block_tables,
|
||||
context_lens=self.forward_metadata.seq_lens_cpu_int,
|
||||
out=attn_output,
|
||||
)
|
||||
torch_npu._npu_paged_attention(
|
||||
query=query,
|
||||
key_cache=k_cache,
|
||||
value_cache=v_cache,
|
||||
num_heads=layer.tp_q_head_num,
|
||||
num_kv_heads=layer.tp_k_head_num,
|
||||
scale_value=layer.scaling,
|
||||
block_table=self.forward_metadata.block_tables,
|
||||
context_lens=self.forward_metadata.seq_lens_cpu_int,
|
||||
out=attn_output,
|
||||
)
|
||||
return attn_output.view(num_tokens, layer.tp_q_head_num * layer.v_head_dim)
|
||||
else:
|
||||
if save_kv_cache:
|
||||
@@ -370,9 +467,7 @@ class AscendAttnBackend(AttentionBackend):
|
||||
kv_c = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
|
||||
k_pe = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)
|
||||
|
||||
if (self.graph_mode or self.use_fia) and (
|
||||
layer.tp_q_head_num // layer.tp_k_head_num
|
||||
) >= 8:
|
||||
if self.use_fia and (layer.tp_q_head_num // layer.tp_k_head_num) >= 8:
|
||||
"""layer.tp_q_head_num // layer.tp_k_head_num < 8 will support in the later version of CANN"""
|
||||
kv_c = kv_c.view(
|
||||
-1, self.page_size, layer.tp_k_head_num * self.kv_lora_rank
|
||||
|
||||
@@ -746,19 +746,25 @@ class DeepEPMoE(EPMoE):
|
||||
hidden_states = torch_npu.npu_grouped_matmul(
|
||||
x=[hidden_states],
|
||||
weight=[self.w13_weight],
|
||||
scale=[self.w13_weight_scale.to(output_dtype)],
|
||||
per_token_scale=[pertoken_scale],
|
||||
split_item=2,
|
||||
group_list_type=group_list_type,
|
||||
group_type=0,
|
||||
group_list=seg_indptr,
|
||||
output_dtype=output_dtype,
|
||||
output_dtype=torch.int32,
|
||||
)[0]
|
||||
|
||||
# act_fn: swiglu
|
||||
hidden_states = torch_npu.npu_swiglu(hidden_states)
|
||||
|
||||
hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant(hidden_states)
|
||||
hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
|
||||
x=hidden_states,
|
||||
weight_scale=self.w13_weight_scale.to(torch.float32),
|
||||
activation_scale=pertoken_scale,
|
||||
bias=None,
|
||||
quant_scale=None,
|
||||
quant_offset=None,
|
||||
group_index=seg_indptr,
|
||||
activate_left=True,
|
||||
quant_mode=1,
|
||||
)
|
||||
|
||||
# gmm2: down_proj
|
||||
hidden_states = torch_npu.npu_grouped_matmul(
|
||||
|
||||
@@ -304,12 +304,12 @@ class TopK(CustomOp):
|
||||
global_num_experts = router_logits.shape[-1]
|
||||
|
||||
# NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
|
||||
if global_num_experts == 256 and self.topk_config.renormalize is True:
|
||||
if global_num_experts == 256:
|
||||
|
||||
routed_scaling_factor = self.topk_config.routed_scaling_factor or 1
|
||||
router_logits = router_logits.to(torch.float32)
|
||||
|
||||
return torch_npu.npu_moe_gating_top_k(
|
||||
topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
|
||||
router_logits,
|
||||
k=self.topk_config.top_k,
|
||||
bias=self.topk_config.correction_bias.to(torch.float32),
|
||||
@@ -321,6 +321,16 @@ class TopK(CustomOp):
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
eps=float(1e-20),
|
||||
)
|
||||
|
||||
if self.topk_config.renormalize:
|
||||
topk_weights_sum = (
|
||||
topk_weights.sum(dim=-1, keepdim=True)
|
||||
if self.topk_config.num_fused_shared_experts == 0
|
||||
else topk_weights[:, :-1].sum(dim=-1, keepdim=True)
|
||||
)
|
||||
topk_weights = topk_weights / topk_weights_sum
|
||||
|
||||
return StandardTopKOutput(topk_weights, topk_ids, _)
|
||||
else:
|
||||
self.topk_config.torch_native = True
|
||||
return select_experts(
|
||||
|
||||
@@ -551,7 +551,7 @@ class NPU_W8A8LinearMethodImpl:
|
||||
def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]:
|
||||
params_dict = {}
|
||||
params_dict["input_scale"] = torch.empty(1, dtype=params_dtype)
|
||||
params_dict["input_offset"] = torch.empty(1, dtype=torch.int8)
|
||||
params_dict["input_offset"] = torch.empty(1, dtype=params_dtype)
|
||||
return params_dict
|
||||
|
||||
@staticmethod
|
||||
@@ -582,11 +582,11 @@ class NPU_W8A8LinearMethodImpl:
|
||||
if original_dtype != torch.int8:
|
||||
x = torch_npu.npu_quantize(
|
||||
x,
|
||||
layer.aclnn_input_scale,
|
||||
layer.aclnn_input_scale_reciprocal,
|
||||
layer.aclnn_input_offset,
|
||||
torch.qint8,
|
||||
-1,
|
||||
True,
|
||||
False,
|
||||
)
|
||||
# Only fuse bias add into GEMM for rank 0 (this ensures that
|
||||
# bias will not get added more than once in Attention TP>1 case)
|
||||
@@ -608,6 +608,10 @@ class NPU_W8A8LinearMethodImpl:
|
||||
layer.input_scale.data.repeat(expanding_factor).to(device="npu"),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.aclnn_input_scale_reciprocal = 1 / torch.nn.Parameter(
|
||||
layer.input_scale.data.repeat(expanding_factor).to(device="npu"),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.aclnn_input_offset = torch.nn.Parameter(
|
||||
layer.input_offset.data.repeat(expanding_factor).to(device="npu"),
|
||||
requires_grad=False,
|
||||
|
||||
@@ -918,6 +918,7 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
|
||||
layer_num,
|
||||
self.size // self.page_size + 1,
|
||||
self.page_size,
|
||||
1,
|
||||
self.kv_lora_rank,
|
||||
),
|
||||
dtype=self.store_dtype,
|
||||
@@ -928,6 +929,7 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
|
||||
layer_num,
|
||||
self.size // self.page_size + 1,
|
||||
self.page_size,
|
||||
1,
|
||||
self.qk_rope_head_dim,
|
||||
),
|
||||
dtype=self.store_dtype,
|
||||
@@ -1000,9 +1002,11 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
|
||||
layer_id = layer.layer_id
|
||||
if cache_k.dtype != self.dtype:
|
||||
cache_k = cache_k.to(self.dtype)
|
||||
cache_v = cache_v.to(self.dtype)
|
||||
|
||||
if self.store_dtype != self.dtype:
|
||||
cache_k = cache_k.view(self.store_dtype)
|
||||
cache_v = cache_v.view(self.store_dtype)
|
||||
|
||||
if cache_v is None:
|
||||
cache_k, cache_v = cache_k.split(
|
||||
|
||||
@@ -114,6 +114,7 @@ from sglang.srt.utils import (
|
||||
is_flashinfer_available,
|
||||
is_hip,
|
||||
is_non_idle_and_non_empty,
|
||||
is_npu,
|
||||
is_sm100_supported,
|
||||
log_info_on_rank0,
|
||||
make_layers,
|
||||
@@ -122,6 +123,7 @@ from sglang.srt.utils import (
|
||||
|
||||
_is_hip = is_hip()
|
||||
_is_cuda = is_cuda()
|
||||
_is_npu = is_npu()
|
||||
_is_fp8_fnuz = is_fp8_fnuz()
|
||||
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
|
||||
_is_cpu_amx_available = cpu_has_amx_support()
|
||||
@@ -1181,13 +1183,19 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
k[..., : self.qk_nope_head_dim] = k_nope
|
||||
k[..., self.qk_nope_head_dim :] = k_pe
|
||||
|
||||
latent_cache[:, :, : self.kv_lora_rank] = kv_a.unsqueeze(1)
|
||||
latent_cache[:, :, self.kv_lora_rank :] = k_pe
|
||||
if not _is_npu:
|
||||
latent_cache[:, :, : self.kv_lora_rank] = kv_a.unsqueeze(1)
|
||||
latent_cache[:, :, self.kv_lora_rank :] = k_pe
|
||||
|
||||
# Save latent cache
|
||||
forward_batch.token_to_kv_pool.set_kv_buffer(
|
||||
self.attn_mha, forward_batch.out_cache_loc, latent_cache, None
|
||||
)
|
||||
# Save latent cache
|
||||
forward_batch.token_to_kv_pool.set_kv_buffer(
|
||||
self.attn_mha, forward_batch.out_cache_loc, latent_cache, None
|
||||
)
|
||||
else:
|
||||
# To reduce a time-costing split operation
|
||||
forward_batch.token_to_kv_pool.set_kv_buffer(
|
||||
self.attn_mha, forward_batch.out_cache_loc, kv_a.unsqueeze(1), k_pe
|
||||
)
|
||||
|
||||
return q, k, v, forward_batch
|
||||
|
||||
|
||||
Reference in New Issue
Block a user