[Feature]: Support 310P device run qwen2.5/3 dense and qwen2.5vl models (#5776)
### What this PR does / why we need it?
Add basic 310p support. Only dense models work with eager mode now.
- vLLM version: v0.13.0
- vLLM main:
2f4e6548ef
---------
Signed-off-by: Tflowers-0129 <2906339855@qq.com>
Signed-off-by: Shaoxu Cheng <2906339855@qq.com>
This commit is contained in:
@@ -62,14 +62,17 @@ set(VLLM_ASCEND_CUSTOM_OP
|
||||
)
|
||||
|
||||
set(VLLM_ASCEND_CUSTOM_OP_EXCLUDE
|
||||
${KERNEL_FILES}/bgmv_expand.cpp
|
||||
${KERNEL_FILES}/bgmv_shrink.cpp
|
||||
${KERNEL_FILES}/sgmv_expand.cpp
|
||||
${KERNEL_FILES}/sgmv_shrink.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/csrc/kernels/bgmv_expand.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/csrc/kernels/bgmv_shrink.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/csrc/kernels/sgmv_expand.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/csrc/kernels/sgmv_shrink.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/csrc/mla_preprocess/op_kernel/mla_preprocess_kernel.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/csrc/batch_matmul_transpose/op_kernel/batch_matmul_transpose_kernel.cpp
|
||||
)
|
||||
|
||||
if(SOC_VERSION STREQUAL "ASCEND310P3")
|
||||
if(SOC_VERSION STREQUAL "ascend310p3")
|
||||
message(STATUS "310P hardware detected: disabling MLAPO operators")
|
||||
message(STATUS "310P hardware detected: excluding batch_matmul_transpose operators")
|
||||
list(REMOVE_ITEM VLLM_ASCEND_CUSTOM_OP ${VLLM_ASCEND_CUSTOM_OP_EXCLUDE})
|
||||
endif()
|
||||
|
||||
@@ -79,7 +82,7 @@ ascendc_library(vllm_ascend_kernels SHARED
|
||||
|
||||
message("TORCH_NPU_PATH is ${TORCH_NPU_PATH}")
|
||||
|
||||
if(SOC_VERSION STREQUAL "ASCEND310P3")
|
||||
if(SOC_VERSION STREQUAL "ascend310p3")
|
||||
file(GLOB VLLM_ASCEND_SRC
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/csrc/*.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/csrc/aclnn_torch_adapter/*.cpp)
|
||||
|
||||
@@ -36,7 +36,9 @@ def test_llm_models(dtype: str, max_tokens: int) -> None:
|
||||
vllm_model.generate_greedy(example_prompts, max_tokens)
|
||||
|
||||
|
||||
def test_multimodal_vl():
|
||||
@pytest.mark.skip(reason="310P: multimodal test skipped, offline is ok")
|
||||
@pytest.mark.parametrize("dtype", ["float16"])
|
||||
def test_multimodal_vl(dtype: str):
|
||||
image = ImageAsset("cherry_blossom").pil_image.convert("RGB")
|
||||
|
||||
img_questions = [
|
||||
@@ -60,6 +62,7 @@ def test_multimodal_vl():
|
||||
"max_pixels": 1280 * 28 * 28,
|
||||
"fps": 1,
|
||||
},
|
||||
dtype=dtype,
|
||||
max_model_len=8192,
|
||||
enforce_eager=True,
|
||||
limit_mm_per_prompt={"image": 1}) as vllm_model:
|
||||
|
||||
@@ -1,3 +1,20 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.e2e.conftest import VllmRunner
|
||||
@@ -5,7 +22,6 @@ from tests.e2e.conftest import VllmRunner
|
||||
|
||||
@pytest.mark.parametrize("dtype", ["float16"])
|
||||
@pytest.mark.parametrize("max_tokens", [5])
|
||||
@pytest.skip("310p does not support parallel inference now. Fix me")
|
||||
def test_models(dtype: str, max_tokens: int) -> None:
|
||||
example_prompts = [
|
||||
"Hello, my name is",
|
||||
|
||||
0
vllm_ascend/_310p/__init__.py
Normal file
0
vllm_ascend/_310p/__init__.py
Normal file
0
vllm_ascend/_310p/attention/__init__.py
Normal file
0
vllm_ascend/_310p/attention/__init__.py
Normal file
98
vllm_ascend/_310p/attention/attention_mask.py
Normal file
98
vllm_ascend/_310p/attention/attention_mask.py
Normal file
@@ -0,0 +1,98 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import torch_npu
|
||||
|
||||
import vllm_ascend.attention.attention_mask as _base_mask
|
||||
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, nd_to_nz_spec
|
||||
|
||||
_BASE_BUILDER: Callable[[torch.device], Any] = _base_mask.AttentionMaskBuilder
|
||||
|
||||
|
||||
def _gen_causal_additive_mask_fp16(max_seq_len: int, device: torch.device) -> torch.Tensor:
|
||||
tril = torch.ones((max_seq_len, max_seq_len), dtype=torch.bool, device=device).tril_()
|
||||
upper = ~tril
|
||||
m = torch.zeros((max_seq_len, max_seq_len), dtype=torch.float16, device=device)
|
||||
m.masked_fill_(upper, float("-inf"))
|
||||
return m
|
||||
|
||||
|
||||
def build_splitfuse_attn_mask_310p(attn_metadata, device, *, full_mask_cache=None, full_mask_cache_len=0):
|
||||
qsl = attn_metadata.query_start_loc.detach().to("cpu", dtype=torch.int32)
|
||||
qlens = qsl[1:] - qsl[:-1]
|
||||
|
||||
context_lens = attn_metadata.seq_lens.to(dtype=torch.int32)
|
||||
L = int(context_lens.max().item())
|
||||
|
||||
q_list = qlens.tolist()
|
||||
c_list = context_lens.detach().to("cpu", dtype=torch.int64).tolist()
|
||||
pos_list = [p for ql, cl in zip(q_list, c_list) for p in range(cl - ql, cl)]
|
||||
position = torch.tensor(pos_list, dtype=torch.long, device=device)
|
||||
|
||||
if full_mask_cache is None or full_mask_cache.device != device or full_mask_cache_len < L:
|
||||
tril = torch.ones((L, L), dtype=torch.bool, device=device).tril_()
|
||||
full = torch.zeros((L, L), dtype=torch.float16, device=device)
|
||||
full.masked_fill_(~tril, float("-inf"))
|
||||
full_mask_cache, full_mask_cache_len = full, L
|
||||
else:
|
||||
full = full_mask_cache[:L, :L].contiguous()
|
||||
|
||||
rows = full.index_select(0, position).contiguous()
|
||||
mask = torch_npu.npu_format_cast(nd_to_nz_spec(rows).contiguous(), ACL_FORMAT_FRACTAL_NZ)
|
||||
return mask, full_mask_cache, full_mask_cache_len
|
||||
|
||||
|
||||
class _AttentionMaskBuilder310P:
|
||||
"""
|
||||
310P adapter:
|
||||
- overrides fp16 causal additive mask generation (use -inf fp16)
|
||||
- delegates all other behaviors to base AttentionMaskBuilder
|
||||
- pooling runner_type is NOT supported on 310P (explicit)
|
||||
"""
|
||||
|
||||
def __init__(self, device: torch.device):
|
||||
self._base = _BASE_BUILDER(device)
|
||||
|
||||
self._fp16_mask_cache: torch.Tensor | None = None
|
||||
self._fp16_mask_cached_len: int = 0
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
return getattr(self._base, name)
|
||||
|
||||
@property
|
||||
def device(self) -> torch.device:
|
||||
return self._base.device
|
||||
|
||||
def _get_fp16_mask(self, max_seq_len: int) -> torch.Tensor:
|
||||
if self._fp16_mask_cache is None or max_seq_len > self._fp16_mask_cached_len:
|
||||
self._fp16_mask_cache = _gen_causal_additive_mask_fp16(max_seq_len, self.device)
|
||||
self._fp16_mask_cached_len = max_seq_len
|
||||
assert self._fp16_mask_cache is not None
|
||||
return self._fp16_mask_cache[:max_seq_len, :max_seq_len].contiguous()
|
||||
|
||||
def get_attention_mask(self, model_config) -> torch.Tensor:
|
||||
if getattr(model_config, "runner_type", None) == "pooling":
|
||||
raise NotImplementedError("310P does not support runner_type='pooling'")
|
||||
return self._get_fp16_mask(2048)
|
||||
|
||||
|
||||
def AttentionMaskBuilder(device: torch.device) -> _AttentionMaskBuilder310P:
|
||||
return _AttentionMaskBuilder310P(device)
|
||||
172
vllm_ascend/_310p/attention/attention_v1.py
Normal file
172
vllm_ascend/_310p/attention/attention_v1.py
Normal file
@@ -0,0 +1,172 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
|
||||
|
||||
import torch
|
||||
import torch_npu
|
||||
|
||||
from vllm_ascend._310p.attention.attention_mask import AttentionMaskBuilder, build_splitfuse_attn_mask_310p
|
||||
from vllm_ascend._310p.attention.metadata_builder import AscendAttentionMetadataBuilder310P
|
||||
from vllm_ascend.attention.attention_v1 import AscendAttentionBackend as _BaseBackend
|
||||
from vllm_ascend.attention.attention_v1 import AscendAttentionBackendImpl as _BaseImpl
|
||||
from vllm_ascend.attention.attention_v1 import AscendAttentionMetadataBuilder, AscendAttentionState
|
||||
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, aligned_16, nd_to_nz_2d
|
||||
|
||||
|
||||
class AscendAttentionBackend310(_BaseBackend):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.attn_mask_builder = AttentionMaskBuilder(self.device)
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_shape(num_blocks: int, block_size: int, num_kv_heads: int, head_size: int):
|
||||
# Align to a multiple of 16, as required by the 310P device.
|
||||
return (2, num_blocks, (num_kv_heads * head_size) // 16, block_size, 16)
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls():
|
||||
return AscendAttentionBackendImpl310
|
||||
|
||||
@staticmethod
|
||||
def get_builder_cls() -> type["AscendAttentionMetadataBuilder"]:
|
||||
return AscendAttentionMetadataBuilder310P
|
||||
|
||||
|
||||
class AscendAttentionBackendImpl310(_BaseImpl):
|
||||
def forward_paged_attention(self, query, attn_metadata, output):
|
||||
if attn_metadata.seq_lens.device != query.device:
|
||||
attn_metadata.seq_lens = attn_metadata.seq_lens.to(device=query.device, non_blocking=True)
|
||||
return super().forward_paged_attention(query, attn_metadata, output)
|
||||
|
||||
def _forward_prefill_310p_fallback(self, query, key, value, attn_metadata, output):
|
||||
real_tokens = int(attn_metadata.seq_lens.sum().item())
|
||||
|
||||
query, key, value, output = (aligned_16(t) for t in (query, key, value, output))
|
||||
|
||||
seq_len = attn_metadata.seq_lens
|
||||
if seq_len.dtype != torch.int32:
|
||||
seq_len = seq_len.to(torch.int32)
|
||||
|
||||
aligned_tokens = int(query.shape[0])
|
||||
delta = aligned_tokens - real_tokens
|
||||
if delta:
|
||||
seq_len = seq_len.clone()
|
||||
seq_len[-1] += delta
|
||||
|
||||
mask = attn_metadata.attn_mask
|
||||
if mask is not None and mask.dim() == 2:
|
||||
max_len = int(seq_len.max().item())
|
||||
aligned_len = ((max_len + 15) // 16) * 16
|
||||
|
||||
mask2d = mask[:aligned_len, :aligned_len].contiguous()
|
||||
mask2d = mask2d.to(torch.float16)
|
||||
mask_nz = nd_to_nz_2d(mask2d).contiguous()
|
||||
|
||||
bsz = int(seq_len.numel())
|
||||
if bsz > 1:
|
||||
mask_nz = mask_nz.repeat(bsz, 1, 1, 1).contiguous()
|
||||
|
||||
mask = torch_npu.npu_format_cast(mask_nz, ACL_FORMAT_FRACTAL_NZ)
|
||||
|
||||
torch_npu._npu_flash_attention(
|
||||
query=query,
|
||||
key=key,
|
||||
value=value,
|
||||
mask=mask,
|
||||
seq_len=seq_len,
|
||||
scale_value=self.scale,
|
||||
num_heads=self.num_heads,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
out=output,
|
||||
)
|
||||
|
||||
out_real = output[:real_tokens, :, :]
|
||||
return out_real
|
||||
|
||||
def _forward_chunked_prefill_310p(self, query, attn_metadata, output):
|
||||
assert attn_metadata is not None
|
||||
|
||||
if query.dtype == torch.float32:
|
||||
query = query.to(torch.float16)
|
||||
|
||||
qsl_cpu = attn_metadata.query_start_loc.detach().to("cpu", dtype=torch.int32)
|
||||
qlens = (qsl_cpu[1:] - qsl_cpu[:-1]).to(torch.int32)
|
||||
|
||||
context_lens = attn_metadata.seq_lens
|
||||
if context_lens.dtype != torch.int32:
|
||||
context_lens = context_lens.to(torch.int32)
|
||||
|
||||
block_table = attn_metadata.block_tables.detach()
|
||||
if block_table.dtype != torch.int32:
|
||||
block_table = block_table.to(torch.int32)
|
||||
|
||||
if not hasattr(self, "_sf_full_mask_cache"):
|
||||
self._sf_full_mask_cache = None
|
||||
self._sf_full_mask_cache_len = 0
|
||||
|
||||
mask, self._sf_full_mask_cache, self._sf_full_mask_cache_len = build_splitfuse_attn_mask_310p(
|
||||
attn_metadata,
|
||||
query.device,
|
||||
full_mask_cache=self._sf_full_mask_cache,
|
||||
full_mask_cache_len=int(self._sf_full_mask_cache_len),
|
||||
)
|
||||
|
||||
if qlens.device.type != "cpu":
|
||||
qlens = qlens.to("cpu")
|
||||
if context_lens.device != query.device:
|
||||
context_lens = context_lens.to(query.device, non_blocking=True)
|
||||
|
||||
torch_npu._npu_paged_attention_splitfuse(
|
||||
query=query,
|
||||
key_cache=self.key_cache,
|
||||
value_cache=self.value_cache,
|
||||
mask=mask,
|
||||
block_table=block_table,
|
||||
seq_len=qlens,
|
||||
context_lens=context_lens,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
num_heads=self.num_heads,
|
||||
scale_value=self.scale,
|
||||
out=output,
|
||||
)
|
||||
|
||||
def forward_impl(self, query, key, value, kv_cache, attn_metadata, output):
|
||||
state = attn_metadata.attn_state
|
||||
|
||||
if state == AscendAttentionState.DecodeOnly:
|
||||
return self.forward_paged_attention(query, attn_metadata, output)
|
||||
|
||||
if state == AscendAttentionState.PrefillNoCache:
|
||||
num_tokens = query.shape[0]
|
||||
q = query[:num_tokens]
|
||||
k = key[:num_tokens]
|
||||
v = value[:num_tokens]
|
||||
out = self._forward_prefill_310p_fallback(q, k, v, attn_metadata, output)
|
||||
output[:num_tokens] = out
|
||||
return output
|
||||
|
||||
if state == AscendAttentionState.ChunkedPrefill:
|
||||
self._forward_chunked_prefill_310p(query, attn_metadata, output)
|
||||
return output
|
||||
|
||||
raise NotImplementedError(
|
||||
f"{self.__class__.__name__}.forward_impl: 310P only supports "
|
||||
f"{AscendAttentionState.DecodeOnly.name}, "
|
||||
f"{AscendAttentionState.PrefillNoCache.name}, "
|
||||
f"{AscendAttentionState.ChunkedPrefill.name}, "
|
||||
f"got {state!r}."
|
||||
)
|
||||
40
vllm_ascend/_310p/attention/metadata_builder.py
Normal file
40
vllm_ascend/_310p/attention/metadata_builder.py
Normal file
@@ -0,0 +1,40 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
|
||||
from vllm_ascend._310p.attention.attention_mask import AttentionMaskBuilder
|
||||
from vllm_ascend.attention.attention_v1 import AscendAttentionMetadataBuilder as _BaseBuilder
|
||||
|
||||
|
||||
class AscendAttentionMetadataBuilder310P(_BaseBuilder):
|
||||
def __init__(
|
||||
self,
|
||||
kv_cache_spec: AttentionSpec,
|
||||
layer_names: list[str],
|
||||
vllm_config: VllmConfig,
|
||||
device: torch.device,
|
||||
):
|
||||
super().__init__(kv_cache_spec, layer_names, vllm_config, device)
|
||||
|
||||
self.attn_mask_builder: Any = AttentionMaskBuilder(self.device)
|
||||
100
vllm_ascend/_310p/modelrunner_310p.py
Normal file
100
vllm_ascend/_310p/modelrunner_310p.py
Normal file
@@ -0,0 +1,100 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import torch_npu
|
||||
from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheConfig
|
||||
from vllm.v1.worker.utils import bind_kv_cache
|
||||
|
||||
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ
|
||||
from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
|
||||
|
||||
|
||||
class NPUModelRunner310(NPUModelRunner):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._acl_format = ACL_FORMAT_FRACTAL_NZ
|
||||
|
||||
def _initialize_kv_cache_tensors_310p(self, kv_cache_config: KVCacheConfig) -> dict[str, Any]:
|
||||
if self.vllm_config.kv_transfer_config is not None:
|
||||
raise ValueError("KV cache transfer is not supported for 310P.")
|
||||
|
||||
kv_cache_sizes: dict[str, int] = {}
|
||||
for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
|
||||
assert len(kv_cache_tensor.shared_by) == 1, (
|
||||
"KV cache tensor shared by multiple layers is not supported in 310P."
|
||||
)
|
||||
kv_cache_sizes[kv_cache_tensor.shared_by[0]] = kv_cache_tensor.size
|
||||
|
||||
kv_caches: dict[str, Any] = {}
|
||||
|
||||
for group in self._kv_cache_spec_attn_group_iterator():
|
||||
kv_cache_spec = group.kv_cache_spec
|
||||
attn_backend = group.backend
|
||||
|
||||
if not isinstance(kv_cache_spec, FullAttentionSpec):
|
||||
raise ValueError("Unknown KV cache spec type.")
|
||||
|
||||
for layer_name in group.layer_names:
|
||||
if layer_name in self.runner_only_attn_layers:
|
||||
continue
|
||||
|
||||
tensor_size = kv_cache_sizes[layer_name]
|
||||
assert tensor_size % kv_cache_spec.page_size_bytes == 0
|
||||
num_blocks = tensor_size // kv_cache_spec.page_size_bytes
|
||||
assert num_blocks >= kv_cache_config.num_blocks
|
||||
|
||||
if hasattr(attn_backend, "get_supported_block_size") and self.use_hybrid_blocks:
|
||||
block_size = attn_backend.get_supported_block_size()[0]
|
||||
block_size_chunk = kv_cache_spec.block_size // block_size
|
||||
kv_cache_shape = attn_backend.get_kv_cache_shape(
|
||||
num_blocks * block_size_chunk,
|
||||
block_size,
|
||||
kv_cache_spec.num_kv_heads,
|
||||
kv_cache_spec.head_size,
|
||||
)
|
||||
else:
|
||||
kv_cache_shape = attn_backend.get_kv_cache_shape(
|
||||
num_blocks,
|
||||
kv_cache_spec.block_size,
|
||||
kv_cache_spec.num_kv_heads,
|
||||
kv_cache_spec.head_size,
|
||||
)
|
||||
|
||||
dtype = kv_cache_spec.dtype
|
||||
|
||||
if "attn" in layer_name:
|
||||
k_tensor = torch.zeros(kv_cache_shape[1:], dtype=dtype, device=self.device)
|
||||
v_tensor = torch.zeros(kv_cache_shape[1:], dtype=dtype, device=self.device)
|
||||
k_cache = torch_npu.npu_format_cast(k_tensor, self._acl_format)
|
||||
v_cache = torch_npu.npu_format_cast(v_tensor, self._acl_format)
|
||||
kv_caches[layer_name] = (k_cache, v_cache)
|
||||
|
||||
bind_kv_cache(
|
||||
kv_caches,
|
||||
self.compilation_config.static_forward_context,
|
||||
self.kv_caches,
|
||||
1, # 310p devices donnot support: hf_config.model_type == "longcat_flash"
|
||||
)
|
||||
return kv_caches
|
||||
|
||||
def initialize_kv_cache_tensors(self, kv_cache_config: KVCacheConfig) -> dict[str, Any]:
|
||||
return self._initialize_kv_cache_tensors_310p(kv_cache_config)
|
||||
0
vllm_ascend/_310p/ops/__init__.py
Normal file
0
vllm_ascend/_310p/ops/__init__.py
Normal file
30
vllm_ascend/_310p/ops/activation.py
Normal file
30
vllm_ascend/_310p/ops/activation.py
Normal file
@@ -0,0 +1,30 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from vllm_ascend.ops.activation import AscendSiluAndMul as _Base
|
||||
|
||||
|
||||
class AscendSiluAndMul310(_Base):
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
torch.ops.vllm.maybe_prefetch_mlp_down_proj(x)
|
||||
h = x.shape[-1] // 2
|
||||
out = (F.silu(x[..., :h].to(torch.float32)) * x[..., h:].to(torch.float32)).to(torch.float16)
|
||||
torch.ops.vllm.maybe_wait_prefetch_done(out)
|
||||
return out
|
||||
104
vllm_ascend/_310p/ops/mm_encoder_attention.py
Normal file
104
vllm_ascend/_310p/ops/mm_encoder_attention.py
Normal file
@@ -0,0 +1,104 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
|
||||
import einops
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch_npu
|
||||
|
||||
import vllm_ascend.envs as envs_ascend
|
||||
from vllm_ascend.ops.mm_encoder_attention import MAX_PAD_SIZE, MIN_PAD_SIZE
|
||||
from vllm_ascend.ops.mm_encoder_attention import AscendMMEncoderAttention as _Base
|
||||
|
||||
|
||||
class AscendMMEncoderAttention310(_Base):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def forward_oot(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
max_seqlen: int | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
bsz, q_len = query.size()[:2]
|
||||
kv_len = key.size(1)
|
||||
|
||||
q, k, v = self.reshape_qkv_to_3d(query, key, value, bsz, q_len, kv_len)
|
||||
|
||||
enable_pad = envs_ascend.USE_OPTIMIZED_MODEL and self.head_size > MIN_PAD_SIZE and self.head_size < MAX_PAD_SIZE
|
||||
|
||||
origin_shape = q.shape[-1]
|
||||
if enable_pad:
|
||||
pad_len = MAX_PAD_SIZE - origin_shape
|
||||
q = F.pad(q, (0, pad_len), mode="constant", value=0)
|
||||
k = F.pad(k, (0, pad_len), mode="constant", value=0)
|
||||
v = F.pad(v, (0, pad_len), mode="constant", value=0)
|
||||
|
||||
origin_dim = origin_shape
|
||||
cur_dim = q.shape[-1]
|
||||
pad16 = (16 - cur_dim % 16) % 16
|
||||
if pad16:
|
||||
q = F.pad(q, (0, pad16), mode="constant", value=0)
|
||||
k = F.pad(k, (0, pad16), mode="constant", value=0)
|
||||
v = F.pad(v, (0, pad16), mode="constant", value=0)
|
||||
|
||||
if cu_seqlens is None:
|
||||
cu_seqlens = torch.arange(
|
||||
0,
|
||||
(bsz + 1) * q_len,
|
||||
step=q_len,
|
||||
dtype=torch.int32,
|
||||
device=query.device,
|
||||
)
|
||||
|
||||
total_q_tokens = bsz * q_len
|
||||
context_flat = q.new_empty((total_q_tokens, self.num_heads, q.shape[-1]))
|
||||
|
||||
st = 0
|
||||
seg_lens = torch.diff(cu_seqlens).to("cpu", dtype=torch.int64).tolist()
|
||||
for seg_len in seg_lens:
|
||||
seg_len = int(seg_len)
|
||||
ed = st + seg_len
|
||||
|
||||
q_i = q[st:ed].unsqueeze(0) # [1, S, H, D]
|
||||
k_i = k[st:ed].unsqueeze(0)
|
||||
v_i = v[st:ed].unsqueeze(0)
|
||||
|
||||
qs = int(q_i.shape[1])
|
||||
kvs = int(k_i.shape[1])
|
||||
|
||||
out_i = torch_npu.npu_prompt_flash_attention(
|
||||
q_i,
|
||||
k_i,
|
||||
v_i,
|
||||
input_layout="BSND",
|
||||
num_heads=self.num_heads,
|
||||
num_key_value_heads=self.num_kv_heads,
|
||||
scale_value=self.head_size**-0.5,
|
||||
pre_tokens=qs,
|
||||
next_tokens=kvs,
|
||||
)
|
||||
context_flat[st:ed] = out_i[0]
|
||||
st = ed
|
||||
|
||||
context_flat = context_flat[..., :origin_dim]
|
||||
context_layer = einops.rearrange(context_flat, "(b s) h d -> b s h d", b=bsz).contiguous()
|
||||
return context_layer
|
||||
23
vllm_ascend/_310p/ops/rotary_embedding.py
Normal file
23
vllm_ascend/_310p/ops/rotary_embedding.py
Normal file
@@ -0,0 +1,23 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
|
||||
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
|
||||
|
||||
|
||||
class AscendMRotaryEmbedding310(MRotaryEmbedding):
|
||||
def forward_oot(self, positions, query, key):
|
||||
return super().forward_oot(positions, query, key)
|
||||
37
vllm_ascend/_310p/worker_310p.py
Normal file
37
vllm_ascend/_310p/worker_310p.py
Normal file
@@ -0,0 +1,37 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
|
||||
import torch_npu
|
||||
from vllm.logger import logger
|
||||
|
||||
from vllm_ascend._310p.modelrunner_310p import NPUModelRunner310
|
||||
from vllm_ascend.worker.worker import NPUWorker, init_workspace_manager
|
||||
|
||||
|
||||
class NPUWorker310(NPUWorker):
|
||||
def init_device(self):
|
||||
self.device = self._init_device()
|
||||
|
||||
torch_npu.npu.set_compile_mode(jit_compile=False)
|
||||
|
||||
init_workspace_manager(self.device, num_ubatches=1)
|
||||
|
||||
self.model_runner = NPUModelRunner310(self.vllm_config, self.device)
|
||||
|
||||
def _warm_up_atb(self):
|
||||
# 310p device donot support torch_npu._npu_matmul_add_fp32 atb ops
|
||||
logger.info("Skip warm-up atb ops for 310P device")
|
||||
@@ -23,7 +23,6 @@ from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type
|
||||
|
||||
|
||||
class NullHandle:
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@@ -32,12 +31,12 @@ class NullHandle:
|
||||
|
||||
|
||||
def communication_adaptation_310p():
|
||||
|
||||
def broadcast310p_wrapper(fn):
|
||||
def broadcast310p(tensor, src=0, group=None, async_op=False, group_src=None):
|
||||
root = group_src if group_src is not None else src
|
||||
|
||||
def broadcast310p(tensor, src, group=None, async_op=False):
|
||||
if tensor.device == torch.device('cpu'):
|
||||
return fn(tensor, src, group, async_op)
|
||||
if tensor.device == torch.device("cpu"):
|
||||
return fn(tensor, src=root, group=group, async_op=async_op)
|
||||
rank = torch.distributed.get_rank(group)
|
||||
world_size = torch.distributed.get_world_size(group)
|
||||
tensor_list = [torch.empty_like(tensor) for _ in range(world_size)]
|
||||
@@ -51,13 +50,10 @@ def communication_adaptation_310p():
|
||||
|
||||
return broadcast310p
|
||||
|
||||
torch.distributed.broadcast = broadcast310p_wrapper(
|
||||
torch.distributed.broadcast)
|
||||
torch.distributed.distributed_c10d.broadcast = broadcast310p_wrapper(
|
||||
torch.distributed.distributed_c10d.broadcast)
|
||||
torch.distributed.broadcast = broadcast310p_wrapper(torch.distributed.broadcast)
|
||||
torch.distributed.distributed_c10d.broadcast = broadcast310p_wrapper(torch.distributed.distributed_c10d.broadcast)
|
||||
|
||||
def all_reduce_wrapper_310p(fn):
|
||||
|
||||
def all_reduce(
|
||||
tensor,
|
||||
op=torch.distributed.ReduceOp.SUM,
|
||||
@@ -83,10 +79,10 @@ def communication_adaptation_310p():
|
||||
|
||||
return all_reduce
|
||||
|
||||
torch.distributed.all_reduce = all_reduce_wrapper_310p(
|
||||
torch.distributed.all_reduce)
|
||||
torch.distributed.all_reduce = all_reduce_wrapper_310p(torch.distributed.all_reduce)
|
||||
torch.distributed.distributed_c10d.all_reduce = all_reduce_wrapper_310p(
|
||||
torch.distributed.distributed_c10d.all_reduce)
|
||||
torch.distributed.distributed_c10d.all_reduce
|
||||
)
|
||||
|
||||
|
||||
if get_ascend_device_type() == AscendDeviceType._310P:
|
||||
|
||||
@@ -48,6 +48,7 @@ from vllm_ascend.utils import (
|
||||
update_aclgraph_sizes,
|
||||
update_cudagraph_capture_sizes,
|
||||
update_default_aclgraph_sizes,
|
||||
is_310p,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -322,7 +323,9 @@ class NPUPlatform(Platform):
|
||||
if parallel_config and parallel_config.worker_cls == "auto":
|
||||
# TODO: this is a tricky way to disable `use_sequence_parallel_moe` in vllm.
|
||||
parallel_config.all2all_backend = "flashinfer_all2allv"
|
||||
if ascend_config.xlite_graph_config.enabled:
|
||||
if is_310p():
|
||||
parallel_config.worker_cls = "vllm_ascend._310p.worker_310p.NPUWorker310"
|
||||
elif ascend_config.xlite_graph_config.enabled:
|
||||
logger.info("openEuler Xlite enabled. See: https://atomgit.com/openeuler/GVirt/tree/master/xlite")
|
||||
parallel_config.worker_cls = "vllm_ascend.xlite.xlite_worker.XliteWorker"
|
||||
else:
|
||||
@@ -394,13 +397,27 @@ class NPUPlatform(Platform):
|
||||
|
||||
@classmethod
|
||||
def get_attn_backend_cls(cls, selected_backend, attn_selector_config):
|
||||
key = (attn_selector_config.use_mla, attn_selector_config.use_sparse)
|
||||
|
||||
backend_map = {
|
||||
(True, False): "vllm_ascend.attention.mla_v1.AscendMLABackend",
|
||||
(False, False): "vllm_ascend.attention.attention_v1.AscendAttentionBackend",
|
||||
(True, True): "vllm_ascend.attention.sfa_v1.AscendSFABackend",
|
||||
}
|
||||
backend_map_310 = {
|
||||
(
|
||||
False,
|
||||
False,
|
||||
): "vllm_ascend._310p.attention.attention_v1.AscendAttentionBackend310",
|
||||
# TODO If MLA/SFA is supported in the future, consider implementing the logic described in these comments.
|
||||
# (True, False): "...AscendMLABackend310",
|
||||
# (True, True): "...AscendSFABackend310",
|
||||
}
|
||||
|
||||
return backend_map[(attn_selector_config.use_mla, attn_selector_config.use_sparse)]
|
||||
if is_310p():
|
||||
return backend_map_310.get(key, backend_map_310[(False, False)])
|
||||
|
||||
return backend_map[key]
|
||||
|
||||
@classmethod
|
||||
def get_punica_wrapper(cls) -> str:
|
||||
|
||||
@@ -74,6 +74,10 @@ _GRAPH_PRINT_STREAM_LOCK = Lock()
|
||||
_HAS_ROPE = None
|
||||
|
||||
|
||||
def is_310p():
|
||||
return get_ascend_device_type() == AscendDeviceType._310P
|
||||
|
||||
|
||||
def _print_callback_on_stream(*args):
|
||||
"""Callback function to print arguments on the dedicated print stream."""
|
||||
global _GRAPH_PRINT_STREAM
|
||||
@@ -713,6 +717,22 @@ def register_ascend_customop(vllm_config: VllmConfig | None = None):
|
||||
"ApplyRotaryEmb": AscendApplyRotaryEmb,
|
||||
}
|
||||
|
||||
# 310P: override selected ops with 310P implementations (keep minimal changes outside _310p)
|
||||
if is_310p():
|
||||
from vllm_ascend._310p.ops.activation import AscendSiluAndMul310
|
||||
from vllm_ascend._310p.ops.mm_encoder_attention import AscendMMEncoderAttention310
|
||||
from vllm_ascend._310p.ops.rotary_embedding import (
|
||||
AscendMRotaryEmbedding310,
|
||||
)
|
||||
|
||||
REGISTERED_ASCEND_OPS.update(
|
||||
{
|
||||
"SiluAndMul": AscendSiluAndMul310,
|
||||
"MMEncoderAttention": AscendMMEncoderAttention310,
|
||||
"MRotaryEmbedding": AscendMRotaryEmbedding310,
|
||||
}
|
||||
)
|
||||
|
||||
for name, op_cls in REGISTERED_ASCEND_OPS.items():
|
||||
CustomOp.register_oot(_decorated_op_cls=op_cls, name=name)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user