[bugfix] Fix mooncake kvpool accuracy issue (#4976)

### What this PR does / why we need it?

The current KVPool has an accuracy issue
https://github.com/vllm-project/vllm-ascend/issues/4412. This PR aims to
fix the precision problem without impacting prefill performance.

Note: Due to a bug in ADXL, calling `current_event.synchronize()` may
occasionally hang. This issue will be fixed in CANN version 8.5.RC1. You
can manually build the master branch of the project at
https://gitcode.com/cann/hixl to resolve this issue before the 8.5.RC1
release.


- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c

---------

Signed-off-by: LCAIZJ <leichao139636@163.com>
This commit is contained in:
Chao Lei
2025-12-16 11:33:16 +08:00
committed by GitHub
parent 9e24bdd44c
commit 9c02fa9867
3 changed files with 38 additions and 2 deletions

View File

@@ -1,6 +1,7 @@
from dataclasses import dataclass
from typing import Iterable, List, Optional, Tuple, Union
import torch
from vllm.distributed.kv_transfer.kv_connector.v1.base import \
KVConnectorMetadata
from vllm.logger import logger
@@ -284,6 +285,8 @@ class ReqMeta:
is_last_chunk: Optional[bool] = None
current_event: Optional[torch.npu.Event] = None
@staticmethod
def from_request_tracker(
tracker: RequestTracker,
@@ -375,3 +378,4 @@ class LasyerMultiBlockReqMeta:
block_ids: list[int]
layer_id: int
is_last_chunk: Optional[bool] = True
current_event: Optional[torch.npu.Event] = None

View File

@@ -114,6 +114,7 @@ class KVCacheStoreSendingThread(KVTransferThread):
block_ids = req_meta.block_ids
req_id = req_meta.req_id
is_last_chunk = req_meta.is_last_chunk
current_event = req_meta.current_event
starts = []
ends = []
keys = []
@@ -161,6 +162,14 @@ class KVCacheStoreSendingThread(KVTransferThread):
addrs.append(addr)
sizes.append(size)
if keys:
"""
Note: Due to a bug in ADXL, calling current_event.synchronize() may occasionally hang.
This issue will be fixed in CANN version 8.5.rc1.
You can manually build the master branch of the project at https://gitcode.com/cann/hixl
to resolve this issue before the 8.5.RC1 release.
"""
if current_event is not None:
current_event.synchronize()
self.m_store.put(keys, addrs, sizes)
if is_last_chunk:
@@ -235,6 +244,7 @@ class KVCacheStoreLayerSendingThread(KVTransferThread):
ends = req_meta.ends
keys = req_meta.keys
layer_id = req_meta.layer_id
current_event = req_meta.current_event
total_block = len(keys)
is_last_chunk = req_meta.is_last_chunk
if not self.dcp_size > 1:
@@ -270,6 +280,8 @@ class KVCacheStoreLayerSendingThread(KVTransferThread):
addr_list.append(addr)
size_list.append(size)
if current_event is not None:
current_event.synchronize()
self.m_store.put(key_list, addr_list, size_list)
if layer_id == self.final_layer_id and is_last_chunk:

View File

@@ -251,12 +251,20 @@ class KVPoolWorker:
connector_metadata: AscendConnectorMetadata) -> None:
if self.current_layer == 0:
self.layerwise_storers = []
current_event = None
for request in connector_metadata.requests:
can_save = request.can_save
if can_save is None or not can_save:
continue
current_event = torch.npu.Event()
current_event.record()
break
for request in connector_metadata.requests:
can_save = request.can_save
if can_save is None or not can_save:
continue
layerwise_storer = self.store_layer(request)
layerwise_storer = self.store_layer(request, current_event)
self.layerwise_storers.append(layerwise_storer)
for layerwise_storer in self.layerwise_storers:
try:
@@ -266,11 +274,21 @@ class KVPoolWorker:
self.current_layer = self.current_layer + 1
def wait_for_save(self, connector_metadata: AscendConnectorMetadata):
current_event = None
for request in connector_metadata.requests:
can_save = request.can_save
if can_save is None or not can_save:
continue
current_event = torch.npu.Event()
current_event.record()
break
for request in connector_metadata.requests:
can_save = request.can_save
if can_save is None or not can_save:
continue
request.current_event = current_event
self.kv_send_thread.add_request( # type: ignore[union-attr]
request, )
@@ -347,6 +365,7 @@ class KVPoolWorker:
def store_layer(
self,
request: ReqMeta,
current_event: Optional[torch.npu.Event],
) -> Generator[None, None, None]:
"""
Store the KV cache in a layerwise manner.
@@ -385,7 +404,8 @@ class KVPoolWorker:
keys_multi_chunk, starts,
ends, request.block_ids,
layer_id,
request.is_last_chunk)
request.is_last_chunk,
current_event)
self.kv_send_thread.add_request( # type: ignore[union-attr, call-arg]
req_meta) # type: ignore[union-attr, call-arg, arg-type]
yield