[Feature]: Support 310P device run qwen2.5/3 dense and qwen2.5vl models (#5776)
### What this PR does / why we need it?
Add basic 310p support. Only dense models work with eager mode now.
- vLLM version: v0.13.0
- vLLM main:
2f4e6548ef
---------
Signed-off-by: Tflowers-0129 <2906339855@qq.com>
Signed-off-by: Shaoxu Cheng <2906339855@qq.com>
This commit is contained in:
0
vllm_ascend/_310p/attention/__init__.py
Normal file
0
vllm_ascend/_310p/attention/__init__.py
Normal file
98
vllm_ascend/_310p/attention/attention_mask.py
Normal file
98
vllm_ascend/_310p/attention/attention_mask.py
Normal file
@@ -0,0 +1,98 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import torch_npu
|
||||
|
||||
import vllm_ascend.attention.attention_mask as _base_mask
|
||||
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, nd_to_nz_spec
|
||||
|
||||
_BASE_BUILDER: Callable[[torch.device], Any] = _base_mask.AttentionMaskBuilder
|
||||
|
||||
|
||||
def _gen_causal_additive_mask_fp16(max_seq_len: int, device: torch.device) -> torch.Tensor:
|
||||
tril = torch.ones((max_seq_len, max_seq_len), dtype=torch.bool, device=device).tril_()
|
||||
upper = ~tril
|
||||
m = torch.zeros((max_seq_len, max_seq_len), dtype=torch.float16, device=device)
|
||||
m.masked_fill_(upper, float("-inf"))
|
||||
return m
|
||||
|
||||
|
||||
def build_splitfuse_attn_mask_310p(attn_metadata, device, *, full_mask_cache=None, full_mask_cache_len=0):
|
||||
qsl = attn_metadata.query_start_loc.detach().to("cpu", dtype=torch.int32)
|
||||
qlens = qsl[1:] - qsl[:-1]
|
||||
|
||||
context_lens = attn_metadata.seq_lens.to(dtype=torch.int32)
|
||||
L = int(context_lens.max().item())
|
||||
|
||||
q_list = qlens.tolist()
|
||||
c_list = context_lens.detach().to("cpu", dtype=torch.int64).tolist()
|
||||
pos_list = [p for ql, cl in zip(q_list, c_list) for p in range(cl - ql, cl)]
|
||||
position = torch.tensor(pos_list, dtype=torch.long, device=device)
|
||||
|
||||
if full_mask_cache is None or full_mask_cache.device != device or full_mask_cache_len < L:
|
||||
tril = torch.ones((L, L), dtype=torch.bool, device=device).tril_()
|
||||
full = torch.zeros((L, L), dtype=torch.float16, device=device)
|
||||
full.masked_fill_(~tril, float("-inf"))
|
||||
full_mask_cache, full_mask_cache_len = full, L
|
||||
else:
|
||||
full = full_mask_cache[:L, :L].contiguous()
|
||||
|
||||
rows = full.index_select(0, position).contiguous()
|
||||
mask = torch_npu.npu_format_cast(nd_to_nz_spec(rows).contiguous(), ACL_FORMAT_FRACTAL_NZ)
|
||||
return mask, full_mask_cache, full_mask_cache_len
|
||||
|
||||
|
||||
class _AttentionMaskBuilder310P:
|
||||
"""
|
||||
310P adapter:
|
||||
- overrides fp16 causal additive mask generation (use -inf fp16)
|
||||
- delegates all other behaviors to base AttentionMaskBuilder
|
||||
- pooling runner_type is NOT supported on 310P (explicit)
|
||||
"""
|
||||
|
||||
def __init__(self, device: torch.device):
|
||||
self._base = _BASE_BUILDER(device)
|
||||
|
||||
self._fp16_mask_cache: torch.Tensor | None = None
|
||||
self._fp16_mask_cached_len: int = 0
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
return getattr(self._base, name)
|
||||
|
||||
@property
|
||||
def device(self) -> torch.device:
|
||||
return self._base.device
|
||||
|
||||
def _get_fp16_mask(self, max_seq_len: int) -> torch.Tensor:
|
||||
if self._fp16_mask_cache is None or max_seq_len > self._fp16_mask_cached_len:
|
||||
self._fp16_mask_cache = _gen_causal_additive_mask_fp16(max_seq_len, self.device)
|
||||
self._fp16_mask_cached_len = max_seq_len
|
||||
assert self._fp16_mask_cache is not None
|
||||
return self._fp16_mask_cache[:max_seq_len, :max_seq_len].contiguous()
|
||||
|
||||
def get_attention_mask(self, model_config) -> torch.Tensor:
|
||||
if getattr(model_config, "runner_type", None) == "pooling":
|
||||
raise NotImplementedError("310P does not support runner_type='pooling'")
|
||||
return self._get_fp16_mask(2048)
|
||||
|
||||
|
||||
def AttentionMaskBuilder(device: torch.device) -> _AttentionMaskBuilder310P:
|
||||
return _AttentionMaskBuilder310P(device)
|
||||
172
vllm_ascend/_310p/attention/attention_v1.py
Normal file
172
vllm_ascend/_310p/attention/attention_v1.py
Normal file
@@ -0,0 +1,172 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
|
||||
|
||||
import torch
|
||||
import torch_npu
|
||||
|
||||
from vllm_ascend._310p.attention.attention_mask import AttentionMaskBuilder, build_splitfuse_attn_mask_310p
|
||||
from vllm_ascend._310p.attention.metadata_builder import AscendAttentionMetadataBuilder310P
|
||||
from vllm_ascend.attention.attention_v1 import AscendAttentionBackend as _BaseBackend
|
||||
from vllm_ascend.attention.attention_v1 import AscendAttentionBackendImpl as _BaseImpl
|
||||
from vllm_ascend.attention.attention_v1 import AscendAttentionMetadataBuilder, AscendAttentionState
|
||||
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, aligned_16, nd_to_nz_2d
|
||||
|
||||
|
||||
class AscendAttentionBackend310(_BaseBackend):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.attn_mask_builder = AttentionMaskBuilder(self.device)
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_shape(num_blocks: int, block_size: int, num_kv_heads: int, head_size: int):
|
||||
# Align to a multiple of 16, as required by the 310P device.
|
||||
return (2, num_blocks, (num_kv_heads * head_size) // 16, block_size, 16)
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls():
|
||||
return AscendAttentionBackendImpl310
|
||||
|
||||
@staticmethod
|
||||
def get_builder_cls() -> type["AscendAttentionMetadataBuilder"]:
|
||||
return AscendAttentionMetadataBuilder310P
|
||||
|
||||
|
||||
class AscendAttentionBackendImpl310(_BaseImpl):
|
||||
def forward_paged_attention(self, query, attn_metadata, output):
|
||||
if attn_metadata.seq_lens.device != query.device:
|
||||
attn_metadata.seq_lens = attn_metadata.seq_lens.to(device=query.device, non_blocking=True)
|
||||
return super().forward_paged_attention(query, attn_metadata, output)
|
||||
|
||||
def _forward_prefill_310p_fallback(self, query, key, value, attn_metadata, output):
|
||||
real_tokens = int(attn_metadata.seq_lens.sum().item())
|
||||
|
||||
query, key, value, output = (aligned_16(t) for t in (query, key, value, output))
|
||||
|
||||
seq_len = attn_metadata.seq_lens
|
||||
if seq_len.dtype != torch.int32:
|
||||
seq_len = seq_len.to(torch.int32)
|
||||
|
||||
aligned_tokens = int(query.shape[0])
|
||||
delta = aligned_tokens - real_tokens
|
||||
if delta:
|
||||
seq_len = seq_len.clone()
|
||||
seq_len[-1] += delta
|
||||
|
||||
mask = attn_metadata.attn_mask
|
||||
if mask is not None and mask.dim() == 2:
|
||||
max_len = int(seq_len.max().item())
|
||||
aligned_len = ((max_len + 15) // 16) * 16
|
||||
|
||||
mask2d = mask[:aligned_len, :aligned_len].contiguous()
|
||||
mask2d = mask2d.to(torch.float16)
|
||||
mask_nz = nd_to_nz_2d(mask2d).contiguous()
|
||||
|
||||
bsz = int(seq_len.numel())
|
||||
if bsz > 1:
|
||||
mask_nz = mask_nz.repeat(bsz, 1, 1, 1).contiguous()
|
||||
|
||||
mask = torch_npu.npu_format_cast(mask_nz, ACL_FORMAT_FRACTAL_NZ)
|
||||
|
||||
torch_npu._npu_flash_attention(
|
||||
query=query,
|
||||
key=key,
|
||||
value=value,
|
||||
mask=mask,
|
||||
seq_len=seq_len,
|
||||
scale_value=self.scale,
|
||||
num_heads=self.num_heads,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
out=output,
|
||||
)
|
||||
|
||||
out_real = output[:real_tokens, :, :]
|
||||
return out_real
|
||||
|
||||
def _forward_chunked_prefill_310p(self, query, attn_metadata, output):
|
||||
assert attn_metadata is not None
|
||||
|
||||
if query.dtype == torch.float32:
|
||||
query = query.to(torch.float16)
|
||||
|
||||
qsl_cpu = attn_metadata.query_start_loc.detach().to("cpu", dtype=torch.int32)
|
||||
qlens = (qsl_cpu[1:] - qsl_cpu[:-1]).to(torch.int32)
|
||||
|
||||
context_lens = attn_metadata.seq_lens
|
||||
if context_lens.dtype != torch.int32:
|
||||
context_lens = context_lens.to(torch.int32)
|
||||
|
||||
block_table = attn_metadata.block_tables.detach()
|
||||
if block_table.dtype != torch.int32:
|
||||
block_table = block_table.to(torch.int32)
|
||||
|
||||
if not hasattr(self, "_sf_full_mask_cache"):
|
||||
self._sf_full_mask_cache = None
|
||||
self._sf_full_mask_cache_len = 0
|
||||
|
||||
mask, self._sf_full_mask_cache, self._sf_full_mask_cache_len = build_splitfuse_attn_mask_310p(
|
||||
attn_metadata,
|
||||
query.device,
|
||||
full_mask_cache=self._sf_full_mask_cache,
|
||||
full_mask_cache_len=int(self._sf_full_mask_cache_len),
|
||||
)
|
||||
|
||||
if qlens.device.type != "cpu":
|
||||
qlens = qlens.to("cpu")
|
||||
if context_lens.device != query.device:
|
||||
context_lens = context_lens.to(query.device, non_blocking=True)
|
||||
|
||||
torch_npu._npu_paged_attention_splitfuse(
|
||||
query=query,
|
||||
key_cache=self.key_cache,
|
||||
value_cache=self.value_cache,
|
||||
mask=mask,
|
||||
block_table=block_table,
|
||||
seq_len=qlens,
|
||||
context_lens=context_lens,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
num_heads=self.num_heads,
|
||||
scale_value=self.scale,
|
||||
out=output,
|
||||
)
|
||||
|
||||
def forward_impl(self, query, key, value, kv_cache, attn_metadata, output):
|
||||
state = attn_metadata.attn_state
|
||||
|
||||
if state == AscendAttentionState.DecodeOnly:
|
||||
return self.forward_paged_attention(query, attn_metadata, output)
|
||||
|
||||
if state == AscendAttentionState.PrefillNoCache:
|
||||
num_tokens = query.shape[0]
|
||||
q = query[:num_tokens]
|
||||
k = key[:num_tokens]
|
||||
v = value[:num_tokens]
|
||||
out = self._forward_prefill_310p_fallback(q, k, v, attn_metadata, output)
|
||||
output[:num_tokens] = out
|
||||
return output
|
||||
|
||||
if state == AscendAttentionState.ChunkedPrefill:
|
||||
self._forward_chunked_prefill_310p(query, attn_metadata, output)
|
||||
return output
|
||||
|
||||
raise NotImplementedError(
|
||||
f"{self.__class__.__name__}.forward_impl: 310P only supports "
|
||||
f"{AscendAttentionState.DecodeOnly.name}, "
|
||||
f"{AscendAttentionState.PrefillNoCache.name}, "
|
||||
f"{AscendAttentionState.ChunkedPrefill.name}, "
|
||||
f"got {state!r}."
|
||||
)
|
||||
40
vllm_ascend/_310p/attention/metadata_builder.py
Normal file
40
vllm_ascend/_310p/attention/metadata_builder.py
Normal file
@@ -0,0 +1,40 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
|
||||
from vllm_ascend._310p.attention.attention_mask import AttentionMaskBuilder
|
||||
from vllm_ascend.attention.attention_v1 import AscendAttentionMetadataBuilder as _BaseBuilder
|
||||
|
||||
|
||||
class AscendAttentionMetadataBuilder310P(_BaseBuilder):
|
||||
def __init__(
|
||||
self,
|
||||
kv_cache_spec: AttentionSpec,
|
||||
layer_names: list[str],
|
||||
vllm_config: VllmConfig,
|
||||
device: torch.device,
|
||||
):
|
||||
super().__init__(kv_cache_spec, layer_names, vllm_config, device)
|
||||
|
||||
self.attn_mask_builder: Any = AttentionMaskBuilder(self.device)
|
||||
Reference in New Issue
Block a user