[CI] upgrade to newest pta (#187)
Upgrade to newest torch-npu Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com> Co-authored-by: angazenn <zengyanjia@huawei.com>
This commit is contained in:
@@ -577,13 +577,11 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
self.num_kv_heads,
|
||||
self.head_size)
|
||||
slots = attn_metadata.slot_mapping
|
||||
torch_npu.npu_reshapecache(key=key,
|
||||
value=value,
|
||||
keyCache=key_cache,
|
||||
valueCache=value_cache,
|
||||
slotMapping=slots,
|
||||
compressType=0,
|
||||
kvCacheCfg=0)
|
||||
torch_npu._npu_reshape_and_cache(key=key,
|
||||
value=value,
|
||||
key_cache=key_cache,
|
||||
value_cache=value_cache,
|
||||
slot_indices=slots)
|
||||
|
||||
if attn_metadata.num_prefills > 0:
|
||||
|
||||
@@ -596,32 +594,15 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
np.array(
|
||||
attn_metadata.prefill_metadata.seq_lens).astype(
|
||||
np.int32))
|
||||
torch_npu.npu_selfattention(
|
||||
torch_npu._npu_flash_attention(
|
||||
query=query,
|
||||
key=key,
|
||||
value=value,
|
||||
mask=mask,
|
||||
maskType=1,
|
||||
isTriuMask=0,
|
||||
seqLen=self.seq_lens_tensor_cpu,
|
||||
scale=self.scale,
|
||||
qScale=1,
|
||||
headNum=self.num_heads,
|
||||
kvHeadNum=self.num_kv_heads,
|
||||
mlaVHeadSize=0,
|
||||
calcType=3,
|
||||
kernelType=0,
|
||||
clampType=0,
|
||||
scaleType=0,
|
||||
quantType=0,
|
||||
cacheType=0,
|
||||
batchRunStatusEnable=False,
|
||||
kvcacheCfg=0,
|
||||
clampMin=0,
|
||||
clampMax=0,
|
||||
inputLayout=0,
|
||||
windowSize=0,
|
||||
outDataType=0,
|
||||
seq_len=self.seq_lens_tensor_cpu,
|
||||
scale_value=self.scale,
|
||||
num_heads=self.num_heads,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
out=output)
|
||||
else:
|
||||
# TODO: Will support prefix cache and chunked prefill soon.
|
||||
@@ -634,26 +615,16 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
np.array(attn_metadata.decode_metadata.seq_lens).astype(
|
||||
np.int32))
|
||||
block_tables = attn_metadata.decode_metadata.block_tables
|
||||
torch_npu.npu_pagedattention(
|
||||
torch_npu._npu_paged_attention(
|
||||
query=query,
|
||||
keyCache=key_cache,
|
||||
valueCache=value_cache,
|
||||
contextLens=self.seq_lens_tensor_cpu,
|
||||
maskType=0,
|
||||
kvHeadNum=self.num_kv_heads,
|
||||
headNum=self.num_heads,
|
||||
mlaVHeadSize=0,
|
||||
qkScale=self.scale,
|
||||
scaleType=0,
|
||||
blockTables=block_tables,
|
||||
batchRunStatusEnable=False,
|
||||
hasQuantOffset=False,
|
||||
calcType=3,
|
||||
quantType=0,
|
||||
compressType=0,
|
||||
inputLayout=0,
|
||||
outDataType=0,
|
||||
attnOut=output)
|
||||
key_cache=key_cache,
|
||||
value_cache=value_cache,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
num_heads=self.num_heads,
|
||||
scale_value=self.scale,
|
||||
block_table=block_tables,
|
||||
context_lens=self.seq_lens_tensor_cpu,
|
||||
out=output)
|
||||
|
||||
return output.view(num_tokens, self.hidden_size)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user