[Bugfix][kvcache] revert multiple kv cache groups (#923)

Revert the changes related to multiple KV cache groups, as this feature has been reverted in vLLM: https://github.com/vllm-project/vllm/pull/18459

---------

Signed-off-by: MengqingCao <cmq0113@163.com>
Author: Mengqing Cao
Date: 2025-05-22 15:15:33 +08:00
Committed by: GitHub
Parent: b4d6672d01
Commit: 7aa4f85f10
3 changed files with 34 additions and 30 deletions
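Note on what is being reverted: the hunks below switch call sites back from a per-group list of block tables (indexed with [0]) to a single per-batch block table. A minimal, self-contained sketch of the two access patterns; the _BlockTable class and the concrete sizes are illustrative stand-ins, only the block_table / get_device_tensor access pattern comes from the diff itself:

import torch

class _BlockTable:
    """Toy stand-in for the runner's per-request block table."""
    def __init__(self, num_reqs: int, max_blocks: int):
        self._table = torch.zeros(num_reqs, max_blocks, dtype=torch.int32)
    def get_cpu_tensor(self) -> torch.Tensor:
        return self._table
    def get_device_tensor(self) -> torch.Tensor:
        return self._table  # a real runner keeps a mirrored copy on the NPU

# Post-revert (single KV cache group): input_batch.block_table is one table.
single = _BlockTable(num_reqs=4, max_blocks=8)
block_table = single.get_device_tensor()[:4]

# Pre-revert (multiple KV cache groups): it was a list with one table per
# group, so call sites had to pick group 0 first.
grouped = [_BlockTable(num_reqs=4, max_blocks=8)]
block_table_grouped = grouped[0].get_device_tensor()[:4]

assert block_table.shape == block_table_grouped.shape == (4, 8)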


@@ -30,7 +30,6 @@ from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.worker.gpu_input_batch import InputBatch
 from vllm_ascend.ops.attention import vanilla_chunked_prefill
-from vllm_ascend.utils import vllm_version_is
 class AscendAttentionBackend(AttentionBackend):
@@ -141,14 +140,8 @@ class AscendAttentionMetadataBuilder:
     def build(self, num_reqs, num_actual_tokens, max_query_len,
               common_prefix_len):
-        if vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1"):
-            block_table = (self.runner.input_batch.block_table.
-                           get_device_tensor()[:num_reqs])
-        else:
-            block_table = self.runner.input_batch.block_table[
-                0].get_device_tensor()
-            block_table[:num_reqs, :self.runner.max_num_blocks_per_req] = (
-                block_table[:num_reqs])
+        block_table = (
+            self.runner.input_batch.block_table.get_device_tensor()[:num_reqs])
         query_lens = self.runner.query_lens
         seq_lens = self.runner.seq_lens_cpu[:num_reqs]


@@ -16,7 +16,6 @@ from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla
-from vllm_ascend.utils import vllm_version_is
 from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
 if TYPE_CHECKING:
@@ -239,12 +238,8 @@ class AscendMLAMetadataBuilder:
         # function. We should avoid GPU -> CPU sync as much as possible because
         # it blocks on all previous kernels.
         device = self.runner.device
-        if vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1"):
-            block_table = (self.runner.input_batch.block_table.
-                           get_device_tensor()[:num_reqs])
-        else:
-            block_table = (self.runner.input_batch.block_table[0].
-                           get_device_tensor()[:num_reqs])
+        block_table = (
+            self.runner.input_batch.block_table.get_device_tensor()[:num_reqs])
         slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
             device, non_blocking=True)
         input_positions = self.runner.positions_cpu[:num_actual_tokens].to(

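The context lines above keep the existing pattern of staging metadata in CPU tensors and copying them to the device with non_blocking=True, which matches the comment about avoiding GPU -> CPU syncs. A small, self-contained sketch of that pattern in plain PyTorch; the tensor name and size are illustrative, and it falls back to CPU when no accelerator is present:

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# Stage metadata on the host; pinned (page-locked) memory allows async copies.
slot_mapping_cpu = torch.zeros(1024, dtype=torch.int64)
if torch.cuda.is_available():
    slot_mapping_cpu = slot_mapping_cpu.pin_memory()

# non_blocking=True lets the host keep queueing kernels instead of waiting for
# the host-to-device copy; a device -> host copy would force a sync instead.
slot_mapping = slot_mapping_cpu.to(device, non_blocking=True)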

@@ -164,6 +164,24 @@ class NPUModelRunner:
             raise NotImplementedError(
                 "Non-Attention backend is not supported by V1 NPUModelRunner.")
+        self.attn_backend = get_attn_backend(
+            self.head_size,
+            self.dtype,
+            self.kv_cache_dtype,
+            self.block_size,
+            self.model_config.is_attention_free,
+            use_mla=self.model_config.use_mla,
+        )
+        if self.attn_backend is None:
+            error_msg = (
+                f"Error with get_att_backend: {self.head_size=}, "
+                f"{self.dtype=}, {self.kv_cache_dtype=}, {self.block_size=}, "
+                f"{self.model_config.is_attention_free=}, "
+                f"{self.model_config.use_mla=}")
+            logger.error(error_msg)
+            raise NotImplementedError(
+                "Non-Attention backend is not supported by V1 GPUModelRunner.")
+        self.attn_metadata_builder = self.attn_backend.get_builder_cls()(
+            weakref.proxy(self))
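The restored error path uses the f-string "=" debug specifier (Python 3.8+), which prints both the expression and its value. A tiny standalone illustration; the variable names and values here are made up, only the message prefix is copied from the code above:

head_size, use_mla = 128, False
msg = f"Error with get_att_backend: {head_size=}, {use_mla=}"
print(msg)  # -> Error with get_att_backend: head_size=128, use_mla=False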
@@ -196,6 +214,17 @@ class NPUModelRunner:
                 pin_memory=True,
                 vocab_size=self.model_config.get_vocab_size(),
             )
+        else:
+            self.input_batch = InputBatch(
+                max_num_reqs=self.max_num_reqs,
+                max_model_len=self.model_config.max_model_len,
+                max_num_blocks_per_req=self.max_num_blocks_per_req,
+                max_num_batched_tokens=self.max_num_tokens,
+                device=self.device,
+                pin_memory=True,
+                vocab_size=self.model_config.get_vocab_size(),
+            )
         self.input_ids = torch.zeros(self.max_num_tokens,
                                      dtype=torch.int32,
                                      device=self.device)
@@ -542,10 +571,7 @@ class NPUModelRunner:
         block_table_indices = (req_indices * self.max_num_blocks_per_req +
                                positions_np // self.block_size)
-        if vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1"):
-            block_table_cpu = self.input_batch.block_table.get_cpu_tensor()
-        else:
-            block_table_cpu = self.input_batch.block_table[0].get_cpu_tensor()
+        block_table_cpu = self.input_batch.block_table.get_cpu_tensor()
         block_numbers = block_table_cpu.flatten()[block_table_indices].numpy()
         block_offsets = positions_np % self.block_size
         np.add(block_numbers * self.block_size,
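The restored lines compute the paged-KV slot mapping from a flattened CPU block table: each token's row is looked up as req_index * max_num_blocks_per_req + position // block_size, and its slot is block_number * block_size + position % block_size. A short NumPy walk-through of that arithmetic; the concrete block numbers, positions, and sizes are made up:

import numpy as np

block_size = 4
max_num_blocks_per_req = 3
# Two requests; row i holds the physical block numbers assigned to request i.
block_table_cpu = np.array([[7, 2, 0],
                            [5, 9, 1]], dtype=np.int64)

# Three scheduled tokens: request 0 at position 5, request 1 at positions 0 and 6.
req_indices = np.array([0, 1, 1])
positions_np = np.array([5, 0, 6])

block_table_indices = (req_indices * max_num_blocks_per_req +
                       positions_np // block_size)
block_numbers = block_table_cpu.flatten()[block_table_indices]
block_offsets = positions_np % block_size
slot_mapping = block_numbers * block_size + block_offsets

print(slot_mapping)  # [ 9 20 38]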
@@ -960,16 +986,6 @@ class NPUModelRunner:
         """
         import torch_npu
         kv_caches: Dict[str, torch.Tensor] = {}
-        if not (vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1")):
-            self.input_batch = InputBatch(
-                max_num_reqs=self.max_num_reqs,
-                max_model_len=self.model_config.max_model_len,
-                max_num_batched_tokens=self.max_num_tokens,
-                device=self.device,
-                pin_memory=True,
-                vocab_size=self.model_config.get_vocab_size(),
-                kv_cache_config=kv_cache_config,
-            )
         for kv_cache_group in kv_cache_config.kv_cache_groups:
             kv_cache_spec = kv_cache_group.kv_cache_spec
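The surviving loop walks the KV cache groups in kv_cache_config (after this revert there is effectively a single group) and reads each group's kv_cache_spec before allocating cache tensors. A self-contained toy of that structure and loop; the dataclass fields, shapes, and helper name below are illustrative stand-ins, not the exact vLLM definitions:

from dataclasses import dataclass, field
from typing import Dict, List

import torch

@dataclass
class ToyKVCacheSpec:
    block_size: int
    num_kv_heads: int
    head_size: int
    dtype: torch.dtype

@dataclass
class ToyKVCacheGroup:
    layer_names: List[str]
    kv_cache_spec: ToyKVCacheSpec

@dataclass
class ToyKVCacheConfig:
    num_blocks: int
    kv_cache_groups: List[ToyKVCacheGroup] = field(default_factory=list)

def allocate_kv_caches(cfg: ToyKVCacheConfig) -> Dict[str, torch.Tensor]:
    kv_caches: Dict[str, torch.Tensor] = {}
    for kv_cache_group in cfg.kv_cache_groups:
        spec = kv_cache_group.kv_cache_spec
        for layer_name in kv_cache_group.layer_names:
            # One (K, V) buffer per layer: [2, num_blocks, block_size, heads, head_size].
            kv_caches[layer_name] = torch.zeros(
                2, cfg.num_blocks, spec.block_size,
                spec.num_kv_heads, spec.head_size, dtype=spec.dtype)
    return kv_caches

cfg = ToyKVCacheConfig(
    num_blocks=16,
    kv_cache_groups=[ToyKVCacheGroup(
        layer_names=["layer.0", "layer.1"],
        kv_cache_spec=ToyKVCacheSpec(block_size=128, num_kv_heads=8,
                                     head_size=64, dtype=torch.float16))])
caches = allocate_kv_caches(cfg)
assert caches["layer.0"].shape == (2, 16, 128, 8, 64)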