[CI] Upgrade to newest pta (MLA and FusedMoE) (#189)

Upgrade to the newest pta (torch_npu): migrate the MLA attention and FusedMoE group top-k paths to the renamed torch_npu operators.

---------

Signed-off-by: SidaoY <1024863041@qq.com>
Author: HongtaoYang
Date: 2025-02-27 18:50:52 +08:00
Committed by: GitHub
Parent: c131e43e7d
Commit: 1715230867

2 changed files with 26 additions and 56 deletions


@@ -770,13 +770,9 @@ class AscendMLAAttentionBackendImpl(MLAAttentionImpl):
                 num_blocks, block_size, self.num_kv_heads,
                 self.qk_rope_head_dim + self.kv_lora_rank)
             slots = attn_metadata.slot_mapping
-            torch_npu.npu_reshapecache(key=k_cache,
-                                       value=None,
-                                       keyCache=key_cache,
-                                       valueCache=None,
-                                       slotMapping=slots,
-                                       compressType=0,
-                                       kvCacheCfg=1)
+            torch_npu._npu_reshape_and_cache_siso(key=k_cache,
+                                                  key_cache=key_cache,
+                                                  slot_indices=slots)

         if attn_metadata.num_prefills > 0:
             attn_output = torch.empty(num_tokens,
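
Not part of the diff: a minimal standalone sketch of the new single-input reshape-and-cache call, assuming an Ascend NPU runtime with a matching torch_npu build; all sizes and the int32 slot dtype are illustrative assumptions. Only the key-side cache is written, so the old value / valueCache / compressType / kvCacheCfg arguments have no counterpart.

import torch
import torch_npu

# Illustrative sizes only; real shapes come from the model and cache config
# (head_size here stands in for qk_rope_head_dim + kv_lora_rank).
num_tokens, num_kv_heads, head_size = 16, 1, 576
num_blocks, block_size = 8, 128

k_cache = torch.randn(num_tokens, num_kv_heads, head_size).npu()
key_cache = torch.zeros(num_blocks, block_size, num_kv_heads, head_size).npu()
slots = torch.arange(num_tokens, dtype=torch.int32).npu()

# Single-input/single-output variant: scatters k_cache into key_cache at the
# given slot indices; there is no value-side cache to update for MLA.
torch_npu._npu_reshape_and_cache_siso(key=k_cache,
                                      key_cache=key_cache,
                                      slot_indices=slots)
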
@@ -793,32 +789,16 @@ class AscendMLAAttentionBackendImpl(MLAAttentionImpl):
                 self.seq_lens_tensor_cpu = torch.from_numpy(
                     np.array(attn_metadata.prefill_metadata.seq_lens).astype(
                         np.int32))
-                torch_npu.npu_selfattention(query=query,
-                                            key=key,
-                                            value=value,
-                                            kvcacheCfg=0,
-                                            mask=mask,
-                                            maskType=1,
-                                            isTriuMask=0,
-                                            seqLen=self.seq_lens_tensor_cpu,
-                                            scale=self.scale,
-                                            qScale=1,
-                                            scaleType=0,
-                                            headNum=self.num_heads,
-                                            kvHeadNum=self.num_heads,
-                                            mlaVHeadSize=0,
-                                            calcType=3,
-                                            kernelType=0,
-                                            clampType=0,
-                                            quantType=0,
-                                            cacheType=0,
-                                            windowSize=0,
-                                            clampMin=0,
-                                            clampMax=0,
-                                            batchRunStatusEnable=False,
-                                            inputLayout=0,
-                                            outDataType=0,
-                                            out=attn_output)
+                torch_npu._npu_flash_attention(
+                    query=query,
+                    key=key,
+                    value=value,
+                    mask=mask,
+                    seq_len=self.seq_lens_tensor_cpu,
+                    scale_value=self.scale,
+                    num_heads=self.num_heads,
+                    num_kv_heads=self.num_heads,
+                    out=attn_output)
             else:
                 # TODO: Will support prefix cache and chunked prefill soon.
                 raise RuntimeError(
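
For orientation, a hedged sketch of the new prefill call in isolation (not part of the diff). Shapes, the causal-mask construction, and the single-sequence seq_len are assumptions for illustration; the enum-style knobs of the old npu_selfattention (maskType, calcType, quantType, ...) are dropped and left to kernel defaults.

import numpy as np
import torch
import torch_npu

num_tokens, num_heads, head_dim = 32, 8, 192   # illustrative prefill shapes
scale = 1.0 / head_dim**0.5

query = torch.randn(num_tokens, num_heads, head_dim).npu()
key = torch.randn(num_tokens, num_heads, head_dim).npu()
value = torch.randn(num_tokens, num_heads, head_dim).npu()
# Assumed causal mask layout; the real mask comes from attn_metadata.
mask = torch.triu(torch.ones(num_tokens, num_tokens), diagonal=1).bool().npu()
seq_lens = torch.from_numpy(np.array([num_tokens], dtype=np.int32))  # kept on CPU
attn_output = torch.empty(num_tokens, num_heads, head_dim).npu()

torch_npu._npu_flash_attention(query=query,
                               key=key,
                               value=value,
                               mask=mask,
                               seq_len=seq_lens,
                               scale_value=scale,
                               num_heads=num_heads,
                               num_kv_heads=num_heads,
                               out=attn_output)
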
@@ -835,25 +815,16 @@ class AscendMLAAttentionBackendImpl(MLAAttentionImpl):
                 np.array(attn_metadata.decode_metadata.seq_lens).astype(
                     np.int32))
             block_tables = attn_metadata.decode_metadata.block_tables
-            torch_npu.npu_pagedattention(query=query,
-                                         keyCache=key_cache,
-                                         valueCache=None,
-                                         contextLens=self.seq_lens_tensor_cpu,
-                                         maskType=0,
-                                         kvHeadNum=self.num_kv_heads,
-                                         headNum=self.num_heads,
-                                         mlaVHeadSize=self.kv_lora_rank,
-                                         qkScale=self.scale,
-                                         blockTables=block_tables,
-                                         batchRunStatusEnable=False,
-                                         hasQuantOffset=False,
-                                         compressType=0,
-                                         calcType=0,
-                                         scaleType=0,
-                                         quantType=0,
-                                         inputLayout=0,
-                                         outDataType=-1,
-                                         attnOut=attn_output)
+            torch_npu._npu_paged_attention_mla(
+                query=query,
+                key_cache=key_cache,
+                num_kv_heads=self.num_kv_heads,
+                num_heads=self.num_heads,
+                scale_value=self.scale,
+                block_table=block_tables,
+                context_lens=self.seq_lens_tensor_cpu,
+                mla_vheadsize=self.kv_lora_rank,
+                out=attn_output)
             attn_output_t = torch.transpose(attn_output, 0, 1)
             attn_output_t = torch.bmm(attn_output_t, self.w_vc)
             attn_output = torch.transpose(attn_output_t, 0, 1)
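
Likewise, a hedged sketch of the decode-time paged-attention call in isolation (not part of the diff); kv_lora_rank, qk_rope_head_dim, and the block layout below are assumed values for illustration. mla_vheadsize=kv_lora_rank means each head's output stays in the latent dimension, which the surrounding bmm with self.w_vc then projects back to the per-head value dimension.

import numpy as np
import torch
import torch_npu

batch_size, num_heads, num_kv_heads = 4, 8, 1   # illustrative decode shapes
kv_lora_rank, qk_rope_head_dim = 512, 64        # assumed MLA dimensions
num_blocks, block_size = 8, 128
scale = 1.0 / (kv_lora_rank + qk_rope_head_dim)**0.5

query = torch.randn(batch_size, num_heads, kv_lora_rank + qk_rope_head_dim).npu()
key_cache = torch.randn(num_blocks, block_size, num_kv_heads,
                        kv_lora_rank + qk_rope_head_dim).npu()
block_tables = torch.zeros(batch_size, 4, dtype=torch.int32).npu()
context_lens = torch.from_numpy(np.ones(batch_size, dtype=np.int32))  # CPU tensor
attn_output = torch.empty(batch_size, num_heads, kv_lora_rank).npu()

torch_npu._npu_paged_attention_mla(query=query,
                                   key_cache=key_cache,
                                   num_kv_heads=num_kv_heads,
                                   num_heads=num_heads,
                                   scale_value=scale,
                                   block_table=block_tables,
                                   context_lens=context_lens,
                                   mla_vheadsize=kv_lora_rank,
                                   out=attn_output)
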


@@ -50,10 +50,9 @@ def group_topk(hidden_states: torch.Tensor,
     topk_group = 0 if topk_group is None else topk_group
     num_expert_group = 0 if num_expert_group is None else num_expert_group
-    torch_npu.npu_group_topk(input=scores,
-                             out=scores,
-                             group_num=num_expert_group,
-                             k=topk_group)
+    torch_npu._npu_group_topk(self=scores,
+                              k=topk_group,
+                              group_num=num_expert_group)
     if e_score_correction_bias is not None:
         topk_ids = torch.topk(scores, k=topk, dim=-1, sorted=False)[1]
         # Use original unbiased scores for the routing weights
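
Finally, a hedged sketch of the renamed group top-k op (not part of the diff); all sizes are illustrative. The explicit out= argument of the old npu_group_topk is gone, and, judging from the calling code above, _npu_group_topk filters the scores in place so the subsequent torch.topk runs on the group-restricted scores.

import torch
import torch_npu

num_tokens, num_experts = 4, 64          # illustrative sizes
num_expert_group, topk_group, topk = 8, 3, 6

scores = torch.rand(num_tokens, num_experts).npu()

# Scores tensor is passed via `self=` and appears to be updated in place:
# experts outside the top `k` expert groups are masked out before routing.
torch_npu._npu_group_topk(self=scores,
                          k=topk_group,
                          group_num=num_expert_group)

topk_weights, topk_ids = torch.topk(scores, k=topk, dim=-1, sorted=False)
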