diff --git a/setup.py b/setup.py
index 2f07b7c1..b0c54cdf 100644
--- a/setup.py
+++ b/setup.py
@@ -160,6 +160,7 @@ def gen_build_info():
         "ascend310p3vir02": "_310P",
         "ascend310p3vir04": "_310P",
         "ascend310p3vir08": "_310P",
+        "ascend910_9579": "_910_95",
     }
     assert soc_version in soc_to_device, f"Undefined soc_version: {soc_version}. Please file an issue to vllm-ascend."
diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py
index d97acf65..b5e91096 100644
--- a/vllm_ascend/attention/attention_v1.py
+++ b/vllm_ascend/attention/attention_v1.py
@@ -37,7 +37,8 @@ from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
                                          split_decodes_and_prefills)
 from vllm_ascend.compilation.acl_graph import (get_graph_params,
                                                update_graph_params_workspaces)
-from vllm_ascend.utils import weak_ref_tensors
+from vllm_ascend.utils import (AscendDeviceType, get_ascend_device_type,
+                               weak_ref_tensors)
 
 
 @register_backend(AttentionBackendEnum.CUSTOM, "ASCEND")
@@ -541,12 +542,45 @@ class AscendAttentionBackendImpl(AttentionImpl):
             output[:num_tokens] = attn_output[:num_tokens]
         return output
 
+    def _forward_decode_only_ascend91095(
+        self,
+        query: torch.Tensor,
+        attn_metadata: AscendMetadata,
+        output: torch.Tensor,
+    ) -> torch.Tensor:
+        batch_size = attn_metadata.query_lens.shape[0]
+        num_block, block_size, _, _ = self.key_cache.shape  # type: ignore
+        key = self.key_cache.view(  # type: ignore
+            num_block, block_size, -1)
+        value = self.value_cache.view(  # type: ignore
+            num_block, block_size, -1)
+        actual_seq_lengths_kv = attn_metadata.seq_lens_list
+
+        attn_output, _ = torch_npu.npu_fused_infer_attention_score(
+            query=query,
+            key=key,
+            value=value,
+            block_table=attn_metadata.block_tables,
+            input_layout="TND",
+            block_size=block_size,
+            actual_seq_lengths=attn_metadata.actual_seq_lengths_q,
+            actual_seq_lengths_kv=actual_seq_lengths_kv,
+            num_key_value_heads=self.num_kv_heads,
+            num_heads=self.num_heads,
+            scale=self.scale,
+        )
+        output[:batch_size] = attn_output[:batch_size]
+        return output
+
     def _forward_decode_only(
         self,
         query: torch.Tensor,
         attn_metadata: AscendMetadata,
         output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
+        if get_ascend_device_type() == AscendDeviceType._910_95:
+            return self._forward_decode_only_ascend91095(
+                query, attn_metadata, output)
         if self.sliding_window is not None and attn_metadata.seq_lens.shape[
                 0] == query.size(0):
             batch_size = attn_metadata.seq_lens.shape[0]
@@ -633,12 +667,20 @@ class AscendAttentionBackendImpl(AttentionImpl):
         if self.key_cache is None:
             self.key_cache, self.value_cache = kv_cache[0], kv_cache[1]
         slots = attn_metadata.slot_mapping
-        torch_npu._npu_reshape_and_cache(
-            key=key[:attn_metadata.num_actual_tokens],
-            value=value[:attn_metadata.num_actual_tokens],
-            key_cache=self.key_cache,
-            value_cache=self.value_cache,
-            slot_indices=slots)
+        if get_ascend_device_type() == AscendDeviceType._910_95:
+            torch_npu.npu_scatter_pa_kv_cache(
+                key=key[:attn_metadata.num_actual_tokens],
+                value=value[:attn_metadata.num_actual_tokens].contiguous(),
+                key_cache=self.key_cache,
+                value_cache=self.value_cache,
+                slot_mapping=slots)
+        else:
+            torch_npu._npu_reshape_and_cache(
+                key=key[:attn_metadata.num_actual_tokens],
+                value=value[:attn_metadata.num_actual_tokens],
+                key_cache=self.key_cache,
+                value_cache=self.value_cache,
+                slot_indices=slots)
         return key, value
 
     def forward_impl(
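Note on the new decode path above: `npu_fused_infer_attention_score` consumes the paged caches with the head axes flattened, so `_forward_decode_only_ascend91095` views the 4-D cache as 3-D before the call. The sketch below is illustrative only (plain `torch`, made-up shapes; not vllm-ascend code) and shows that this `view` is a zero-copy relayout of the same storage:

```python
import torch

# Hypothetical cache shape, matching the unpacking in the patch:
# num_block, block_size, _, _ = self.key_cache.shape
num_block, block_size, num_kv_heads, head_dim = 8, 128, 4, 64
key_cache = torch.randn(num_block, block_size, num_kv_heads, head_dim)

# Same call shape as the patch: fold heads and head_dim into one axis.
key = key_cache.view(num_block, block_size, -1)

assert key.shape == (num_block, block_size, num_kv_heads * head_dim)
# view() aliases the original storage, so no data is copied.
assert key.data_ptr() == key_cache.data_ptr()
```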
diff --git a/vllm_ascend/ops/rotary_embedding.py b/vllm_ascend/ops/rotary_embedding.py
index eb88e959..ef398fae 100644
--- a/vllm_ascend/ops/rotary_embedding.py
+++ b/vllm_ascend/ops/rotary_embedding.py
@@ -410,7 +410,8 @@ class AscendMRotaryEmbedding(MRotaryEmbedding):
         query: torch.Tensor,
         key: torch.Tensor,
     ):
-        if self.mrope_section != [16, 24, 24]:
+        if self.mrope_section != [16, 24, 24] or \
+                get_ascend_device_type() == AscendDeviceType._910_95:
             return super().forward_oot(positions, query, key)
 
         import torch_npu
diff --git a/vllm_ascend/worker/worker_v1.py b/vllm_ascend/worker/worker_v1.py
index b15d71de..bd7bebfa 100644
--- a/vllm_ascend/worker/worker_v1.py
+++ b/vllm_ascend/worker/worker_v1.py
@@ -52,8 +52,9 @@ from vllm_ascend.device_allocator.camem import CaMemAllocator
 from vllm_ascend.distributed.parallel_state import init_ascend_model_parallel
 from vllm_ascend.ops.triton.triton_utils import init_device_properties_triton
 from vllm_ascend.platform import NPUPlatform
-from vllm_ascend.utils import (check_ascend_device_type, enable_sp,
-                               is_enable_nz, register_ascend_customop)
+from vllm_ascend.utils import (AscendDeviceType, check_ascend_device_type,
+                               enable_sp, get_ascend_device_type, is_enable_nz,
+                               register_ascend_customop)
 from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
 
 torch._dynamo.trace_rules.clear_lru_cache()  # noqa: E402
@@ -87,7 +88,8 @@ class NPUWorker(WorkerBase):
         # Register ops when worker init.
         from vllm_ascend import ops
         ops.register_dummy_fusion_op()
-        _register_atb_extensions()
+        if get_ascend_device_type() != AscendDeviceType._910_95:
+            _register_atb_extensions()
         register_ascend_customop(vllm_config)
         # init ascend config and soc version
         init_ascend_config(vllm_config)
@@ -356,7 +358,8 @@ class NPUWorker(WorkerBase):
             self.model_runner.capture_model()
         # Call ATB matmul to warm up; otherwise, the first operation (ReshapeAndCache)
         # may cause performance degradation at runtime.
-        self._warm_up_atb()
+        if get_ascend_device_type() != AscendDeviceType._910_95:
+            self._warm_up_atb()
         # Reset the seed to ensure that the random state is not affected by
         # the model initialization and profiling.
         NPUPlatform.seed_everything(self.model_config.seed)
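For review, the worker-side changes reduce to one pattern: every ATB-dependent step (extension registration, matmul warm-up) is skipped when the detected device type is `_910_95`, since that path no longer goes through the ATB ReshapeAndCache op. Below is a minimal stand-alone sketch of that control flow; `DeviceType`, `detect_device_type`, and the two ATB helpers are hypothetical stubs, not the real vllm-ascend or torch_npu APIs:

```python
from enum import Enum


class DeviceType(Enum):
    # Stand-in for vllm_ascend.utils.AscendDeviceType; setup.py maps
    # SoC version "ascend910_9579" to the _910_95 member.
    DEFAULT = "default"
    _910_95 = "910_95"


def detect_device_type() -> DeviceType:
    # Placeholder for get_ascend_device_type(); hard-coded here.
    return DeviceType.DEFAULT


def register_atb_extensions() -> None:
    print("ATB extensions registered")


def warm_up_atb() -> None:
    print("ATB matmul warmed up")


def init_and_warm_up() -> None:
    # Mirrors NPUWorker: both ATB steps are gated on the device type.
    if detect_device_type() != DeviceType._910_95:
        register_atb_extensions()
        warm_up_atb()


if __name__ == "__main__":
    init_and_warm_up()
```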