init v0.11.0rc0

This commit is contained in:
2025-10-14 10:38:28 +08:00
parent 67afd0ea78
commit 66dc16f966
278 changed files with 28130 additions and 11708 deletions

View File

@@ -11,7 +11,7 @@ from collections import defaultdict, deque
from collections.abc import Iterator
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, List, Optional, Tuple
from typing import TYPE_CHECKING, Any, List, Optional, OrderedDict, Tuple
import msgspec
import numpy as np
@@ -19,6 +19,7 @@ import numpy.typing as npt
import torch
import zmq
from mooncake.engine import TransferEngine # type: ignore
from vllm import envs
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
@@ -29,6 +30,7 @@ from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.request import RequestStatus
import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
@@ -67,12 +69,16 @@ class KVCacheTaskTracker:
# intentionally delayed. Each entry is a tuple of (request_id,
# timestamp). If a request remains in this queue for too long, it will
# be force-freed.
self.delayed_free_requests: deque[Tuple[str, float]] = deque()
self.record_finished_requests: set[str] = set()
self.delayed_free_requests: OrderedDict[str, float] = OrderedDict()
def update_done_task_count(self, request_id: str):
with self.done_task_lock:
self.finished_requests.add(request_id)
self._remove_delayed_requests(request_id)
if request_id in self.delayed_free_requests:
self._remove_delayed_requests(request_id)
else:
self.record_finished_requests.add(request_id)
def get_and_clear_finished_requests(self) -> set[str]:
"""
@@ -90,7 +96,10 @@ class KVCacheTaskTracker:
def add_delayed_request(self, request_id: str, delay_start_time: float):
"""Add a delayed free request."""
with self.done_task_lock:
self.delayed_free_requests.append((request_id, delay_start_time))
if request_id not in self.record_finished_requests:
self.delayed_free_requests[request_id] = delay_start_time
else:
self.record_finished_requests.discard(request_id)
def _retrieve_expired_requests(self):
"""Retrieve all expired delayed requests."""
@@ -98,10 +107,11 @@ class KVCacheTaskTracker:
# Free delayed requests if they exceed the timeout
current_time = time.time()
while self.delayed_free_requests:
request_id, delay_start_time = self.delayed_free_requests[0]
request_id = next(iter(self.delayed_free_requests))
delay_start_time = self.delayed_free_requests[request_id]
if (current_time - delay_start_time
> envs_ascend.VLLM_ASCEND_KVCACHE_DELAY_FREE_TIMEOUT):
self.delayed_free_requests.popleft()
> envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT):
self.delayed_free_requests.popitem(last=False)
expired_requests.add(request_id)
logger.info("Force freed request: %s", request_id)
else:
@@ -110,8 +120,7 @@ class KVCacheTaskTracker:
def _remove_delayed_requests(self, request_id: str):
"""Remove all delayed free requests matching the given request_id."""
self.delayed_free_requests = deque(
(r, t) for r, t in self.delayed_free_requests if r != request_id)
self.delayed_free_requests.pop(request_id)
class KVCacheSendingThread(threading.Thread):
@@ -230,6 +239,7 @@ class KVCacheRecvingThread(threading.Thread):
self.block_len = block_len
# TODO(jianzs): find a better way to detect MLA.
self.use_mla = len(block_len) == 2
self.use_sfa = len(block_len) == 3
self.request_queue: queue.Queue[Any] = queue.Queue()
# TODO(jianzs): make this configurable
@@ -341,8 +351,12 @@ class KVCacheRecvingThread(threading.Thread):
src_list, dst_list, length_list = [], [], []
for k, (src_layer_base_addr, dst_layer_base_addr) in enumerate(
zip(local_kv_caches_base_addrs, remote_kv_caches_base_addrs)):
block_len = (self.block_len[k % 2]
if self.use_mla else self.block_len[0])
if self.use_mla:
block_len = (self.block_len[k % 2])
elif self.use_sfa:
block_len = (self.block_len[k % 3])
else:
block_len = (self.block_len[0])
for i, remote_block_id in enumerate(grouped_remote_block_ids):
local_block_ids = grouped_local_block_ids[i]
src = src_layer_base_addr + local_block_ids[0] * block_len
@@ -559,6 +573,7 @@ class MooncakeConnectorScheduler:
def __init__(self, vllm_config: VllmConfig, engine_id: str):
self.vllm_config = vllm_config
self.ascend_config = get_ascend_config()
self.block_size = vllm_config.cache_config.block_size
self.engine_id = engine_id
logger.info("Initializing Mooncake Scheduler %s", engine_id)
@@ -718,7 +733,7 @@ class MooncakeConnectorScheduler:
assert "tp_size" in decode_parallel_config.keys()
self._decode_tp_size = decode_parallel_config["tp_size"]
if self.vllm_config.model_config.use_mla:
if self.vllm_config.model_config.use_mla or self.ascend_config.use_sfa:
return self._decode_tp_size
else:
# TODO support mha and gqa
@@ -782,10 +797,12 @@ class MooncakeConnectorWorker:
assert len(device_ids) > self.tp_rank # type: ignore
self.device_id = device_ids[self.tp_rank] # type: ignore
self._initialize(
hostname=self.side_channel_host + ':' + '0' + ':' + 'npu_' \
+ str(self.device_id),
device_name=None)
if vllm_config.kv_transfer_config.get_from_extra_config(
'use_ascend_direct', False):
hostname = self.side_channel_host
else:
hostname = f"{self.side_channel_host}:0:npu_{self.device_id}"
self._initialize(hostname=hostname, device_name=None)
self.te_rpc_port = self.engine.get_rpc_port()
# Background thread for sending or receiving KV caches.
@@ -837,7 +854,9 @@ class MooncakeConnectorWorker:
# TODO(tms): Find a more robust way to detect and handle MLA
self.use_mla = first_kv_cache_tuple[0].size(
-1) != first_kv_cache_tuple[1].size(-1)
-1) != first_kv_cache_tuple[1].size(-1) and len(
first_kv_cache_tuple) == 2
self.use_sfa = len(first_kv_cache_tuple) == 3
if self.use_mla:
# MLA case.[num_block, block_size, 1, hidden_dim]
self.num_blocks = first_kv_cache.shape[0]
@@ -851,6 +870,21 @@ class MooncakeConnectorWorker:
logger.info(
"num_blocks: %s, block_shape_norm: %s, block_shape_pe: %s",
self.num_blocks, block_shape_norm, block_shape_pe)
elif self.use_sfa:
self.num_blocks = first_kv_cache.shape[0]
block_rank = 3 # [block_size, latent_dim]
block_shape_norm = first_kv_cache_tuple[0].shape[-block_rank:]
block_shape_pe = first_kv_cache_tuple[1].shape[-block_rank:]
block_shape_k = first_kv_cache_tuple[2].shape[-block_rank:]
self.block_len = [
first_kv_cache[0].element_size() * math.prod(block_shape_norm),
first_kv_cache[1].element_size() * math.prod(block_shape_pe),
first_kv_cache[2].element_size() * math.prod(block_shape_k)
]
logger.info(
"num_blocks: %s, block_shape_norm: %s, block_shape_pe: %s, block_shape_k: %s",
self.num_blocks, block_shape_norm, block_shape_pe,
block_shape_k)
else:
# [num_block, block_size, num_head, hidden_dim]
self.num_blocks = first_kv_cache.shape[0]
@@ -861,8 +895,9 @@ class MooncakeConnectorWorker:
logger.info("num_blocks: %s, block_shape: %s", self.num_blocks,
block_shape)
logger.info("Registering KV_Caches. use_mla: %s, shape %s",
self.use_mla, first_kv_cache.shape)
logger.info(
"Registering KV_Caches. use_mla: %s, use_sfa: %s, shape %s",
self.use_mla, self.use_sfa, first_kv_cache.shape)
self.kv_caches = kv_caches
kv_caches_base_addr = []
@@ -874,9 +909,16 @@ class MooncakeConnectorWorker:
region_len = self.num_blocks * self.block_len[i % 2]
kv_caches_base_addr.append(base_addr)
self._register(base_addr, region_len)
elif self.use_sfa:
for i, cache in enumerate(cache_or_caches, 0):
base_addr = cache.data_ptr()
region_len = self.num_blocks * self.block_len[i % 3]
kv_caches_base_addr.append(base_addr)
self._register(base_addr, region_len)
else:
cache_list = [cache_or_caches
] if self.use_mla else cache_or_caches
cache_list = [
cache_or_caches
] if self.use_mla or self.use_sfa else cache_or_caches
for cache in cache_list:
base_addr = cache.data_ptr()
region_len = self.num_blocks * self.block_len[0]