vllm-ascend support Ascend950 with Qwen dense model. (#4228)
### What this PR does / why we need it?
vllm-ascend support Ascend950 with Qwen dense model
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
---------
Signed-off-by: wangyao <iwangyao@outlook.com>
Co-authored-by: weijinqian0 <1184188277@qq.com>
This commit is contained in:
1
setup.py
1
setup.py
@@ -160,6 +160,7 @@ def gen_build_info():
|
|||||||
"ascend310p3vir02": "_310P",
|
"ascend310p3vir02": "_310P",
|
||||||
"ascend310p3vir04": "_310P",
|
"ascend310p3vir04": "_310P",
|
||||||
"ascend310p3vir08": "_310P",
|
"ascend310p3vir08": "_310P",
|
||||||
|
"ascend910_9579": "_910_95",
|
||||||
}
|
}
|
||||||
|
|
||||||
assert soc_version in soc_to_device, f"Undefined soc_version: {soc_version}. Please file an issue to vllm-ascend."
|
assert soc_version in soc_to_device, f"Undefined soc_version: {soc_version}. Please file an issue to vllm-ascend."
|
||||||
|
|||||||
@@ -37,7 +37,8 @@ from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
|
|||||||
split_decodes_and_prefills)
|
split_decodes_and_prefills)
|
||||||
from vllm_ascend.compilation.acl_graph import (get_graph_params,
|
from vllm_ascend.compilation.acl_graph import (get_graph_params,
|
||||||
update_graph_params_workspaces)
|
update_graph_params_workspaces)
|
||||||
from vllm_ascend.utils import weak_ref_tensors
|
from vllm_ascend.utils import (AscendDeviceType, get_ascend_device_type,
|
||||||
|
weak_ref_tensors)
|
||||||
|
|
||||||
|
|
||||||
@register_backend(AttentionBackendEnum.CUSTOM, "ASCEND")
|
@register_backend(AttentionBackendEnum.CUSTOM, "ASCEND")
|
||||||
@@ -541,12 +542,45 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
|||||||
output[:num_tokens] = attn_output[:num_tokens]
|
output[:num_tokens] = attn_output[:num_tokens]
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
def _forward_decode_only_ascend91095(
|
||||||
|
self,
|
||||||
|
query: torch.Tensor,
|
||||||
|
attn_metadata: AscendMetadata,
|
||||||
|
output: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
batch_size = attn_metadata.query_lens.shape[0]
|
||||||
|
num_block, block_size, _, _ = self.key_cache.shape # type: ignore
|
||||||
|
key = self.key_cache.view( # type: ignore
|
||||||
|
num_block, block_size, -1)
|
||||||
|
value = self.value_cache.view( # type: ignore
|
||||||
|
num_block, block_size, -1)
|
||||||
|
actual_seq_lengths_kv = attn_metadata.seq_lens_list
|
||||||
|
|
||||||
|
attn_output, _ = torch_npu.npu_fused_infer_attention_score(
|
||||||
|
query=query,
|
||||||
|
key=key,
|
||||||
|
value=value,
|
||||||
|
block_table=attn_metadata.block_tables,
|
||||||
|
input_layout="TND",
|
||||||
|
block_size=block_size,
|
||||||
|
actual_seq_lengths=attn_metadata.actual_seq_lengths_q,
|
||||||
|
actual_seq_lengths_kv=actual_seq_lengths_kv,
|
||||||
|
num_key_value_heads=self.num_kv_heads,
|
||||||
|
num_heads=self.num_heads,
|
||||||
|
scale=self.scale,
|
||||||
|
)
|
||||||
|
output[:batch_size] = attn_output[:batch_size]
|
||||||
|
return output
|
||||||
|
|
||||||
def _forward_decode_only(
|
def _forward_decode_only(
|
||||||
self,
|
self,
|
||||||
query: torch.Tensor,
|
query: torch.Tensor,
|
||||||
attn_metadata: AscendMetadata,
|
attn_metadata: AscendMetadata,
|
||||||
output: Optional[torch.Tensor] = None,
|
output: Optional[torch.Tensor] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
if get_ascend_device_type() == AscendDeviceType._910_95:
|
||||||
|
return self._forward_decode_only_ascend91095(
|
||||||
|
query, attn_metadata, output)
|
||||||
if self.sliding_window is not None and attn_metadata.seq_lens.shape[
|
if self.sliding_window is not None and attn_metadata.seq_lens.shape[
|
||||||
0] == query.size(0):
|
0] == query.size(0):
|
||||||
batch_size = attn_metadata.seq_lens.shape[0]
|
batch_size = attn_metadata.seq_lens.shape[0]
|
||||||
@@ -633,6 +667,14 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
|||||||
if self.key_cache is None:
|
if self.key_cache is None:
|
||||||
self.key_cache, self.value_cache = kv_cache[0], kv_cache[1]
|
self.key_cache, self.value_cache = kv_cache[0], kv_cache[1]
|
||||||
slots = attn_metadata.slot_mapping
|
slots = attn_metadata.slot_mapping
|
||||||
|
if get_ascend_device_type() == AscendDeviceType._910_95:
|
||||||
|
torch_npu.npu_scatter_pa_kv_cache(
|
||||||
|
key=key[:attn_metadata.num_actual_tokens],
|
||||||
|
value=value[:attn_metadata.num_actual_tokens].contiguous(),
|
||||||
|
key_cache=self.key_cache,
|
||||||
|
value_cache=self.value_cache,
|
||||||
|
slot_mapping=slots)
|
||||||
|
else:
|
||||||
torch_npu._npu_reshape_and_cache(
|
torch_npu._npu_reshape_and_cache(
|
||||||
key=key[:attn_metadata.num_actual_tokens],
|
key=key[:attn_metadata.num_actual_tokens],
|
||||||
value=value[:attn_metadata.num_actual_tokens],
|
value=value[:attn_metadata.num_actual_tokens],
|
||||||
|
|||||||
@@ -410,7 +410,8 @@ class AscendMRotaryEmbedding(MRotaryEmbedding):
|
|||||||
query: torch.Tensor,
|
query: torch.Tensor,
|
||||||
key: torch.Tensor,
|
key: torch.Tensor,
|
||||||
):
|
):
|
||||||
if self.mrope_section != [16, 24, 24]:
|
if self.mrope_section != [16, 24, 24] or \
|
||||||
|
get_ascend_device_type() == AscendDeviceType._910_95:
|
||||||
return super().forward_oot(positions, query, key)
|
return super().forward_oot(positions, query, key)
|
||||||
|
|
||||||
import torch_npu
|
import torch_npu
|
||||||
|
|||||||
@@ -52,8 +52,9 @@ from vllm_ascend.device_allocator.camem import CaMemAllocator
|
|||||||
from vllm_ascend.distributed.parallel_state import init_ascend_model_parallel
|
from vllm_ascend.distributed.parallel_state import init_ascend_model_parallel
|
||||||
from vllm_ascend.ops.triton.triton_utils import init_device_properties_triton
|
from vllm_ascend.ops.triton.triton_utils import init_device_properties_triton
|
||||||
from vllm_ascend.platform import NPUPlatform
|
from vllm_ascend.platform import NPUPlatform
|
||||||
from vllm_ascend.utils import (check_ascend_device_type, enable_sp,
|
from vllm_ascend.utils import (AscendDeviceType, check_ascend_device_type,
|
||||||
is_enable_nz, register_ascend_customop)
|
enable_sp, get_ascend_device_type, is_enable_nz,
|
||||||
|
register_ascend_customop)
|
||||||
from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
|
from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
|
||||||
|
|
||||||
torch._dynamo.trace_rules.clear_lru_cache() # noqa: E402
|
torch._dynamo.trace_rules.clear_lru_cache() # noqa: E402
|
||||||
@@ -87,6 +88,7 @@ class NPUWorker(WorkerBase):
|
|||||||
# Register ops when worker init.
|
# Register ops when worker init.
|
||||||
from vllm_ascend import ops
|
from vllm_ascend import ops
|
||||||
ops.register_dummy_fusion_op()
|
ops.register_dummy_fusion_op()
|
||||||
|
if get_ascend_device_type() != AscendDeviceType._910_95:
|
||||||
_register_atb_extensions()
|
_register_atb_extensions()
|
||||||
register_ascend_customop(vllm_config)
|
register_ascend_customop(vllm_config)
|
||||||
# init ascend config and soc version
|
# init ascend config and soc version
|
||||||
@@ -356,6 +358,7 @@ class NPUWorker(WorkerBase):
|
|||||||
self.model_runner.capture_model()
|
self.model_runner.capture_model()
|
||||||
# Call ATB matmul to warm up; otherwise, the first operation (ReshapeAndCache)
|
# Call ATB matmul to warm up; otherwise, the first operation (ReshapeAndCache)
|
||||||
# may cause performance degradation at runtime.
|
# may cause performance degradation at runtime.
|
||||||
|
if get_ascend_device_type() != AscendDeviceType._910_95:
|
||||||
self._warm_up_atb()
|
self._warm_up_atb()
|
||||||
# Reset the seed to ensure that the random state is not affected by
|
# Reset the seed to ensure that the random state is not affected by
|
||||||
# the model initialization and profiling.
|
# the model initialization and profiling.
|
||||||
|
|||||||
Reference in New Issue
Block a user