[Bugfix] Adjust inputbatch to be compatible with latest vllm (#945)

Adjust InputBatch to be compatible with the latest vLLM, as the kv-cache group
feature has been redone in https://github.com/vllm-project/vllm/pull/18593

---------

Signed-off-by: MengqingCao <cmq0113@163.com>
This commit is contained in:
Mengqing Cao
2025-05-26 10:33:28 +08:00
committed by GitHub
parent 1f9fb869ad
commit a0c3e9ba50
3 changed files with 33 additions and 33 deletions

View File

@@ -30,6 +30,7 @@ from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.worker.gpu_input_batch import InputBatch
from vllm_ascend.ops.attention import vanilla_chunked_prefill
from vllm_ascend.utils import vllm_version_is
class AscendAttentionBackend(AttentionBackend):
@@ -141,8 +142,14 @@ class AscendAttentionMetadataBuilder:
def build(self, num_reqs, num_actual_tokens, max_query_len,
common_prefix_len):
block_table = (
self.runner.input_batch.block_table.get_device_tensor()[:num_reqs])
if vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1"):
block_table = (self.runner.input_batch.block_table.
get_device_tensor()[:num_reqs])
else:
block_table = self.runner.input_batch.block_table[
0].get_device_tensor()
block_table[:num_reqs, :self.runner.max_num_blocks_per_req] = (
block_table[:num_reqs])
query_lens = self.runner.query_lens
seq_lens = self.runner.seq_lens_cpu[:num_reqs]

View File

@@ -16,6 +16,7 @@ from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla
from vllm_ascend.utils import vllm_version_is
from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
if TYPE_CHECKING:
@@ -238,8 +239,14 @@ class AscendMLAMetadataBuilder:
# function. We should avoid GPU -> CPU sync as much as possible because
# it blocks on all previous kernels.
device = self.runner.device
block_table = (
self.runner.input_batch.block_table.get_device_tensor()[:num_reqs])
if vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1"):
block_table = (self.runner.input_batch.block_table.
get_device_tensor()[:num_reqs])
else:
block_table = self.runner.input_batch.block_table[
0].get_device_tensor()
block_table[:num_reqs, :self.runner.max_num_blocks_per_req] = (
block_table[:num_reqs])
slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
device, non_blocking=True)
input_positions = self.runner.positions_cpu[:num_actual_tokens].to(

View File

@@ -114,6 +114,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
def __init__(self, vllm_config: VllmConfig, device: torch.device):
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
self.cache_config = vllm_config.cache_config
self.lora_config = vllm_config.lora_config
self.scheduler_config = vllm_config.scheduler_config
self.speculative_config = vllm_config.speculative_config
@@ -172,24 +173,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
raise NotImplementedError(
"Non-Attention backend is not supported by V1 NPUModelRunner.")
self.attn_backend = get_attn_backend(
self.head_size,
self.dtype,
self.kv_cache_dtype,
self.block_size,
self.model_config.is_attention_free,
use_mla=self.model_config.use_mla,
)
if self.attn_backend is None:
error_msg = (
f"Error with get_att_backend: {self.head_size=}, "
f"{self.dtype=}, {self.kv_cache_dtype=}, {self.block_size=}, "
f"{self.model_config.is_attention_free=}, "
f"{self.model_config.use_mla=}")
logger.error(error_msg)
raise NotImplementedError(
"Non-Attention backend is not supported by V1 GPUModelRunner.")
self.attn_metadata_builder = self.attn_backend.get_builder_cls()(
weakref.proxy(self))
@@ -237,16 +220,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
pin_memory=True,
vocab_size=self.model_config.get_vocab_size(),
)
else:
self.input_batch = InputBatch(
max_num_reqs=self.max_num_reqs,
max_model_len=self.model_config.max_model_len,
max_num_blocks_per_req=self.max_num_blocks_per_req,
max_num_batched_tokens=self.max_num_tokens,
device=self.device,
pin_memory=True,
vocab_size=self.model_config.get_vocab_size(),
)
self.input_ids = torch.zeros(self.max_num_tokens,
dtype=torch.int32,
@@ -600,7 +573,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
block_table_indices = (req_indices * self.max_num_blocks_per_req +
positions_np // self.block_size)
block_table_cpu = self.input_batch.block_table.get_cpu_tensor()
if vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1"):
block_table_cpu = self.input_batch.block_table.get_cpu_tensor()
else:
block_table_cpu = self.input_batch.block_table[0].get_cpu_tensor()
block_numbers = block_table_cpu.flatten()[block_table_indices].numpy()
block_offsets = positions_np % self.block_size
np.add(block_numbers * self.block_size,
@@ -1206,6 +1182,16 @@ class NPUModelRunner(LoRAModelRunnerMixin):
"""
import torch_npu
kv_caches: Dict[str, torch.Tensor] = {}
if not (vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1")):
self.input_batch = InputBatch(
max_num_reqs=self.max_num_reqs,
max_model_len=self.model_config.max_model_len,
max_num_batched_tokens=self.max_num_tokens,
device=self.device,
pin_memory=True,
vocab_size=self.model_config.get_vocab_size(),
block_size=self.cache_config.block_size,
)
for kv_cache_group in kv_cache_config.kv_cache_groups:
kv_cache_spec = kv_cache_group.kv_cache_spec