diff --git a/vllm_ascend/ascend_config.py b/vllm_ascend/ascend_config.py index 31bf2190..8b485c17 100644 --- a/vllm_ascend/ascend_config.py +++ b/vllm_ascend/ascend_config.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from typing import Optional +from uuid import uuid4 from vllm.logger import logger @@ -143,6 +144,11 @@ class AscendConfig: get_flashcomm2_oproj_tp_size_and_validate_config self.flashcomm2_oproj_tensor_parallel_size = get_flashcomm2_oproj_tp_size_and_validate_config( self, vllm_config) + kv_cfg = vllm_config.kv_transfer_config + if kv_cfg is not None and not getattr(kv_cfg, "_engine_id_patched", + False): + kv_cfg.engine_id = f"{kv_cfg.engine_id}-{uuid4().hex}" + kv_cfg._engine_id_patched = True class AscendCompilationConfig: diff --git a/vllm_ascend/distributed/mooncake_connector.py b/vllm_ascend/distributed/mooncake_connector.py index 2d376058..8fe21146 100644 --- a/vllm_ascend/distributed/mooncake_connector.py +++ b/vllm_ascend/distributed/mooncake_connector.py @@ -68,6 +68,27 @@ class ReqMeta: remote_dcp_size: int +@dataclass +class SizedDict(OrderedDict): + + def __init__(self, max_size=16000, *args, **kwargs): + self.max_size = max_size + super().__init__(*args, **kwargs) + + def __setitem__(self, key, value): + super().__setitem__(key, value) + if len(self) > self.max_size: + self.popitem(last=False) + + def __getitem__(self, key): + try: + return super().__getitem__(key) + except KeyError: + value: dict[int, list[int]] = {} + self[key] = value + return value + + class KVCacheTaskTracker: def __init__(self): @@ -253,11 +274,11 @@ class KVCacheRecvingThread(threading.Thread): self.ready_event = ready_event self.kv_caches_base_addr: dict[str, dict[int, list[int]]] = \ - defaultdict(dict) + SizedDict() self.kv_caches_base_addr[local_engine_id][local_handshake_port] = \ local_kv_caches_base_addr self.remote_te_port: dict[str, dict[int, int]] = \ - defaultdict(dict) + SizedDict() self.block_len = block_len # TODO(jianzs): find a better way to detect MLA. self.use_mla = len(block_len) == 2 diff --git a/vllm_ascend/distributed/mooncake_layerwise_connector.py b/vllm_ascend/distributed/mooncake_layerwise_connector.py index f85549bd..1f5c44e1 100644 --- a/vllm_ascend/distributed/mooncake_layerwise_connector.py +++ b/vllm_ascend/distributed/mooncake_layerwise_connector.py @@ -8,7 +8,7 @@ import queue import struct import threading import time -from collections import defaultdict, deque +from collections import OrderedDict, defaultdict, deque from collections.abc import Iterator from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass @@ -67,6 +67,27 @@ class ReqMeta: metaserver: Optional[str] +@dataclass +class SizedDict(OrderedDict): + + def __init__(self, max_size=16000, *args, **kwargs): + self.max_size = max_size + super().__init__(*args, **kwargs) + + def __setitem__(self, key, value): + super().__setitem__(key, value) + if len(self) > self.max_size: + self.popitem(last=False) + + def __getitem__(self, key): + try: + return super().__getitem__(key) + except KeyError: + value: dict[int, list[int]] = {} + self[key] = value + return value + + class KVCacheSendingLayerThread(threading.Thread): def __init__(self, @@ -695,9 +716,9 @@ class MooncakeLayerwiseConnectorWorker: self.encoder = msgspec.msgpack.Encoder() self.remote_kv_caches_base_addr: dict[str, dict[int, list[int]]] = \ - defaultdict(dict) + SizedDict() self.remote_te_port: dict[str, dict[int, int]] = \ - defaultdict(dict) + SizedDict() self.remote_sockets_lock = threading.Lock() self.remote_sockets: dict[ # type: ignore str, deque[zmq.Socket]] = defaultdict( # type: ignore