[Refactor] Refactor Ascend attention implementation forward (#3714)
### What this PR does / why we need it?
This PR refactors the Ascend attention implementation to align with
vLLM's core interfaces, simplifying the code and improving
maintainability.
### Key Changes:
* **Align with vLLM's Attention Interface**: The `forward` method
signature in `AscendAttentionBackendImpl` now matches the base
`AttentionImpl` in vLLM, removing the custom `trace_flag`.
* **Enable Opaque Attention Operator**: By adding `opaque_attention_op`
to `AscendPlatform`, we allow vLLM to wrap our attention kernel in its
standard `vllm.unified_attention_with_output` operator. This avoids the
need for a custom call path.
* **Remove Obsolete Code**:
* The custom op `vllm.unified_ascend_attention_with_output` has been
deleted as it is now redundant.
* The `trace_flag` and its associated logic were removed, reducing code
complexity.
* An outdated quantization branch within the attention implementation
was cleaned up.
* **Improve Readability**: Renamed output variables (`output` vs.
`intermediate_output`) and added comments to clarify the in-place nature
of the attention output.
### Does this PR introduce _any_ user-facing change?
None.
### How was this patch tested?
No extra tests needed.
- vLLM version: v0.11.0rc3
- vLLM main:
17c540a993
---------
Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
This commit is contained in:
@@ -293,10 +293,7 @@ class NPUPlatform(Platform):
|
||||
"When enabling VLLM_COMPILE aclgraph, please make sure compilation_config.mode == CompilationMode.VLLM_COMPILE and compilation_config.cudagraph_mode == CUDAGraphMode.VLLM_COMPILE"
|
||||
compilation_config.set_splitting_ops_for_v1()
|
||||
compilation_config.use_inductor = False
|
||||
compilation_config.splitting_ops.extend([
|
||||
"vllm::unified_ascend_attention_with_output",
|
||||
"vllm::mla_forward"
|
||||
])
|
||||
compilation_config.splitting_ops.extend(["vllm::mla_forward"])
|
||||
update_aclgraph_sizes(vllm_config)
|
||||
elif compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY:
|
||||
logger.info(
|
||||
@@ -453,6 +450,10 @@ class NPUPlatform(Platform):
|
||||
def is_pin_memory_available(cls):
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def opaque_attention_op(cls) -> bool:
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def get_static_graph_wrapper_cls(cls) -> str:
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user