[bugfix] Fixing KV Pool Memory Retention and Performance Degradation Issues (#5751)
### What this PR does / why we need it?
1.Fixed memory retention on certain GPUs caused by missing PUT
operations.
2.Fixed performance degradation resulting from architectural
incompatibilities in the underlying refactor.
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.13.0
- vLLM main:
2f4e6548ef
---------
Signed-off-by: fems14 <1804143737@qq.com>
This commit is contained in:
@@ -6,6 +6,9 @@ from unittest.mock import MagicMock
|
|||||||
fake_engine = types.ModuleType("mooncake.engine")
|
fake_engine = types.ModuleType("mooncake.engine")
|
||||||
fake_engine.TransferEngine = MagicMock() # type: ignore[attr-defined]
|
fake_engine.TransferEngine = MagicMock() # type: ignore[attr-defined]
|
||||||
sys.modules["mooncake.engine"] = fake_engine
|
sys.modules["mooncake.engine"] = fake_engine
|
||||||
|
fake_store = types.ModuleType("mooncake.store")
|
||||||
|
fake_store.ReplicateConfig = MagicMock() # type: ignore[attr-defined]
|
||||||
|
sys.modules["mooncake.store"] = fake_store
|
||||||
|
|
||||||
from vllm_ascend.distributed.kvpool.backend.mooncake_backend import ( # noqa: E402
|
from vllm_ascend.distributed.kvpool.backend.mooncake_backend import ( # noqa: E402
|
||||||
_convert_to_bytes, _parse_global_segment_size)
|
_convert_to_bytes, _parse_global_segment_size)
|
||||||
|
|||||||
@@ -136,21 +136,9 @@ class AscendStoreConnector(KVConnectorBase_V1):
|
|||||||
finished_req_ids: set[str]) -> tuple[set[str], set[str]]:
|
finished_req_ids: set[str]) -> tuple[set[str], set[str]]:
|
||||||
"""Get the finished recving and sending requests."""
|
"""Get the finished recving and sending requests."""
|
||||||
assert self.connector_worker is not None
|
assert self.connector_worker is not None
|
||||||
meta = self._get_connector_metadata()
|
|
||||||
done_sending, done_recving = self.connector_worker.get_finished(
|
done_sending, done_recving = self.connector_worker.get_finished(
|
||||||
finished_req_ids)
|
finished_req_ids)
|
||||||
sended_and_finished: set[str] = set()
|
return done_sending, done_recving
|
||||||
for item in list(self.sended_but_unfinished_reqs):
|
|
||||||
if item not in meta.unfinished_request_ids:
|
|
||||||
sended_and_finished.add(item)
|
|
||||||
self.sended_but_unfinished_reqs.remove(item)
|
|
||||||
for item in done_sending:
|
|
||||||
if item in meta.unfinished_request_ids:
|
|
||||||
self.sended_but_unfinished_reqs.add(item)
|
|
||||||
else:
|
|
||||||
sended_and_finished.add(item)
|
|
||||||
|
|
||||||
return sended_and_finished, done_recving
|
|
||||||
|
|
||||||
|
|
||||||
class LookupKeyServer:
|
class LookupKeyServer:
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ from dataclasses import dataclass
|
|||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
# Third Party
|
# Third Party
|
||||||
|
from mooncake.store import ReplicateConfig # type: ignore
|
||||||
from vllm.config import ParallelConfig
|
from vllm.config import ParallelConfig
|
||||||
from vllm.logger import logger
|
from vllm.logger import logger
|
||||||
from vllm.utils.network_utils import get_ip
|
from vllm.utils.network_utils import get_ip
|
||||||
@@ -56,7 +57,11 @@ class MooncakeBackend(Backend):
|
|||||||
def put(self, keys: list[str], addrs: list[list[int]],
|
def put(self, keys: list[str], addrs: list[list[int]],
|
||||||
sizes: list[list[int]]):
|
sizes: list[list[int]]):
|
||||||
try:
|
try:
|
||||||
res = self.store.batch_put_from_multi_buffers(keys, addrs, sizes)
|
config = ReplicateConfig()
|
||||||
|
config.preferred_segment = self.local_seg
|
||||||
|
config.prefer_alloc_in_same_node = True
|
||||||
|
res = self.store.batch_put_from_multi_buffers(
|
||||||
|
keys, addrs, sizes, config)
|
||||||
for value in res:
|
for value in res:
|
||||||
if value < 0:
|
if value < 0:
|
||||||
logger.error(f"Failed to put key {keys},res:{res}")
|
logger.error(f"Failed to put key {keys},res:{res}")
|
||||||
@@ -66,7 +71,8 @@ class MooncakeBackend(Backend):
|
|||||||
def get(self, keys: list[str], addrs: list[list[int]],
|
def get(self, keys: list[str], addrs: list[list[int]],
|
||||||
sizes: list[list[int]]):
|
sizes: list[list[int]]):
|
||||||
try:
|
try:
|
||||||
res = self.store.batch_get_into_multi_buffers(keys, addrs, sizes)
|
res = self.store.batch_get_into_multi_buffers(
|
||||||
|
keys, addrs, sizes, True)
|
||||||
for value in res:
|
for value in res:
|
||||||
if value < 0:
|
if value < 0:
|
||||||
logger.error(f"Failed to get key {keys}, res:{res}")
|
logger.error(f"Failed to get key {keys}, res:{res}")
|
||||||
|
|||||||
@@ -223,6 +223,8 @@ class LoadSpec:
|
|||||||
# Whether the scheduler allow us to load the tokens
|
# Whether the scheduler allow us to load the tokens
|
||||||
can_load: bool
|
can_load: bool
|
||||||
|
|
||||||
|
token_len: int = 0
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class RequestTracker:
|
class RequestTracker:
|
||||||
|
|||||||
@@ -125,7 +125,6 @@ class KVCacheStoreSendingThread(KVTransferThread):
|
|||||||
token_len = req_meta.token_len_chunk
|
token_len = req_meta.token_len_chunk
|
||||||
block_ids = req_meta.block_ids
|
block_ids = req_meta.block_ids
|
||||||
req_id = req_meta.req_id
|
req_id = req_meta.req_id
|
||||||
is_last_chunk = req_meta.is_last_chunk
|
|
||||||
current_event = req_meta.current_event
|
current_event = req_meta.current_event
|
||||||
starts = []
|
starts = []
|
||||||
ends = []
|
ends = []
|
||||||
@@ -142,15 +141,15 @@ class KVCacheStoreSendingThread(KVTransferThread):
|
|||||||
keys = keys[self.tp_rank % self.put_step::self.put_step]
|
keys = keys[self.tp_rank % self.put_step::self.put_step]
|
||||||
|
|
||||||
if not keys:
|
if not keys:
|
||||||
if is_last_chunk:
|
with self.done_task_lock:
|
||||||
self.set_finished_request(req_id)
|
self.stored_requests[req_id] -= 1
|
||||||
return
|
return
|
||||||
|
|
||||||
skip_block_num = self.lookup(keys)
|
skip_block_num = self.lookup(keys)
|
||||||
|
|
||||||
if skip_block_num == len(keys):
|
if skip_block_num == len(keys):
|
||||||
if is_last_chunk:
|
with self.done_task_lock:
|
||||||
self.set_finished_request(req_id)
|
self.stored_requests[req_id] -= 1
|
||||||
return
|
return
|
||||||
|
|
||||||
starts = starts[skip_block_num:]
|
starts = starts[skip_block_num:]
|
||||||
@@ -208,6 +207,7 @@ class KVCacheStoreRecvingThread(KVTransferThread):
|
|||||||
name="KVCacheStoreRecvingThread")
|
name="KVCacheStoreRecvingThread")
|
||||||
|
|
||||||
def _handle_request(self, req_meta: ReqMeta):
|
def _handle_request(self, req_meta: ReqMeta):
|
||||||
|
token_len = req_meta.load_spec.token_len # type: ignore[union-attr]
|
||||||
req_id = req_meta.req_id
|
req_id = req_meta.req_id
|
||||||
mask_num = (
|
mask_num = (
|
||||||
req_meta.load_spec.vllm_cached_tokens # type: ignore[union-attr]
|
req_meta.load_spec.vllm_cached_tokens # type: ignore[union-attr]
|
||||||
@@ -216,7 +216,7 @@ class KVCacheStoreRecvingThread(KVTransferThread):
|
|||||||
size_list = []
|
size_list = []
|
||||||
key_list = []
|
key_list = []
|
||||||
for start, end, key in self.token_database.process_tokens(
|
for start, end, key in self.token_database.process_tokens(
|
||||||
req_meta.token_len_chunk, req_meta.block_hashes, mask_num):
|
token_len, req_meta.block_hashes, mask_num):
|
||||||
addr, size, _ = self.token_database.prepare_value(
|
addr, size, _ = self.token_database.prepare_value(
|
||||||
start, end, req_meta.block_ids)
|
start, end, req_meta.block_ids)
|
||||||
key_list.append(key.to_string())
|
key_list.append(key.to_string())
|
||||||
|
|||||||
@@ -134,6 +134,12 @@ class KVPoolWorker:
|
|||||||
self.use_mla, partitions)
|
self.use_mla, partitions)
|
||||||
|
|
||||||
real_backend = backend_map.get(self.backend.lower())
|
real_backend = backend_map.get(self.backend.lower())
|
||||||
|
|
||||||
|
# be removed later
|
||||||
|
if self.backend == "mooncake":
|
||||||
|
self.head_or_tp_rank = self.tp_rank
|
||||||
|
self.put_step = 1
|
||||||
|
|
||||||
self.m_store = real_backend( # type: ignore[misc]
|
self.m_store = real_backend( # type: ignore[misc]
|
||||||
parallel_config)
|
parallel_config)
|
||||||
|
|
||||||
@@ -245,7 +251,7 @@ class KVPoolWorker:
|
|||||||
token_len = request.load_spec.kvpool_cached_tokens + 1
|
token_len = request.load_spec.kvpool_cached_tokens + 1
|
||||||
else:
|
else:
|
||||||
token_len = request.load_spec.kvpool_cached_tokens
|
token_len = request.load_spec.kvpool_cached_tokens
|
||||||
request.token_len_chunk = token_len
|
request.load_spec.token_len = token_len
|
||||||
if self.use_layerwise:
|
if self.use_layerwise:
|
||||||
layerwise_retriever = self.retrieve_layer(request)
|
layerwise_retriever = self.retrieve_layer(request)
|
||||||
next(layerwise_retriever) # first layer load
|
next(layerwise_retriever) # first layer load
|
||||||
|
|||||||
Reference in New Issue
Block a user