[0.11.0]Chery pick pta upgrade change (#3940)
This PR cherry-pick two commit from main to upgrade torch-npu to 2.7.1 official release --------- Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
@@ -18,8 +18,6 @@ from vllm.forward_context import BatchDescriptor, get_forward_context
|
||||
from vllm.logger import logger
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from vllm_ascend.utils import version_check
|
||||
|
||||
from ..utils import weak_ref_tensors
|
||||
|
||||
|
||||
@@ -213,32 +211,20 @@ def update_attn_params(update_stream, forward_context, runtime_shape):
|
||||
output,
|
||||
) = param
|
||||
seq_lens = forward_context.attn_metadata[key].seq_lens
|
||||
torch_npu_check = version_check()
|
||||
|
||||
with torch.npu.stream(update_stream):
|
||||
torch.npu.graph_task_update_begin(update_stream, handle)
|
||||
if torch_npu_check:
|
||||
torch_npu._npu_paged_attention(
|
||||
query=query,
|
||||
key_cache=key_cache,
|
||||
value_cache=value_cache,
|
||||
num_kv_heads=num_kv_heads,
|
||||
num_heads=num_heads,
|
||||
scale_value=scale,
|
||||
block_table=block_table,
|
||||
context_lens=seq_lens,
|
||||
out=output,
|
||||
workspace=graph_params.workspaces.get(runtime_shape))
|
||||
else:
|
||||
torch_npu._npu_paged_attention(query=query,
|
||||
key_cache=key_cache,
|
||||
value_cache=value_cache,
|
||||
num_kv_heads=num_kv_heads,
|
||||
num_heads=num_heads,
|
||||
scale_value=scale,
|
||||
block_table=block_table,
|
||||
context_lens=seq_lens,
|
||||
out=output)
|
||||
torch_npu._npu_paged_attention(
|
||||
query=query,
|
||||
key_cache=key_cache,
|
||||
value_cache=value_cache,
|
||||
num_kv_heads=num_kv_heads,
|
||||
num_heads=num_heads,
|
||||
scale_value=scale,
|
||||
block_table=block_table,
|
||||
context_lens=seq_lens,
|
||||
out=output,
|
||||
workspace=graph_params.workspaces.get(runtime_shape))
|
||||
torch.npu.graph_task_update_end(update_stream)
|
||||
|
||||
event.record(update_stream)
|
||||
|
||||
Reference in New Issue
Block a user