[Attn] Support encoder-only attention with torch sdpa (#290)
### What this PR does / why we need it? Support encoder-only attention with torch sdpa fix https://github.com/vllm-project/vllm-ascend/pull/229#issuecomment-2695942741 ### Does this PR introduce _any_ user-facing change? N/A ### How was this patch tested? Test locally with `pytest vllm-project/vllm/tests/entrypoints/openai/test_score.py` **Note**: Since torch compile on npu are still work in process, we need to comment the following code to make UT run: https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/vocab_parallel_embedding.py#L138 result: ```bash /home/xxx/miniconda3/envs/atb/lib/python3.10/site-packages/pytest_asyncio/plugin.py:207: PytestDeprecationWarning: The configuration option "asyncio_default_fixture_loop_scope" is unset. The event loop scope for asynchronous fixtures will default to the fixture caching scope. Future versions of pytest-asyncio will default the loop scope for asynchronous fixtures to function scope. Set the default fixture loop scope explicitly in order to avoid unexpected behavior in the future. Valid fixture loop scopes are: "function", "class", "module", "package", "session" warnings.warn(PytestDeprecationWarning(_DEFAULT_FIXTURE_LOOP_SCOPE_UNSET)) ================================================================================== test session starts =================================================================================== platform linux -- Python 3.10.16, pytest-8.3.4, pluggy-1.5.0 rootdir: /home/xxx/code/vllm-cpu/vllm configfile: pyproject.toml plugins: shard-0.1.2, rerunfailures-15.0, asyncio-0.25.3, anyio-4.8.0, mock-3.14.0, forked-1.6.0, typeguard-4.3.0 asyncio: mode=strict, asyncio_default_fixture_loop_scope=None collected 8 items Running 8 items in this shard tests/entrypoints/openai/test_score.py ........ [100%] ==================================================================================== warnings summary ==================================================================================== ../../../miniconda3/envs/atb/lib/python3.10/site-packages/torch_npu/dynamo/torchair/__init__.py:8 /home/cmq/miniconda3/envs/atb/lib/python3.10/site-packages/torch_npu/dynamo/torchair/__init__.py:8: DeprecationWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html import pkg_resources -- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html ======================================================================== 8 passed, 1 warning in 131.42s (0:02:11) ======================================================================== ``` This ut will be included in CI when torch compile feature is done. Signed-off-by: MengqingCao <cmq0113@163.com>
This commit is contained in:
@@ -21,6 +21,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.nn.functional import scaled_dot_product_attention
|
||||
|
||||
try:
|
||||
import torch_npu # noqa: F401
|
||||
@@ -715,6 +716,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
value = value.view(-1, self.num_kv_heads, self.head_size)
|
||||
# TODO: Remove this contiguous in the future.
|
||||
value = value.contiguous()
|
||||
attn_type = self.attn_type
|
||||
|
||||
output = torch.empty(num_tokens,
|
||||
self.num_heads,
|
||||
@@ -758,23 +760,50 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
|
||||
if (attn_metadata.block_tables is None
|
||||
or attn_metadata.block_tables.numel() == 0):
|
||||
assert attn_metadata.attn_mask is not None
|
||||
mask = attn_metadata.attn_mask
|
||||
assert attn_metadata.prefill_metadata is not None
|
||||
self.seq_lens_tensor_cpu = torch.from_numpy(
|
||||
np.array(
|
||||
attn_metadata.prefill_metadata.seq_lens).astype(
|
||||
np.int32))
|
||||
torch_npu._npu_flash_attention(
|
||||
query=query,
|
||||
key=key,
|
||||
value=value,
|
||||
mask=mask,
|
||||
seq_len=self.seq_lens_tensor_cpu,
|
||||
scale_value=self.scale,
|
||||
num_heads=self.num_heads,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
out=output)
|
||||
if attn_type == AttentionType.ENCODER_ONLY:
|
||||
# TODO: change to use torch_npu encoder attention op, instead
|
||||
# of torch sdpa
|
||||
query = query.movedim(0, query.dim() - 2)
|
||||
key = key.movedim(0, key.dim() - 2)
|
||||
value = value.movedim(0, value.dim() - 2)
|
||||
|
||||
causal_attn = (attn_type == AttentionType.DECODER)
|
||||
if attn_metadata.seq_lens is not None:
|
||||
seq_lens_q = seq_lens_kv = attn_metadata.seq_lens
|
||||
attn_masks = [None] * len(seq_lens_q)
|
||||
start_q, start_kv = 0, 0
|
||||
for seq_len_q, seq_len_kv, mask in zip(
|
||||
seq_lens_q, seq_lens_kv, attn_masks):
|
||||
end_q = start_q + seq_len_q
|
||||
end_kv = start_kv + seq_len_kv
|
||||
sub_out = scaled_dot_product_attention(
|
||||
query[None, :, start_q:end_q, :],
|
||||
key[None, :, start_kv:end_kv, :],
|
||||
value[None, :, start_kv:end_kv, :],
|
||||
attn_mask=mask,
|
||||
dropout_p=0.0,
|
||||
is_causal=causal_attn and mask is None,
|
||||
scale=self.scale).squeeze(0).movedim(
|
||||
query.dim() - 2, 0)
|
||||
output[start_q:end_q, :, :] = sub_out
|
||||
start_q, start_kv = end_q, end_kv
|
||||
else:
|
||||
assert attn_metadata.attn_mask is not None
|
||||
mask = attn_metadata.attn_mask
|
||||
assert attn_metadata.prefill_metadata is not None
|
||||
self.seq_lens_tensor_cpu = torch.from_numpy(
|
||||
np.array(attn_metadata.prefill_metadata.seq_lens).
|
||||
astype(np.int32))
|
||||
torch_npu._npu_flash_attention(
|
||||
query=query,
|
||||
key=key,
|
||||
value=value,
|
||||
mask=mask,
|
||||
seq_len=self.seq_lens_tensor_cpu,
|
||||
scale_value=self.scale,
|
||||
num_heads=self.num_heads,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
out=output)
|
||||
else:
|
||||
# TODO: Will support prefix cache and chunked prefill soon.
|
||||
raise RuntimeError(
|
||||
|
||||
Reference in New Issue
Block a user