[310P]: refactoring for 310p kvcache and some ops class (#6117)
### What this PR does / why we need it?
* Refactor the LayerNorm and activation operator classes to decouple the
310P device implementation from the main branch.
* Refactor `mm_encoder_attention` on 310P to use the
`torch_npu._npu_flash_attention_unpad` operator.
* Refactor the QKV inputs in the prefill stage of `attention_v1` on 310P
so they are no longer padded to 16× alignment.
* Refactor `model_runner` on 310P to align the KV-cache initialization
logic with the mainline implementation.
### Does this PR introduce _any_ user-facing change?
NO
### How was this patch tested?
use the e2e tests.
- vLLM version: v0.13.0
- vLLM main:
d68209402d
---------
Signed-off-by: Tflowers-0129 <2906339855@qq.com>
This commit is contained in:
@@ -25,7 +25,7 @@ from vllm_ascend._310p.attention.metadata_builder import AscendAttentionMetadata
|
||||
from vllm_ascend.attention.attention_v1 import AscendAttentionBackend as _BaseBackend
|
||||
from vllm_ascend.attention.attention_v1 import AscendAttentionBackendImpl as _BaseImpl
|
||||
from vllm_ascend.attention.attention_v1 import AscendAttentionMetadataBuilder, AscendAttentionState, AscendMetadata
|
||||
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, aligned_16, nd_to_nz_2d
|
||||
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, nd_to_nz_2d
|
||||
|
||||
|
||||
class AscendAttentionBackend310(_BaseBackend):
|
||||
@@ -64,8 +64,6 @@ class AscendAttentionBackendImpl310(_BaseImpl):
|
||||
def _forward_prefill_310p_fallback(self, query, key, value, attn_metadata, output):
|
||||
real_tokens = int(attn_metadata.seq_lens.sum().item())
|
||||
|
||||
query, key, value, output = (aligned_16(t) for t in (query, key, value, output))
|
||||
|
||||
seq_len = attn_metadata.seq_lens
|
||||
if seq_len.dtype != torch.int32:
|
||||
seq_len = seq_len.to(torch.int32)
|
||||
|
||||
186
vllm_ascend/_310p/model_runner_310p.py
Normal file
186
vllm_ascend/_310p/model_runner_310p.py
Normal file
@@ -0,0 +1,186 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import torch_npu
|
||||
from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheConfig
|
||||
from vllm.v1.worker.utils import bind_kv_cache
|
||||
|
||||
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ
|
||||
from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
|
||||
|
||||
|
||||
class NPUModelRunner310(NPUModelRunner):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._acl_format = ACL_FORMAT_FRACTAL_NZ
|
||||
|
||||
def initialize_kv_cache_tensors(
|
||||
self,
|
||||
kv_cache_config: KVCacheConfig,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Initialize KV cache tensors for 310P.
|
||||
|
||||
1) allocate buffers
|
||||
2) reshape / transform to the final layout
|
||||
3) optional cross-layer sharing
|
||||
4) bind buffers to the static forward context
|
||||
"""
|
||||
# 310P limitation: KV transfer is not supported.
|
||||
if self.vllm_config.kv_transfer_config is not None:
|
||||
raise ValueError("KV cache transfer is not supported for 310P.")
|
||||
|
||||
kv_cache_raw_tensors = self._allocate_kv_cache_tensors_310p(kv_cache_config)
|
||||
kv_caches = self._reshape_kv_cache_tensors_310p(kv_cache_config, kv_cache_raw_tensors)
|
||||
|
||||
# Keep the same cross-layer KV cache sharing logic as the main branch.
|
||||
# For 310P, this is expected to be empty in most cases, but keeping it
|
||||
# makes the code path consistent and easier to reason about.
|
||||
for layer_name, target_layer_name in self.shared_kv_cache_layers.items():
|
||||
kv_caches[layer_name] = kv_caches[target_layer_name]
|
||||
|
||||
# 310P devices do not support the "longcat_flash" special case here, so always be "1".
|
||||
bind_kv_cache(
|
||||
kv_caches,
|
||||
self.compilation_config.static_forward_context,
|
||||
self.kv_caches,
|
||||
1,
|
||||
)
|
||||
return kv_caches
|
||||
|
||||
def _allocate_kv_cache_tensors_310p(
|
||||
self,
|
||||
kv_cache_config: KVCacheConfig,
|
||||
) -> dict[str, tuple[torch.Tensor, torch.Tensor]]:
|
||||
"""
|
||||
Allocate KV cache buffers for each attention layer.
|
||||
|
||||
Unlike the non-310p path, 310P uses torch.zeros directly with the final dtype,
|
||||
and defers layout casting (ACL format) to the reshape step.
|
||||
"""
|
||||
# Build a mapping: layer_name -> tensor_size(bytes).
|
||||
kv_cache_sizes: dict[str, int] = {}
|
||||
for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
|
||||
# 310P limitation: a KV cache tensor must not be shared by multiple layers.
|
||||
assert len(kv_cache_tensor.shared_by) == 1, (
|
||||
"KV cache tensor shared by multiple layers is not supported in 310P."
|
||||
)
|
||||
kv_cache_sizes[kv_cache_tensor.shared_by[0]] = kv_cache_tensor.size
|
||||
|
||||
kv_cache_raw_tensors: dict[str, tuple[torch.Tensor, torch.Tensor]] = {}
|
||||
|
||||
for group in self._kv_cache_spec_attn_group_iterator():
|
||||
kv_cache_spec = group.kv_cache_spec
|
||||
attn_backend = group.backend
|
||||
|
||||
if not isinstance(kv_cache_spec, FullAttentionSpec):
|
||||
raise ValueError("Unknown KV cache spec type.")
|
||||
|
||||
for layer_name in group.layer_names:
|
||||
if layer_name in self.runner_only_attn_layers:
|
||||
continue
|
||||
|
||||
if "attn" not in layer_name:
|
||||
continue
|
||||
|
||||
# Compute how many blocks this layer can hold.
|
||||
tensor_size = kv_cache_sizes[layer_name]
|
||||
assert tensor_size % kv_cache_spec.page_size_bytes == 0
|
||||
num_blocks = tensor_size // kv_cache_spec.page_size_bytes
|
||||
|
||||
# `num_blocks` must be >= the number KVCacheManager may allocate.
|
||||
assert num_blocks >= kv_cache_config.num_blocks
|
||||
|
||||
# Determine the KV cache shape from backend.
|
||||
kv_cache_shape = self._get_kv_cache_shape_310p(
|
||||
attn_backend=attn_backend,
|
||||
kv_cache_spec=kv_cache_spec,
|
||||
num_blocks=num_blocks,
|
||||
)
|
||||
|
||||
shape = kv_cache_shape[1:]
|
||||
dtype = kv_cache_spec.dtype
|
||||
|
||||
k_tensor = torch.zeros(shape, dtype=dtype, device=self.device)
|
||||
v_tensor = torch.zeros(shape, dtype=dtype, device=self.device)
|
||||
kv_cache_raw_tensors[layer_name] = (k_tensor, v_tensor)
|
||||
|
||||
return kv_cache_raw_tensors
|
||||
|
||||
def _reshape_kv_cache_tensors_310p(
|
||||
self,
|
||||
kv_cache_config: KVCacheConfig,
|
||||
kv_cache_raw_tensors: dict[str, tuple[torch.Tensor, torch.Tensor]],
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Transform allocated KV cache buffers into the final layout required by 310P.
|
||||
|
||||
For 310P, this mainly means casting tensors into the expected ACL format.
|
||||
"""
|
||||
kv_caches: dict[str, Any] = {}
|
||||
|
||||
for group in self._kv_cache_spec_attn_group_iterator():
|
||||
kv_cache_spec = group.kv_cache_spec
|
||||
if not isinstance(kv_cache_spec, FullAttentionSpec):
|
||||
raise ValueError("Unknown KV cache spec type.")
|
||||
|
||||
for layer_name in group.layer_names:
|
||||
if layer_name in self.runner_only_attn_layers:
|
||||
continue
|
||||
if "attn" not in layer_name:
|
||||
continue
|
||||
|
||||
k_tensor, v_tensor = kv_cache_raw_tensors[layer_name]
|
||||
|
||||
# In-place ACL layout cast to avoid the extra allocation of npu_format_cast,
|
||||
# which can spike peak memory (~2x KV cache) during initialization and trigger OOM.
|
||||
torch_npu.npu_format_cast_(k_tensor, self._acl_format)
|
||||
torch_npu.npu_format_cast_(v_tensor, self._acl_format)
|
||||
kv_caches[layer_name] = (k_tensor, v_tensor)
|
||||
|
||||
return kv_caches
|
||||
|
||||
def _get_kv_cache_shape_310p(
|
||||
self,
|
||||
attn_backend: Any,
|
||||
kv_cache_spec: FullAttentionSpec,
|
||||
num_blocks: int,
|
||||
) -> tuple[int, ...]:
|
||||
"""
|
||||
Compute KV cache shape with (optional) hybrid block support.
|
||||
"""
|
||||
if hasattr(attn_backend, "get_supported_block_size") and self.use_hybrid_blocks:
|
||||
block_size = attn_backend.get_supported_block_size()[0]
|
||||
block_size_chunk = kv_cache_spec.block_size // block_size
|
||||
return attn_backend.get_kv_cache_shape(
|
||||
num_blocks * block_size_chunk,
|
||||
block_size,
|
||||
kv_cache_spec.num_kv_heads,
|
||||
kv_cache_spec.head_size,
|
||||
)
|
||||
|
||||
return attn_backend.get_kv_cache_shape(
|
||||
num_blocks,
|
||||
kv_cache_spec.block_size,
|
||||
kv_cache_spec.num_kv_heads,
|
||||
kv_cache_spec.head_size,
|
||||
)
|
||||
@@ -1,100 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import torch_npu
|
||||
from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheConfig
|
||||
from vllm.v1.worker.utils import bind_kv_cache
|
||||
|
||||
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ
|
||||
from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
|
||||
|
||||
|
||||
class NPUModelRunner310(NPUModelRunner):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._acl_format = ACL_FORMAT_FRACTAL_NZ
|
||||
|
||||
def _initialize_kv_cache_tensors_310p(self, kv_cache_config: KVCacheConfig) -> dict[str, Any]:
|
||||
if self.vllm_config.kv_transfer_config is not None:
|
||||
raise ValueError("KV cache transfer is not supported for 310P.")
|
||||
|
||||
kv_cache_sizes: dict[str, int] = {}
|
||||
for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
|
||||
assert len(kv_cache_tensor.shared_by) == 1, (
|
||||
"KV cache tensor shared by multiple layers is not supported in 310P."
|
||||
)
|
||||
kv_cache_sizes[kv_cache_tensor.shared_by[0]] = kv_cache_tensor.size
|
||||
|
||||
kv_caches: dict[str, Any] = {}
|
||||
|
||||
for group in self._kv_cache_spec_attn_group_iterator():
|
||||
kv_cache_spec = group.kv_cache_spec
|
||||
attn_backend = group.backend
|
||||
|
||||
if not isinstance(kv_cache_spec, FullAttentionSpec):
|
||||
raise ValueError("Unknown KV cache spec type.")
|
||||
|
||||
for layer_name in group.layer_names:
|
||||
if layer_name in self.runner_only_attn_layers:
|
||||
continue
|
||||
|
||||
tensor_size = kv_cache_sizes[layer_name]
|
||||
assert tensor_size % kv_cache_spec.page_size_bytes == 0
|
||||
num_blocks = tensor_size // kv_cache_spec.page_size_bytes
|
||||
assert num_blocks >= kv_cache_config.num_blocks
|
||||
|
||||
if hasattr(attn_backend, "get_supported_block_size") and self.use_hybrid_blocks:
|
||||
block_size = attn_backend.get_supported_block_size()[0]
|
||||
block_size_chunk = kv_cache_spec.block_size // block_size
|
||||
kv_cache_shape = attn_backend.get_kv_cache_shape(
|
||||
num_blocks * block_size_chunk,
|
||||
block_size,
|
||||
kv_cache_spec.num_kv_heads,
|
||||
kv_cache_spec.head_size,
|
||||
)
|
||||
else:
|
||||
kv_cache_shape = attn_backend.get_kv_cache_shape(
|
||||
num_blocks,
|
||||
kv_cache_spec.block_size,
|
||||
kv_cache_spec.num_kv_heads,
|
||||
kv_cache_spec.head_size,
|
||||
)
|
||||
|
||||
dtype = kv_cache_spec.dtype
|
||||
|
||||
if "attn" in layer_name:
|
||||
k_tensor = torch.zeros(kv_cache_shape[1:], dtype=dtype, device=self.device)
|
||||
v_tensor = torch.zeros(kv_cache_shape[1:], dtype=dtype, device=self.device)
|
||||
k_cache = torch_npu.npu_format_cast(k_tensor, self._acl_format)
|
||||
v_cache = torch_npu.npu_format_cast(v_tensor, self._acl_format)
|
||||
kv_caches[layer_name] = (k_cache, v_cache)
|
||||
|
||||
bind_kv_cache(
|
||||
kv_caches,
|
||||
self.compilation_config.static_forward_context,
|
||||
self.kv_caches,
|
||||
1, # 310p devices donnot support: hf_config.model_type == "longcat_flash"
|
||||
)
|
||||
return kv_caches
|
||||
|
||||
def initialize_kv_cache_tensors(self, kv_cache_config: KVCacheConfig) -> dict[str, Any]:
|
||||
return self._initialize_kv_cache_tensors_310p(kv_cache_config)
|
||||
44
vllm_ascend/_310p/ops/layernorm.py
Normal file
44
vllm_ascend/_310p/ops/layernorm.py
Normal file
@@ -0,0 +1,44 @@
|
||||
import torch
|
||||
import torch_npu
|
||||
|
||||
from vllm_ascend.ops.layernorm import AscendGemmaRMSNorm, AscendRMSNorm
|
||||
|
||||
|
||||
class AscendRMSNorm310(AscendRMSNorm):
|
||||
def forward_oot(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
residual: torch.Tensor | None = None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
if residual is not None:
|
||||
orig_dtype = residual.dtype
|
||||
if x is None or x.numel() == 0 or x.shape[-1] == 0:
|
||||
x = residual.to(dtype=residual.dtype)
|
||||
else:
|
||||
x = x + residual.to(x.dtype)
|
||||
|
||||
residual = x.to(orig_dtype)
|
||||
x, _ = torch_npu.npu_rms_norm(x, self.weight, self.variance_epsilon)
|
||||
return x, residual
|
||||
|
||||
x, residual = torch_npu.npu_rms_norm(x, self.weight, self.variance_epsilon)
|
||||
if self.bias is not None:
|
||||
x.add_(self.bias)
|
||||
return x
|
||||
|
||||
|
||||
class AscendGemmaRMSNorm310(AscendGemmaRMSNorm):
|
||||
def forward_oot(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
residual: torch.Tensor | None = None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
if residual is not None:
|
||||
orig_dtype = residual.dtype
|
||||
x = x + residual.to(x.dtype)
|
||||
residual = x.to(orig_dtype)
|
||||
x, _ = torch_npu.npu_rms_norm(x, 1.0 + self.weight, self.variance_epsilon)
|
||||
return x, residual
|
||||
|
||||
x, _ = torch_npu.npu_rms_norm(x, 1.0 + self.weight, self.variance_epsilon)
|
||||
return x
|
||||
@@ -17,11 +17,8 @@
|
||||
|
||||
import einops
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch_npu
|
||||
|
||||
import vllm_ascend.envs as envs_ascend
|
||||
from vllm_ascend.ops.mm_encoder_attention import MAX_PAD_SIZE, MIN_PAD_SIZE
|
||||
from vllm_ascend.ops.mm_encoder_attention import AscendMMEncoderAttention as _Base
|
||||
|
||||
|
||||
@@ -43,23 +40,6 @@ class AscendMMEncoderAttention310(_Base):
|
||||
|
||||
q, k, v = self.reshape_qkv_to_3d(query, key, value, bsz, q_len, kv_len)
|
||||
|
||||
enable_pad = envs_ascend.USE_OPTIMIZED_MODEL and self.head_size > MIN_PAD_SIZE and self.head_size < MAX_PAD_SIZE
|
||||
|
||||
origin_shape = q.shape[-1]
|
||||
if enable_pad:
|
||||
pad_len = MAX_PAD_SIZE - origin_shape
|
||||
q = F.pad(q, (0, pad_len), mode="constant", value=0)
|
||||
k = F.pad(k, (0, pad_len), mode="constant", value=0)
|
||||
v = F.pad(v, (0, pad_len), mode="constant", value=0)
|
||||
|
||||
origin_dim = origin_shape
|
||||
cur_dim = q.shape[-1]
|
||||
pad16 = (16 - cur_dim % 16) % 16
|
||||
if pad16:
|
||||
q = F.pad(q, (0, pad16), mode="constant", value=0)
|
||||
k = F.pad(k, (0, pad16), mode="constant", value=0)
|
||||
v = F.pad(v, (0, pad16), mode="constant", value=0)
|
||||
|
||||
if cu_seqlens is None:
|
||||
cu_seqlens = torch.arange(
|
||||
0,
|
||||
@@ -69,36 +49,19 @@ class AscendMMEncoderAttention310(_Base):
|
||||
device=query.device,
|
||||
)
|
||||
|
||||
total_q_tokens = bsz * q_len
|
||||
context_flat = q.new_empty((total_q_tokens, self.num_heads, q.shape[-1]))
|
||||
seq_len = torch.diff(cu_seqlens).to("cpu", dtype=torch.int32)
|
||||
|
||||
st = 0
|
||||
seg_lens = torch.diff(cu_seqlens).to("cpu", dtype=torch.int64).tolist()
|
||||
for seg_len in seg_lens:
|
||||
seg_len = int(seg_len)
|
||||
ed = st + seg_len
|
||||
context_layer = torch.empty_like(q)
|
||||
torch_npu._npu_flash_attention_unpad(
|
||||
query=q,
|
||||
key=k,
|
||||
value=v,
|
||||
seq_len=seq_len,
|
||||
scale_value=self.head_size**-0.5,
|
||||
num_heads=self.num_heads,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
out=context_layer,
|
||||
)
|
||||
|
||||
q_i = q[st:ed].unsqueeze(0) # [1, S, H, D]
|
||||
k_i = k[st:ed].unsqueeze(0)
|
||||
v_i = v[st:ed].unsqueeze(0)
|
||||
|
||||
qs = int(q_i.shape[1])
|
||||
kvs = int(k_i.shape[1])
|
||||
|
||||
out_i = torch_npu.npu_prompt_flash_attention(
|
||||
q_i,
|
||||
k_i,
|
||||
v_i,
|
||||
input_layout="BSND",
|
||||
num_heads=self.num_heads,
|
||||
num_key_value_heads=self.num_kv_heads,
|
||||
scale_value=self.head_size**-0.5,
|
||||
pre_tokens=qs,
|
||||
next_tokens=kvs,
|
||||
)
|
||||
context_flat[st:ed] = out_i[0]
|
||||
st = ed
|
||||
|
||||
context_flat = context_flat[..., :origin_dim]
|
||||
context_layer = einops.rearrange(context_flat, "(b s) h d -> b s h d", b=bsz).contiguous()
|
||||
context_layer = einops.rearrange(context_layer, "(b s) h d -> b s h d", b=bsz).contiguous()
|
||||
return context_layer
|
||||
|
||||
@@ -18,7 +18,7 @@
|
||||
import torch_npu
|
||||
from vllm.logger import logger
|
||||
|
||||
from vllm_ascend._310p.modelrunner_310p import NPUModelRunner310
|
||||
from vllm_ascend._310p.model_runner_310p import NPUModelRunner310
|
||||
from vllm_ascend.worker.worker import NPUWorker, init_workspace_manager
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user