[Attn] Support encoder-only attention with torch sdpa (#290)

### What this PR does / why we need it?
Support encoder-only attention with torch SDPA
(`scaled_dot_product_attention`). Fixes
https://github.com/vllm-project/vllm-ascend/pull/229#issuecomment-2695942741

### Does this PR introduce _any_ user-facing change?
N/A

### How was this patch tested?
Tested locally with
`pytest vllm-project/vllm/tests/entrypoints/openai/test_score.py`.
**Note**: Since torch compile on NPU is still a work in progress, we need
to comment out the following code to make the UT run:

https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/vocab_parallel_embedding.py#L138

result:
```bash
/home/xxx/miniconda3/envs/atb/lib/python3.10/site-packages/pytest_asyncio/plugin.py:207: PytestDeprecationWarning: The configuration option "asyncio_default_fixture_loop_scope" is unset.
The event loop scope for asynchronous fixtures will default to the fixture caching scope. Future versions of pytest-asyncio will default the loop scope for asynchronous fixtures to function scope. Set the default fixture loop scope explicitly in order to avoid unexpected behavior in the future. Valid fixture loop scopes are: "function", "class", "module", "package", "session"

  warnings.warn(PytestDeprecationWarning(_DEFAULT_FIXTURE_LOOP_SCOPE_UNSET))
================================================================================== test session starts ===================================================================================
platform linux -- Python 3.10.16, pytest-8.3.4, pluggy-1.5.0
rootdir: /home/xxx/code/vllm-cpu/vllm
configfile: pyproject.toml
plugins: shard-0.1.2, rerunfailures-15.0, asyncio-0.25.3, anyio-4.8.0, mock-3.14.0, forked-1.6.0, typeguard-4.3.0
asyncio: mode=strict, asyncio_default_fixture_loop_scope=None
collected 8 items                                                                                                                                                                        
Running 8 items in this shard

tests/entrypoints/openai/test_score.py ........                                                                                                                                    [100%]

==================================================================================== warnings summary ====================================================================================
../../../miniconda3/envs/atb/lib/python3.10/site-packages/torch_npu/dynamo/torchair/__init__.py:8
  /home/cmq/miniconda3/envs/atb/lib/python3.10/site-packages/torch_npu/dynamo/torchair/__init__.py:8: DeprecationWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html
    import pkg_resources

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
======================================================================== 8 passed, 1 warning in 131.42s (0:02:11) ========================================================================
```

This UT will be included in CI once the torch compile feature is done.

Signed-off-by: MengqingCao <cmq0113@163.com>
This commit is contained in:
Mengqing Cao
2025-03-12 08:57:29 +08:00
committed by GitHub
parent 12aa7115b5
commit 5c7a95b01d

View File

@@ -21,6 +21,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
import numpy as np
import torch
from torch.nn.functional import scaled_dot_product_attention
try:
import torch_npu # noqa: F401
@@ -715,6 +716,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
value = value.view(-1, self.num_kv_heads, self.head_size)
# TODO: Remove this contiguous in the future.
value = value.contiguous()
attn_type = self.attn_type
output = torch.empty(num_tokens,
self.num_heads,
@@ -758,23 +760,50 @@ class AscendAttentionBackendImpl(AttentionImpl):
if (attn_metadata.block_tables is None
or attn_metadata.block_tables.numel() == 0):
assert attn_metadata.attn_mask is not None
mask = attn_metadata.attn_mask
assert attn_metadata.prefill_metadata is not None
self.seq_lens_tensor_cpu = torch.from_numpy(
np.array(
attn_metadata.prefill_metadata.seq_lens).astype(
np.int32))
torch_npu._npu_flash_attention(
query=query,
key=key,
value=value,
mask=mask,
seq_len=self.seq_lens_tensor_cpu,
scale_value=self.scale,
num_heads=self.num_heads,
num_kv_heads=self.num_kv_heads,
out=output)
if attn_type == AttentionType.ENCODER_ONLY:
# TODO: change to use torch_npu encoder attention op, instead
# of torch sdpa
query = query.movedim(0, query.dim() - 2)
key = key.movedim(0, key.dim() - 2)
value = value.movedim(0, value.dim() - 2)
causal_attn = (attn_type == AttentionType.DECODER)
if attn_metadata.seq_lens is not None:
seq_lens_q = seq_lens_kv = attn_metadata.seq_lens
attn_masks = [None] * len(seq_lens_q)
start_q, start_kv = 0, 0
for seq_len_q, seq_len_kv, mask in zip(
seq_lens_q, seq_lens_kv, attn_masks):
end_q = start_q + seq_len_q
end_kv = start_kv + seq_len_kv
sub_out = scaled_dot_product_attention(
query[None, :, start_q:end_q, :],
key[None, :, start_kv:end_kv, :],
value[None, :, start_kv:end_kv, :],
attn_mask=mask,
dropout_p=0.0,
is_causal=causal_attn and mask is None,
scale=self.scale).squeeze(0).movedim(
query.dim() - 2, 0)
output[start_q:end_q, :, :] = sub_out
start_q, start_kv = end_q, end_kv
else:
assert attn_metadata.attn_mask is not None
mask = attn_metadata.attn_mask
assert attn_metadata.prefill_metadata is not None
self.seq_lens_tensor_cpu = torch.from_numpy(
np.array(attn_metadata.prefill_metadata.seq_lens).
astype(np.int32))
torch_npu._npu_flash_attention(
query=query,
key=key,
value=value,
mask=mask,
seq_len=self.seq_lens_tensor_cpu,
scale_value=self.scale,
num_heads=self.num_heads,
num_kv_heads=self.num_kv_heads,
out=output)
else:
# TODO: Will support prefix cache and chunked prefill soon.
raise RuntimeError(