[P/D][main] Clean connector history information (#4650)

### What this PR does / why we need it?
Clean connector history information when the node restarts.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
By ci

- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c

---------

Signed-off-by: wangxiaoteng <wangxiaoteng@huawei.com>
Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
wangxiaoteng888
2025-12-05 16:22:23 +08:00
committed by GitHub
parent a336543977
commit 41fbc5ebc9
3 changed files with 53 additions and 5 deletions

View File

@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
from uuid import uuid4
from vllm.logger import logger
@@ -143,6 +144,11 @@ class AscendConfig:
get_flashcomm2_oproj_tp_size_and_validate_config
self.flashcomm2_oproj_tensor_parallel_size = get_flashcomm2_oproj_tp_size_and_validate_config(
self, vllm_config)
kv_cfg = vllm_config.kv_transfer_config
if kv_cfg is not None and not getattr(kv_cfg, "_engine_id_patched",
False):
kv_cfg.engine_id = f"{kv_cfg.engine_id}-{uuid4().hex}"
kv_cfg._engine_id_patched = True
class AscendCompilationConfig:

View File

@@ -68,6 +68,27 @@ class ReqMeta:
remote_dcp_size: int
@dataclass
class SizedDict(OrderedDict):
def __init__(self, max_size=16000, *args, **kwargs):
self.max_size = max_size
super().__init__(*args, **kwargs)
def __setitem__(self, key, value):
super().__setitem__(key, value)
if len(self) > self.max_size:
self.popitem(last=False)
def __getitem__(self, key):
try:
return super().__getitem__(key)
except KeyError:
value: dict[int, list[int]] = {}
self[key] = value
return value
class KVCacheTaskTracker:
def __init__(self):
@@ -253,11 +274,11 @@ class KVCacheRecvingThread(threading.Thread):
self.ready_event = ready_event
self.kv_caches_base_addr: dict[str, dict[int, list[int]]] = \
defaultdict(dict)
SizedDict()
self.kv_caches_base_addr[local_engine_id][local_handshake_port] = \
local_kv_caches_base_addr
self.remote_te_port: dict[str, dict[int, int]] = \
defaultdict(dict)
SizedDict()
self.block_len = block_len
# TODO(jianzs): find a better way to detect MLA.
self.use_mla = len(block_len) == 2

View File

@@ -8,7 +8,7 @@ import queue
import struct
import threading
import time
from collections import defaultdict, deque
from collections import OrderedDict, defaultdict, deque
from collections.abc import Iterator
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
@@ -67,6 +67,27 @@ class ReqMeta:
metaserver: Optional[str]
@dataclass
class SizedDict(OrderedDict):
def __init__(self, max_size=16000, *args, **kwargs):
self.max_size = max_size
super().__init__(*args, **kwargs)
def __setitem__(self, key, value):
super().__setitem__(key, value)
if len(self) > self.max_size:
self.popitem(last=False)
def __getitem__(self, key):
try:
return super().__getitem__(key)
except KeyError:
value: dict[int, list[int]] = {}
self[key] = value
return value
class KVCacheSendingLayerThread(threading.Thread):
def __init__(self,
@@ -695,9 +716,9 @@ class MooncakeLayerwiseConnectorWorker:
self.encoder = msgspec.msgpack.Encoder()
self.remote_kv_caches_base_addr: dict[str, dict[int, list[int]]] = \
defaultdict(dict)
SizedDict()
self.remote_te_port: dict[str, dict[int, int]] = \
defaultdict(dict)
SizedDict()
self.remote_sockets_lock = threading.Lock()
self.remote_sockets: dict[ # type: ignore
str, deque[zmq.Socket]] = defaultdict( # type: ignore