vllm-ascend support Ascend950 with Qwen dense model. (#4228)
### What this PR does / why we need it?
vllm-ascend support Ascend950 with Qwen dense model
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
---------
Signed-off-by: wangyao <iwangyao@outlook.com>
Co-authored-by: weijinqian0 <1184188277@qq.com>
This commit is contained in:
@@ -37,7 +37,8 @@ from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
|
||||
split_decodes_and_prefills)
|
||||
from vllm_ascend.compilation.acl_graph import (get_graph_params,
|
||||
update_graph_params_workspaces)
|
||||
from vllm_ascend.utils import weak_ref_tensors
|
||||
from vllm_ascend.utils import (AscendDeviceType, get_ascend_device_type,
|
||||
weak_ref_tensors)
|
||||
|
||||
|
||||
@register_backend(AttentionBackendEnum.CUSTOM, "ASCEND")
|
||||
@@ -541,12 +542,45 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
output[:num_tokens] = attn_output[:num_tokens]
|
||||
return output
|
||||
|
||||
def _forward_decode_only_ascend91095(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
attn_metadata: AscendMetadata,
|
||||
output: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
batch_size = attn_metadata.query_lens.shape[0]
|
||||
num_block, block_size, _, _ = self.key_cache.shape # type: ignore
|
||||
key = self.key_cache.view( # type: ignore
|
||||
num_block, block_size, -1)
|
||||
value = self.value_cache.view( # type: ignore
|
||||
num_block, block_size, -1)
|
||||
actual_seq_lengths_kv = attn_metadata.seq_lens_list
|
||||
|
||||
attn_output, _ = torch_npu.npu_fused_infer_attention_score(
|
||||
query=query,
|
||||
key=key,
|
||||
value=value,
|
||||
block_table=attn_metadata.block_tables,
|
||||
input_layout="TND",
|
||||
block_size=block_size,
|
||||
actual_seq_lengths=attn_metadata.actual_seq_lengths_q,
|
||||
actual_seq_lengths_kv=actual_seq_lengths_kv,
|
||||
num_key_value_heads=self.num_kv_heads,
|
||||
num_heads=self.num_heads,
|
||||
scale=self.scale,
|
||||
)
|
||||
output[:batch_size] = attn_output[:batch_size]
|
||||
return output
|
||||
|
||||
def _forward_decode_only(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
attn_metadata: AscendMetadata,
|
||||
output: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
if get_ascend_device_type() == AscendDeviceType._910_95:
|
||||
return self._forward_decode_only_ascend91095(
|
||||
query, attn_metadata, output)
|
||||
if self.sliding_window is not None and attn_metadata.seq_lens.shape[
|
||||
0] == query.size(0):
|
||||
batch_size = attn_metadata.seq_lens.shape[0]
|
||||
@@ -633,12 +667,20 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
if self.key_cache is None:
|
||||
self.key_cache, self.value_cache = kv_cache[0], kv_cache[1]
|
||||
slots = attn_metadata.slot_mapping
|
||||
torch_npu._npu_reshape_and_cache(
|
||||
key=key[:attn_metadata.num_actual_tokens],
|
||||
value=value[:attn_metadata.num_actual_tokens],
|
||||
key_cache=self.key_cache,
|
||||
value_cache=self.value_cache,
|
||||
slot_indices=slots)
|
||||
if get_ascend_device_type() == AscendDeviceType._910_95:
|
||||
torch_npu.npu_scatter_pa_kv_cache(
|
||||
key=key[:attn_metadata.num_actual_tokens],
|
||||
value=value[:attn_metadata.num_actual_tokens].contiguous(),
|
||||
key_cache=self.key_cache,
|
||||
value_cache=self.value_cache,
|
||||
slot_mapping=slots)
|
||||
else:
|
||||
torch_npu._npu_reshape_and_cache(
|
||||
key=key[:attn_metadata.num_actual_tokens],
|
||||
value=value[:attn_metadata.num_actual_tokens],
|
||||
key_cache=self.key_cache,
|
||||
value_cache=self.value_cache,
|
||||
slot_indices=slots)
|
||||
return key, value
|
||||
|
||||
def forward_impl(
|
||||
|
||||
Reference in New Issue
Block a user