Revert PTA upgrade PR (#3352)
we notice that torch npu 0919 doesn't work. This PR revert related change which rely on 0919 version. Revert PR: #3295 #3205 #3102 Related: #3353 - vLLM version: v0.11.0
This commit is contained in:
@@ -215,17 +215,15 @@ def update_attn_params(update_stream, forward_context, runtime_shape):
|
||||
|
||||
with torch.npu.stream(update_stream):
|
||||
torch.npu.graph_task_update_begin(update_stream, handle)
|
||||
torch_npu._npu_paged_attention(
|
||||
query=query,
|
||||
key_cache=key_cache,
|
||||
value_cache=value_cache,
|
||||
num_kv_heads=num_kv_heads,
|
||||
num_heads=num_heads,
|
||||
scale_value=scale,
|
||||
block_table=block_table,
|
||||
context_lens=seq_lens,
|
||||
out=output,
|
||||
workspace=graph_params.workspaces.get(runtime_shape))
|
||||
torch_npu._npu_paged_attention(query=query,
|
||||
key_cache=key_cache,
|
||||
value_cache=value_cache,
|
||||
num_kv_heads=num_kv_heads,
|
||||
num_heads=num_heads,
|
||||
scale_value=scale,
|
||||
block_table=block_table,
|
||||
context_lens=seq_lens,
|
||||
out=output)
|
||||
torch.npu.graph_task_update_end(update_stream)
|
||||
|
||||
event.record(update_stream)
|
||||
@@ -258,11 +256,5 @@ def set_graph_params(aclgraph_capture_sizes: set[int]):
|
||||
)
|
||||
|
||||
|
||||
def update_graph_params_workspaces(num_tokens: int, workspace: int):
|
||||
global _graph_params
|
||||
if _graph_params is not None:
|
||||
_graph_params.workspaces[num_tokens] = workspace
|
||||
|
||||
|
||||
def get_graph_params():
|
||||
return _graph_params
|
||||
|
||||
Reference in New Issue
Block a user