[Refactor] Refactor Ascend attention implementation forward (#3714)
### What this PR does / why we need it?
This PR refactors the Ascend attention implementation to align with
vLLM's core interfaces, simplifying the code and improving
maintainability.
### Key Changes:
* **Align with vLLM's Attention Interface**: The `forward` method
signature in `AscendAttentionBackendImpl` now matches the base
`AttentionImpl` in vLLM, removing the custom `trace_flag`.
* **Enable Opaque Attention Operator**: By adding `opaque_attention_op`
to `AscendPlatform`, we allow vLLM to wrap our attention kernel in its
standard `vllm.unified_attention_with_output` operator. This avoids the
need for a custom call path.
* **Remove Obsolete Code**:
* The custom op `vllm.unified_ascend_attention_with_output` has been
deleted as it is now redundant.
* The `trace_flag` and its associated logic were removed, reducing code
complexity.
* An outdated quantization branch within the attention implementation
was cleaned up.
* **Improve Readability**: Renamed output variables (`output` vs.
`intermediate_output`) and added comments to clarify the in-place nature
of the attention output.
### Does this PR introduce _any_ user-facing change?
None.
### How was this patch tested?
No extra tests needed.
- vLLM version: v0.11.0rc3
- vLLM main:
17c540a993
---------
Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
This commit is contained in:
@@ -255,59 +255,6 @@ class TestAscendAttentionBackendImpl(TestBase):
|
|||||||
attn_type=self.attention_type.DECODER,
|
attn_type=self.attention_type.DECODER,
|
||||||
kv_sharing_target_layer_name=None)
|
kv_sharing_target_layer_name=None)
|
||||||
|
|
||||||
@patch('torch.ops.vllm.unified_ascend_attention_with_output')
|
|
||||||
def test_forward_trace_flag_true(self, mock_unified_attention):
|
|
||||||
"""Test forward pass when trace_flag is True"""
|
|
||||||
query = torch.randn(10, 8 * 64)
|
|
||||||
key = torch.randn(10, 8 * 64)
|
|
||||||
value = torch.randn(10, 8 * 64)
|
|
||||||
kv_cache = torch.empty(2, 0, 0, 8, 64)
|
|
||||||
metadata = self.attn_metadata
|
|
||||||
layer = self.layer
|
|
||||||
|
|
||||||
output = self.impl.forward(layer,
|
|
||||||
query,
|
|
||||||
key,
|
|
||||||
value,
|
|
||||||
kv_cache,
|
|
||||||
metadata,
|
|
||||||
trace_flag=True)
|
|
||||||
|
|
||||||
mock_unified_attention.assert_called_once()
|
|
||||||
assert output.shape == (10, 8 * 64)
|
|
||||||
|
|
||||||
@patch('torch_npu._npu_paged_attention_splitfuse')
|
|
||||||
def test_forward_with_quant_method(self, mock_paged_attention):
|
|
||||||
"""Test forward pass when layer has quant_method"""
|
|
||||||
query = torch.randn(10, 8 * 64)
|
|
||||||
key = torch.randn(10, 8 * 64)
|
|
||||||
value = torch.randn(10, 8 * 64)
|
|
||||||
k_cache = torch.ones(1, 10, 8, 64, dtype=torch.int8)
|
|
||||||
v_cache = torch.ones(1, 10, 8, 64, dtype=torch.int8)
|
|
||||||
kv_cache = [k_cache, v_cache]
|
|
||||||
ret_value = torch.ones(1, 1, 10, 8, 64, dtype=torch.int8)
|
|
||||||
|
|
||||||
metadata = MagicMock()
|
|
||||||
metadata.num_actual_tokens = torch.randn(10, 8 * 64)
|
|
||||||
metadata.block_tables = torch.randn(10, 8 * 64)
|
|
||||||
metadata.seq_lens = torch.randn(10, 8 * 64)
|
|
||||||
metadata.attn_mask = torch.randn(10, 8 * 64)
|
|
||||||
metadata.query_lens = torch.randn(10, 8 * 64)
|
|
||||||
layer = self.layer
|
|
||||||
layer.quant_method = MagicMock()
|
|
||||||
layer.quant_method.apply.return_value = ret_value
|
|
||||||
|
|
||||||
output = self.impl.forward(layer,
|
|
||||||
query,
|
|
||||||
key,
|
|
||||||
value,
|
|
||||||
kv_cache,
|
|
||||||
metadata,
|
|
||||||
trace_flag=False)
|
|
||||||
|
|
||||||
layer.quant_method.apply.assert_called_once()
|
|
||||||
assert output.shape == (10, 8 * 64)
|
|
||||||
|
|
||||||
def test_forward_no_attn_metadata(self):
|
def test_forward_no_attn_metadata(self):
|
||||||
"""Test forward pass when attn_metadata is None"""
|
"""Test forward pass when attn_metadata is None"""
|
||||||
query = torch.randn(10, 8 * 64)
|
query = torch.randn(10, 8 * 64)
|
||||||
@@ -315,14 +262,10 @@ class TestAscendAttentionBackendImpl(TestBase):
|
|||||||
value = torch.randn(10, 8 * 64)
|
value = torch.randn(10, 8 * 64)
|
||||||
kv_cache = torch.empty(2, 0, 0, 8, 64)
|
kv_cache = torch.empty(2, 0, 0, 8, 64)
|
||||||
layer = self.layer_no_quant
|
layer = self.layer_no_quant
|
||||||
|
output = torch.empty_like(query)
|
||||||
|
|
||||||
output = self.impl.forward(layer,
|
output = self.impl.forward(layer, query, key, value, kv_cache, None,
|
||||||
query,
|
output)
|
||||||
key,
|
|
||||||
value,
|
|
||||||
kv_cache,
|
|
||||||
None,
|
|
||||||
trace_flag=False)
|
|
||||||
|
|
||||||
assert output.shape == (10, 8 * 64)
|
assert output.shape == (10, 8 * 64)
|
||||||
|
|
||||||
@@ -335,6 +278,8 @@ class TestAscendAttentionBackendImpl(TestBase):
|
|||||||
key = torch.randn(10, 8 * 64)
|
key = torch.randn(10, 8 * 64)
|
||||||
value = torch.randn(10, 8 * 64)
|
value = torch.randn(10, 8 * 64)
|
||||||
kv_cache = torch.empty(2, 5, 128, 8, 64)
|
kv_cache = torch.empty(2, 5, 128, 8, 64)
|
||||||
|
output = torch.empty_like(query)
|
||||||
|
|
||||||
metadata = self.attn_metadata
|
metadata = self.attn_metadata
|
||||||
metadata.attn_state = AscendAttentionState.PrefillNoCache
|
metadata.attn_state = AscendAttentionState.PrefillNoCache
|
||||||
metadata.attn_mask = torch.randn(1, 1, 10, 10)
|
metadata.attn_mask = torch.randn(1, 1, 10, 10)
|
||||||
@@ -344,15 +289,9 @@ class TestAscendAttentionBackendImpl(TestBase):
|
|||||||
metadata.num_decodes = 0
|
metadata.num_decodes = 0
|
||||||
metadata.num_prefills = 10
|
metadata.num_prefills = 10
|
||||||
layer = self.layer_no_quant
|
layer = self.layer_no_quant
|
||||||
# layer.quant_method.apply.return_value = metadata
|
|
||||||
print(self.layer_no_quant._v_scale_float)
|
output = self.impl.forward(layer, query, key, value, kv_cache,
|
||||||
output = self.impl.forward(layer,
|
metadata, output)
|
||||||
query,
|
|
||||||
key,
|
|
||||||
value,
|
|
||||||
kv_cache,
|
|
||||||
metadata,
|
|
||||||
trace_flag=False)
|
|
||||||
|
|
||||||
mock_reshape_cache.assert_called_once()
|
mock_reshape_cache.assert_called_once()
|
||||||
mock_flash_attention.assert_called_once()
|
mock_flash_attention.assert_called_once()
|
||||||
@@ -367,6 +306,8 @@ class TestAscendAttentionBackendImpl(TestBase):
|
|||||||
key = torch.randn(10, 8 * 64)
|
key = torch.randn(10, 8 * 64)
|
||||||
value = torch.randn(10, 8 * 64)
|
value = torch.randn(10, 8 * 64)
|
||||||
kv_cache = torch.empty(2, 5, 128, 8, 64)
|
kv_cache = torch.empty(2, 5, 128, 8, 64)
|
||||||
|
output = torch.empty_like(query)
|
||||||
|
|
||||||
metadata = self.attn_metadata
|
metadata = self.attn_metadata
|
||||||
metadata.attn_state = AscendAttentionState.PrefillCacheHit
|
metadata.attn_state = AscendAttentionState.PrefillCacheHit
|
||||||
metadata.attn_mask = torch.randn(1, 1, 10, 10)
|
metadata.attn_mask = torch.randn(1, 1, 10, 10)
|
||||||
@@ -379,13 +320,8 @@ class TestAscendAttentionBackendImpl(TestBase):
|
|||||||
metadata.num_prefills = 10
|
metadata.num_prefills = 10
|
||||||
layer = self.layer_no_quant
|
layer = self.layer_no_quant
|
||||||
|
|
||||||
output = self.impl.forward(layer,
|
output = self.impl.forward(layer, query, key, value, kv_cache,
|
||||||
query,
|
metadata, output)
|
||||||
key,
|
|
||||||
value,
|
|
||||||
kv_cache,
|
|
||||||
metadata,
|
|
||||||
trace_flag=False)
|
|
||||||
|
|
||||||
mock_flash_attention_qlens.assert_called_once()
|
mock_flash_attention_qlens.assert_called_once()
|
||||||
assert output.shape == (10, 8 * 64)
|
assert output.shape == (10, 8 * 64)
|
||||||
@@ -401,6 +337,8 @@ class TestAscendAttentionBackendImpl(TestBase):
|
|||||||
key = torch.randn(10, 8 * 64)
|
key = torch.randn(10, 8 * 64)
|
||||||
value = torch.randn(10, 8 * 64)
|
value = torch.randn(10, 8 * 64)
|
||||||
kv_cache = torch.empty(2, 5, 128, 8, 64)
|
kv_cache = torch.empty(2, 5, 128, 8, 64)
|
||||||
|
output = torch.empty_like(query)
|
||||||
|
|
||||||
metadata = self.attn_metadata
|
metadata = self.attn_metadata
|
||||||
metadata.attn_state = AscendAttentionState.DecodeOnly
|
metadata.attn_state = AscendAttentionState.DecodeOnly
|
||||||
metadata.seq_lens = torch.tensor([10])
|
metadata.seq_lens = torch.tensor([10])
|
||||||
@@ -413,13 +351,8 @@ class TestAscendAttentionBackendImpl(TestBase):
|
|||||||
|
|
||||||
mock_get_forward_context.return_value = MagicMock(capturing=False)
|
mock_get_forward_context.return_value = MagicMock(capturing=False)
|
||||||
|
|
||||||
output = self.impl.forward(layer,
|
output = self.impl.forward(layer, query, key, value, kv_cache,
|
||||||
query,
|
metadata, output)
|
||||||
key,
|
|
||||||
value,
|
|
||||||
kv_cache,
|
|
||||||
metadata,
|
|
||||||
trace_flag=False)
|
|
||||||
|
|
||||||
mock_paged_attention.assert_called_once()
|
mock_paged_attention.assert_called_once()
|
||||||
assert output.shape == (10, 8 * 64)
|
assert output.shape == (10, 8 * 64)
|
||||||
@@ -509,6 +442,8 @@ class TestAscendAttentionBackendImpl(TestBase):
|
|||||||
key = torch.randn(10, 8 * 64)
|
key = torch.randn(10, 8 * 64)
|
||||||
value = torch.randn(10, 8 * 64)
|
value = torch.randn(10, 8 * 64)
|
||||||
kv_cache = torch.empty(2, 5, 128, 8, 64)
|
kv_cache = torch.empty(2, 5, 128, 8, 64)
|
||||||
|
output = torch.empty_like(query)
|
||||||
|
|
||||||
metadata = self.attn_metadata
|
metadata = self.attn_metadata
|
||||||
metadata.attn_state = AscendAttentionState.DecodeOnly
|
metadata.attn_state = AscendAttentionState.DecodeOnly
|
||||||
metadata.seq_lens = torch.tensor([10])
|
metadata.seq_lens = torch.tensor([10])
|
||||||
@@ -522,13 +457,8 @@ class TestAscendAttentionBackendImpl(TestBase):
|
|||||||
mock_get_forward_context.return_value = MagicMock(capturing=True)
|
mock_get_forward_context.return_value = MagicMock(capturing=True)
|
||||||
mock_get_graph_params.return_value = graph_params
|
mock_get_graph_params.return_value = graph_params
|
||||||
|
|
||||||
output = self.impl.forward(layer,
|
output = self.impl.forward(layer, query, key, value, kv_cache,
|
||||||
query,
|
metadata, output)
|
||||||
key,
|
|
||||||
value,
|
|
||||||
kv_cache,
|
|
||||||
metadata,
|
|
||||||
trace_flag=False)
|
|
||||||
|
|
||||||
mock_paged_attention.assert_called_once()
|
mock_paged_attention.assert_called_once()
|
||||||
self.assertEqual(len(graph_params.handles[num_tokens]), 0)
|
self.assertEqual(len(graph_params.handles[num_tokens]), 0)
|
||||||
@@ -542,6 +472,8 @@ class TestAscendAttentionBackendImpl(TestBase):
|
|||||||
key = torch.randn(10, 8 * 64)
|
key = torch.randn(10, 8 * 64)
|
||||||
value = torch.randn(10, 8 * 64)
|
value = torch.randn(10, 8 * 64)
|
||||||
kv_cache = torch.empty(2, 5, 128, 8, 64)
|
kv_cache = torch.empty(2, 5, 128, 8, 64)
|
||||||
|
output = torch.empty(10, 8, 64)
|
||||||
|
|
||||||
metadata = self.attn_metadata
|
metadata = self.attn_metadata
|
||||||
metadata.attn_state = AscendAttentionState.DecodeOnly
|
metadata.attn_state = AscendAttentionState.DecodeOnly
|
||||||
metadata.seq_lens = torch.tensor([10] * 10)
|
metadata.seq_lens = torch.tensor([10] * 10)
|
||||||
@@ -553,16 +485,11 @@ class TestAscendAttentionBackendImpl(TestBase):
|
|||||||
layer = self.layer_no_quant
|
layer = self.layer_no_quant
|
||||||
mock_fused_infer_attention_score.return_value = (torch.ones(10, 8,
|
mock_fused_infer_attention_score.return_value = (torch.ones(10, 8,
|
||||||
64), 1)
|
64), 1)
|
||||||
output = self.impl_swa.forward(layer,
|
output = self.impl_swa.forward(layer, query, key, value, kv_cache,
|
||||||
query,
|
metadata, output)
|
||||||
key,
|
|
||||||
value,
|
|
||||||
kv_cache,
|
|
||||||
metadata,
|
|
||||||
trace_flag=False)
|
|
||||||
print(output.shape)
|
print(output.shape)
|
||||||
mock_fused_infer_attention_score.assert_called_once()
|
mock_fused_infer_attention_score.assert_called_once()
|
||||||
assert output.shape == (10, 8 * 64)
|
assert output.shape == (10, 8, 64)
|
||||||
|
|
||||||
@patch('vllm_ascend.attention.attention_v1.get_forward_context')
|
@patch('vllm_ascend.attention.attention_v1.get_forward_context')
|
||||||
@patch('torch_npu._npu_reshape_and_cache')
|
@patch('torch_npu._npu_reshape_and_cache')
|
||||||
@@ -576,6 +503,7 @@ class TestAscendAttentionBackendImpl(TestBase):
|
|||||||
key = torch.randn(10, 8 * 64)
|
key = torch.randn(10, 8 * 64)
|
||||||
value = torch.randn(10, 8 * 64)
|
value = torch.randn(10, 8 * 64)
|
||||||
kv_cache = torch.empty(2, 5, 128, 8, 64)
|
kv_cache = torch.empty(2, 5, 128, 8, 64)
|
||||||
|
output = torch.empty_like(query)
|
||||||
|
|
||||||
metadata = self.attn_metadata
|
metadata = self.attn_metadata
|
||||||
metadata.attn_state = AscendAttentionState.DecodeOnly
|
metadata.attn_state = AscendAttentionState.DecodeOnly
|
||||||
@@ -583,6 +511,7 @@ class TestAscendAttentionBackendImpl(TestBase):
|
|||||||
metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
|
metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
|
||||||
metadata.num_actual_tokens = 10
|
metadata.num_actual_tokens = 10
|
||||||
metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
|
metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
|
||||||
|
layer = self.layer_no_quant
|
||||||
metadata.num_decodes = 10
|
metadata.num_decodes = 10
|
||||||
metadata.num_prefills = 0
|
metadata.num_prefills = 0
|
||||||
|
|
||||||
@@ -591,13 +520,8 @@ class TestAscendAttentionBackendImpl(TestBase):
|
|||||||
|
|
||||||
mock_get_forward_context.return_value = MagicMock(capturing=False)
|
mock_get_forward_context.return_value = MagicMock(capturing=False)
|
||||||
|
|
||||||
output = self.impl_swa.forward(self.layer_no_quant,
|
output = self.impl_swa.forward(layer, query, key, value, kv_cache,
|
||||||
query,
|
metadata, output)
|
||||||
key,
|
|
||||||
value,
|
|
||||||
kv_cache,
|
|
||||||
metadata,
|
|
||||||
trace_flag=False)
|
|
||||||
|
|
||||||
mock_paged_attention.assert_called_once()
|
mock_paged_attention.assert_called_once()
|
||||||
mock_fused_infer_attention_score.assert_not_called()
|
mock_fused_infer_attention_score.assert_not_called()
|
||||||
@@ -618,6 +542,8 @@ class TestAscendAttentionBackendImpl(TestBase):
|
|||||||
key = torch.randn(10, 8 * 192)
|
key = torch.randn(10, 8 * 192)
|
||||||
value = torch.randn(10, 8 * 192)
|
value = torch.randn(10, 8 * 192)
|
||||||
kv_cache = torch.empty(2, 5, 128, 8, 192)
|
kv_cache = torch.empty(2, 5, 128, 8, 192)
|
||||||
|
output = torch.empty_like(query)
|
||||||
|
|
||||||
metadata = self.attn_metadata
|
metadata = self.attn_metadata
|
||||||
metadata.attn_mask = torch.randn(1, 1, 10, 10)
|
metadata.attn_mask = torch.randn(1, 1, 10, 10)
|
||||||
metadata.query_lens = torch.tensor([10])
|
metadata.query_lens = torch.tensor([10])
|
||||||
@@ -631,13 +557,8 @@ class TestAscendAttentionBackendImpl(TestBase):
|
|||||||
mock_version.cann = "8.4.RC1"
|
mock_version.cann = "8.4.RC1"
|
||||||
mock_vanilla_prefill.return_value = MagicMock()
|
mock_vanilla_prefill.return_value = MagicMock()
|
||||||
|
|
||||||
output = self.impl_192.forward(layer,
|
output = self.impl_192.forward(layer, query, key, value, kv_cache,
|
||||||
query,
|
metadata, output)
|
||||||
key,
|
|
||||||
value,
|
|
||||||
kv_cache,
|
|
||||||
metadata,
|
|
||||||
trace_flag=False)
|
|
||||||
|
|
||||||
mock_vanilla_prefill.assert_called_once()
|
mock_vanilla_prefill.assert_called_once()
|
||||||
assert output.shape == (10, 8 * 192)
|
assert output.shape == (10, 8 * 192)
|
||||||
@@ -653,6 +574,8 @@ class TestAscendAttentionBackendImpl(TestBase):
|
|||||||
key = torch.randn(10, 8 * 64)
|
key = torch.randn(10, 8 * 64)
|
||||||
value = torch.randn(10, 8 * 64)
|
value = torch.randn(10, 8 * 64)
|
||||||
kv_cache = torch.empty(2, 5, 128, 8, 64)
|
kv_cache = torch.empty(2, 5, 128, 8, 64)
|
||||||
|
output = torch.empty_like(query)
|
||||||
|
|
||||||
metadata = self.attn_metadata
|
metadata = self.attn_metadata
|
||||||
metadata.attn_mask = torch.randn(1, 1, 10, 10)
|
metadata.attn_mask = torch.randn(1, 1, 10, 10)
|
||||||
metadata.query_lens = torch.tensor([10])
|
metadata.query_lens = torch.tensor([10])
|
||||||
@@ -666,13 +589,8 @@ class TestAscendAttentionBackendImpl(TestBase):
|
|||||||
|
|
||||||
mock_version.cann = "8.4.RC1"
|
mock_version.cann = "8.4.RC1"
|
||||||
|
|
||||||
output = self.impl.forward(layer,
|
output = self.impl.forward(layer, query, key, value, kv_cache,
|
||||||
query,
|
metadata, output)
|
||||||
key,
|
|
||||||
value,
|
|
||||||
kv_cache,
|
|
||||||
metadata,
|
|
||||||
trace_flag=False)
|
|
||||||
|
|
||||||
mock_paged_attention.assert_called_once()
|
mock_paged_attention.assert_called_once()
|
||||||
assert output.shape == (10, 8 * 64)
|
assert output.shape == (10, 8 * 64)
|
||||||
@@ -690,6 +608,8 @@ class TestAscendAttentionBackendImpl(TestBase):
|
|||||||
key = torch.randn(10, 8 * 64)
|
key = torch.randn(10, 8 * 64)
|
||||||
value = torch.randn(10, 8 * 64)
|
value = torch.randn(10, 8 * 64)
|
||||||
kv_cache = torch.empty(2, 5, 128, 8, 64)
|
kv_cache = torch.empty(2, 5, 128, 8, 64)
|
||||||
|
output = torch.empty_like(query)
|
||||||
|
|
||||||
metadata = self.attn_metadata
|
metadata = self.attn_metadata
|
||||||
metadata.attn_mask = torch.randn(1, 1, 10, 10)
|
metadata.attn_mask = torch.randn(1, 1, 10, 10)
|
||||||
metadata.query_lens = torch.tensor([10])
|
metadata.query_lens = torch.tensor([10])
|
||||||
@@ -703,13 +623,9 @@ class TestAscendAttentionBackendImpl(TestBase):
|
|||||||
|
|
||||||
mock_npu_format_cast.return_value = metadata.attn_mask
|
mock_npu_format_cast.return_value = metadata.attn_mask
|
||||||
mock_version.cann = "8.4.RC1"
|
mock_version.cann = "8.4.RC1"
|
||||||
output = self.impl.forward(layer,
|
|
||||||
query,
|
output = self.impl.forward(layer, query, key, value, kv_cache,
|
||||||
key,
|
metadata, output)
|
||||||
value,
|
|
||||||
kv_cache,
|
|
||||||
metadata,
|
|
||||||
trace_flag=False)
|
|
||||||
|
|
||||||
mock_paged_attention.assert_called_once()
|
mock_paged_attention.assert_called_once()
|
||||||
assert output.shape == (10, 8 * 64)
|
assert output.shape == (10, 8 * 64)
|
||||||
@@ -720,6 +636,8 @@ class TestAscendAttentionBackendImpl(TestBase):
|
|||||||
key = torch.randn(10, 8 * 64)
|
key = torch.randn(10, 8 * 64)
|
||||||
value = torch.randn(10, 8 * 64)
|
value = torch.randn(10, 8 * 64)
|
||||||
kv_cache = torch.empty(2, 5, 128, 8, 64)
|
kv_cache = torch.empty(2, 5, 128, 8, 64)
|
||||||
|
output = torch.empty_like(query)
|
||||||
|
|
||||||
metadata = self.attn_metadata
|
metadata = self.attn_metadata
|
||||||
metadata.attn_mask = torch.randn(1, 1, 10, 10)
|
metadata.attn_mask = torch.randn(1, 1, 10, 10)
|
||||||
metadata.query_lens = torch.tensor([10])
|
metadata.query_lens = torch.tensor([10])
|
||||||
@@ -732,10 +650,5 @@ class TestAscendAttentionBackendImpl(TestBase):
|
|||||||
layer = self.layer_no_quant
|
layer = self.layer_no_quant
|
||||||
|
|
||||||
with self.assertRaises(NotImplementedError):
|
with self.assertRaises(NotImplementedError):
|
||||||
self.impl_error.forward(layer,
|
self.impl_error.forward(layer, query, key, value, kv_cache,
|
||||||
query,
|
metadata, output)
|
||||||
key,
|
|
||||||
value,
|
|
||||||
kv_cache,
|
|
||||||
metadata,
|
|
||||||
trace_flag=False)
|
|
||||||
|
|||||||
@@ -91,5 +91,5 @@ class TestAscendAttentionTorchairBackendImpl(TestBase):
|
|||||||
torch.ones(1))
|
torch.ones(1))
|
||||||
|
|
||||||
result = self.impl.forward(layer, query, key, value, kv_cache,
|
result = self.impl.forward(layer, query, key, value, kv_cache,
|
||||||
metadata, output, False)
|
metadata, output)
|
||||||
self.assertEqual(result.shape[0], num_tokens)
|
self.assertEqual(result.shape[0], num_tokens)
|
||||||
|
|||||||
@@ -32,25 +32,22 @@ from vllm.distributed import (get_dcp_group,
|
|||||||
get_decode_context_model_parallel_rank,
|
get_decode_context_model_parallel_rank,
|
||||||
get_decode_context_model_parallel_world_size)
|
get_decode_context_model_parallel_world_size)
|
||||||
from vllm.forward_context import ForwardContext, get_forward_context
|
from vllm.forward_context import ForwardContext, get_forward_context
|
||||||
from vllm.utils import cdiv, direct_register_custom_op
|
from vllm.utils import cdiv
|
||||||
from vllm.v1.attention.backends.utils import AttentionCGSupport
|
from vllm.v1.attention.backends.utils import AttentionCGSupport
|
||||||
from vllm.v1.core.sched.output import SchedulerOutput
|
from vllm.v1.core.sched.output import SchedulerOutput
|
||||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||||
|
|
||||||
# isort: off
|
|
||||||
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
|
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
|
||||||
maybe_save_kv_layer_to_connector,
|
split_decodes_and_prefills)
|
||||||
split_decodes_and_prefills,
|
|
||||||
wait_for_kv_layer_from_connector)
|
|
||||||
from vllm_ascend.compilation.acl_graph import (get_graph_params,
|
from vllm_ascend.compilation.acl_graph import (get_graph_params,
|
||||||
update_graph_params_workspaces)
|
update_graph_params_workspaces)
|
||||||
from vllm_ascend.ops.attention import vanilla_chunked_prefill
|
from vllm_ascend.ops.attention import vanilla_chunked_prefill
|
||||||
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p,
|
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p,
|
||||||
nd_to_nz_2d, nd_to_nz_spec,
|
nd_to_nz_2d, nd_to_nz_spec,
|
||||||
prefill_context_parallel_enable, version_check)
|
prefill_context_parallel_enable, version_check,
|
||||||
|
weak_ref_tensors)
|
||||||
from ..utils import weak_ref_tensors
|
|
||||||
|
|
||||||
|
# isort: off
|
||||||
if prefill_context_parallel_enable():
|
if prefill_context_parallel_enable():
|
||||||
from vllm.distributed import (get_pcp_group,
|
from vllm.distributed import (get_pcp_group,
|
||||||
get_prefill_context_model_parallel_rank,
|
get_prefill_context_model_parallel_rank,
|
||||||
@@ -484,7 +481,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
|||||||
num_kv_heads=self.num_kv_heads,
|
num_kv_heads=self.num_kv_heads,
|
||||||
out=output)
|
out=output)
|
||||||
assert output is not None
|
assert output is not None
|
||||||
return output[:num_tokens, :, :]
|
return output[:num_tokens]
|
||||||
|
|
||||||
def _forward_prefill_cache_hit(
|
def _forward_prefill_cache_hit(
|
||||||
self,
|
self,
|
||||||
@@ -937,169 +934,14 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
|||||||
attn_lse_allgather)
|
attn_lse_allgather)
|
||||||
return attn_out
|
return attn_out
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
layer: AttentionLayer,
|
|
||||||
query: torch.Tensor,
|
|
||||||
key: torch.Tensor,
|
|
||||||
value: torch.Tensor,
|
|
||||||
kv_cache: Tuple[torch.Tensor],
|
|
||||||
attn_metadata: AscendMetadata,
|
|
||||||
output: Optional[torch.Tensor] = None,
|
|
||||||
trace_flag: bool = True,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
"""Forward pass with Ascend attention.
|
|
||||||
Args:
|
|
||||||
query: shape = [batch_size, seq_len, num_heads * head_size]
|
|
||||||
key: shape = [batch_size, seq_len, num_kv_heads * head_size]
|
|
||||||
value: shape = [batch_size, seq_len, num_kv_heads * head_size]
|
|
||||||
kv_cache: shape = [key_cache, value_cache]
|
|
||||||
key_cache = [num_blocks, block_size,
|
|
||||||
num_kv_heads, head_size]
|
|
||||||
value_cache = [num_blocks, block_size,
|
|
||||||
num_kv_heads, head_size]
|
|
||||||
attn_metadata: Metadata for attention.
|
|
||||||
Returns:
|
|
||||||
shape = [batch_size * seq_len, num_heads, head_size]
|
|
||||||
"""
|
|
||||||
num_tokens = query.shape[0]
|
|
||||||
use_kv_cache_int8 = len(
|
|
||||||
kv_cache) > 0 and kv_cache[0].dtype == torch.int8
|
|
||||||
if output is None:
|
|
||||||
output = torch.empty(num_tokens,
|
|
||||||
self.num_heads,
|
|
||||||
self.head_size,
|
|
||||||
dtype=query.dtype,
|
|
||||||
device=query.device)
|
|
||||||
ori_output = output
|
|
||||||
if trace_flag:
|
|
||||||
torch.ops.vllm.unified_ascend_attention_with_output(
|
|
||||||
query=query,
|
|
||||||
key=key,
|
|
||||||
value=value,
|
|
||||||
output=output,
|
|
||||||
layer_name=layer.layer_name)
|
|
||||||
|
|
||||||
elif hasattr(layer, 'quant_method') and use_kv_cache_int8:
|
|
||||||
output = layer.quant_method.apply(layer, query, key, value,
|
|
||||||
kv_cache, attn_metadata,
|
|
||||||
self.attn_type, self.scale,
|
|
||||||
output)
|
|
||||||
|
|
||||||
else:
|
|
||||||
if attn_metadata is None:
|
|
||||||
return output.view(num_tokens, self.hidden_size).fill_(0)
|
|
||||||
num_decode_tokens = attn_metadata.num_decode_tokens
|
|
||||||
has_decode = attn_metadata.num_decodes > 0
|
|
||||||
has_prefill = attn_metadata.num_prefills > 0
|
|
||||||
|
|
||||||
assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
|
|
||||||
attn_type = self.attn_type
|
|
||||||
if attn_type != AttentionType.DECODER and attn_type != AttentionType.ENCODER_ONLY:
|
|
||||||
raise NotImplementedError("Encoder/decoder cross-attention "
|
|
||||||
"are not implemented for "
|
|
||||||
"PallasAttentionBackendImpl")
|
|
||||||
# View q k v to BSH.
|
|
||||||
query = query.view(-1, self.num_heads, self.head_size)
|
|
||||||
key = key.view(-1, self.num_kv_heads, self.head_size)
|
|
||||||
value = value.view(-1, self.num_kv_heads, self.head_size)
|
|
||||||
# TODO: Remove this contiguous in the future.
|
|
||||||
value = value.contiguous()
|
|
||||||
|
|
||||||
if len(kv_cache) > 1:
|
|
||||||
if self.key_cache is None:
|
|
||||||
self.key_cache, self.value_cache = kv_cache[0], kv_cache[1]
|
|
||||||
|
|
||||||
if has_decode:
|
|
||||||
slot_mapping = attn_metadata.slot_mapping[:num_decode_tokens * self.pcp_size: self.pcp_size] \
|
|
||||||
if self.pcp_size * self.dcp_size > 1 else attn_metadata.slot_mapping[:num_decode_tokens]
|
|
||||||
torch_npu._npu_reshape_and_cache(
|
|
||||||
key=key[:num_decode_tokens],
|
|
||||||
value=value[:num_decode_tokens],
|
|
||||||
key_cache=self.key_cache,
|
|
||||||
value_cache=self.value_cache,
|
|
||||||
slot_indices=slot_mapping)
|
|
||||||
|
|
||||||
if has_prefill:
|
|
||||||
if self.pcp_size > 1:
|
|
||||||
kv = torch.cat([key, value], dim=-1)
|
|
||||||
all_kv = get_pcp_group().all_gather(kv, dim=0)
|
|
||||||
pcp_allgather_restore_idx = attn_metadata.prefill.pcp_allgather_restore_idx if attn_metadata.prefill else None
|
|
||||||
all_kv = torch.index_select(all_kv, 0,
|
|
||||||
pcp_allgather_restore_idx)
|
|
||||||
key, value = all_kv.split(
|
|
||||||
[self.head_size, self.head_size], dim=-1)
|
|
||||||
|
|
||||||
torch_npu._npu_reshape_and_cache(
|
|
||||||
key=key[self.pcp_size *
|
|
||||||
num_decode_tokens:attn_metadata.
|
|
||||||
num_actual_tokens_pcp_padded],
|
|
||||||
value=value[self.pcp_size *
|
|
||||||
num_decode_tokens:attn_metadata.
|
|
||||||
num_actual_tokens_pcp_padded],
|
|
||||||
key_cache=self.key_cache,
|
|
||||||
value_cache=self.value_cache,
|
|
||||||
slot_indices=attn_metadata.
|
|
||||||
slot_mapping[self.pcp_size *
|
|
||||||
num_decode_tokens:attn_metadata.
|
|
||||||
num_actual_tokens_pcp_padded])
|
|
||||||
|
|
||||||
if self.pcp_size * self.dcp_size > 1:
|
|
||||||
output = self._forward_pcp_dcp(query, key, value,
|
|
||||||
attn_metadata, output)
|
|
||||||
|
|
||||||
elif attn_type == AttentionType.ENCODER_ONLY:
|
|
||||||
cum_seq_len = attn_metadata.query_start_loc[1:].tolist()
|
|
||||||
attn_out = torch_npu.npu_fusion_attention(
|
|
||||||
query,
|
|
||||||
key,
|
|
||||||
value,
|
|
||||||
head_num=self.num_heads,
|
|
||||||
input_layout="TND",
|
|
||||||
scale=self.scale,
|
|
||||||
sparse_mode=4,
|
|
||||||
atten_mask=attn_metadata.attn_mask,
|
|
||||||
next_tockens=attn_metadata.max_query_len,
|
|
||||||
actual_seq_qlen=cum_seq_len,
|
|
||||||
actual_seq_kvlen=cum_seq_len,
|
|
||||||
)
|
|
||||||
output = attn_out[0]
|
|
||||||
# V0-Style scheduler situation.
|
|
||||||
elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
|
|
||||||
output = self._forward_prefill_no_cache(
|
|
||||||
query, key, value, attn_metadata, output, num_tokens)
|
|
||||||
elif attn_metadata.attn_state == \
|
|
||||||
AscendAttentionState.PrefillCacheHit:
|
|
||||||
output = self._forward_prefill_cache_hit(
|
|
||||||
query, attn_metadata, output)
|
|
||||||
elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
|
|
||||||
output = self._forward_decode_only(query, attn_metadata,
|
|
||||||
output)
|
|
||||||
# Normal V1 situation.
|
|
||||||
else:
|
|
||||||
if torch.version.cann.startswith("8.3"):
|
|
||||||
# npu_fused_infer_attention_score does not support cases
|
|
||||||
# where query.shape[0] != attn_metadata.query_start_loc[-1].
|
|
||||||
# Thus we need unpad it here.
|
|
||||||
num_tokens = attn_metadata.query_start_loc[-1]
|
|
||||||
query = query[:num_tokens]
|
|
||||||
output = self._forward_v1_style(query, attn_metadata, output)
|
|
||||||
|
|
||||||
# to make in-place change to the output tensor
|
|
||||||
if hasattr(layer, 'quant_method') and use_kv_cache_int8:
|
|
||||||
output = output.view(num_tokens, self.num_heads, self.head_size)
|
|
||||||
ori_output[:num_tokens, :, :] = output[:num_tokens, :, :]
|
|
||||||
return output.view(num_tokens, self.hidden_size)
|
|
||||||
|
|
||||||
def _forward_pcp_dcp(self, query: torch.Tensor, key: torch.Tensor,
|
def _forward_pcp_dcp(self, query: torch.Tensor, key: torch.Tensor,
|
||||||
value: torch.Tensor, attn_metadata: AscendMetadata,
|
value: torch.Tensor, attn_metadata: AscendMetadata,
|
||||||
output: Optional[torch.Tensor]) -> torch.Tensor:
|
output: torch.Tensor) -> torch.Tensor:
|
||||||
assert attn_metadata is not None
|
assert attn_metadata is not None
|
||||||
has_decode = attn_metadata.num_decodes > 0
|
has_decode = attn_metadata.num_decodes > 0
|
||||||
has_prefill = attn_metadata.num_prefills > 0
|
has_prefill = attn_metadata.num_prefills > 0
|
||||||
num_decode_tokens = attn_metadata.num_decode_tokens
|
num_decode_tokens = attn_metadata.num_decode_tokens
|
||||||
if output is None:
|
|
||||||
raise ValueError("Output buffer is required")
|
|
||||||
if has_decode:
|
if has_decode:
|
||||||
decode_query = query[:num_decode_tokens]
|
decode_query = query[:num_decode_tokens]
|
||||||
output_decode = self._forward_decode_pcp_dcp(
|
output_decode = self._forward_decode_pcp_dcp(
|
||||||
@@ -1131,47 +973,137 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
|||||||
output[num_decode_tokens:] = output_prefill
|
output[num_decode_tokens:] = output_prefill
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
def forward(
|
||||||
def unified_ascend_attention_with_output(
|
self,
|
||||||
|
layer: AttentionLayer,
|
||||||
query: torch.Tensor,
|
query: torch.Tensor,
|
||||||
key: torch.Tensor,
|
key: torch.Tensor,
|
||||||
value: torch.Tensor,
|
value: torch.Tensor,
|
||||||
output: torch.Tensor,
|
kv_cache: Tuple[torch.Tensor],
|
||||||
layer_name: str,
|
attn_metadata: AscendMetadata,
|
||||||
) -> None:
|
output: Optional[torch.Tensor] = None,
|
||||||
wait_for_kv_layer_from_connector(layer_name)
|
output_scale: Optional[torch.Tensor] = None,
|
||||||
forward_context: ForwardContext = get_forward_context()
|
output_block_scale: Optional[torch.Tensor] = None,
|
||||||
attn_metadata = forward_context.attn_metadata
|
) -> torch.Tensor:
|
||||||
if isinstance(attn_metadata, dict):
|
"""Forward pass with Ascend attention.
|
||||||
attn_metadata = attn_metadata[layer_name]
|
Args:
|
||||||
self = forward_context.no_compile_layers[layer_name]
|
query: shape = [num_tokens, num_heads, head_size]
|
||||||
kv_cache = self.kv_cache[forward_context.virtual_engine]
|
key: shape = [num_tokens, num_kv_heads, head_size]
|
||||||
self.impl.forward(self,
|
value: shape = [num_tokens, num_kv_heads, head_size]
|
||||||
|
kv_cache: shape =
|
||||||
|
[2, num_blocks, block_size, num_kv_heads, head_size]
|
||||||
|
attn_metadata: Metadata for attention.
|
||||||
|
Returns:
|
||||||
|
shape = [num_tokens, num_heads * head_size]
|
||||||
|
"""
|
||||||
|
assert output is not None, "Output tensor must be provided."
|
||||||
|
|
||||||
|
if output_scale is not None or output_block_scale is not None:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"fused output quantization is not yet supported"
|
||||||
|
" for AscendAttentionBackendImpl")
|
||||||
|
|
||||||
|
num_tokens = query.shape[0]
|
||||||
|
if attn_metadata is None:
|
||||||
|
return output
|
||||||
|
|
||||||
|
# NOTE: Currently, we have various attention paths for different
|
||||||
|
# scenarios, and not all of them are in-place operations. Therefore,
|
||||||
|
# we need to create a separate tensor to hold the attention result.
|
||||||
|
# In the future, we may consolidate them into fewer paths, which will
|
||||||
|
# hopefully allow us to use in-place operation by default.
|
||||||
|
intermediate_output: torch.Tensor
|
||||||
|
|
||||||
|
assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
|
||||||
|
attn_type = self.attn_type
|
||||||
|
if attn_type != AttentionType.DECODER and attn_type != AttentionType.ENCODER_ONLY:
|
||||||
|
raise NotImplementedError("Encoder/decoder cross-attention "
|
||||||
|
"are not implemented for "
|
||||||
|
"PallasAttentionBackendImpl")
|
||||||
|
|
||||||
|
num_decode_tokens = attn_metadata.num_decode_tokens
|
||||||
|
has_decode = attn_metadata.num_decodes > 0
|
||||||
|
has_prefill = attn_metadata.num_prefills > 0
|
||||||
|
|
||||||
|
if len(kv_cache) > 1:
|
||||||
|
if self.key_cache is None:
|
||||||
|
self.key_cache, self.value_cache = kv_cache[0], kv_cache[1]
|
||||||
|
|
||||||
|
if has_decode:
|
||||||
|
slot_mapping = attn_metadata.slot_mapping[:num_decode_tokens * self.pcp_size: self.pcp_size] \
|
||||||
|
if self.pcp_size * self.dcp_size > 1 else attn_metadata.slot_mapping[:num_decode_tokens]
|
||||||
|
torch_npu._npu_reshape_and_cache(
|
||||||
|
key=key[:num_decode_tokens],
|
||||||
|
value=value[:num_decode_tokens],
|
||||||
|
key_cache=self.key_cache,
|
||||||
|
value_cache=self.value_cache,
|
||||||
|
slot_indices=slot_mapping)
|
||||||
|
|
||||||
|
if has_prefill:
|
||||||
|
if self.pcp_size > 1:
|
||||||
|
kv = torch.cat([key, value], dim=-1)
|
||||||
|
all_kv = get_pcp_group().all_gather(kv, dim=0)
|
||||||
|
pcp_allgather_restore_idx = attn_metadata.prefill.pcp_allgather_restore_idx if attn_metadata.prefill else None
|
||||||
|
all_kv = torch.index_select(all_kv, 0,
|
||||||
|
pcp_allgather_restore_idx)
|
||||||
|
key, value = all_kv.split([self.head_size, self.head_size],
|
||||||
|
dim=-1)
|
||||||
|
|
||||||
|
torch_npu._npu_reshape_and_cache(
|
||||||
|
key=key[self.pcp_size * num_decode_tokens:attn_metadata.
|
||||||
|
num_actual_tokens_pcp_padded],
|
||||||
|
value=value[self.pcp_size *
|
||||||
|
num_decode_tokens:attn_metadata.
|
||||||
|
num_actual_tokens_pcp_padded],
|
||||||
|
key_cache=self.key_cache,
|
||||||
|
value_cache=self.value_cache,
|
||||||
|
slot_indices=attn_metadata.
|
||||||
|
slot_mapping[self.pcp_size *
|
||||||
|
num_decode_tokens:attn_metadata.
|
||||||
|
num_actual_tokens_pcp_padded])
|
||||||
|
|
||||||
|
if self.pcp_size * self.dcp_size > 1:
|
||||||
|
intermediate_output = self._forward_pcp_dcp(
|
||||||
|
query, key, value, attn_metadata, output)
|
||||||
|
elif attn_type == AttentionType.ENCODER_ONLY:
|
||||||
|
# TODO(zzzwwjj): Deal with this `cum_seq_len` more elegantly.
|
||||||
|
cum_seq_len = attn_metadata.query_start_loc[1:].tolist()
|
||||||
|
intermediate_output = torch_npu.npu_fusion_attention(
|
||||||
query,
|
query,
|
||||||
key,
|
key,
|
||||||
value,
|
value,
|
||||||
kv_cache,
|
head_num=self.num_heads,
|
||||||
attn_metadata,
|
input_layout="TND",
|
||||||
output,
|
scale=self.scale,
|
||||||
trace_flag=False)
|
sparse_mode=4,
|
||||||
maybe_save_kv_layer_to_connector(layer_name, kv_cache)
|
atten_mask=attn_metadata.attn_mask,
|
||||||
return
|
pre_tockens=attn_metadata.max_query_len,
|
||||||
|
next_tockens=attn_metadata.max_query_len,
|
||||||
|
actual_seq_qlen=cum_seq_len,
|
||||||
|
actual_seq_kvlen=cum_seq_len,
|
||||||
|
)[0]
|
||||||
|
# V0-Style scheduler situation.
|
||||||
|
elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
|
||||||
|
intermediate_output = self._forward_prefill_no_cache(
|
||||||
|
query, key, value, attn_metadata, output, num_tokens)
|
||||||
|
elif attn_metadata.attn_state == \
|
||||||
|
AscendAttentionState.PrefillCacheHit:
|
||||||
|
intermediate_output = self._forward_prefill_cache_hit(
|
||||||
|
query, attn_metadata, output)
|
||||||
|
elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
|
||||||
|
intermediate_output = self._forward_decode_only(
|
||||||
|
query, attn_metadata, output)
|
||||||
|
# Normal V1 situation.
|
||||||
|
else:
|
||||||
|
if torch.version.cann.startswith("8.3"):
|
||||||
|
# npu_fused_infer_attention_score does not support cases
|
||||||
|
# where query.shape[0] != attn_metadata.query_start_loc[-1].
|
||||||
|
# Thus we need unpad it here.
|
||||||
|
num_tokens = attn_metadata.query_start_loc[-1]
|
||||||
|
query = query[:num_tokens]
|
||||||
|
intermediate_output = self._forward_v1_style(
|
||||||
|
query, attn_metadata, output)
|
||||||
|
|
||||||
|
output[:num_tokens] = intermediate_output[:num_tokens]
|
||||||
|
|
||||||
def unified_attention_with_output_fake(
|
return output
|
||||||
query: torch.Tensor,
|
|
||||||
key: torch.Tensor,
|
|
||||||
value: torch.Tensor,
|
|
||||||
output: torch.Tensor,
|
|
||||||
layer_name: str,
|
|
||||||
) -> None:
|
|
||||||
return
|
|
||||||
|
|
||||||
|
|
||||||
direct_register_custom_op(
|
|
||||||
op_name="unified_ascend_attention_with_output",
|
|
||||||
op_func=unified_ascend_attention_with_output,
|
|
||||||
mutates_args=["output"],
|
|
||||||
fake_impl=unified_attention_with_output_fake,
|
|
||||||
dispatch_key="PrivateUse1",
|
|
||||||
)
|
|
||||||
|
|||||||
@@ -293,10 +293,7 @@ class NPUPlatform(Platform):
|
|||||||
"When enabling VLLM_COMPILE aclgraph, please make sure compilation_config.mode == CompilationMode.VLLM_COMPILE and compilation_config.cudagraph_mode == CUDAGraphMode.VLLM_COMPILE"
|
"When enabling VLLM_COMPILE aclgraph, please make sure compilation_config.mode == CompilationMode.VLLM_COMPILE and compilation_config.cudagraph_mode == CUDAGraphMode.VLLM_COMPILE"
|
||||||
compilation_config.set_splitting_ops_for_v1()
|
compilation_config.set_splitting_ops_for_v1()
|
||||||
compilation_config.use_inductor = False
|
compilation_config.use_inductor = False
|
||||||
compilation_config.splitting_ops.extend([
|
compilation_config.splitting_ops.extend(["vllm::mla_forward"])
|
||||||
"vllm::unified_ascend_attention_with_output",
|
|
||||||
"vllm::mla_forward"
|
|
||||||
])
|
|
||||||
update_aclgraph_sizes(vllm_config)
|
update_aclgraph_sizes(vllm_config)
|
||||||
elif compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY:
|
elif compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY:
|
||||||
logger.info(
|
logger.info(
|
||||||
@@ -453,6 +450,10 @@ class NPUPlatform(Platform):
|
|||||||
def is_pin_memory_available(cls):
|
def is_pin_memory_available(cls):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def opaque_attention_op(cls) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_static_graph_wrapper_cls(cls) -> str:
|
def get_static_graph_wrapper_cls(cls) -> str:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -125,7 +125,6 @@ class CustomQwen2Attention(Qwen2Attention):
|
|||||||
v,
|
v,
|
||||||
kv_cache=kv_cache,
|
kv_cache=kv_cache,
|
||||||
attn_metadata=attn_metadata,
|
attn_metadata=attn_metadata,
|
||||||
trace_flag=False,
|
|
||||||
**forward_kwargs)
|
**forward_kwargs)
|
||||||
output, _ = self.o_proj(attn_output)
|
output, _ = self.o_proj(attn_output)
|
||||||
return output
|
return output
|
||||||
|
|||||||
@@ -257,7 +257,6 @@ class CustomQwen3MoeAttention(Qwen3MoeAttention):
|
|||||||
v,
|
v,
|
||||||
kv_cache=kv_cache,
|
kv_cache=kv_cache,
|
||||||
attn_metadata=attn_metadata,
|
attn_metadata=attn_metadata,
|
||||||
trace_flag=False,
|
|
||||||
**forward_kwargs)
|
**forward_kwargs)
|
||||||
output, _ = self.o_proj(attn_output)
|
output, _ = self.o_proj(attn_output)
|
||||||
return output
|
return output
|
||||||
|
|||||||
@@ -625,7 +625,7 @@ class PanguProMoEAttention(nn.Module):
|
|||||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||||
q, k = self.rotary_emb(positions, q, k)
|
q, k = self.rotary_emb(positions, q, k)
|
||||||
if self.torchair_graph_enabled:
|
if self.torchair_graph_enabled:
|
||||||
forward_kwargs = {'trace_flag': False}
|
forward_kwargs = {}
|
||||||
output_shape = q.shape
|
output_shape = q.shape
|
||||||
attn_output = torch.empty(output_shape,
|
attn_output = torch.empty(output_shape,
|
||||||
dtype=q.dtype,
|
dtype=q.dtype,
|
||||||
|
|||||||
@@ -314,7 +314,6 @@ class AscendAttentionTorchairBackendImpl(AttentionImpl):
|
|||||||
kv_cache: torch.Tensor,
|
kv_cache: torch.Tensor,
|
||||||
attn_metadata: AscendTorchairMetadata,
|
attn_metadata: AscendTorchairMetadata,
|
||||||
output: Optional[torch.Tensor] = None,
|
output: Optional[torch.Tensor] = None,
|
||||||
trace_flag: bool = False,
|
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""Forward pass with Ascend attention.
|
"""Forward pass with Ascend attention.
|
||||||
Args:
|
Args:
|
||||||
|
|||||||
@@ -3740,7 +3740,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
splitting_ops_contain_attention = (
|
splitting_ops_contain_attention = (
|
||||||
self.compilation_config.splitting_ops is not None
|
self.compilation_config.splitting_ops is not None
|
||||||
and all(op in self.compilation_config.splitting_ops for op in [
|
and all(op in self.compilation_config.splitting_ops for op in [
|
||||||
"vllm.unified_ascend_attention_with_output",
|
|
||||||
"vllm.mla_forward",
|
"vllm.mla_forward",
|
||||||
]))
|
]))
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user