[P/D][main] Clean connector history information (#4650)

### What this PR does / why we need it?
Clean connector history information when the node restarts.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
By ci

- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c

---------

Signed-off-by: wangxiaoteng <wangxiaoteng@huawei.com>
Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
wangxiaoteng888
2025-12-05 16:22:23 +08:00
committed by GitHub
parent a336543977
commit 41fbc5ebc9
3 changed files with 53 additions and 5 deletions

View File

@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Optional from typing import Optional
from uuid import uuid4
from vllm.logger import logger from vllm.logger import logger
@@ -143,6 +144,11 @@ class AscendConfig:
get_flashcomm2_oproj_tp_size_and_validate_config get_flashcomm2_oproj_tp_size_and_validate_config
self.flashcomm2_oproj_tensor_parallel_size = get_flashcomm2_oproj_tp_size_and_validate_config( self.flashcomm2_oproj_tensor_parallel_size = get_flashcomm2_oproj_tp_size_and_validate_config(
self, vllm_config) self, vllm_config)
kv_cfg = vllm_config.kv_transfer_config
if kv_cfg is not None and not getattr(kv_cfg, "_engine_id_patched",
False):
kv_cfg.engine_id = f"{kv_cfg.engine_id}-{uuid4().hex}"
kv_cfg._engine_id_patched = True
class AscendCompilationConfig: class AscendCompilationConfig:

View File

@@ -68,6 +68,27 @@ class ReqMeta:
remote_dcp_size: int remote_dcp_size: int
@dataclass
class SizedDict(OrderedDict):
def __init__(self, max_size=16000, *args, **kwargs):
self.max_size = max_size
super().__init__(*args, **kwargs)
def __setitem__(self, key, value):
super().__setitem__(key, value)
if len(self) > self.max_size:
self.popitem(last=False)
def __getitem__(self, key):
try:
return super().__getitem__(key)
except KeyError:
value: dict[int, list[int]] = {}
self[key] = value
return value
class KVCacheTaskTracker: class KVCacheTaskTracker:
def __init__(self): def __init__(self):
@@ -253,11 +274,11 @@ class KVCacheRecvingThread(threading.Thread):
self.ready_event = ready_event self.ready_event = ready_event
self.kv_caches_base_addr: dict[str, dict[int, list[int]]] = \ self.kv_caches_base_addr: dict[str, dict[int, list[int]]] = \
defaultdict(dict) SizedDict()
self.kv_caches_base_addr[local_engine_id][local_handshake_port] = \ self.kv_caches_base_addr[local_engine_id][local_handshake_port] = \
local_kv_caches_base_addr local_kv_caches_base_addr
self.remote_te_port: dict[str, dict[int, int]] = \ self.remote_te_port: dict[str, dict[int, int]] = \
defaultdict(dict) SizedDict()
self.block_len = block_len self.block_len = block_len
# TODO(jianzs): find a better way to detect MLA. # TODO(jianzs): find a better way to detect MLA.
self.use_mla = len(block_len) == 2 self.use_mla = len(block_len) == 2

View File

@@ -8,7 +8,7 @@ import queue
import struct import struct
import threading import threading
import time import time
from collections import defaultdict, deque from collections import OrderedDict, defaultdict, deque
from collections.abc import Iterator from collections.abc import Iterator
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass from dataclasses import dataclass
@@ -67,6 +67,27 @@ class ReqMeta:
metaserver: Optional[str] metaserver: Optional[str]
@dataclass
class SizedDict(OrderedDict):
def __init__(self, max_size=16000, *args, **kwargs):
self.max_size = max_size
super().__init__(*args, **kwargs)
def __setitem__(self, key, value):
super().__setitem__(key, value)
if len(self) > self.max_size:
self.popitem(last=False)
def __getitem__(self, key):
try:
return super().__getitem__(key)
except KeyError:
value: dict[int, list[int]] = {}
self[key] = value
return value
class KVCacheSendingLayerThread(threading.Thread): class KVCacheSendingLayerThread(threading.Thread):
def __init__(self, def __init__(self,
@@ -695,9 +716,9 @@ class MooncakeLayerwiseConnectorWorker:
self.encoder = msgspec.msgpack.Encoder() self.encoder = msgspec.msgpack.Encoder()
self.remote_kv_caches_base_addr: dict[str, dict[int, list[int]]] = \ self.remote_kv_caches_base_addr: dict[str, dict[int, list[int]]] = \
defaultdict(dict) SizedDict()
self.remote_te_port: dict[str, dict[int, int]] = \ self.remote_te_port: dict[str, dict[int, int]] = \
defaultdict(dict) SizedDict()
self.remote_sockets_lock = threading.Lock() self.remote_sockets_lock = threading.Lock()
self.remote_sockets: dict[ # type: ignore self.remote_sockets: dict[ # type: ignore
str, deque[zmq.Socket]] = defaultdict( # type: ignore str, deque[zmq.Socket]] = defaultdict( # type: ignore