[CI] Upgrade to the newest pta (MLA and FusedMoE) (#189)

Upgrade to the newest pta (MLA and FusedMoE).

---------

Signed-off-by: SidaoY <1024863041@qq.com>
This commit is contained in:
HongtaoYang
2025-02-27 18:50:52 +08:00
committed by GitHub
parent c131e43e7d
commit 1715230867
2 changed files with 26 additions and 56 deletions

View File

@@ -770,13 +770,9 @@ class AscendMLAAttentionBackendImpl(MLAAttentionImpl):
num_blocks, block_size, self.num_kv_heads, num_blocks, block_size, self.num_kv_heads,
self.qk_rope_head_dim + self.kv_lora_rank) self.qk_rope_head_dim + self.kv_lora_rank)
slots = attn_metadata.slot_mapping slots = attn_metadata.slot_mapping
torch_npu.npu_reshapecache(key=k_cache, torch_npu._npu_reshape_and_cache_siso(key=k_cache,
value=None, key_cache=key_cache,
keyCache=key_cache, slot_indices=slots)
valueCache=None,
slotMapping=slots,
compressType=0,
kvCacheCfg=1)
if attn_metadata.num_prefills > 0: if attn_metadata.num_prefills > 0:
attn_output = torch.empty(num_tokens, attn_output = torch.empty(num_tokens,
@@ -793,32 +789,16 @@ class AscendMLAAttentionBackendImpl(MLAAttentionImpl):
self.seq_lens_tensor_cpu = torch.from_numpy( self.seq_lens_tensor_cpu = torch.from_numpy(
np.array(attn_metadata.prefill_metadata.seq_lens).astype( np.array(attn_metadata.prefill_metadata.seq_lens).astype(
np.int32)) np.int32))
torch_npu.npu_selfattention(query=query, torch_npu._npu_flash_attention(
key=key, query=query,
value=value, key=key,
kvcacheCfg=0, value=value,
mask=mask, mask=mask,
maskType=1, seq_len=self.seq_lens_tensor_cpu,
isTriuMask=0, scale_value=self.scale,
seqLen=self.seq_lens_tensor_cpu, num_heads=self.num_heads,
scale=self.scale, num_kv_heads=self.num_heads,
qScale=1, out=attn_output)
scaleType=0,
headNum=self.num_heads,
kvHeadNum=self.num_heads,
mlaVHeadSize=0,
calcType=3,
kernelType=0,
clampType=0,
quantType=0,
cacheType=0,
windowSize=0,
clampMin=0,
clampMax=0,
batchRunStatusEnable=False,
inputLayout=0,
outDataType=0,
out=attn_output)
else: else:
# TODO: Will support prefix cache and chunked prefill soon. # TODO: Will support prefix cache and chunked prefill soon.
raise RuntimeError( raise RuntimeError(
@@ -835,25 +815,16 @@ class AscendMLAAttentionBackendImpl(MLAAttentionImpl):
np.array(attn_metadata.decode_metadata.seq_lens).astype( np.array(attn_metadata.decode_metadata.seq_lens).astype(
np.int32)) np.int32))
block_tables = attn_metadata.decode_metadata.block_tables block_tables = attn_metadata.decode_metadata.block_tables
torch_npu.npu_pagedattention(query=query, torch_npu._npu_paged_attention_mla(
keyCache=key_cache, query=query,
valueCache=None, key_cache=key_cache,
contextLens=self.seq_lens_tensor_cpu, num_kv_heads=self.num_kv_heads,
maskType=0, num_heads=self.num_heads,
kvHeadNum=self.num_kv_heads, scale_value=self.scale,
headNum=self.num_heads, block_table=block_tables,
mlaVHeadSize=self.kv_lora_rank, context_lens=self.seq_lens_tensor_cpu,
qkScale=self.scale, mla_vheadsize=self.kv_lora_rank,
blockTables=block_tables, out=attn_output)
batchRunStatusEnable=False,
hasQuantOffset=False,
compressType=0,
calcType=0,
scaleType=0,
quantType=0,
inputLayout=0,
outDataType=-1,
attnOut=attn_output)
attn_output_t = torch.transpose(attn_output, 0, 1) attn_output_t = torch.transpose(attn_output, 0, 1)
attn_output_t = torch.bmm(attn_output_t, self.w_vc) attn_output_t = torch.bmm(attn_output_t, self.w_vc)
attn_output = torch.transpose(attn_output_t, 0, 1) attn_output = torch.transpose(attn_output_t, 0, 1)

View File

@@ -50,10 +50,9 @@ def group_topk(hidden_states: torch.Tensor,
topk_group = 0 if topk_group is None else topk_group topk_group = 0 if topk_group is None else topk_group
num_expert_group = 0 if num_expert_group is None else num_expert_group num_expert_group = 0 if num_expert_group is None else num_expert_group
torch_npu.npu_group_topk(input=scores, torch_npu._npu_group_topk(self=scores,
out=scores, k=topk_group,
group_num=num_expert_group, group_num=num_expert_group)
k=topk_group)
if e_score_correction_bias is not None: if e_score_correction_bias is not None:
topk_ids = torch.topk(scores, k=topk, dim=-1, sorted=False)[1] topk_ids = torch.topk(scores, k=topk, dim=-1, sorted=False)[1]
# Use original unbiased scores for the routing weights # Use original unbiased scores for the routing weights