[bugfix] Fix mooncake kvpool accuracy issue (#4976)
### What this PR does / why we need it?
The current KVPool has an accuracy issue
https://github.com/vllm-project/vllm-ascend/issues/4412. This PR aims to
fix the precision problem without impacting prefill performance.
Note: Due to a bug in ADXL, calling `current_event.synchronize()` may
occasionally hang. This issue will be fixed in CANN version 8.5.RC1. You
can manually build the master branch of the project at
https://gitcode.com/cann/hixl to resolve this issue before the 8.5.RC1
release.
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
---------
Signed-off-by: LCAIZJ <leichao139636@163.com>
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Iterable, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import \
|
||||
KVConnectorMetadata
|
||||
from vllm.logger import logger
|
||||
@@ -284,6 +285,8 @@ class ReqMeta:
|
||||
|
||||
is_last_chunk: Optional[bool] = None
|
||||
|
||||
current_event: Optional[torch.npu.Event] = None
|
||||
|
||||
@staticmethod
|
||||
def from_request_tracker(
|
||||
tracker: RequestTracker,
|
||||
@@ -375,3 +378,4 @@ class LasyerMultiBlockReqMeta:
|
||||
block_ids: list[int]
|
||||
layer_id: int
|
||||
is_last_chunk: Optional[bool] = True
|
||||
current_event: Optional[torch.npu.Event] = None
|
||||
@@ -114,6 +114,7 @@ class KVCacheStoreSendingThread(KVTransferThread):
|
||||
block_ids = req_meta.block_ids
|
||||
req_id = req_meta.req_id
|
||||
is_last_chunk = req_meta.is_last_chunk
|
||||
current_event = req_meta.current_event
|
||||
starts = []
|
||||
ends = []
|
||||
keys = []
|
||||
@@ -161,6 +162,14 @@ class KVCacheStoreSendingThread(KVTransferThread):
|
||||
addrs.append(addr)
|
||||
sizes.append(size)
|
||||
if keys:
|
||||
"""
|
||||
Note: Due to a bug in ADXL, calling current_event.synchronize() may occasionally hang.
|
||||
This issue will be fixed in CANN version 8.5.rc1.
|
||||
You can manually build the master branch of the project at https://gitcode.com/cann/hixl
|
||||
to resolve this issue before the 8.5.RC1 release.
|
||||
"""
|
||||
if current_event is not None:
|
||||
current_event.synchronize()
|
||||
self.m_store.put(keys, addrs, sizes)
|
||||
|
||||
if is_last_chunk:
|
||||
@@ -235,6 +244,7 @@ class KVCacheStoreLayerSendingThread(KVTransferThread):
|
||||
ends = req_meta.ends
|
||||
keys = req_meta.keys
|
||||
layer_id = req_meta.layer_id
|
||||
current_event = req_meta.current_event
|
||||
total_block = len(keys)
|
||||
is_last_chunk = req_meta.is_last_chunk
|
||||
if not self.dcp_size > 1:
|
||||
@@ -270,6 +280,8 @@ class KVCacheStoreLayerSendingThread(KVTransferThread):
|
||||
addr_list.append(addr)
|
||||
size_list.append(size)
|
||||
|
||||
if current_event is not None:
|
||||
current_event.synchronize()
|
||||
self.m_store.put(key_list, addr_list, size_list)
|
||||
|
||||
if layer_id == self.final_layer_id and is_last_chunk:
|
||||
|
||||
@@ -251,12 +251,20 @@ class KVPoolWorker:
|
||||
connector_metadata: AscendConnectorMetadata) -> None:
|
||||
if self.current_layer == 0:
|
||||
self.layerwise_storers = []
|
||||
current_event = None
|
||||
for request in connector_metadata.requests:
|
||||
can_save = request.can_save
|
||||
if can_save is None or not can_save:
|
||||
continue
|
||||
current_event = torch.npu.Event()
|
||||
current_event.record()
|
||||
break
|
||||
for request in connector_metadata.requests:
|
||||
can_save = request.can_save
|
||||
if can_save is None or not can_save:
|
||||
continue
|
||||
|
||||
layerwise_storer = self.store_layer(request)
|
||||
layerwise_storer = self.store_layer(request, current_event)
|
||||
self.layerwise_storers.append(layerwise_storer)
|
||||
for layerwise_storer in self.layerwise_storers:
|
||||
try:
|
||||
@@ -266,11 +274,21 @@ class KVPoolWorker:
|
||||
self.current_layer = self.current_layer + 1
|
||||
|
||||
def wait_for_save(self, connector_metadata: AscendConnectorMetadata):
|
||||
current_event = None
|
||||
for request in connector_metadata.requests:
|
||||
can_save = request.can_save
|
||||
if can_save is None or not can_save:
|
||||
continue
|
||||
current_event = torch.npu.Event()
|
||||
current_event.record()
|
||||
break
|
||||
|
||||
for request in connector_metadata.requests:
|
||||
can_save = request.can_save
|
||||
if can_save is None or not can_save:
|
||||
continue
|
||||
|
||||
request.current_event = current_event
|
||||
self.kv_send_thread.add_request( # type: ignore[union-attr]
|
||||
request, )
|
||||
|
||||
@@ -347,6 +365,7 @@ class KVPoolWorker:
|
||||
def store_layer(
|
||||
self,
|
||||
request: ReqMeta,
|
||||
current_event: Optional[torch.npu.Event],
|
||||
) -> Generator[None, None, None]:
|
||||
"""
|
||||
Store the KV cache in a layerwise manner.
|
||||
@@ -385,7 +404,8 @@ class KVPoolWorker:
|
||||
keys_multi_chunk, starts,
|
||||
ends, request.block_ids,
|
||||
layer_id,
|
||||
request.is_last_chunk)
|
||||
request.is_last_chunk,
|
||||
current_event)
|
||||
self.kv_send_thread.add_request( # type: ignore[union-attr, call-arg]
|
||||
req_meta) # type: ignore[union-attr, call-arg, arg-type]
|
||||
yield
|
||||
|
||||
Reference in New Issue
Block a user