[Feature] Support NPUGraph for DeepSeek on Ascend NPU (#9355)
Co-authored-by: Even Zhou <even.y.zhou@outlook.com>
This commit is contained in:
@@ -1,6 +1,12 @@
|
|||||||
|
import concurrent.futures
|
||||||
import logging
|
import logging
|
||||||
|
from typing import List, Tuple
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import numpy.typing as npt
|
||||||
|
|
||||||
from sglang.srt.disaggregation.ascend.transfer_engine import AscendTransferEngine
|
from sglang.srt.disaggregation.ascend.transfer_engine import AscendTransferEngine
|
||||||
|
from sglang.srt.disaggregation.common.utils import group_concurrent_contiguous
|
||||||
from sglang.srt.disaggregation.mooncake.conn import (
|
from sglang.srt.disaggregation.mooncake.conn import (
|
||||||
MooncakeKVBootstrapServer,
|
MooncakeKVBootstrapServer,
|
||||||
MooncakeKVManager,
|
MooncakeKVManager,
|
||||||
@@ -29,6 +35,75 @@ class AscendKVManager(MooncakeKVManager):
|
|||||||
self.kv_args.aux_data_ptrs, self.kv_args.aux_data_lens
|
self.kv_args.aux_data_ptrs, self.kv_args.aux_data_lens
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def send_kvcache(
|
||||||
|
self,
|
||||||
|
mooncake_session_id: str,
|
||||||
|
prefill_kv_indices: npt.NDArray[np.int32],
|
||||||
|
dst_kv_ptrs: list[int],
|
||||||
|
dst_kv_indices: npt.NDArray[np.int32],
|
||||||
|
executor: concurrent.futures.ThreadPoolExecutor,
|
||||||
|
):
|
||||||
|
# Group by indices
|
||||||
|
prefill_kv_blocks, dst_kv_blocks = group_concurrent_contiguous(
|
||||||
|
prefill_kv_indices, dst_kv_indices
|
||||||
|
)
|
||||||
|
|
||||||
|
num_layers = len(self.kv_args.kv_data_ptrs)
|
||||||
|
layers_params = [
|
||||||
|
(
|
||||||
|
self.kv_args.kv_data_ptrs[layer_id],
|
||||||
|
dst_kv_ptrs[layer_id],
|
||||||
|
self.kv_args.kv_item_lens[layer_id],
|
||||||
|
)
|
||||||
|
for layer_id in range(num_layers)
|
||||||
|
]
|
||||||
|
|
||||||
|
def set_transfer_blocks(
|
||||||
|
src_ptr: int, dst_ptr: int, item_len: int
|
||||||
|
) -> List[Tuple[int, int, int]]:
|
||||||
|
transfer_blocks = []
|
||||||
|
for prefill_index, decode_index in zip(prefill_kv_blocks, dst_kv_blocks):
|
||||||
|
src_addr = src_ptr + int(prefill_index[0]) * item_len
|
||||||
|
dst_addr = dst_ptr + int(decode_index[0]) * item_len
|
||||||
|
length = item_len * len(prefill_index)
|
||||||
|
transfer_blocks.append((src_addr, dst_addr, length))
|
||||||
|
return transfer_blocks
|
||||||
|
|
||||||
|
# Worker function for processing a single layer
|
||||||
|
def process_layer(src_ptr: int, dst_ptr: int, item_len: int) -> int:
|
||||||
|
transfer_blocks = set_transfer_blocks(src_ptr, dst_ptr, item_len)
|
||||||
|
return self._transfer_data(mooncake_session_id, transfer_blocks)
|
||||||
|
|
||||||
|
# Worker function for processing all layers in a batch
|
||||||
|
def process_layers(layers_params: List[Tuple[int, int, int]]) -> int:
|
||||||
|
transfer_blocks = []
|
||||||
|
for src_ptr, dst_ptr, item_len in layers_params:
|
||||||
|
transfer_blocks.extend(set_transfer_blocks(src_ptr, dst_ptr, item_len))
|
||||||
|
return self._transfer_data(mooncake_session_id, transfer_blocks)
|
||||||
|
|
||||||
|
if self.enable_custom_mem_pool:
|
||||||
|
futures = [
|
||||||
|
executor.submit(
|
||||||
|
process_layer,
|
||||||
|
src_ptr,
|
||||||
|
dst_ptr,
|
||||||
|
item_len,
|
||||||
|
)
|
||||||
|
for (src_ptr, dst_ptr, item_len) in layers_params
|
||||||
|
]
|
||||||
|
for future in concurrent.futures.as_completed(futures):
|
||||||
|
status = future.result()
|
||||||
|
if status != 0:
|
||||||
|
for f in futures:
|
||||||
|
f.cancel()
|
||||||
|
return status
|
||||||
|
else:
|
||||||
|
# Combining all layers' params in one batch transfer is more efficient
|
||||||
|
# compared to using multiple threads
|
||||||
|
return process_layers(layers_params)
|
||||||
|
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
class AscendKVSender(MooncakeKVSender):
|
class AscendKVSender(MooncakeKVSender):
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -158,7 +158,7 @@ class AscendAttnBackend(AttentionBackend):
|
|||||||
self.graph_mode = True
|
self.graph_mode = True
|
||||||
|
|
||||||
def get_cuda_graph_seq_len_fill_value(self):
|
def get_cuda_graph_seq_len_fill_value(self):
|
||||||
return 1
|
return 0
|
||||||
|
|
||||||
def forward_extend(
|
def forward_extend(
|
||||||
self,
|
self,
|
||||||
@@ -167,7 +167,7 @@ class AscendAttnBackend(AttentionBackend):
|
|||||||
v,
|
v,
|
||||||
layer: RadixAttention,
|
layer: RadixAttention,
|
||||||
forward_batch: ForwardBatch,
|
forward_batch: ForwardBatch,
|
||||||
save_kv_cache=True,
|
save_kv_cache: bool = True,
|
||||||
):
|
):
|
||||||
if not self.use_mla:
|
if not self.use_mla:
|
||||||
if save_kv_cache:
|
if save_kv_cache:
|
||||||
@@ -253,6 +253,136 @@ class AscendAttnBackend(AttentionBackend):
|
|||||||
|
|
||||||
return attn_output
|
return attn_output
|
||||||
|
|
||||||
|
def forward_decode_graph(
|
||||||
|
self,
|
||||||
|
q: torch.Tensor,
|
||||||
|
k: torch.Tensor,
|
||||||
|
v: torch.Tensor,
|
||||||
|
layer: RadixAttention,
|
||||||
|
forward_batch: ForwardBatch,
|
||||||
|
save_kv_cache: bool = True,
|
||||||
|
q_rope: Optional[torch.Tensor] = None,
|
||||||
|
k_rope: Optional[torch.Tensor] = None,
|
||||||
|
):
|
||||||
|
if save_kv_cache:
|
||||||
|
if self.use_mla:
|
||||||
|
k = k.view(-1, layer.tp_k_head_num, self.kv_lora_rank)
|
||||||
|
k_rope = k_rope.view(-1, layer.tp_k_head_num, self.qk_rope_head_dim)
|
||||||
|
forward_batch.token_to_kv_pool.set_kv_buffer(
|
||||||
|
layer, forward_batch.out_cache_loc, k, k_rope
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
forward_batch.token_to_kv_pool.set_kv_buffer(
|
||||||
|
layer, forward_batch.out_cache_loc, k, v
|
||||||
|
)
|
||||||
|
|
||||||
|
if not self.use_mla:
|
||||||
|
k_cache = forward_batch.token_to_kv_pool.get_key_buffer(
|
||||||
|
layer.layer_id
|
||||||
|
).view(-1, self.page_size, layer.tp_k_head_num * layer.qk_head_dim)
|
||||||
|
v_cache = forward_batch.token_to_kv_pool.get_value_buffer(
|
||||||
|
layer.layer_id
|
||||||
|
).view(-1, self.page_size, layer.tp_v_head_num * layer.v_head_dim)
|
||||||
|
query = q.view(-1, 1, layer.tp_q_head_num * layer.qk_head_dim)
|
||||||
|
if self.forward_metadata.seq_lens_cpu_int is None:
|
||||||
|
actual_seq_len_kv = self.forward_metadata.seq_lens_cpu_list
|
||||||
|
else:
|
||||||
|
actual_seq_len_kv = (
|
||||||
|
self.forward_metadata.seq_lens_cpu_int.cpu().int().tolist()
|
||||||
|
)
|
||||||
|
num_tokens = query.shape[0]
|
||||||
|
workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
|
||||||
|
query,
|
||||||
|
k_cache,
|
||||||
|
v_cache,
|
||||||
|
block_table=self.forward_metadata.block_tables,
|
||||||
|
block_size=self.page_size,
|
||||||
|
num_heads=layer.tp_q_head_num,
|
||||||
|
num_key_value_heads=layer.tp_k_head_num,
|
||||||
|
input_layout="BSH",
|
||||||
|
scale=layer.scaling,
|
||||||
|
actual_seq_lengths_kv=actual_seq_len_kv,
|
||||||
|
)
|
||||||
|
output = torch.empty(
|
||||||
|
(num_tokens, 1, layer.tp_q_head_num * layer.v_head_dim),
|
||||||
|
dtype=q.dtype,
|
||||||
|
device=q.device,
|
||||||
|
)
|
||||||
|
softmax_lse = torch.empty(1, dtype=q.dtype, device=q.device)
|
||||||
|
torch_npu.npu_fused_infer_attention_score.out(
|
||||||
|
query,
|
||||||
|
k_cache,
|
||||||
|
v_cache,
|
||||||
|
block_table=self.forward_metadata.block_tables,
|
||||||
|
block_size=self.page_size,
|
||||||
|
num_heads=layer.tp_q_head_num,
|
||||||
|
num_key_value_heads=layer.tp_k_head_num,
|
||||||
|
input_layout="BSH",
|
||||||
|
scale=layer.scaling,
|
||||||
|
actual_seq_lengths_kv=actual_seq_len_kv,
|
||||||
|
workspace=workspace,
|
||||||
|
out=[output, softmax_lse],
|
||||||
|
)
|
||||||
|
return output.view(num_tokens, layer.tp_q_head_num * layer.v_head_dim)
|
||||||
|
else:
|
||||||
|
c_kv, k_rope = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
|
||||||
|
k_rope_cache = k_rope.view(
|
||||||
|
-1, layer.tp_k_head_num, self.page_size, self.qk_rope_head_dim
|
||||||
|
)
|
||||||
|
c_kv_cache = c_kv.view(
|
||||||
|
-1, layer.tp_v_head_num, self.page_size, self.kv_lora_rank
|
||||||
|
)
|
||||||
|
|
||||||
|
q_nope = q.view(-1, layer.tp_q_head_num, 1, self.kv_lora_rank)
|
||||||
|
q_rope = q_rope.view(-1, layer.tp_q_head_num, 1, self.qk_rope_head_dim)
|
||||||
|
if self.forward_metadata.seq_lens_cpu_int is None:
|
||||||
|
actual_seq_len_kv = self.forward_metadata.seq_lens_cpu_list
|
||||||
|
else:
|
||||||
|
actual_seq_len_kv = (
|
||||||
|
self.forward_metadata.seq_lens_cpu_int.cpu().int().tolist()
|
||||||
|
)
|
||||||
|
|
||||||
|
workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
|
||||||
|
q_nope,
|
||||||
|
c_kv_cache,
|
||||||
|
c_kv_cache,
|
||||||
|
query_rope=q_rope,
|
||||||
|
key_rope=k_rope_cache,
|
||||||
|
num_heads=layer.tp_q_head_num,
|
||||||
|
num_key_value_heads=layer.tp_k_head_num,
|
||||||
|
block_table=self.forward_metadata.block_tables,
|
||||||
|
block_size=self.page_size,
|
||||||
|
input_layout="BNSD",
|
||||||
|
scale=layer.scaling,
|
||||||
|
actual_seq_lengths_kv=actual_seq_len_kv,
|
||||||
|
antiquant_mode=0,
|
||||||
|
antiquant_scale=None,
|
||||||
|
sparse_mode=0,
|
||||||
|
)
|
||||||
|
output = torch.zeros_like(q_nope, dtype=q.dtype, device=q.device)
|
||||||
|
softmax_lse = torch.empty(1, dtype=q.dtype, device=q.device)
|
||||||
|
|
||||||
|
torch_npu.npu_fused_infer_attention_score.out(
|
||||||
|
q_nope,
|
||||||
|
c_kv_cache,
|
||||||
|
c_kv_cache,
|
||||||
|
query_rope=q_rope,
|
||||||
|
key_rope=k_rope_cache,
|
||||||
|
num_heads=layer.tp_q_head_num,
|
||||||
|
num_key_value_heads=layer.tp_k_head_num,
|
||||||
|
block_table=self.forward_metadata.block_tables,
|
||||||
|
block_size=self.page_size,
|
||||||
|
input_layout="BNSD",
|
||||||
|
scale=layer.scaling,
|
||||||
|
actual_seq_lengths_kv=actual_seq_len_kv,
|
||||||
|
antiquant_mode=0,
|
||||||
|
antiquant_scale=None,
|
||||||
|
sparse_mode=0,
|
||||||
|
workspace=workspace,
|
||||||
|
out=[output, softmax_lse],
|
||||||
|
)
|
||||||
|
return output.view(-1, layer.tp_q_head_num * self.kv_lora_rank)
|
||||||
|
|
||||||
def forward_decode(
|
def forward_decode(
|
||||||
self,
|
self,
|
||||||
q: torch.Tensor,
|
q: torch.Tensor,
|
||||||
@@ -260,106 +390,73 @@ class AscendAttnBackend(AttentionBackend):
|
|||||||
v: torch.Tensor,
|
v: torch.Tensor,
|
||||||
layer: RadixAttention,
|
layer: RadixAttention,
|
||||||
forward_batch: ForwardBatch,
|
forward_batch: ForwardBatch,
|
||||||
save_kv_cache: bool = False,
|
save_kv_cache: bool = True,
|
||||||
# For multi-head latent attention
|
# For multi-head latent attention
|
||||||
q_rope: Optional[torch.Tensor] = None,
|
q_rope: Optional[torch.Tensor] = None,
|
||||||
k_rope: Optional[torch.Tensor] = None,
|
k_rope: Optional[torch.Tensor] = None,
|
||||||
):
|
):
|
||||||
|
if self.graph_mode:
|
||||||
|
return self.forward_decode_graph(
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
layer,
|
||||||
|
forward_batch,
|
||||||
|
save_kv_cache,
|
||||||
|
q_rope=q_rope,
|
||||||
|
k_rope=k_rope,
|
||||||
|
)
|
||||||
|
|
||||||
if not self.use_mla:
|
if not self.use_mla:
|
||||||
if save_kv_cache:
|
if save_kv_cache:
|
||||||
forward_batch.token_to_kv_pool.set_kv_buffer(
|
forward_batch.token_to_kv_pool.set_kv_buffer(
|
||||||
layer, forward_batch.out_cache_loc, k, v
|
layer, forward_batch.out_cache_loc, k, v
|
||||||
)
|
)
|
||||||
num_tokens = q.shape[0]
|
num_tokens = q.shape[0]
|
||||||
if self.graph_mode:
|
k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
|
||||||
k_cache = forward_batch.token_to_kv_pool.get_key_buffer(
|
v_cache = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)
|
||||||
layer.layer_id
|
if self.use_fia:
|
||||||
).view(-1, self.page_size, layer.tp_k_head_num * layer.qk_head_dim)
|
attn_output, _ = torch.ops.npu.npu_fused_infer_attention_score(
|
||||||
v_cache = forward_batch.token_to_kv_pool.get_value_buffer(
|
q.view(
|
||||||
layer.layer_id
|
forward_batch.batch_size,
|
||||||
).view(-1, self.page_size, layer.tp_v_head_num * layer.v_head_dim)
|
-1,
|
||||||
query = q.view(-1, 1, layer.tp_q_head_num * layer.qk_head_dim)
|
layer.tp_q_head_num,
|
||||||
workspace = (
|
layer.qk_head_dim,
|
||||||
torch_npu._npu_fused_infer_attention_score_get_max_workspace(
|
),
|
||||||
query,
|
k_cache.view(
|
||||||
k_cache,
|
-1, self.page_size, layer.tp_k_head_num * layer.qk_head_dim
|
||||||
v_cache,
|
),
|
||||||
block_table=self.forward_metadata.block_tables,
|
v_cache.view(
|
||||||
block_size=self.page_size,
|
-1, self.page_size, layer.tp_v_head_num * layer.qk_head_dim
|
||||||
num_heads=layer.tp_q_head_num,
|
),
|
||||||
num_key_value_heads=layer.tp_k_head_num,
|
|
||||||
input_layout="BSH",
|
|
||||||
scale=layer.scaling,
|
|
||||||
actual_seq_lengths_kv=self.forward_metadata.seq_lens_cpu_list,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
attn_output = torch.empty(
|
|
||||||
(num_tokens, 1, layer.tp_q_head_num * layer.v_head_dim),
|
|
||||||
dtype=q.dtype,
|
|
||||||
device=q.device,
|
|
||||||
)
|
|
||||||
softmax_lse = torch.empty(1, dtype=q.dtype, device=q.device)
|
|
||||||
torch_npu.npu_fused_infer_attention_score.out(
|
|
||||||
query,
|
|
||||||
k_cache,
|
|
||||||
v_cache,
|
|
||||||
block_table=self.forward_metadata.block_tables,
|
|
||||||
block_size=self.page_size,
|
|
||||||
num_heads=layer.tp_q_head_num,
|
num_heads=layer.tp_q_head_num,
|
||||||
num_key_value_heads=layer.tp_k_head_num,
|
num_key_value_heads=layer.tp_k_head_num,
|
||||||
input_layout="BSH",
|
input_layout="BSND",
|
||||||
|
atten_mask=None,
|
||||||
|
block_size=self.page_size,
|
||||||
|
block_table=self.forward_metadata.block_tables,
|
||||||
|
actual_seq_lengths_kv=self.forward_metadata.seq_lens_cpu_int,
|
||||||
scale=layer.scaling,
|
scale=layer.scaling,
|
||||||
actual_seq_lengths_kv=self.forward_metadata.seq_lens_cpu_list,
|
|
||||||
workspace=workspace,
|
|
||||||
out=[attn_output, softmax_lse],
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
|
query = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
|
||||||
v_cache = forward_batch.token_to_kv_pool.get_value_buffer(
|
attn_output = torch.empty(
|
||||||
layer.layer_id
|
(num_tokens, layer.tp_q_head_num, layer.v_head_dim),
|
||||||
|
dtype=query.dtype,
|
||||||
|
device=query.device,
|
||||||
)
|
)
|
||||||
if self.use_fia:
|
|
||||||
attn_output, _ = torch.ops.npu.npu_fused_infer_attention_score(
|
|
||||||
q.view(
|
|
||||||
forward_batch.batch_size,
|
|
||||||
-1,
|
|
||||||
layer.tp_q_head_num,
|
|
||||||
layer.qk_head_dim,
|
|
||||||
),
|
|
||||||
k_cache.view(
|
|
||||||
-1, self.page_size, layer.tp_k_head_num * layer.qk_head_dim
|
|
||||||
),
|
|
||||||
v_cache.view(
|
|
||||||
-1, self.page_size, layer.tp_v_head_num * layer.qk_head_dim
|
|
||||||
),
|
|
||||||
num_heads=layer.tp_q_head_num,
|
|
||||||
num_key_value_heads=layer.tp_k_head_num,
|
|
||||||
input_layout="BSND",
|
|
||||||
atten_mask=None,
|
|
||||||
block_size=self.page_size,
|
|
||||||
block_table=self.forward_metadata.block_tables,
|
|
||||||
actual_seq_lengths_kv=self.forward_metadata.seq_lens_cpu_int,
|
|
||||||
scale=layer.scaling,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
query = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
|
|
||||||
attn_output = torch.empty(
|
|
||||||
(num_tokens, layer.tp_q_head_num, layer.v_head_dim),
|
|
||||||
dtype=query.dtype,
|
|
||||||
device=query.device,
|
|
||||||
)
|
|
||||||
|
|
||||||
torch_npu._npu_paged_attention(
|
torch_npu._npu_paged_attention(
|
||||||
query=query,
|
query=query,
|
||||||
key_cache=k_cache,
|
key_cache=k_cache,
|
||||||
value_cache=v_cache,
|
value_cache=v_cache,
|
||||||
num_heads=layer.tp_q_head_num,
|
num_heads=layer.tp_q_head_num,
|
||||||
num_kv_heads=layer.tp_k_head_num,
|
num_kv_heads=layer.tp_k_head_num,
|
||||||
scale_value=layer.scaling,
|
scale_value=layer.scaling,
|
||||||
block_table=self.forward_metadata.block_tables,
|
block_table=self.forward_metadata.block_tables,
|
||||||
context_lens=self.forward_metadata.seq_lens_cpu_int,
|
context_lens=self.forward_metadata.seq_lens_cpu_int,
|
||||||
out=attn_output,
|
out=attn_output,
|
||||||
)
|
)
|
||||||
return attn_output.view(num_tokens, layer.tp_q_head_num * layer.v_head_dim)
|
return attn_output.view(num_tokens, layer.tp_q_head_num * layer.v_head_dim)
|
||||||
else:
|
else:
|
||||||
if save_kv_cache:
|
if save_kv_cache:
|
||||||
@@ -370,9 +467,7 @@ class AscendAttnBackend(AttentionBackend):
|
|||||||
kv_c = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
|
kv_c = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
|
||||||
k_pe = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)
|
k_pe = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)
|
||||||
|
|
||||||
if (self.graph_mode or self.use_fia) and (
|
if self.use_fia and (layer.tp_q_head_num // layer.tp_k_head_num) >= 8:
|
||||||
layer.tp_q_head_num // layer.tp_k_head_num
|
|
||||||
) >= 8:
|
|
||||||
"""layer.tp_q_head_num // layer.tp_k_head_num < 8 will support in the later version of CANN"""
|
"""layer.tp_q_head_num // layer.tp_k_head_num < 8 will support in the later version of CANN"""
|
||||||
kv_c = kv_c.view(
|
kv_c = kv_c.view(
|
||||||
-1, self.page_size, layer.tp_k_head_num * self.kv_lora_rank
|
-1, self.page_size, layer.tp_k_head_num * self.kv_lora_rank
|
||||||
|
|||||||
@@ -746,19 +746,25 @@ class DeepEPMoE(EPMoE):
|
|||||||
hidden_states = torch_npu.npu_grouped_matmul(
|
hidden_states = torch_npu.npu_grouped_matmul(
|
||||||
x=[hidden_states],
|
x=[hidden_states],
|
||||||
weight=[self.w13_weight],
|
weight=[self.w13_weight],
|
||||||
scale=[self.w13_weight_scale.to(output_dtype)],
|
|
||||||
per_token_scale=[pertoken_scale],
|
|
||||||
split_item=2,
|
split_item=2,
|
||||||
group_list_type=group_list_type,
|
group_list_type=group_list_type,
|
||||||
group_type=0,
|
group_type=0,
|
||||||
group_list=seg_indptr,
|
group_list=seg_indptr,
|
||||||
output_dtype=output_dtype,
|
output_dtype=torch.int32,
|
||||||
)[0]
|
)[0]
|
||||||
|
|
||||||
# act_fn: swiglu
|
# act_fn: swiglu
|
||||||
hidden_states = torch_npu.npu_swiglu(hidden_states)
|
hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
|
||||||
|
x=hidden_states,
|
||||||
hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant(hidden_states)
|
weight_scale=self.w13_weight_scale.to(torch.float32),
|
||||||
|
activation_scale=pertoken_scale,
|
||||||
|
bias=None,
|
||||||
|
quant_scale=None,
|
||||||
|
quant_offset=None,
|
||||||
|
group_index=seg_indptr,
|
||||||
|
activate_left=True,
|
||||||
|
quant_mode=1,
|
||||||
|
)
|
||||||
|
|
||||||
# gmm2: down_proj
|
# gmm2: down_proj
|
||||||
hidden_states = torch_npu.npu_grouped_matmul(
|
hidden_states = torch_npu.npu_grouped_matmul(
|
||||||
|
|||||||
@@ -304,12 +304,12 @@ class TopK(CustomOp):
|
|||||||
global_num_experts = router_logits.shape[-1]
|
global_num_experts = router_logits.shape[-1]
|
||||||
|
|
||||||
# NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
|
# NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
|
||||||
if global_num_experts == 256 and self.topk_config.renormalize is True:
|
if global_num_experts == 256:
|
||||||
|
|
||||||
routed_scaling_factor = self.topk_config.routed_scaling_factor or 1
|
routed_scaling_factor = self.topk_config.routed_scaling_factor or 1
|
||||||
router_logits = router_logits.to(torch.float32)
|
router_logits = router_logits.to(torch.float32)
|
||||||
|
|
||||||
return torch_npu.npu_moe_gating_top_k(
|
topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
|
||||||
router_logits,
|
router_logits,
|
||||||
k=self.topk_config.top_k,
|
k=self.topk_config.top_k,
|
||||||
bias=self.topk_config.correction_bias.to(torch.float32),
|
bias=self.topk_config.correction_bias.to(torch.float32),
|
||||||
@@ -321,6 +321,16 @@ class TopK(CustomOp):
|
|||||||
routed_scaling_factor=routed_scaling_factor,
|
routed_scaling_factor=routed_scaling_factor,
|
||||||
eps=float(1e-20),
|
eps=float(1e-20),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self.topk_config.renormalize:
|
||||||
|
topk_weights_sum = (
|
||||||
|
topk_weights.sum(dim=-1, keepdim=True)
|
||||||
|
if self.topk_config.num_fused_shared_experts == 0
|
||||||
|
else topk_weights[:, :-1].sum(dim=-1, keepdim=True)
|
||||||
|
)
|
||||||
|
topk_weights = topk_weights / topk_weights_sum
|
||||||
|
|
||||||
|
return StandardTopKOutput(topk_weights, topk_ids, _)
|
||||||
else:
|
else:
|
||||||
self.topk_config.torch_native = True
|
self.topk_config.torch_native = True
|
||||||
return select_experts(
|
return select_experts(
|
||||||
|
|||||||
@@ -551,7 +551,7 @@ class NPU_W8A8LinearMethodImpl:
|
|||||||
def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]:
|
def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]:
|
||||||
params_dict = {}
|
params_dict = {}
|
||||||
params_dict["input_scale"] = torch.empty(1, dtype=params_dtype)
|
params_dict["input_scale"] = torch.empty(1, dtype=params_dtype)
|
||||||
params_dict["input_offset"] = torch.empty(1, dtype=torch.int8)
|
params_dict["input_offset"] = torch.empty(1, dtype=params_dtype)
|
||||||
return params_dict
|
return params_dict
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -582,11 +582,11 @@ class NPU_W8A8LinearMethodImpl:
|
|||||||
if original_dtype != torch.int8:
|
if original_dtype != torch.int8:
|
||||||
x = torch_npu.npu_quantize(
|
x = torch_npu.npu_quantize(
|
||||||
x,
|
x,
|
||||||
layer.aclnn_input_scale,
|
layer.aclnn_input_scale_reciprocal,
|
||||||
layer.aclnn_input_offset,
|
layer.aclnn_input_offset,
|
||||||
torch.qint8,
|
torch.qint8,
|
||||||
-1,
|
-1,
|
||||||
True,
|
False,
|
||||||
)
|
)
|
||||||
# Only fuse bias add into GEMM for rank 0 (this ensures that
|
# Only fuse bias add into GEMM for rank 0 (this ensures that
|
||||||
# bias will not get added more than once in Attention TP>1 case)
|
# bias will not get added more than once in Attention TP>1 case)
|
||||||
@@ -608,6 +608,10 @@ class NPU_W8A8LinearMethodImpl:
|
|||||||
layer.input_scale.data.repeat(expanding_factor).to(device="npu"),
|
layer.input_scale.data.repeat(expanding_factor).to(device="npu"),
|
||||||
requires_grad=False,
|
requires_grad=False,
|
||||||
)
|
)
|
||||||
|
layer.aclnn_input_scale_reciprocal = 1 / torch.nn.Parameter(
|
||||||
|
layer.input_scale.data.repeat(expanding_factor).to(device="npu"),
|
||||||
|
requires_grad=False,
|
||||||
|
)
|
||||||
layer.aclnn_input_offset = torch.nn.Parameter(
|
layer.aclnn_input_offset = torch.nn.Parameter(
|
||||||
layer.input_offset.data.repeat(expanding_factor).to(device="npu"),
|
layer.input_offset.data.repeat(expanding_factor).to(device="npu"),
|
||||||
requires_grad=False,
|
requires_grad=False,
|
||||||
|
|||||||
@@ -918,6 +918,7 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
|
|||||||
layer_num,
|
layer_num,
|
||||||
self.size // self.page_size + 1,
|
self.size // self.page_size + 1,
|
||||||
self.page_size,
|
self.page_size,
|
||||||
|
1,
|
||||||
self.kv_lora_rank,
|
self.kv_lora_rank,
|
||||||
),
|
),
|
||||||
dtype=self.store_dtype,
|
dtype=self.store_dtype,
|
||||||
@@ -928,6 +929,7 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
|
|||||||
layer_num,
|
layer_num,
|
||||||
self.size // self.page_size + 1,
|
self.size // self.page_size + 1,
|
||||||
self.page_size,
|
self.page_size,
|
||||||
|
1,
|
||||||
self.qk_rope_head_dim,
|
self.qk_rope_head_dim,
|
||||||
),
|
),
|
||||||
dtype=self.store_dtype,
|
dtype=self.store_dtype,
|
||||||
@@ -1000,9 +1002,11 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
|
|||||||
layer_id = layer.layer_id
|
layer_id = layer.layer_id
|
||||||
if cache_k.dtype != self.dtype:
|
if cache_k.dtype != self.dtype:
|
||||||
cache_k = cache_k.to(self.dtype)
|
cache_k = cache_k.to(self.dtype)
|
||||||
|
cache_v = cache_v.to(self.dtype)
|
||||||
|
|
||||||
if self.store_dtype != self.dtype:
|
if self.store_dtype != self.dtype:
|
||||||
cache_k = cache_k.view(self.store_dtype)
|
cache_k = cache_k.view(self.store_dtype)
|
||||||
|
cache_v = cache_v.view(self.store_dtype)
|
||||||
|
|
||||||
if cache_v is None:
|
if cache_v is None:
|
||||||
cache_k, cache_v = cache_k.split(
|
cache_k, cache_v = cache_k.split(
|
||||||
|
|||||||
@@ -114,6 +114,7 @@ from sglang.srt.utils import (
|
|||||||
is_flashinfer_available,
|
is_flashinfer_available,
|
||||||
is_hip,
|
is_hip,
|
||||||
is_non_idle_and_non_empty,
|
is_non_idle_and_non_empty,
|
||||||
|
is_npu,
|
||||||
is_sm100_supported,
|
is_sm100_supported,
|
||||||
log_info_on_rank0,
|
log_info_on_rank0,
|
||||||
make_layers,
|
make_layers,
|
||||||
@@ -122,6 +123,7 @@ from sglang.srt.utils import (
|
|||||||
|
|
||||||
_is_hip = is_hip()
|
_is_hip = is_hip()
|
||||||
_is_cuda = is_cuda()
|
_is_cuda = is_cuda()
|
||||||
|
_is_npu = is_npu()
|
||||||
_is_fp8_fnuz = is_fp8_fnuz()
|
_is_fp8_fnuz = is_fp8_fnuz()
|
||||||
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
|
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
|
||||||
_is_cpu_amx_available = cpu_has_amx_support()
|
_is_cpu_amx_available = cpu_has_amx_support()
|
||||||
@@ -1181,13 +1183,19 @@ class DeepseekV2AttentionMLA(nn.Module):
|
|||||||
k[..., : self.qk_nope_head_dim] = k_nope
|
k[..., : self.qk_nope_head_dim] = k_nope
|
||||||
k[..., self.qk_nope_head_dim :] = k_pe
|
k[..., self.qk_nope_head_dim :] = k_pe
|
||||||
|
|
||||||
latent_cache[:, :, : self.kv_lora_rank] = kv_a.unsqueeze(1)
|
if not _is_npu:
|
||||||
latent_cache[:, :, self.kv_lora_rank :] = k_pe
|
latent_cache[:, :, : self.kv_lora_rank] = kv_a.unsqueeze(1)
|
||||||
|
latent_cache[:, :, self.kv_lora_rank :] = k_pe
|
||||||
|
|
||||||
# Save latent cache
|
# Save latent cache
|
||||||
forward_batch.token_to_kv_pool.set_kv_buffer(
|
forward_batch.token_to_kv_pool.set_kv_buffer(
|
||||||
self.attn_mha, forward_batch.out_cache_loc, latent_cache, None
|
self.attn_mha, forward_batch.out_cache_loc, latent_cache, None
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
# To reduce a time-costing split operation
|
||||||
|
forward_batch.token_to_kv_pool.set_kv_buffer(
|
||||||
|
self.attn_mha, forward_batch.out_cache_loc, kv_a.unsqueeze(1), k_pe
|
||||||
|
)
|
||||||
|
|
||||||
return q, k, v, forward_batch
|
return q, k, v, forward_batch
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user