diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py
index 36ac972..61e26e1 100644
--- a/vllm_ascend/attention/attention_v1.py
+++ b/vllm_ascend/attention/attention_v1.py
@@ -30,6 +30,7 @@ from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.worker.gpu_input_batch import InputBatch
 
 from vllm_ascend.ops.attention import vanilla_chunked_prefill
+from vllm_ascend.utils import vllm_version_is
 
 
 class AscendAttentionBackend(AttentionBackend):
@@ -141,8 +142,14 @@ class AscendAttentionMetadataBuilder:
 
     def build(self, num_reqs, num_actual_tokens, max_query_len,
               common_prefix_len):
-        block_table = (
-            self.runner.input_batch.block_table.get_device_tensor()[:num_reqs])
+        if vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1"):
+            block_table = (self.runner.input_batch.block_table.
+                           get_device_tensor()[:num_reqs])
+        else:
+            block_table = self.runner.input_batch.block_table[
+                0].get_device_tensor()
+            block_table[:num_reqs, :self.runner.max_num_blocks_per_req] = (
+                block_table[:num_reqs])
 
         query_lens = self.runner.query_lens
         seq_lens = self.runner.seq_lens_cpu[:num_reqs]
diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py
index d987eab..eb40f41 100644
--- a/vllm_ascend/attention/mla_v1.py
+++ b/vllm_ascend/attention/mla_v1.py
@@ -16,6 +16,7 @@ from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
 
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla
+from vllm_ascend.utils import vllm_version_is
 from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
 
 if TYPE_CHECKING:
@@ -238,8 +239,14 @@ class AscendMLAMetadataBuilder:
         # function. We should avoid GPU -> CPU sync as much as possible because
         # it blocks on all previous kernels.
         device = self.runner.device
-        block_table = (
-            self.runner.input_batch.block_table.get_device_tensor()[:num_reqs])
+        if vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1"):
+            block_table = (self.runner.input_batch.block_table.
+                           get_device_tensor()[:num_reqs])
+        else:
+            block_table = self.runner.input_batch.block_table[
+                0].get_device_tensor()
+            block_table[:num_reqs, :self.runner.max_num_blocks_per_req] = (
+                block_table[:num_reqs])
         slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
             device, non_blocking=True)
         input_positions = self.runner.positions_cpu[:num_actual_tokens].to(
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index 2ee7426..91f8195 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -114,6 +114,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
     def __init__(self, vllm_config: VllmConfig, device: torch.device):
         self.vllm_config = vllm_config
         self.model_config = vllm_config.model_config
+        self.cache_config = vllm_config.cache_config
         self.lora_config = vllm_config.lora_config
         self.scheduler_config = vllm_config.scheduler_config
         self.speculative_config = vllm_config.speculative_config
@@ -172,24 +173,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             raise NotImplementedError(
                 "Non-Attention backend is not supported by V1 NPUModelRunner.")
 
-        self.attn_backend = get_attn_backend(
-            self.head_size,
-            self.dtype,
-            self.kv_cache_dtype,
-            self.block_size,
-            self.model_config.is_attention_free,
-            use_mla=self.model_config.use_mla,
-        )
-        if self.attn_backend is None:
-            error_msg = (
-                f"Error with get_att_backend: {self.head_size=}, "
-                f"{self.dtype=}, {self.kv_cache_dtype=}, {self.block_size=}, "
-                f"{self.model_config.is_attention_free=}, "
-                f"{self.model_config.use_mla=}")
-            logger.error(error_msg)
-            raise NotImplementedError(
-                "Non-Attention backend is not supported by V1 GPUModelRunner.")
-
         self.attn_metadata_builder = self.attn_backend.get_builder_cls()(
             weakref.proxy(self))
 
@@ -237,16 +220,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                 pin_memory=True,
                 vocab_size=self.model_config.get_vocab_size(),
             )
-        else:
-            self.input_batch = InputBatch(
-                max_num_reqs=self.max_num_reqs,
-                max_model_len=self.model_config.max_model_len,
-                max_num_blocks_per_req=self.max_num_blocks_per_req,
-                max_num_batched_tokens=self.max_num_tokens,
-                device=self.device,
-                pin_memory=True,
-                vocab_size=self.model_config.get_vocab_size(),
-            )
 
         self.input_ids = torch.zeros(self.max_num_tokens,
                                      dtype=torch.int32,
@@ -600,7 +573,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         block_table_indices = (req_indices * self.max_num_blocks_per_req +
                                positions_np // self.block_size)
 
-        block_table_cpu = self.input_batch.block_table.get_cpu_tensor()
+        if vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1"):
+            block_table_cpu = self.input_batch.block_table.get_cpu_tensor()
+        else:
+            block_table_cpu = self.input_batch.block_table[0].get_cpu_tensor()
         block_numbers = block_table_cpu.flatten()[block_table_indices].numpy()
         block_offsets = positions_np % self.block_size
         np.add(block_numbers * self.block_size,
@@ -1206,6 +1182,16 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         """
         import torch_npu
         kv_caches: Dict[str, torch.Tensor] = {}
+        if not (vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1")):
+            self.input_batch = InputBatch(
+                max_num_reqs=self.max_num_reqs,
+                max_model_len=self.model_config.max_model_len,
+                max_num_batched_tokens=self.max_num_tokens,
+                device=self.device,
+                pin_memory=True,
+                vocab_size=self.model_config.get_vocab_size(),
+                block_size=self.cache_config.block_size,
+            )
 
         for kv_cache_group in kv_cache_config.kv_cache_groups:
             kv_cache_spec = kv_cache_group.kv_cache_spec
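
For reference, the `vllm_version_is` gate used throughout this diff switches between the old single-`BlockTable` API (vLLM 0.8.5 / 0.8.5.post1) and the newer grouped layout where `input_batch.block_table[0]` holds the block table for KV-cache group 0. Below is a minimal sketch of that pattern; `vllm_version_is` is assumed to compare against the installed vLLM version string, and `get_block_table_device_tensor` is a hypothetical helper for illustration only, not part of this change.

```python
# Minimal sketch, not part of the PR: how the version gate and the
# block-table access pattern used in this diff fit together.
import vllm


def vllm_version_is(target: str) -> bool:
    # Assumption: the helper in vllm_ascend/utils.py compares the installed
    # vLLM release string against `target` exactly.
    return vllm.__version__ == target


def get_block_table_device_tensor(runner, num_reqs: int):
    """Hypothetical helper: per-request block table as a device tensor.

    On vLLM 0.8.5 / 0.8.5.post1, `input_batch.block_table` is a single
    BlockTable object; on later versions it is a list-like group of
    per-KV-cache-group block tables, and group 0 is used here.
    """
    if vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1"):
        return runner.input_batch.block_table.get_device_tensor()[:num_reqs]
    return runner.input_batch.block_table[0].get_device_tensor()[:num_reqs]
```

The same gate drives the deferred `InputBatch` construction in the model runner's KV-cache initialization, where newer vLLM takes `block_size` instead of `max_num_blocks_per_req`.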