From a486ff8c11ae258e35e6e0b11a0743172f8fb112 Mon Sep 17 00:00:00 2001
From: Chao Lei
Date: Tue, 30 Sep 2025 15:10:29 +0800
Subject: [PATCH] KVCache Transfer via Layer-wise Strategy in Disaggregation (#2602)

### What this PR does / why we need it?
See RFC: https://github.com/vllm-project/vllm-ascend/issues/2470
This PR adds a new KV connector for layer-wise KV transfer.

### Does this PR introduce _any_ user-facing change?
Yes, a new KV connector is added. Users can now enable the layer-wise transfer feature.
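A minimal launch sketch (values are illustrative; the connector name matches the registration added in this PR, and the extra-config layout mirrors the unit tests):

```bash
# Prefill (producer) side; the decode side uses "kv_role": "kv_consumer".
vllm serve <model> \
  --kv-transfer-config '{
    "kv_connector": "MooncakeLayerwiseConnector",
    "kv_role": "kv_producer",
    "kv_port": 5000,
    "kv_connector_extra_config": {
      "prefill": {"tp_size": 2, "dp_size": 1},
      "decode": {"tp_size": 2, "dp_size": 1}
    }
  }'
```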
### How was this patch tested?

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/releases/v0.11.0

---------

Signed-off-by: leichao.lc
Signed-off-by: CaveNightingale <2859066733@qq.com>
Signed-off-by: nwpu-zxr
Signed-off-by: wangxiaoteng
Signed-off-by: hanxinlong <50882499@qq.com>
Signed-off-by: liziyu
Co-authored-by: CaveNightingale <2859066733@qq.com>
Co-authored-by: nwpu-zxr
Co-authored-by: wangxiaoteng
Co-authored-by: hanxinlong <50882499@qq.com>
---
 ..._balance_proxy_layerwise_server_example.py |  576 +++++++
 .../load_balance_proxy_server_example.py      |    2 +-
 tests/ut/distributed/test_parallel_state.py   |    9 +-
 .../test_mooncake_layerwise_connector.py      | 1001 ++++++++++++
 vllm_ascend/ascend_config.py                  |   11 +
 vllm_ascend/distributed/__init__.py           |    5 +
 vllm_ascend/distributed/mooncake_connector.py |    2 +-
 .../mooncake_layerwise_connector.py           | 1335 +++++++++++++++++
 vllm_ascend/distributed/parallel_state.py     |   28 +
 vllm_ascend/distributed/utils.py              |   47 +
 10 files changed, 3012 insertions(+), 4 deletions(-)
 create mode 100644 examples/disaggregated_prefill_v1/load_balance_proxy_layerwise_server_example.py
 create mode 100644 tests/ut/kv_connector/test_mooncake_layerwise_connector.py
 create mode 100644 vllm_ascend/distributed/mooncake_layerwise_connector.py
 create mode 100644 vllm_ascend/distributed/utils.py

diff --git a/examples/disaggregated_prefill_v1/load_balance_proxy_layerwise_server_example.py b/examples/disaggregated_prefill_v1/load_balance_proxy_layerwise_server_example.py
new file mode 100644
index 0000000..61d4201
--- /dev/null
+++ b/examples/disaggregated_prefill_v1/load_balance_proxy_layerwise_server_example.py
@@ -0,0 +1,576 @@
+# Adapted from https://github.com/vllm-project/vllm/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py
+
+# SPDX-License-Identifier: Apache-2.0
+#
+# Tutorial: Using the Load Balance Proxy Server Example
+#
+# This proxy server is designed to distribute requests between multiple
+# "prefiller" and "decoder" backend servers for large language model inference.
+# It is useful for scaling out inference workloads and balancing load across
+# multiple backend instances.
+#
+# Features:
+# - Load balances requests to multiple prefiller and decoder servers.
+# - Supports OpenAI-compatible /v1/completions and /v1/chat/completions endpoints.
+# - Streams responses from backend servers to clients.
+#
+# Prerequisites:
+# - Python 3.8+
+# - Install dependencies:
+#     pip install fastapi httpx uvicorn vllm
+#
+# Step 1: Start Your Backend Servers
+# ----------------------------------
+# You need to have at least one prefiller and one decoder backend running.
+# These can be mock servers or actual vLLM servers. For example:
+#
+#   vllm serve --host 0.0.0.0 --port 8100 ...  # Prefiller 1
+#   vllm serve --host 0.0.0.0 --port 8101 ...  # Prefiller 2
+#   vllm serve --host 0.0.0.0 --port 8200 ...  # Decoder 1
+#   vllm serve --host 0.0.0.0 --port 8201 ...  # Decoder 2
+#
+# Step 2: Start the Proxy Server
+# ------------------------------
+# Run the proxy server, specifying the host/port for each prefiller and decoder:
+#
+#   python load_balance_proxy_server_example.py \
+#     --host 0.0.0.0 --port 9000 \
+#     --prefiller-hosts 127.0.0.1 127.0.0.1 \
+#     --prefiller-ports 8100 8101 \
+#     --decoder-hosts 127.0.0.1 127.0.0.1 \
+#     --decoder-ports 8200 8201
+#
+# This will start the proxy on port 9000, load balancing between two prefiller
+# and two decoder servers.
+#
+# Step 3: Send a Request to the Proxy
+# -----------------------------------
+# You can now send OpenAI-compatible requests to the proxy. For example:
+#
+#   curl -X POST http://localhost:9000/v1/completions \
+#     -H "Content-Type: application/json" \
+#     -d '{
+#       "model": "your-model",
+#       "prompt": "The quick brown fox jumps over the lazy dog",
+#       "max_tokens": 16
+#     }'
+#
+# Or for chat completions:
+#
+#   curl -X POST http://localhost:9000/v1/chat/completions \
+#     -H "Content-Type: application/json" \
+#     -d '{
+#       "model": "your-model",
+#       "messages": [{"role": "user", "content": "Hello!"}],
+#       "max_tokens": 16
+#     }'
+#
+# Step 4: Health Check
+# --------------------
+# To check if the proxy is running and see how many backend instances are
+# connected, use:
+#
+#   curl http://localhost:9000/healthcheck
+#
+# This will return a JSON object with the status and the number of prefiller
+# and decoder instances.
+#
+# Notes:
+# - You can scale the number of prefiller and decoder servers as needed.
+# - The proxy dispatches each request to the least-loaded prefiller and
+#   decoder rather than doing naive round-robin.
+# - For production, ensure your backend servers are robust and secure.
+#
+# For more details, see the code and comments in this file.
+
+
+import argparse
+import asyncio
+import functools
+import heapq
+import os
+import sys
+import uuid
+from contextlib import asynccontextmanager
+from typing import List
+
+import httpx
+from fastapi import FastAPI, Request
+from fastapi.responses import StreamingResponse
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
+
+# Add uvloop for faster event loop if available
+try:
+    import uvloop
+    asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
+except ImportError:
+    pass
+
+
+class ServerState:
+
+    def __init__(self, host, port):
+        self.host = host
+        self.port = port
+        self.url = f'http://{host}:{port}/v1'
+        self.client = httpx.AsyncClient(timeout=None,
+                                        base_url=self.url,
+                                        limits=httpx.Limits(
+                                            max_connections=100000,
+                                            max_keepalive_connections=100000))
+        self.active_tokens = 0
+        self.active_kv_cache = 0  # Only for prefiller
+        self.active_requests = 0  # Number of active requests
+        self.aborted_requests = set()  # Track aborted requests
+        # Removed individual server lock - will use global locks instead
+
+
+class ProxyState:
+
+    def __init__(self, prefiller_instances, decoder_instances):
+        self.prefillers: List[ServerState] = [
+            ServerState(h, p) for h, p in prefiller_instances
+        ]
+        self.decoders: List[ServerState] = [
+            ServerState(h, p) for h, p in decoder_instances
+        ]
+        self.req_to_prefiller = {}
+        self.req_id_lock = asyncio.Lock()
+        # Removed selection locks - no longer needed for synchronous methods
+
+        # Initialize priority queues for efficient server selection
+        # Each entry is (priority_score, server_index, server_reference)
+        # Lower priority score = higher priority (less loaded)
+        self.prefiller_heap = [(0, i, server)
+                               for i, server in enumerate(self.prefillers)]
+        self.decoder_heap = [(0, i, server)
+                             for i, server in enumerate(self.decoders)]
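+        # heapq compares tuple elements left to right, so the unique server
+        # index breaks priority ties and ServerState itself is never compared.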
+        heapq.heapify(self.prefiller_heap)
+        heapq.heapify(self.decoder_heap)
+        self.req_id_future = {}
+
+    def _update_prefiller_priority(self, server_idx: int):
+        """Update the priority of a prefiller server in the heap."""
+        server = self.prefillers[server_idx]
+        # Priority based on active_tokens and active_kv_cache
+        priority = server.active_tokens + server.active_kv_cache * 0.3
+        # Remove the stale entry and push a fresh one (O(n) rebuild, which is
+        # fine for a small server pool)
+        self.prefiller_heap = [(p, i, s) for p, i, s in self.prefiller_heap
+                               if i != server_idx]
+        heapq.heappush(self.prefiller_heap,
+                       (priority, server_idx, server))  # type: ignore
+
+    def _update_decoder_priority(self, server_idx: int):
+        """Update the priority of a decoder server in the heap."""
+        server = self.decoders[server_idx]
+        priority = server.active_tokens
+        # Remove the stale entry and push a fresh one
+        self.decoder_heap = [(p, i, s) for p, i, s in self.decoder_heap
+                             if i != server_idx]
+        heapq.heappush(self.decoder_heap,
+                       (priority, server_idx, server))  # type: ignore
+
+    def abort_prefiller_request(self, server_idx: int,
+                                request_id):  # Changed to synchronous
+        """
+        Mark a request as aborted. This helps release the KV cache on the
+        prefiller node.
+        """
+        # No lock needed - atomic operation
+        self.prefillers[server_idx].aborted_requests.add(request_id)
+
+    def aquire_aborted_prefiller_requests(
+            self, server_idx: int):  # Changed to synchronous
+        """
+        Get the set of aborted requests and clear it.
+        This is used to release the KV cache on the prefiller node.
+        """
+        # No lock needed - atomic operation
+        aborted_requests = self.prefillers[server_idx].aborted_requests.copy()
+        self.prefillers[server_idx].aborted_requests.clear()
+        return aborted_requests
+
+    async def next_req_id(self):
+        async with self.req_id_lock:
+            return str(uuid.uuid4())
+
+    def select_prefiller(self, token_count):  # Changed to synchronous
+        # No lock needed - entire function is atomic
+        if not self.prefiller_heap:
+            raise RuntimeError("No prefiller servers available")
+
+        priority, chosen, server = heapq.heappop(self.prefiller_heap)
+
+        # Update the chosen server atomically
+        self.prefillers[chosen].active_tokens += token_count
+        self.prefillers[chosen].active_kv_cache += token_count
+
+        # Update priority and re-add to heap
+        self._update_prefiller_priority(chosen)
+
+        return chosen
+
+    def release_prefiller(self, idx, token_count):  # Changed to synchronous
+        # No lock needed - atomic operation
+        self.prefillers[idx].active_tokens -= token_count
+        # Update priority queue after releasing
+        self._update_prefiller_priority(idx)
+
+    def release_prefiller_kv(self, idx, token_count):  # Changed to synchronous
+        # No lock needed - atomic operation
+        if self.prefillers[idx].active_kv_cache > 0:
+            self.prefillers[idx].active_kv_cache -= token_count
+        # Update priority queue after releasing
+        self._update_prefiller_priority(idx)
+
+    def select_decoder(self, token_count):  # Changed to synchronous
+        # No lock needed - entire function is atomic
+        if not self.decoder_heap:
+            raise RuntimeError("No decoder servers available")
+
+        priority, chosen, server = heapq.heappop(self.decoder_heap)
+
+        # Update the chosen server atomically
+        self.decoders[chosen].active_tokens += token_count
+
+        # Update priority and re-add to heap
+        self._update_decoder_priority(chosen)
+
+        return chosen
+
+    def release_decoder(self, idx, token_count):  # Changed to synchronous
+        # No lock needed - atomic operation
+        self.decoders[idx].active_tokens -= token_count
+        # Update priority queue after releasing
+        self._update_decoder_priority(idx)
+
+    # 
Omni_infer's calculate_input_scores function + def calculate_prefill_scores(self, request_length: int) -> float: + length_score = request_length / 4.0 + input_score = length_score * 0.0345 + 120.0745 + return input_score + + def calculate_decode_scores(self, request_length: int) -> float: + return request_length + + +proxy_state = None + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--port", type=int, default=8000) + parser.add_argument("--host", type=str, default="localhost") + parser.add_argument("--prefiller-hosts", + type=str, + nargs="+", + default=["localhost"]) + parser.add_argument("--prefiller-ports", + type=int, + nargs="+", + default=[8001]) + parser.add_argument("--decoder-hosts", + type=str, + nargs="+", + default=["localhost"]) + parser.add_argument("--decoder-ports", type=int, nargs="+", default=[8002]) + parser.add_argument("--max-retries", + type=int, + default=3, + help="Maximum number of retries for HTTP requests") + parser.add_argument( + "--retry-delay", + type=float, + default=0.001, + help="Base delay (seconds) for exponential backoff retries") + args = parser.parse_args() + if len(args.prefiller_hosts) != len(args.prefiller_ports): + raise ValueError( + "Number of prefiller hosts must match number of prefiller ports") + if len(args.decoder_hosts) != len(args.decoder_ports): + raise ValueError( + "Number of decoder hosts must match number of decoder ports") + args.prefiller_instances = list( + zip(args.prefiller_hosts, args.prefiller_ports)) + args.decoder_instances = list(zip(args.decoder_hosts, args.decoder_ports)) + return args + + +@asynccontextmanager +async def lifespan(app: FastAPI): + global proxy_state + proxy_state = ProxyState(global_args.prefiller_instances, + global_args.decoder_instances) + print( + f"Initialized {len(proxy_state.prefillers)} prefill clients and {len(proxy_state.decoders)} decode clients." 
+ ) + yield + for p in proxy_state.prefillers: + await p.client.aclose() + for d in proxy_state.decoders: + await d.client.aclose() + + +async def listen_for_disconnect(request: Request) -> None: + """Return if a disconnect message is received""" + while True: + message = await request.receive() + if message["type"] == "http.disconnect": + break + + +def with_cancellation(handler_func): + + @functools.wraps(handler_func) + async def wrapper(*args, **kwargs): + request = kwargs["request"] + handler_task = asyncio.create_task(handler_func(*args, **kwargs)) + cancellation_task = asyncio.create_task(listen_for_disconnect(request)) + done, pending = await asyncio.wait([handler_task, cancellation_task], + return_when=asyncio.FIRST_COMPLETED) + for task in pending: + task.cancel() + if handler_task in done: + return handler_task.result() + return None + + return wrapper + + +app = FastAPI(lifespan=lifespan) + + +async def send_request_to_service(client: httpx.AsyncClient, + prefiller_id: int, + endpoint: str, + req_data: dict, + request_id: str, + max_retries: int = 3, + base_delay: float = 0.2): + aborted_requests = proxy_state.aquire_aborted_prefiller_requests( + prefiller_id) + req_data = req_data.copy() + req_data['kv_transfer_params'] = { + "do_remote_decode": True, + "do_remote_prefill": False, + "remote_engine_id": None, + "remote_block_ids": None, + "remote_host": None, + "remote_port": None, + "aborted_request": list(aborted_requests), + "metaserver": f"http://{global_args.host}:{global_args.port}/v1/metaserver" + } + req_data["stream"] = False + req_data["max_tokens"] = 1 + if "stream_options" in req_data: + del req_data["stream_options"] + headers = { + "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", + "X-Request-Id": request_id + } + last_exc = None + for attempt in range(1, max_retries + 1): + try: + response = await client.post(endpoint, + json=req_data, + headers=headers) + response.raise_for_status() + if request_id in proxy_state.req_id_future: + result_future = proxy_state.req_id_future[request_id] + result_future.set_result(response.json()["kv_transfer_params"]) + return + except (httpx.RequestError, httpx.HTTPStatusError) as e: + logger.warning( + f"Attempt {attempt} failed for {endpoint}: {str(e)}") + last_exc = e + if attempt < max_retries: + await asyncio.sleep(base_delay * (2**(attempt - 1))) + else: + logger.error( + f"All {max_retries} attempts failed for {endpoint}.") + raise last_exc + + +async def stream_service_response_with_retry(client: httpx.AsyncClient, + endpoint: str, + req_data: dict, + request_id: str, + max_retries: int = 3, + base_delay: float = 0.2): + headers = { + "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", + "X-Request-Id": request_id + } + for attempt in range(1, max_retries + 1): + try: + async with client.stream("POST", + endpoint, + json=req_data, + headers=headers) as response: + response.raise_for_status() + first_chunk_sent = False + async for chunk in response.aiter_bytes(): + first_chunk_sent = True + yield chunk + return # Success, exit after streaming + except (httpx.RequestError, httpx.HTTPStatusError) as e: + if attempt < max_retries: + logger.warning( + f"Attempt {attempt} failed for streaming {endpoint}: {str(e)}" + ) + await asyncio.sleep(base_delay * (2**(attempt - 1))) + else: + logger.error( + f"All {max_retries} attempts failed for streaming {endpoint}." 
+                )
+                raise e
+        except Exception as e:
+            # If any chunk has been sent, do not retry, just log and drop
+            if 'first_chunk_sent' in locals() and first_chunk_sent:
+                logger.error(
+                    f"Streaming to client interrupted after response started: {str(e)}"
+                )
+                return
+            else:
+                if attempt < max_retries:
+                    logger.warning(
+                        f"Attempt {attempt} failed for streaming {endpoint}: {str(e)}"
+                    )
+                    await asyncio.sleep(base_delay * (2**(attempt - 1)))
+                else:
+                    logger.error(
+                        f"All {max_retries} attempts failed for streaming {endpoint}."
+                    )
+                    raise e
+
+
+def get_api_request_id(api, req_id):
+    if api == "/completions":
+        return "cmpl-" + req_id + "-0"
+    elif api == "/chat/completions":
+        return "chatcmpl-" + req_id
+
+
+async def _handle_completions(api: str, request: Request):
+    try:
+        req_data = await request.json()
+        req_body = await request.body()
+        request_length = len(req_body)
+        prefiller_score = proxy_state.calculate_prefill_scores(request_length)
+        logger.debug(
+            f"Request length: {request_length}, Prefiller score: {prefiller_score}"
+        )
+        request_id = await proxy_state.next_req_id()
+        # Select prefiller
+        prefiller_idx = proxy_state.select_prefiller(prefiller_score)
+        prefiller = proxy_state.prefillers[prefiller_idx]
+        result_future = asyncio.Future()  # type: ignore
+        request_id_api = get_api_request_id(api, request_id)
+        proxy_state.req_id_future[request_id_api] = result_future
+        # Send request to prefiller
+        asyncio.get_running_loop().create_task(send_request_to_service(
+            prefiller.client,
+            prefiller_idx,
+            api,
+            req_data,
+            request_id,
+            max_retries=global_args.max_retries,
+            base_delay=global_args.retry_delay))
+        proxy_state.release_prefiller(prefiller_idx, prefiller_score)
+
+        response = await result_future
+        del proxy_state.req_id_future[request_id_api]
+        req_data["kv_transfer_params"] = response
+
+        # Select decoder
+        decoder_score = proxy_state.calculate_decode_scores(request_length)
+        logger.debug("Decoder score: %f", decoder_score)
+        # Use the prefiller's kv_transfer_params to select decoder
+        decoder_idx = proxy_state.select_decoder(decoder_score)
+        decoder = proxy_state.decoders[decoder_idx]
+        logger.debug("Using %s %s", prefiller.url, decoder.url)
+        # Stream response from decoder
+        released_kv = False
+        async def generate_stream():
+            nonlocal released_kv
+            # Only one await per chunk, minimal logic in loop
+            try:
+                async for chunk in stream_service_response_with_retry(
+                        decoder.client,
+                        api,
+                        req_data,
+                        request_id=request_id,
+                        max_retries=global_args.max_retries,
+                        base_delay=global_args.retry_delay):
+                    if not released_kv and chunk:
+                        proxy_state.release_prefiller_kv(
+                            prefiller_idx, prefiller_score)
+                        released_kv = True
+                    yield chunk
+            except Exception as e:
+                logger.error(
+                    f"Error during streaming from decoder {decoder.url}: {str(e)}. The aborted request {request_id} will be routed to the target prefiller when a new request is ready to dispatch to it"
+                )
+                proxy_state.abort_prefiller_request(prefiller_idx, request_id)
+                proxy_state.release_prefiller_kv(prefiller_idx,
+                                                 prefiller_score)
+
+            # After streaming done, release tokens
+            proxy_state.release_decoder(decoder_idx, decoder_score)
+
+        return StreamingResponse(generate_stream(),
+                                 media_type="application/json")
+    except Exception as e:
+        import traceback
+        exc_info = sys.exc_info()
+        print("Error occurred in disagg prefill proxy server"
+              f" - {api} endpoint")
+        print(e)
+        print("".join(traceback.format_exception(*exc_info)))
+        raise
+
+
+@app.post("/v1/completions")
+@with_cancellation
+async def 
handle_completions(request: Request): + return await _handle_completions("/completions", request) + + +@app.post("/v1/chat/completions") +@with_cancellation +async def handle_chat_completions(request: Request): + return await _handle_completions("/chat/completions", request) + + +@app.get("/healthcheck") +async def healthcheck(): + return { + "status": "ok", + "prefill_instances": len(proxy_state.prefillers), + "decode_instances": len(proxy_state.decoders) + } + + +@app.post("/v1/metaserver") +async def metaserver(request: Request): + try: + req_data = await request.json() + request_id = req_data.pop("request_id", None) + if request_id in proxy_state.req_id_future: + result_future = proxy_state.req_id_future[request_id] + result_future.set_result(req_data) + except Exception as e: + logger.error( + f"Post metaserver failed with: {str(e)}" + ) + + +if __name__ == '__main__': + global global_args + global_args = parse_args() + import uvicorn + uvicorn.run(app, host=global_args.host, port=global_args.port) diff --git a/examples/disaggregated_prefill_v1/load_balance_proxy_server_example.py b/examples/disaggregated_prefill_v1/load_balance_proxy_server_example.py index 2728931..fd1c7e5 100644 --- a/examples/disaggregated_prefill_v1/load_balance_proxy_server_example.py +++ b/examples/disaggregated_prefill_v1/load_balance_proxy_server_example.py @@ -544,4 +544,4 @@ if __name__ == '__main__': global global_args global_args = parse_args() import uvicorn - uvicorn.run(app, host=global_args.host, port=global_args.port) + uvicorn.run(app, host=global_args.host, port=global_args.port) \ No newline at end of file diff --git a/tests/ut/distributed/test_parallel_state.py b/tests/ut/distributed/test_parallel_state.py index 6b52b7b..f6c3315 100644 --- a/tests/ut/distributed/test_parallel_state.py +++ b/tests/ut/distributed/test_parallel_state.py @@ -4,8 +4,9 @@ import pytest from vllm.config import ParallelConfig from vllm_ascend.distributed.parallel_state import ( - _LMTP, _MC2, _OTP, destroy_ascend_model_parallel, get_lmhead_tp_group, - get_mc2_group, get_otp_group, init_ascend_model_parallel) + _LMTP, _MC2, _OTP, _P_TP, destroy_ascend_model_parallel, + get_lmhead_tp_group, get_mc2_group, get_otp_group, get_p_tp_group, + init_ascend_model_parallel) @pytest.fixture @@ -30,6 +31,7 @@ def test_init_ascend_model_parallel(mock_distributed, parallel_config): mock_ascend_config = MagicMock() mock_ascend_config.lmhead_tensor_parallel_size = 2 mock_ascend_config.oproj_tensor_parallel_size = 2 + mock_ascend_config.pd_tp_ratio = 2 with patch('vllm_ascend.distributed.parallel_state.model_parallel_initialized', return_value=False), \ patch('vllm_ascend.distributed.parallel_state.init_model_parallel_group'), \ patch('vllm_ascend.distributed.parallel_state.get_ascend_config', return_value=mock_ascend_config): @@ -38,11 +40,14 @@ def test_init_ascend_model_parallel(mock_distributed, parallel_config): mc2_group = get_mc2_group() lmheadtp_group = get_lmhead_tp_group() otp_group = get_otp_group() + p_tp_group = get_p_tp_group() assert mc2_group is not None assert otp_group is not None assert lmheadtp_group is not None + assert p_tp_group is not None destroy_ascend_model_parallel() assert _MC2 is None assert _LMTP is None assert _OTP is None + assert _P_TP is None diff --git a/tests/ut/kv_connector/test_mooncake_layerwise_connector.py b/tests/ut/kv_connector/test_mooncake_layerwise_connector.py new file mode 100644 index 0000000..c7a1fcc --- /dev/null +++ b/tests/ut/kv_connector/test_mooncake_layerwise_connector.py @@ -0,0 
+1,1001 @@ +import os +import sys +import threading +import time +import types +import unittest +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import torch +import zmq + +fake_engine = types.ModuleType("mooncake.engine") +fake_engine.TransferEngine = MagicMock() # type: ignore[attr-defined] +sys.modules["mooncake.engine"] = fake_engine + +from vllm_ascend.distributed.mooncake_layerwise_connector import ( # noqa: E402 + DecodeMooncakeAgentMetadata, KVCacheRecvingLayerThread, + KVCacheSendingLayerThread, KVCacheTaskTracker, KVConnectorRole, + MooncakeLayerwiseConnector, MooncakeLayerwiseConnectorMetadata, + MooncakeLayerwiseConnectorScheduler, MooncakeLayerwiseConnectorWorker, + ReqMeta, SendingLayerThread, ensure_zmq_recv, ensure_zmq_send, + group_concurrent_contiguous, string_to_int64_hash, zmq_ctx) + +GET_META_MSG = b"get_meta_msg" +DONE_RECVING_MSG = b"done_recving_msg" + + +class TestKVCacheTaskTrackerInit(unittest.TestCase): + + def test_init_basic_properties(self): + tracker = KVCacheTaskTracker() + self.assertIsInstance(tracker.done_task_lock, type(threading.Lock())) + self.assertIsInstance(tracker.finished_requests, set) + self.assertIsInstance(tracker.delayed_free_requests, dict) + + +class TestGetAndClearFinishedSingleRequests(unittest.TestCase): + + def setUp(self): + self.tracker = KVCacheTaskTracker() + self.tracker.finished_requests = set() + self.tracker.done_task_lock = threading.Lock() + + def test_empty_requests(self): + result = self.tracker.get_and_clear_finished_requests() + self.assertEqual(result, set()) + self.assertEqual(len(self.tracker.finished_requests), 0) + + def test_single_request(self): + self.tracker.finished_requests = {"req_123"} + result = self.tracker.get_and_clear_finished_requests() + self.assertEqual(result, {"req_123"}) + self.assertEqual(len(self.tracker.finished_requests), 0) + + def test_multiple_requests(self): + self.tracker.finished_requests = {"req_1", "req_2", "req_3"} + result = self.tracker.get_and_clear_finished_requests() + self.assertSetEqual(result, {"req_1", "req_2", "req_3"}) + self.assertEqual(len(self.tracker.finished_requests), 0) + + @patch("vllm_ascend.distributed.mooncake_layerwise_connector.logger") + def test_concurrent_access(self, mock_logger): + from concurrent.futures import ThreadPoolExecutor + self.tracker.finished_requests = {"req_1", "req_2"} + with ThreadPoolExecutor(max_workers=3) as executor: + futures = [ + executor.submit(self.tracker.get_and_clear_finished_requests) + for _ in range(3) + ] + results = [f.result() for f in futures] + self.assertEqual(sum(1 for r in results if r), 1) + self.assertEqual(len(self.tracker.finished_requests), 0) + + +class TestKVCacheSendingLayerThreadBasic(unittest.TestCase): + + def setUp(self): + self.p1 = patch( + 'vllm_ascend.distributed.mooncake_layerwise_connector.get_ascend_config', + new=MagicMock(return_value=SimpleNamespace(pd_tp_ratio=1))) + self.p2 = patch( + 'vllm_ascend.distributed.mooncake_layerwise_connector.get_current_vllm_config', + new=MagicMock(return_value=SimpleNamespace( + scheduler_config=SimpleNamespace(max_model_len=128)))) + self.p1.start() + self.addCleanup(self.p1.stop) + self.p2.start() + self.addCleanup(self.p2.stop) + self.engine = MagicMock() + self.engine.register_memory.return_value = 0 + self.ready_event = threading.Event() + + batch_size, seq_len, hidden_dim, num_heads = 8, 128, 512, 8 + head_dim = hidden_dim // num_heads + self.first_kv_cache = torch.zeros( + (batch_size, num_heads, seq_len, head_dim), + 
dtype=torch.float32, + device='cpu') + + self.thread = KVCacheSendingLayerThread( + tp_rank=0, + tp_size=4, + decode_tp_size=2, + local_engine_id="local_engine", + side_channel_host="localhost", + side_channel_port=5555, + metadata=MagicMock(), + ready_event=self.ready_event, + total_layers=3, + engine=self.engine, + local_kv_base_addr=[0x1000, 0x2000], + block_len=[1024, 2048], + use_mla=True, + first_kv_cache=self.first_kv_cache) + + def test_add_request(self): + req_id = "req1" + meta = DecodeMooncakeAgentMetadata( + req_id=req_id, + block_ids=[3, 4], + host="localhost", + port=6666, + engine_id="remote_engine", + te_rpc_port=6000, + kv_caches_base_addr=[0x3000, 0x4000], + num_blocks=8) + with self.thread.lock: + self.thread.ready_decode[req_id] = meta + + local_block_ids = [1, 2] + key = torch.zeros((1, 1), dtype=torch.float32) + value = torch.zeros((1, 1), dtype=torch.float32) + + self.thread.add_request(request_id=req_id, + local_block_ids=local_block_ids, + layer_index=5, + key=key, + value=value) + + queued = self.thread.send_layer_thread.send_queue.get_nowait() + # queued: (metadata, request_id, local_block_ids, layer_index, key, value) + self.assertEqual(queued[1], "req1") + self.assertEqual(queued[0].host, "localhost") + + @patch.object(KVCacheTaskTracker, 'get_and_clear_finished_requests') + def test_get_finished_requests(self, mock_tracker): + mock_tracker.return_value = {"req1", "req2"} + result = self.thread.get_and_clear_finished_requests() + self.assertEqual(result, {"req1", "req2"}) + + @patch.object(KVCacheTaskTracker, 'add_delayed_request') + def test_add_delayed_request_passthrough(self, mock_add): + mock_add.return_value = None + ret = self.thread.add_delayed_request("req1", 123.456) + mock_add.assert_called_once_with("req1", 123.456) + self.assertIsNone(ret) + + def test_abort_requests_removes_pending(self): + with self.thread.lock: + self.thread.pending_decode["keep"] = [([9], 1)] + self.thread.pending_decode["dropA"] = [([1], 0)] + self.thread.pending_decode["dropB"] = [([2], 0)] + + self.thread._abort_requests({"dropA", "dropB"}) + + with self.thread.lock: + self.assertNotIn("dropA", self.thread.pending_decode) + self.assertNotIn("dropB", self.thread.pending_decode) + self.assertIn("keep", self.thread.pending_decode) + + @patch('vllm_ascend.distributed.mooncake_layerwise_connector.zmq.Context') + @patch( + 'vllm_ascend.distributed.mooncake_layerwise_connector.make_zmq_socket') + @patch( + 'vllm_ascend.distributed.mooncake_layerwise_connector.ensure_zmq_send') + def test_post_transfer_sends_and_receives_ack(self, mock_send, + mock_make_socket, + mock_context): + req_id = "req_ok" + meta = DecodeMooncakeAgentMetadata( + req_id=req_id, + block_ids=[1], + host="127.0.0.1", + port=7777, + engine_id="remote", + te_rpc_port=6000, + kv_caches_base_addr=[0x1], + num_blocks=1, + ) + with self.thread.lock: + self.thread.ready_decode[req_id] = meta + + fake_sock = MagicMock() + fake_sock.recv.return_value = b"ACK" + mock_make_socket.return_value = fake_sock + + self.thread._post_transfer(req_id) + + self.assertTrue(mock_make_socket.called) + _, kwargs = mock_make_socket.call_args + self.assertEqual(kwargs.get('path'), 'tcp://127.0.0.1:7777') + self.assertEqual(kwargs.get('socket_type'), zmq.REQ) # type: ignore + self.assertFalse(kwargs.get('bind', True)) + + mock_send.assert_called_once() + with self.thread.lock: + self.assertNotIn(req_id, self.thread.ready_decode) + + @patch('vllm_ascend.distributed.mooncake_layerwise_connector.zmq.Context') + @patch( + 
'vllm_ascend.distributed.mooncake_layerwise_connector.make_zmq_socket') + @patch( + 'vllm_ascend.distributed.mooncake_layerwise_connector.ensure_zmq_send') + def test_post_transfer_bad_ack_raises_value_error(self, _mock_send, + mock_make_socket, + _mock_context): + req_id = "req_bad" + meta = DecodeMooncakeAgentMetadata( + req_id=req_id, + block_ids=[1], + host="127.0.0.1", + port=8888, + engine_id="remote", + te_rpc_port=6000, + kv_caches_base_addr=[0x2], + num_blocks=1, + ) + with self.thread.lock: + self.thread.ready_decode[req_id] = meta + + fake_sock = MagicMock() + fake_sock.recv.return_value = b"NOT_ACK" + mock_make_socket.return_value = fake_sock + + with self.assertRaises(ValueError): + self.thread._post_transfer(req_id) + + +class TestSendingLayerThread(unittest.TestCase): + + def setUp(self): + self.p1 = patch( + 'vllm_ascend.distributed.mooncake_layerwise_connector.get_ascend_config', + new=MagicMock(return_value=SimpleNamespace(pd_tp_ratio=1))) + self.p2 = patch( + 'vllm_ascend.distributed.mooncake_layerwise_connector.get_current_vllm_config', + new=MagicMock(return_value=SimpleNamespace( + scheduler_config=SimpleNamespace(max_model_len=128)))) + self.p1.start() + self.addCleanup(self.p1.stop) + self.p2.start() + self.addCleanup(self.p2.stop) + self.task_tracker = MagicMock(KVCacheTaskTracker) + self.engine = MagicMock() + self.engine.register_memory.side_effect = lambda addr, size: 0 + batch_size = 8 + seq_len = 128 + hidden_dim = 512 + num_heads = 8 + head_dim = hidden_dim // num_heads # 512 // 8 = 64 + self.first_kv_cache = torch.zeros( + (batch_size, num_heads, seq_len, head_dim), + dtype=torch.float32, + device='cpu') + self.thread = SendingLayerThread( + task_tracker=self.task_tracker, + total_layers=3, + engine=self.engine, + local_kv_base_addr=["0x1000", "0x2000"], + block_len=[1024, 2048], + use_mla=True, + tp_rank=0, + first_kv_cache=self.first_kv_cache) + + @patch.object(SendingLayerThread, "_transfer_kv_cache", autospec=True) + def test_handle_request(self, mock_transfer): + req_id = "req_1" + req_meta = MagicMock(spec=DecodeMooncakeAgentMetadata) + key = torch.zeros((1, 1), dtype=torch.float32) + value = torch.zeros((1, 1), dtype=torch.float32) + item = (req_meta, req_id, [10, 11], 0, key, value) + with patch.object(self.thread.task_tracker, "update_done_task_count") as mock_update_done, \ + patch.object(self.thread.send_queue, "task_done", autospec=True) as mock_task_done: + self.thread._handle_request(item) + mock_transfer.assert_called_once_with(self.thread, req_meta, [10, 11], + 0, key, value) + mock_update_done.assert_called_once_with(req_id) + mock_task_done.assert_called_once() + + @patch('torch.npu.synchronize') + @patch( + 'vllm_ascend.distributed.mooncake_layerwise_connector.group_concurrent_contiguous' + ) + def test_transfer_kv_cache(self, mock_group, mock_sync): + key = torch.zeros((1, 1), dtype=torch.float32) + value = torch.zeros((1, 1), dtype=torch.float32) + mock_sync.return_value = None + self.thread.pd_tp_ratio = 1 + + self.thread.local_kv_base_addr = [1000, 2000] + + meta = DecodeMooncakeAgentMetadata( + req_id="req-ok", + block_ids=[0], + host="127.0.0.1", + port=7777, + engine_id="remote", + te_rpc_port=6000, + kv_caches_base_addr=[4000, 8000], + num_blocks=256, + ) + + mock_group.return_value = ( + [[10, 11, 12], [20, 21]], # grouped_remote_block_ids + [[5, 6, 7], [8, 9]], # grouped_local_block_ids + ) + + self.engine.batch_transfer_sync_write.return_value = 1 + + self.thread._transfer_kv_cache(meta, + local_block_ids=[123], + layer_index=0, 
+ key=key, + value=value) + + # k=0 (block_len=1024): + # grp1: src=1000+5*1024=6120, dst=4000+10*1024=14240, len=3*1024=3072 + # grp2: src=1000+8*1024=9192, dst=4000+20*1024=24480, len=2*1024=2048 + # k=1 (block_len=2048): + # grp1: src=2000+5*2048=12240, dst=8000+10*2048=28480, len=3*2048=6144 + # grp2: src=2000+8*2048=18384, dst=8000+20*2048=48960, len=2*2048=4096 + exp_session = "127.0.0.1:6000" + exp_src = [6120, 9192, 12240, 18384] + exp_dst = [14240, 24480, 28480, 48960] + exp_len = [3072, 2048, 6144, 4096] + + self.engine.batch_transfer_sync_write.assert_called_once() + args, _ = self.engine.batch_transfer_sync_write.call_args + self.assertEqual(args[0], exp_session) + self.assertEqual(args[1], exp_src) + self.assertEqual(args[2], exp_dst) + self.assertEqual(args[3], exp_len) + + +class TestKVCacheRecvingLayerThreadBasic(unittest.TestCase): + + def setUp(self): + self.ready_event = threading.Event() + self.thread = KVCacheRecvingLayerThread( + tp_rank=0, + side_channel_port=5555, + tp_size=4, + local_engine_id="local_engine", + ready_event=self.ready_event, + ) + + def test_get_finished_requests(self): + + with self.thread.lock: + self.thread.done_requests.update({"req1", "req2"}) + + result = self.thread.get_and_clear_finished_requests() + self.assertEqual(result, {"req1", "req2"}) + + result2 = self.thread.get_and_clear_finished_requests() + self.assertEqual(result2, set()) + + +class MockVllmConfig: + + def __init__(self): + self.model_config = MagicMock() + self.parallel_config = MagicMock() + self.cache_config = MagicMock() + self.kv_transfer_config = MagicMock() + self.model_config.use_mla = True + self.parallel_config.tensor_parallel_size = 2 + self.parallel_config.data_parallel_rank_local = 0 + self.parallel_config.data_parallel_size_local = 1 + self.cache_config.block_size = 16 + self.kv_transfer_config.kv_port = 5000 + self.kv_transfer_config.kv_role = 'kv_producer' + self.kv_transfer_config.get_from_extra_config = MagicMock() + self.kv_transfer_config.get_from_extra_config.side_effect = lambda k, d: { + "prefill": { + "tp_size": 2, + "dp_size": 1 + }, + "decode": { + "tp_size": 2, + "dp_size": 1 + } + }.get(k, d) + + +class MockRequest: + + def __init__(self, + request_id, + prompt_token_ids=None, + kv_transfer_params=None, + status=None): + self.request_id = request_id + self.prompt_token_ids = prompt_token_ids or [1, 2, 3, 4] + self.kv_transfer_params = kv_transfer_params or {} + self.status = status or "running" + self.output_token_ids = [101, 102] + + +class TestKVCacheTaskTracker(unittest.TestCase): + + def setUp(self): + self.tracker = KVCacheTaskTracker() + + def test_update_done_task_count(self): + + self.assertEqual(len(self.tracker.finished_requests), 0) + self.assertEqual(len(self.tracker.delayed_free_requests), 0) + + current_time = time.time() + self.tracker.add_delayed_request("req_1", current_time) + + result = self.tracker.delayed_free_requests + self.assertEqual(len(result), 1) + self.assertIn("req_1", result) + self.assertEqual(result["req_1"], current_time) + + with patch.object(self.tracker, "on_done") as mock_on_done: + for _ in range(getattr(self.tracker, "target_count", 1)): + self.tracker.update_done_task_count("req_1") + mock_on_done.assert_called_once_with("req_1") + + self.assertEqual(self.tracker.finished_requests, {"req_1"}) + + result_delayed = self.tracker.delayed_free_requests + self.assertEqual(len(result_delayed), 1) + self.assertIn("req_1", result_delayed) + self.assertEqual(result_delayed["req_1"], current_time) + + def 
test_retrieve_expired_requests(self): + current_time = time.time() + self.tracker.add_delayed_request("req_1", current_time - 600) + self.tracker.add_delayed_request("req_2", current_time) + result = self.tracker._retrieve_expired_requests() + self.assertEqual(result, { + "req_1", + }) + result_delay = self.tracker.delayed_free_requests # dict + self.assertEqual(len(result_delay), 1) + + self.assertIn("req_2", result_delay) + self.assertEqual(result_delay["req_2"], current_time) + + def test_duplicate_task_update(self): + self.tracker.update_done_task_count("req1") + self.tracker.update_done_task_count("req1") + self.tracker.update_done_task_count("req1") + + finished = self.tracker.get_and_clear_finished_requests() + self.assertEqual(finished, {"req1"}) + + +class TestMooncakeLayerwiseConnectorMetadata(unittest.TestCase): + + def test_add_new_req(self): + meta = MooncakeLayerwiseConnectorMetadata() + self.assertEqual(len(meta.requests), 0) + self.assertEqual(len(meta.requests_to_send), 0) + + meta.add_new_req(request_id="req1", + local_block_ids=[1, 2, 3], + kv_transfer_params={ + "remote_block_ids": [4, 5, 6], + "remote_engine_id": "remote_engine", + "remote_host": "localhost", + "remote_port": 5000 + }) + + self.assertEqual(len(meta.requests), 1) + req_meta = meta.requests["req1"] + self.assertIsInstance(req_meta, ReqMeta) + self.assertEqual(req_meta.local_block_ids, [1, 2, 3]) + self.assertEqual(req_meta.remote_block_ids, [4, 5, 6]) + self.assertEqual(req_meta.remote_engine_id, "remote_engine") + self.assertEqual(req_meta.remote_host, "localhost") + self.assertEqual(req_meta.remote_port, 5000) + + +class TestMooncakeLayerwiseConnectorSchedulerMatchedTokens(unittest.TestCase): + + def setUp(self): + config = MockVllmConfig() + self.scheduler = MooncakeLayerwiseConnectorScheduler( + config, "test_engine") + + def test_get_num_new_matched_tokens(self): + request = MockRequest("req1") + tokens, async_flag = self.scheduler.get_num_new_matched_tokens( + request, 0) + self.assertEqual(tokens, 0) + self.assertFalse(async_flag) + + request.kv_transfer_params = {"do_remote_prefill": True} + tokens, async_flag = self.scheduler.get_num_new_matched_tokens( + request, 0) + self.assertEqual(tokens, 4) + self.assertTrue(async_flag) + + def test_build_connector_meta(self): + request = MockRequest("req1") + blocks_mock = MagicMock() + blocks_mock.get_unhashed_block_ids.return_value = [4, 5, 6] + self.scheduler._reqs_need_recv["req1"] = (request, [4, 5, 6]) + request.kv_transfer_params = { + "remote_block_ids": [1, 2, 3], + "remote_engine_id": "remote", + "remote_host": "localhost", + "remote_port": 5000 + } + + meta = self.scheduler.build_connector_meta(MagicMock()) + self.assertIsInstance(meta, MooncakeLayerwiseConnectorMetadata) + self.assertEqual(len(meta.requests), 1) + self.assertEqual(meta.requests["req1"].local_block_ids, [4, 5, 6]) + self.assertEqual(meta.requests["req1"].remote_block_ids, [1, 2, 3]) + self.assertEqual(len(self.scheduler._reqs_need_recv), 0) + + def test_get_finished_count(self): + count = self.scheduler.get_finished_count() + self.assertEqual(count, 2) + + +class TestHelperFunctions(unittest.TestCase): + + def test_group_concurrent_contiguous(self): + src: list[int] = [1, 2, 3, 5, 6] + dst: list[int] = [10, 11, 12, 14, 15] + + src_groups, dst_groups = group_concurrent_contiguous(src, dst) + + self.assertEqual(len(src_groups), 2) + self.assertEqual(src_groups[0], [1, 2, 3]) + self.assertEqual(src_groups[1], [5, 6]) + self.assertEqual(dst_groups[0], [10, 11, 12]) + 
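+        # The second run starts where both src (3 -> 5) and dst (12 -> 14)
+        # break contiguity, so the two sides group into the same two runs.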
self.assertEqual(dst_groups[1], [14, 15]) + + def test_group_concurrent_contiguous_empty(self): + src: list[int] = [] + dst: list[int] = [] + src_groups, dst_groups = group_concurrent_contiguous(src, dst) + self.assertEqual(src_groups, []) + self.assertEqual(dst_groups, []) + + def test_string_to_int64_hash(self): + hash1 = string_to_int64_hash("test_string") + hash2 = string_to_int64_hash("test_string") + self.assertEqual(hash1, hash2) + + hash3 = string_to_int64_hash("different_string") + self.assertNotEqual(hash1, hash3) + + +class TestMooncakeLayerwiseConnectorForScheduler(unittest.TestCase): + + def test_scheduler_role(self): + config = MockVllmConfig() + connector = MooncakeLayerwiseConnector(config, + KVConnectorRole.SCHEDULER) + self.assertIsNotNone(connector.connector_scheduler) + self.assertIsNone(connector.connector_worker) + + @patch.object(MooncakeLayerwiseConnectorScheduler, + "get_num_new_matched_tokens") + def test_scheduler_methods(self, mock_method): + config = MockVllmConfig() + connector = MooncakeLayerwiseConnector(config, + KVConnectorRole.SCHEDULER) + request = MockRequest("req1") + connector.get_num_new_matched_tokens(request, 0) + mock_method.assert_called_once_with(request, 0) + + +class MockKVCacheBlocks: + + def get_unhashed_block_ids(self): + return [4, 5, 6] + + +class MockSchedulerOutput: + pass + + +class MockForwardContext: + pass + + +class TestMooncakeLayerwiseConnector(unittest.TestCase): + + def setUp(self): + self.config = MockVllmConfig() + os.environ["ASCEND_RT_VISIBLE_DEVICES"] = "0,1" + + def test_scheduler_initialization(self): + connector = MooncakeLayerwiseConnector(self.config, + KVConnectorRole.SCHEDULER) + self.assertIsNotNone(connector.connector_scheduler) + self.assertIsNone(connector.connector_worker) + + @patch.object(MooncakeLayerwiseConnectorScheduler, + "get_num_new_matched_tokens") + def test_get_num_new_matched_tokens(self, mock_method): + connector = MooncakeLayerwiseConnector(self.config, + KVConnectorRole.SCHEDULER) + request = MockRequest("req1") + connector.get_num_new_matched_tokens(request, 0) + mock_method.assert_called_once_with(request, 0) + + @patch.object(MooncakeLayerwiseConnectorScheduler, + "update_state_after_alloc") + def test_update_state_after_alloc(self, mock_method): + connector = MooncakeLayerwiseConnector(self.config, + KVConnectorRole.SCHEDULER) + request = MockRequest("req1") + blocks = MockKVCacheBlocks() + connector.update_state_after_alloc(request, blocks, 3) + mock_method.assert_called_once_with(request, blocks, 3) + + @patch.object(MooncakeLayerwiseConnectorScheduler, "build_connector_meta") + def test_build_connector_meta(self, mock_method): + connector = MooncakeLayerwiseConnector(self.config, + KVConnectorRole.SCHEDULER) + scheduler_output = MockSchedulerOutput() + connector.build_connector_meta(scheduler_output) + mock_method.assert_called_once_with(scheduler_output) + + @patch.object(MooncakeLayerwiseConnectorScheduler, "request_finished") + def test_request_finished(self, mock_method): + connector = MooncakeLayerwiseConnector(self.config, + KVConnectorRole.SCHEDULER) + request = MockRequest("req1") + connector.request_finished(request, [1, 2, 3]) + mock_method.assert_called_once_with(request, [1, 2, 3]) + + +class TestMooncakeLayerwiseConnectorScheduler(unittest.TestCase): + + def setUp(self): + self.config = MockVllmConfig() + self.scheduler = MooncakeLayerwiseConnectorScheduler( + self.config, "test_engine") + + def test_get_num_new_matched_tokens_no_remote_prefill(self): + request = 
MockRequest("req1") + tokens, async_flag = self.scheduler.get_num_new_matched_tokens( + request, 0) + self.assertEqual(tokens, 0) + self.assertFalse(async_flag) + + def test_get_num_new_matched_tokens_with_remote_prefill(self): + request = MockRequest("req1", + kv_transfer_params={"do_remote_prefill": True}) + tokens, async_flag = self.scheduler.get_num_new_matched_tokens( + request, 0) + self.assertEqual(tokens, 4) + self.assertTrue(async_flag) + + def test_update_state_after_alloc_no_remote_prefill(self): + request = MockRequest("req1") + blocks = MagicMock() + self.scheduler.update_state_after_alloc(request, blocks, 0) + self.assertEqual(len(self.scheduler._reqs_need_recv), 0) + + def test_update_state_after_alloc_with_remote_prefill(self): + request = MockRequest("req1", + kv_transfer_params={ + "do_remote_prefill": True, + "remote_block_ids": [1, 2, 3], + "remote_engine_id": "remote", + "remote_host": "localhost", + "remote_port": 5000 + }) + blocks = MockKVCacheBlocks() + self.scheduler.update_state_after_alloc(request, blocks, 3) + self.assertEqual(len(self.scheduler._reqs_need_recv), 1) + self.assertEqual(self.scheduler._reqs_need_recv["req1"][0], request) + self.assertEqual(self.scheduler._reqs_need_recv["req1"][1], [4, 5, 6]) + + def test_request_finished_no_remote_decode(self): + request = MockRequest("req1") + delay_free, params = self.scheduler.request_finished( + request, [1, 2, 3]) + self.assertFalse(delay_free) + self.assertIsNone(params) + + +class TestUtils(unittest.TestCase): + + def test_string_to_int64_hash(self): + h1 = string_to_int64_hash("hello") + h2 = string_to_int64_hash("hello") + h3 = string_to_int64_hash("world") + self.assertEqual(h1, h2) + self.assertNotEqual(h1, h3) + self.assertIsInstance(h1, int) + + def test_group_concurrent_contiguous(self): + src: list[int] = [1, 2, 3, 5, 6] + dst: list[int] = [10, 11, 12, 20, 21] + src_g, dst_g = group_concurrent_contiguous(src, dst) + self.assertEqual(src_g, [[1, 2, 3], [5, 6]]) + self.assertEqual(dst_g, [[10, 11, 12], [20, 21]]) + + def test_group_empty(self): + src_g, dst_g = group_concurrent_contiguous([], []) + self.assertEqual(src_g, []) + self.assertEqual(dst_g, []) + + def test_zmq_ctx_invalid_type(self): + with self.assertRaises(ValueError): + with zmq_ctx("INVALID", "tcp://127.0.0.1:5555"): + pass + + @patch( + "vllm_ascend.distributed.mooncake_layerwise_connector.make_zmq_socket") + def test_zmq_ctx_ok(self, mock_make_socket): + mock_socket = MagicMock() + mock_make_socket.return_value = mock_socket + with zmq_ctx(zmq.REQ, "tcp://localhost:1234") as s: # type: ignore + self.assertEqual(s, mock_socket) + + @patch("vllm_ascend.distributed.mooncake_layerwise_connector.logger") + def test_ensure_zmq_send_success(self, mock_logger): + mock_socket = MagicMock() + ensure_zmq_send(mock_socket, b"hello") + mock_socket.send.assert_called_once_with(b"hello") + + @patch("vllm_ascend.distributed.mooncake_layerwise_connector.logger") + def test_ensure_zmq_send_retry_and_fail(self, mock_logger): + mock_socket = MagicMock() + mock_socket.send.side_effect = zmq.ZMQError( # type: ignore + "send failed") + with self.assertRaises(RuntimeError): + ensure_zmq_send(mock_socket, b"hello", max_retries=2) + self.assertEqual(mock_socket.send.call_count, 2) + + @patch("vllm_ascend.distributed.mooncake_layerwise_connector.logger") + def test_ensure_zmq_recv_success(self, mock_logger): + mock_socket = MagicMock() + mock_socket.recv.return_value = b"response" + mock_poller = MagicMock() + mock_poller.poll.return_value = [ + 
(mock_socket, zmq.POLLIN) # type: ignore + ] + data = ensure_zmq_recv(mock_socket, mock_poller) + self.assertEqual(data, b"response") + + @patch("vllm_ascend.distributed.mooncake_layerwise_connector.logger") + def test_ensure_zmq_recv_timeout_and_fail(self, mock_logger): + mock_socket = MagicMock() + mock_poller = MagicMock() + mock_poller.poll.return_value = [] + with self.assertRaises(RuntimeError): + ensure_zmq_recv(mock_socket, + mock_poller, + timeout=0.01, + max_retries=2) + + +class MockMooncakeAgentMetadata: + + def __init__(self, **kwargs): + pass + + +class MockMooncakeLayerwiseConnectorMetadata: + + def __init__(self): + self.requests = {} + + +class MockKVCacheSendingThread(threading.Thread): + + def __init__(self, *args, **kwargs): + super().__init__() + self.daemon = True + self._finished_requests = set() + + def get_and_clear_finished_requests(self): + return self._finished_requests + + def start(self): + pass + + +class MockKVCacheRecvingThread(threading.Thread): + + def __init__(self, *args, **kwargs): + super().__init__() + self.daemon = True + self._finished_requests = set() + self.add_request = MagicMock() + + def get_and_clear_finished_requests(self): + return self._finished_requests + + def start(self): + pass + + +class MockTensor: + + def __init__(self, *args, **kwargs): + self.size = MagicMock(return_value=(10, 16, 8, 16)) + self.element_size = MagicMock(return_value=4) + self.shape = (10, 16, 8, 16) + self.data_ptr = MagicMock(return_value=0x1000) + + +mock_envs_ascend = MagicMock() +mock_envs_ascend.MOONCAKE_CONNECTOR_PROTOCOL = "mock_protocol" + +mock_logger = MagicMock() + + +class MockTransferEngine: + + def initialize(self, *args, **kwargs): + return 0 + + def register_memory(self, *args, **kwargs): + return 1 + + +class MockEnvsAscend: + MOONCAKE_CONNECTOR_PROTOCOL = "mock_protocol" + PHYSICAL_DEVICES = "10,11" + + +def mock_get_tensor_model_parallel_rank(): + return 0 + + +def mock_get_tp_group(): + return MagicMock() + + +def mock_get_ip(): + return "127.0.0.1" + + +def mock_string_to_int64_hash(s): + return hash(s) + + +class TestMooncakeLayerwiseConnectorWorker(unittest.TestCase): + + def setUp(self): + self.envs_ascend_mock = MockEnvsAscend() + self.mock_transfer_engine = MagicMock() + self.mock_transfer_engine.get_rpc_port.return_value = 9090 + self.mock_transfer_engine.initialize.return_value = 0 + self.mock_transfer_engine.register_memory.return_value = 0 + + self.patches = [ + patch('os.getenv', return_value="10,11"), + patch('torch.Tensor.size', return_value=(10, 16, 8, 16)), + patch('torch.Tensor.element_size', return_value=4), + patch('torch.Tensor.data_ptr', return_value=0x1000), + patch('math.prod', return_value=128), + patch('random.Random'), + patch( + 'vllm_ascend.distributed.mooncake_layerwise_connector.get_tensor_model_parallel_rank', + mock_get_tensor_model_parallel_rank), + patch( + 'vllm_ascend.distributed.mooncake_layerwise_connector.get_tp_group', + mock_get_tp_group), + patch( + 'vllm_ascend.distributed.mooncake_layerwise_connector.get_ip', + mock_get_ip), + patch( + 'vllm_ascend.distributed.mooncake_layerwise_connector.string_to_int64_hash', + mock_string_to_int64_hash), + patch( + 'vllm_ascend.distributed.mooncake_layerwise_connector.TransferEngine', + return_value=self.mock_transfer_engine), + patch( + 'vllm_ascend.distributed.mooncake_layerwise_connector.KVCacheSendingLayerThread', + MagicMock()), + patch( + 'vllm_ascend.distributed.mooncake_layerwise_connector.KVCacheRecvingLayerThread', + MagicMock()), + patch( + 
'vllm_ascend.distributed.mooncake_layerwise_connector.logger', + MagicMock()), + patch( + 'vllm_ascend.distributed.mooncake_layerwise_connector.threading.Event', + MagicMock()), + patch.dict('sys.modules', + {'vllm_ascend.envs': self.envs_ascend_mock}), + patch( + 'vllm_ascend.distributed.mooncake_layerwise_connector.get_ascend_config', + return_value=SimpleNamespace(pd_tp_ratio=1), + ), + patch( + 'vllm_ascend.distributed.mooncake_layerwise_connector.get_current_vllm_config', + return_value=SimpleNamespace(scheduler_config=SimpleNamespace( + max_model_len=128)), + ) + ] + + for p in self.patches: + p.start() # type: ignore + + self.vllm_config = MockVllmConfig() + self.engine_id = "test_engine" + self.kv_caches = {"layer1": (MagicMock(), MagicMock())} + + def tearDown(self): + for p in self.patches: + p.stop() # type: ignore + + def test_worker_use_ascend_direct(self): + test_case = [True, False] + + for use_ascend_direct in test_case: + with self.subTest(use_ascend_direct=use_ascend_direct): + config = MagicMock() + config.kv_transfer_config = MagicMock() + config.kv_transfer_config.get_from_extra_config.side_effect = ( + lambda k, d: { + "prefill": { + "tp_size": 2, + "dp_size": 1 + }, + "decode": { + "tp_size": 2, + "dp_size": 1 + }, + "use_ascend_direct": use_ascend_direct, + }.get(k, d)) + + config.parallel_config = MagicMock() + config.parallel_config.tensor_parallel_size = 2 + config.parallel_config.data_parallel_rank_local = 0 + config.parallel_config.data_parallel_size_local = 1 + config.kv_transfer_config.kv_port = 8000 + config.kv_transfer_config.kv_role = 'worker' + + with patch( + "vllm_ascend.distributed.mooncake_layerwise_connector.get_tensor_model_parallel_rank", + return_value=0): + with patch( + "vllm_ascend.distributed.mooncake_layerwise_connector.get_tp_group", + return_value=None): + with patch( + "vllm_ascend.distributed.mooncake_layerwise_connector.get_ip", + return_value="127.0.0.1"): + worker = MooncakeLayerwiseConnectorWorker( + config, self.engine_id) + self.assertIsNotNone(worker) + + def test_register_kv_caches_producer(self): + worker = MooncakeLayerwiseConnectorWorker(self.vllm_config, + self.engine_id) + worker.register_kv_caches(self.kv_caches) + self.assertEqual(len(worker.kv_caches), 1) + self.assertIsNotNone(worker.kv_send_layer_thread) + self.assertIsNone(worker.kv_recv_layer_thread) + + def test_register_kv_caches_consumer(self): + self.vllm_config.kv_transfer_config.kv_role = 'kv_consumer' + worker = MooncakeLayerwiseConnectorWorker(self.vllm_config, + self.engine_id) + worker.register_kv_caches(self.kv_caches) + self.assertIsNone(worker.kv_send_layer_thread) + self.assertIsNotNone(worker.kv_recv_layer_thread) + + def test_register_kv_caches_mla_case(self): + mla_cache1 = MagicMock() + mla_cache1.size.return_value = (10, 16, 1, 16) + mla_cache2 = MagicMock() + mla_cache2.size.return_value = (10, 16, 1, 8) + mla_caches = {"layer1": (mla_cache1, mla_cache2)} + + worker = MooncakeLayerwiseConnectorWorker(self.vllm_config, + self.engine_id) + worker.register_kv_caches(mla_caches) + self.assertTrue(worker.use_mla) + self.assertEqual(len(worker.block_len), 2) + + def test_device_id_selection_with_physical_devices(self): + # Test with physical devices set + worker = MooncakeLayerwiseConnectorWorker(self.vllm_config, + self.engine_id) + # Default tp_rank is 0, so device_id should be 10 + self.assertEqual(worker.device_id, 10) + + +if __name__ == '__main__': + unittest.main() diff --git a/vllm_ascend/ascend_config.py b/vllm_ascend/ascend_config.py index 
65ea3ea..27017c1 100644
--- a/vllm_ascend/ascend_config.py
+++ b/vllm_ascend/ascend_config.py
@@ -94,6 +94,17 @@ class AscendConfig:
             raise AssertionError(
                 "oproj_tensor_parallel_size is only supported in pd scenario and can only be used in D node."
             )
+        self.pd_tp_ratio = 1
+        if vllm_config.kv_transfer_config is not None and not vllm_config.model_config.is_deepseek_mla:
+            prefill_tp_size = vllm_config.kv_transfer_config.get_from_extra_config(
+                "prefill", {"tp_size": 1})["tp_size"]
+            decode_tp_size = vllm_config.kv_transfer_config.get_from_extra_config(
+                "decode", {"tp_size": 1})["tp_size"]
+            pd_tp_ratio: int = prefill_tp_size // decode_tp_size
+            self.pd_tp_ratio = pd_tp_ratio
+            if self.pd_tp_ratio == 0:
+                raise AssertionError(
+                    "P node tp size must be greater than or equal to D node tp size")
 
 
 class TorchairGraphConfig:
diff --git a/vllm_ascend/distributed/__init__.py b/vllm_ascend/distributed/__init__.py
index 26ddd8f..b73f135 100644
--- a/vllm_ascend/distributed/__init__.py
+++ b/vllm_ascend/distributed/__init__.py
@@ -31,3 +31,8 @@ KVConnectorFactory.register_connector(
     "MooncakeConnectorStoreV1",
     "vllm_ascend.distributed.mooncake.mooncake_store_connector_v1",
     "MooncakeConnectorV1")
+
+KVConnectorFactory.register_connector(
+    "MooncakeLayerwiseConnector",
+    "vllm_ascend.distributed.mooncake_layerwise_connector",
+    "MooncakeLayerwiseConnector")
diff --git a/vllm_ascend/distributed/mooncake_connector.py b/vllm_ascend/distributed/mooncake_connector.py
index c0fc1a6..6ecf8e7 100644
--- a/vllm_ascend/distributed/mooncake_connector.py
+++ b/vllm_ascend/distributed/mooncake_connector.py
@@ -1109,4 +1109,4 @@ def ensure_zmq_recv(
         logger.error(f"Receive failed after all retries: {e}")
         raise RuntimeError(
             f"Failed to receive data after {max_retries} "
-            f"retries: {e}")
+            f"retries: {e}")
\ No newline at end of file
diff --git a/vllm_ascend/distributed/mooncake_layerwise_connector.py b/vllm_ascend/distributed/mooncake_layerwise_connector.py
new file mode 100644
index 0000000..c199a37
--- /dev/null
+++ b/vllm_ascend/distributed/mooncake_layerwise_connector.py
@@ -0,0 +1,1335 @@
+# SPDX-License-Identifier: Apache-2.0
+import contextlib
+import hashlib
+import math
+import queue
+import random
+import struct
+import threading
+import time
+from collections import defaultdict
+from collections.abc import Iterator
+from concurrent.futures import ThreadPoolExecutor
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple
+
+import httpx
+import msgspec
+import numpy as np
+import numpy.typing as npt
+import torch
+import zmq
+from mooncake.engine import TransferEngine  # type: ignore
+from vllm import envs
+from vllm.config import VllmConfig, get_current_vllm_config
+from vllm.distributed.kv_transfer.kv_connector.v1.base import (
+    KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
+from vllm.distributed.parallel_state import (get_tensor_model_parallel_rank,
+                                             get_tp_group, get_world_group)
+from vllm.utils import get_ip, logger, make_zmq_path, make_zmq_socket
+from vllm.v1.core.sched.output import SchedulerOutput
+from vllm.v1.request import RequestStatus
+
+import vllm_ascend.envs as envs_ascend
+from vllm_ascend.ascend_config import get_ascend_config
+from vllm_ascend.distributed.utils import (align_memory,
+                                           kv_alltoall_and_rearrange)
+
+if TYPE_CHECKING:
+    from vllm.attention.backends.abstract import AttentionMetadata
+    from vllm.forward_context import ForwardContext
+    from vllm.v1.core.kv_cache_manager import KVCacheBlocks
+    from vllm.v1.request import Request
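+
+# Layer-wise transfer flow implemented in this file:
+#  1. The decode side registers its KV cache with the Mooncake TransferEngine
+#     and ships DecodeMooncakeAgentMetadata (host/port, te_rpc_port, KV cache
+#     base addresses, block ids) to the prefill side over a ZMQ side channel.
+#  2. On the prefill side, SendingLayerThread consumes per-layer send requests
+#     from a queue and writes each layer's blocks directly into the decode
+#     node's registered memory via batch_transfer_sync_write.
+#  3. After all layers of a request are written, the sender notifies the
+#     decode side and waits for an ACK before marking the request finished.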
import Request + +GET_META_MSG = b"get_meta_msg" +DONE_RECVING_MSG = b"done_recving_msg" + + +class MooncakeAgentMetadata(msgspec.Struct, omit_defaults=True, dict=True): + engine_id: str + te_rpc_port: int + kv_caches_base_addr: list[int] + num_blocks: int + + +@dataclass +class ReqMeta: + local_block_ids: list[int] + # Not None if layer-wise is disabled + remote_block_ids: Optional[list[int]] + remote_host: Optional[str] + remote_port: Optional[int] + remote_engine_id: Optional[str] + # Not None if layer-wise is enabled + metaserver: Optional[str] + remote_tp_size: Optional[int] + + +class DecodeMooncakeAgentMetadata(msgspec.Struct, + omit_defaults=True, + dict=True): + req_id: str + block_ids: list[int] + host: str + port: int + engine_id: str + te_rpc_port: int + kv_caches_base_addr: list[int] + num_blocks: int + + +class KVCacheTaskTracker: + + def __init__(self, + target_count: int = 1, + on_done: Callable[[str], None] = lambda x: None, + on_timeout: Callable[[set[str]], Any] = lambda x: None): + super().__init__() + self.target_count = target_count + self.done_task_lock = threading.Lock() + self.done_task_counts: defaultdict[str, int] = defaultdict(int) + self.finished_requests: set[str] = set() + # Only used in prefill node. Tracks requests whose kv blocks freeing is + # intentionally delayed. Each entry is a tuple of (request_id, + # timestamp). If a request remains in this queue for too long, it will + # be force-freed. + # Notice: In layer-wise mode, the transfer may complete before it is + # added to delayed_free_requests when prefill node finishes forwarding. + # Therefore we need to track requests that are removed before being added. + self.delayed_free_requests: dict[str, float] = {} + self.removed_delayed_free_requests: set[str] = set() + self.on_done = on_done + self.on_timeout = on_timeout + + def update_done_task_count(self, request_id: str): + self.done_task_counts[request_id] += 1 + if self.done_task_counts[request_id] == self.target_count: + with self.done_task_lock: + self.finished_requests.add(request_id) + self.done_task_counts.pop(request_id) + self.on_done(request_id) + + def get_and_clear_finished_requests(self) -> set[str]: + """ + Get and clear the requests that have been completed. + Returns: + A set of request IDs that have been completed. 
+ """ + with self.done_task_lock: + finished_requests = self.finished_requests.copy() + expired_requests = self._retrieve_expired_requests() + finished_requests.update(expired_requests) + self.finished_requests.clear() + self.on_timeout(expired_requests) + return finished_requests + + def add_delayed_request(self, request_id: str, delay_start_time: float): + """Add a delayed free request, where delay_start_time is monotonic increasing.""" + with self.done_task_lock: + if request_id in self.removed_delayed_free_requests: + self.removed_delayed_free_requests.remove(request_id) + else: + self.delayed_free_requests[request_id] = delay_start_time + + def _retrieve_expired_requests(self): + """Retrieve all expired delayed requests.""" + expired_requests: set[str] = set() + # Free delayed requests if they exceed the timeout + current_time = time.time() + while self.delayed_free_requests: + request_id, delay_start_time = next( + iter(self.delayed_free_requests.items())) + if (current_time - delay_start_time + > envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT): + self.delayed_free_requests.pop(request_id) + expired_requests.add(request_id) + logger.info("Force freed request: %s", request_id) + else: + break + return expired_requests + + def remove_delayed_request(self, request_id: str): + """Remove all delayed free requests matching the given request_id.""" + with self.done_task_lock: + if self.delayed_free_requests.pop(request_id, None) is None: + self.removed_delayed_free_requests.add(request_id) + + +class KVCacheSendingLayerThread(threading.Thread): + + def __init__(self, tp_rank: int, tp_size: int, decode_tp_size: int, + local_engine_id: str, side_channel_host: str, + side_channel_port: int, metadata: MooncakeAgentMetadata, + ready_event: threading.Event, total_layers: int, + engine: TransferEngine, local_kv_base_addr: list[int], + block_len: list[int], use_mla: bool, + first_kv_cache: torch.Tensor): + super().__init__(daemon=True, name="KVCacheSendingLayerThread") + self.tp_rank = tp_rank + self.tp_size = tp_size + self.decode_tp_size = decode_tp_size + self.local_engine_id = local_engine_id + self.side_channel_host = side_channel_host + self.side_channel_port = side_channel_port + self.task_tracker = KVCacheTaskTracker(total_layers, + on_done=self._post_transfer, + on_timeout=self._abort_requests) + self.send_layer_thread = SendingLayerThread( + self.task_tracker, total_layers, engine, local_kv_base_addr, + block_len, use_mla, self.tp_rank, first_kv_cache) + self.ready_decode = dict[str, DecodeMooncakeAgentMetadata]() + self.pending_decode = dict[str, + list[tuple[list[int], int, torch.Tensor, + torch.Tensor]]]() + self.total_layers = total_layers + self.lock = threading.Lock() + self.ready_event = ready_event + + def get_and_clear_finished_requests(self) -> set[str]: + """ + Get and clear the requests that have been completed. + Returns: + A set of request IDs that have been completed. 
+ """ + # vllm won't call us if all inference is done, so we can't do step 9 here + return self.task_tracker.get_and_clear_finished_requests() + + def add_delayed_request(self, request_id: str, delay_start_time: float): + return self.task_tracker.add_delayed_request(request_id, + delay_start_time) + + def run(self): + """Run the thread to handle KV cache transfer requests.""" + self.send_layer_thread.start() + handshake_port = self.side_channel_port + self.tp_rank + path = make_zmq_path("tcp", self.side_channel_host, handshake_port) + logger.info("Starting listening on path: %s", path) + with zmq_ctx(zmq.ROUTER, path) as sock: # type: ignore + self.ready_event.set() + decoder = msgspec.msgpack.Decoder(type=DecodeMooncakeAgentMetadata) + while True: + try: + frames = sock.recv_multipart() + if len(frames) < 2: + logger.error("Invalid message format: %s", frames) + continue + + identity = frames[0] + payload = [f for f in frames[1:] if f != b""] + if len(payload) != 1: + logger.error("Invalid message format: %s", frames) + continue + + metadata = decoder.decode(payload[0]) + request_id = metadata.req_id + logger.debug( + f"Prefiller has received that request {request_id} from the decoder." + ) + sock.send_multipart((identity, b"", b"ACK")) + self.task_tracker.remove_delayed_request(request_id) + with self.lock: + self.ready_decode[request_id] = metadata + pending = self.pending_decode.pop(request_id, []) + for local_block_ids, layer_index, key, value in pending: + self.send_layer_thread.send_queue.put( + (metadata, request_id, local_block_ids, + layer_index, key, value)) + except Exception as e: + logger.error("Failed to decode message: %s", e) + + def _post_transfer(self, request_id: str): + with self.lock: + decoder_meta = self.ready_decode.pop(request_id) + path = make_zmq_path("tcp", decoder_meta.host, decoder_meta.port) + msg_encoder = msgspec.msgpack.Encoder() + encoded_data = msg_encoder.encode(request_id) + with zmq_ctx(zmq.REQ, path) as sock: # type: ignore + ensure_zmq_send(sock, encoded_data) + ack = sock.recv() + if ack != b"ACK": + raise ValueError(f"Unexpected ACK response: {ack}") + + def add_request(self, request_id: str, local_block_ids: list[int], + layer_index: int, key: torch.Tensor, value: torch.Tensor): + # add request to send layer thread + with self.lock: + if request_id in self.ready_decode: + self.send_layer_thread.send_queue.put( + (self.ready_decode[request_id], request_id, + local_block_ids, layer_index, key, value)) + else: + self.pending_decode.setdefault(request_id, []).append( + (local_block_ids, layer_index, key, value)) + + def _abort_requests(self, request_ids: set[str]): + with self.lock: + for request_id in request_ids: + self.pending_decode.pop(request_id, None) + + +class SendingLayerThread(threading.Thread): + + def __init__(self, task_tracker: KVCacheTaskTracker, total_layers: int, + engine: TransferEngine, local_kv_base_addr: list[int], + block_len: list[int], use_mla: bool, tp_rank: int, + first_kv_cache: torch.Tensor): + super().__init__(daemon=True, name="KVCacheRecvingPrefillerByeThread") + self.send_queue = queue.Queue[tuple[DecodeMooncakeAgentMetadata, str, + list[int], int, torch.Tensor, + torch.Tensor]]() + self.completion_event: threading.Event + self.completion_event_count: int + self.task_tracker = task_tracker + self.total_layers = total_layers + self.local_kv_base_addr = local_kv_base_addr + self.block_len = block_len + self.use_mla = use_mla + self.engine = engine + self.tp_rank = tp_rank + self.pd_tp_ratio = 
+
+
+class SendingLayerThread(threading.Thread):
+
+    def __init__(self, task_tracker: KVCacheTaskTracker, total_layers: int,
+                 engine: TransferEngine, local_kv_base_addr: list[int],
+                 block_len: list[int], use_mla: bool, tp_rank: int,
+                 first_kv_cache: torch.Tensor):
+        super().__init__(daemon=True, name="SendingLayerThread")
+        self.send_queue = queue.Queue[tuple[DecodeMooncakeAgentMetadata, str,
+                                            list[int], int, torch.Tensor,
+                                            torch.Tensor]]()
+        self.completion_event: threading.Event
+        self.completion_event_count: int
+        self.task_tracker = task_tracker
+        self.total_layers = total_layers
+        self.local_kv_base_addr = local_kv_base_addr
+        self.block_len = block_len
+        self.use_mla = use_mla
+        self.engine = engine
+        self.tp_rank = tp_rank
+        self.pd_tp_ratio = get_ascend_config().pd_tp_ratio
+        vllm_config = get_current_vllm_config()
+        max_model_len = vllm_config.scheduler_config.max_model_len
+        first_kv_cache = first_kv_cache[:max_model_len]
+        alignment = 2 * 1024 * 1024
+        self.k_buffer = torch.zeros(
+            first_kv_cache.numel() + alignment,
+            dtype=first_kv_cache.dtype,
+            device=first_kv_cache.device
+        )  # e.g. a (blocks, block_size, head_dim) layer staged as a flat (tokens, head_dim) buffer
+        self.k_buffer = align_memory(self.k_buffer,
+                                     alignment)[:first_kv_cache.numel()].view(
+                                         -1, first_kv_cache.shape[-1])
+        self.v_buffer = torch.zeros(first_kv_cache.numel() + alignment,
+                                    dtype=first_kv_cache.dtype,
+                                    device=first_kv_cache.device)
+        self.v_buffer = align_memory(self.v_buffer,
+                                     alignment)[:first_kv_cache.numel()].view(
+                                         -1, first_kv_cache.shape[-1])
+
+        for tensor in (self.k_buffer, self.v_buffer):
+            assert tensor.data_ptr(
+            ) % alignment == 0, "The address of the registered kv cache should be aligned to 2M"
+            length = tensor.numel() * tensor.element_size()  # bytes
+            ret_value = self.engine.register_memory(tensor.data_ptr(), length)
+            logger.info(
+                f"SendingLayerThread register_memory {tensor.data_ptr()} {length} {ret_value=}"
+            )
+            if ret_value != 0:
+                raise RuntimeError("Mooncake memory registration failed.")
+
+    def run(self):
+        """Send queued KV-cache layers to the remote decoder."""
+        # Send the KV cache for each request in send_queue.
+        local_rank = get_world_group().local_rank
+        device = torch.device(f"npu:{local_rank}")
+        torch.npu.set_device(device)
+        while True:
+            request = self.send_queue.get()
+            self._handle_request(request)
+
+    def _handle_request(self, request: tuple[DecodeMooncakeAgentMetadata, str,
+                                             list[int], int, torch.Tensor,
+                                             torch.Tensor]):
+        # Send one KV layer to the remote decoder.
+        req_meta, request_id, local_block_ids, layer_index, key, value = request
+
+        try:
+            logger.debug(
+                f"Starting to transfer KV cache for request {request_id}.")
+            self._transfer_kv_cache(req_meta, local_block_ids, layer_index,
+                                    key, value)
+            logger.debug(
+                f"Finished transferring KV cache for request {request_id}.")
+        except Exception as e:
+            logger.error("Failed to transfer KV cache for request "
+                         f"{request_id}: {e}")
+        finally:
+            self.task_tracker.update_done_task_count(request_id)
+            self.send_queue.task_done()
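+
+    # Worked example of the descriptor arithmetic in _transfer_kv_cache below
+    # (assumed numbers, for illustration only): with block_len = 65536 bytes
+    # and a contiguous local run [3, 4, 5] mapped to remote run [7, 8, 9],
+    # a single descriptor covers the whole run per K/V region:
+    #   src    = src_layer_base_addr + 3 * 65536
+    #   dst    = dst_layer_base_addr + 7 * 65536
+    #   length = 3 * 65536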
+
+    def _transfer_kv_cache(self, req_meta: DecodeMooncakeAgentMetadata,
+                           local_block_ids: list[int], layer_index: int, key,
+                           value):
+        # Send one KV layer to the remote decoder.
+        if len(local_block_ids) == 0:
+            return
+
+        remote_host = req_meta.host
+        remote_te_port = req_meta.te_rpc_port
+        remote_kv_base_addrs = req_meta.kv_caches_base_addr
+
+        remote_block_ids = req_meta.block_ids
+        if self.pd_tp_ratio == 1:
+            layer_local_kv_base_addr = [
+                self.local_kv_base_addr[i]
+                for i in [2 * layer_index, 2 * layer_index + 1]
+            ]
+            layer_remote_kv_base_addr = [
+                remote_kv_base_addrs[i]
+                for i in [2 * layer_index, 2 * layer_index + 1]
+            ]
+            grouped_remote_block_ids, grouped_local_block_ids = \
+                group_concurrent_contiguous(remote_block_ids, local_block_ids)
+
+            session_id = f"{remote_host}:{remote_te_port}"
+            src_list, dst_list, length_list = [], [], []
+            for k, (src_layer_base_addr, dst_layer_base_addr) in enumerate(
+                    zip(layer_local_kv_base_addr, layer_remote_kv_base_addr)):
+                block_len = self.block_len[
+                    k % 2] if self.use_mla else self.block_len[0]
+                for group_remote_block_id, group_local_block_id in zip(
+                        grouped_remote_block_ids, grouped_local_block_ids):
+                    src = src_layer_base_addr + group_local_block_id[
+                        0] * block_len
+                    dst = dst_layer_base_addr + group_remote_block_id[
+                        0] * block_len
+                    length = len(group_local_block_id) * block_len
+                    src_list.append(src)
+                    dst_list.append(dst)
+                    length_list.append(length)
+            torch.npu.synchronize()
+            ret = self.engine.batch_transfer_sync_write(
+                session_id, src_list, dst_list, length_list)
+
+            if ret < 0:
+                logger.error("Mooncake transfer failed for request %s",
+                             req_meta.req_id)
+                raise RuntimeError(f"Mooncake transfer failed, ret: {ret}")
+        else:
+            # Stage this layer's K/V into the aligned send buffers.
+            key = key.view(-1, key.shape[-1])
+            value = value.view(-1, value.shape[-1])
+            self.k_buffer[:key.shape[0]].copy_(key)
+            self.v_buffer[:value.shape[0]].copy_(value)
+
+            layer_local_kv_base_addr = [
+                self.k_buffer.data_ptr(),
+                self.v_buffer.data_ptr()
+            ]
+
+            layer_remote_kv_base_addr = [
+                remote_kv_base_addrs[i]
+                for i in [2 * layer_index, 2 * layer_index + 1]
+            ]
+
+            grouped_remote_block_ids, _ = group_concurrent_contiguous(
+                remote_block_ids)
+
+            session_id = f"{remote_host}:{remote_te_port}"
+            src_list, dst_list, length_list = [], [], []
+            for k, (src_layer_base_addr, dst_layer_base_addr) in enumerate(
+                    zip(layer_local_kv_base_addr, layer_remote_kv_base_addr)):
+                src_layer_addr = src_layer_base_addr
+                for group_remote_block_id in grouped_remote_block_ids:
+                    block_len = self.block_len[0]
+                    remote_block_len = self.block_len[0] * self.pd_tp_ratio
+                    src_list.append(src_layer_addr)
+
+                    if src_layer_addr + len(
+                            group_remote_block_id
+                    ) * block_len > src_layer_base_addr + key.numel(
+                    ) * key.element_size():
+                        length = src_layer_base_addr + key.numel(
+                        ) * key.element_size() - src_layer_addr
+                    else:
+                        length = len(group_remote_block_id) * block_len
+                    length_list.append(length)
+
+                    dst_list.append(dst_layer_base_addr +
+                                    group_remote_block_id[0] *
+                                    remote_block_len + length *
+                                    (self.tp_rank % self.pd_tp_ratio))
+                    src_layer_addr += length
+            torch.npu.synchronize()
+            ret = self.engine.batch_transfer_sync_write(
+                session_id, src_list, dst_list, length_list)
+            self.completion_event_count -= 1
+
+            if self.completion_event_count == 0 and self.completion_event is not None:
+                logger.debug(
+                    "[_transfer_kv_cache] completion_event_count reached 0, "
+                    "setting completion event")
+                self.completion_event.set()
+
+            if ret < 0:
+                logger.error("Mooncake transfer failed for request %s",
+                             req_meta.req_id)
+                raise RuntimeError(f"Mooncake transfer failed, ret: {ret}")
+
+    def add_event(self, event: threading.Event, count: int) -> None:
+        self.completion_event = event
+        self.completion_event_count = count
+
+
+class KVCacheRecvingLayerThread(threading.Thread):
+
+    def __init__(self, tp_rank: int, side_channel_port: int, tp_size: int,
+                 local_engine_id: str, ready_event: threading.Event):
+        super().__init__(daemon=True, name="KVCacheRecvingLayerThread")
+        self.tp_rank = tp_rank
+        self.tp_size = tp_size
+        self.local_engine_id = local_engine_id
+        self.side_channel_host = get_ip()
+        self.side_channel_port = side_channel_port
+        self.lock = threading.Lock()
+        self.done_requests = set[str]()
+        self.ready_event = ready_event
+
+    def get_and_clear_finished_requests(self) -> set[str]:
+        """
+        Get and clear the requests that have been completed.
+        Returns:
+            A set of request IDs that have been completed.
+        """
+        with self.lock:
+            finished_requests = self.done_requests
+            self.done_requests = set()
+        return finished_requests
+
+    def run(self):
+        """Run the thread to handle KV cache transfer requests."""
+        # TODO: layerwise step 9:
+        # with zmq_ctx(zmq.ROUTER, path) as sock:  # type: ignore
+        #     while True:
+        #         recv msg from prefill that request sending finished
+        # Listen for new requests for metadata.
+        # NOTE(rob): we need each rank to have a unique port.
This hack to keeps + # us moving. We will switch when moving to etcd or where we have a + # single ZMQ socket in the scheduler. + handshake_port = self.side_channel_port + self.tp_rank + path = make_zmq_path("tcp", self.side_channel_host, handshake_port) + logger.info("Starting listening on path: %s", path) + with zmq_ctx(zmq.ROUTER, path) as sock: # type: ignore + self.ready_event.set() + decoder = msgspec.msgpack.Decoder(type=str) + while True: + try: + frames = sock.recv_multipart() + if len(frames) < 2: + logger.error("Invalid message format: %s", frames) + continue + + identity = frames[0] + payload = [f for f in frames[1:] if f != b""] + if len(payload) != 1: + logger.error("Invalid message format: %s", frames) + continue + + request_id = decoder.decode(payload[0]) + with self.lock: + self.done_requests.add(request_id) + sock.send_multipart((identity, b"", b"ACK")) + except Exception as e: + logger.error("Failed to decode message: %s", e) + + +class MooncakeLayerwiseConnectorMetadata(KVConnectorMetadata): + + def __init__(self): + self.requests: dict[str, ReqMeta] = {} + self.requests_to_send: dict[str, float] = {} + + def add_new_req(self, + request_id: str, + local_block_ids: list[int], + kv_transfer_params: dict[str, Any], + metaserver=None): + self.requests[request_id] = ReqMeta( + local_block_ids=local_block_ids, + remote_block_ids=kv_transfer_params.get("remote_block_ids", None), + remote_engine_id=kv_transfer_params["remote_engine_id"], + remote_host=kv_transfer_params["remote_host"], + remote_port=kv_transfer_params["remote_port"], + metaserver=metaserver, + remote_tp_size=kv_transfer_params.get("remote_tp_size", None), + ) + + +class MooncakeLayerwiseConnector(KVConnectorBase_V1): + + def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole): + assert vllm_config.kv_transfer_config is not None + self.engine_id = vllm_config.kv_transfer_config.engine_id + + if role == KVConnectorRole.SCHEDULER: + self.connector_scheduler: Optional[MooncakeLayerwiseConnectorScheduler] = \ + MooncakeLayerwiseConnectorScheduler(vllm_config, str(self.engine_id)) + self.connector_worker: Optional[ + MooncakeLayerwiseConnectorWorker] = None + elif role == KVConnectorRole.WORKER: + self.connector_scheduler = None + self.connector_worker = MooncakeLayerwiseConnectorWorker( + vllm_config, str(self.engine_id)) + + ############################################################ + # Scheduler Side Methods + ############################################################ + + def get_num_new_matched_tokens( + self, request: "Request", + num_computed_tokens: int) -> tuple[int, bool]: + assert self.connector_scheduler is not None + return self.connector_scheduler.get_num_new_matched_tokens( + request, num_computed_tokens) + + def update_state_after_alloc(self, request: "Request", + blocks: "KVCacheBlocks", + num_external_tokens: int): + assert self.connector_scheduler is not None + return self.connector_scheduler.update_state_after_alloc( + request, blocks, num_external_tokens) + + def build_connector_meta( + self, + scheduler_output: SchedulerOutput, + ) -> KVConnectorMetadata: + assert self.connector_scheduler is not None + return self.connector_scheduler.build_connector_meta(scheduler_output) + + def request_finished( + self, + request: "Request", + block_ids: list[int], + ) -> tuple[bool, Optional[dict[str, Any]]]: + assert self.connector_scheduler is not None + return self.connector_scheduler.request_finished(request, block_ids) + + def get_finished_count(self) -> Optional[int]: + assert 
self.connector_scheduler is not None + return self.connector_scheduler.get_finished_count() + + ############################################################ + # Worker Side Methods + ############################################################ + def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): + assert self.connector_worker is not None + self.connector_worker.register_kv_caches(kv_caches) + + def get_finished(self, + finished_req_ids: set[str]) -> tuple[set[str], set[str]]: + """Get the finished recving and sending requests.""" + assert self.connector_worker is not None + return self.connector_worker.get_finished() + + def start_load_kv(self, forward_context: "ForwardContext", + **kwargs) -> None: + assert self.connector_worker is not None + assert isinstance(self._connector_metadata, + MooncakeLayerwiseConnectorMetadata) + self.connector_worker.start_load_kv(self._connector_metadata) + + def wait_for_layer_load(self, layer_name: str) -> None: + """MooncakeLayerwiseConnector does not do layerwise saving.""" + assert self.connector_worker is not None + assert isinstance(self._connector_metadata, + MooncakeLayerwiseConnectorMetadata) + self.connector_worker.wait_for_layer_load(layer_name) + + def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor, + attn_metadata: "AttentionMetadata", **kwargs) -> None: + """MooncakeLayerwiseConnector does not save explicitly.""" + assert self.connector_worker is not None + assert isinstance(self._connector_metadata, + MooncakeLayerwiseConnectorMetadata) + self.connector_worker.save_kv_layer(layer_name, kv_layer, + attn_metadata, + self._connector_metadata) + + def wait_for_save(self): + """MooncakeLayerwiseConnector does not save explicitly.""" + pass + + +class MooncakeLayerwiseConnectorScheduler: + """Implementation of Scheduler side methods""" + + def __init__(self, vllm_config: VllmConfig, engine_id: str): + self.vllm_config = vllm_config + self.block_size = vllm_config.cache_config.block_size + self.engine_id = engine_id + logger.info("Initializing Mooncake Scheduler %s", engine_id) + + self.side_channel_host = get_ip() + self.max_device_id = vllm_config.parallel_config.tensor_parallel_size * \ + vllm_config.parallel_config.data_parallel_size + + # Handshake base port + self.side_channel_port = ( + vllm_config.kv_transfer_config.kv_port + + vllm_config.parallel_config.data_parallel_rank_local * + vllm_config.parallel_config.tensor_parallel_size) + + # Requests that need to start recv. + # New requests are added by update_state_after_alloc in + # the scheduler. Used to make metadata passed to Worker. + self._reqs_need_recv: dict[str, tuple[Request, list[int]]] = {} + self._reqs_need_send: dict[str, float] = {} + self._reqs_need_send_layerwise: dict[str, tuple[str, int, + list[int]]] = {} + + def get_num_new_matched_tokens( + self, request: "Request", + num_computed_tokens: int) -> tuple[int, bool]: + """ + For remote prefill, pull all prompt blocks from remote + asynchronously relative to engine execution. + + Args: + request (Request): the request object. + num_computed_tokens (int): the number of locally + computed tokens for this request + Returns: + * the number of tokens that can be loaded from the + external KV cache beyond what is already computed. + * true if the external KV cache tokens will be loaded + asynchronously (between scheduler steps). 
+ """ + + params = request.kv_transfer_params + logger.debug( + "MooncakeLayerwiseConnector get_num_new_matched_tokens: " + "num_computed_tokens=%s, kv_transfer_params=%s", + num_computed_tokens, params) + + if params is not None and params.get("do_remote_prefill"): + assert num_computed_tokens == 0, "Currently only support " \ + "prefill with num_computed_tokens == 0." + # Assume that the request's KV cache is already fully prefilled and + # can be fetched entirely from the prefill node. + count = len(request.prompt_token_ids) + if count > 0: + return count, True + + # No remote prefill for this request. + return 0, False + + def update_state_after_alloc(self, request: "Request", + blocks: "KVCacheBlocks", + num_external_tokens: int): + + params = request.kv_transfer_params + logger.debug( + "MooncakeLayerwiseConnector update_state_after_alloc: " + "num_external_tokens=%s, kv_transfer_params=%s", + num_external_tokens, params) + + if params is not None and params.get("do_remote_prefill"): + if all(p in params for p in ("remote_engine_id", "remote_host", + "remote_port")): + local_block_ids = (blocks.get_unhashed_block_ids() + if num_external_tokens > 0 else []) + # Get unhashed blocks to pull from remote. + self._reqs_need_recv[request.request_id] = (request, + local_block_ids) + else: + logger.warning( + "Got invalid KVTransferParams: %s. This " + "request will not utilize KVTransfer", params) + params["do_remote_prefill"] = False + + # Layerwise prefiller add request need send + if params is not None and params.get("do_remote_decode"): + local_block_ids = (blocks.get_block_ids()[0]) + self._reqs_need_send_layerwise[request.request_id] = ( + params["metaserver"], len(request.all_token_ids), + local_block_ids) + + def build_connector_meta( + self, + scheduler_output: SchedulerOutput, + ) -> KVConnectorMetadata: + meta = MooncakeLayerwiseConnectorMetadata() + + # Loop through scheduled reqs and convert to ReqMeta. + for req_id, (req, block_ids) in self._reqs_need_recv.items(): + assert req.kv_transfer_params is not None + # For the case where there are no remote blocks to pull + # (block_ids is empty), we don't need to schedule + # an async read on the worker side. 
+ meta.add_new_req( + request_id=req_id, + local_block_ids=block_ids, + kv_transfer_params=req.kv_transfer_params, + ) + + # Clear the list once workers start the transfers + self._reqs_need_recv.clear() + + cached_reqs = scheduler_output.scheduled_cached_reqs + new_reqs = scheduler_output.scheduled_new_reqs + for req_id, new_blocks in zip(cached_reqs.req_ids, + cached_reqs.new_block_ids): + if req_id in self._reqs_need_send_layerwise and new_blocks is not None: + metaserver, total_tokens, block_ids = self._reqs_need_send_layerwise[ + req_id] + block_ids.extend(new_blocks[0]) + + computed_tokens = dict( + list(zip(cached_reqs.req_ids, cached_reqs.num_computed_tokens)) + + [(x.req_id, x.num_computed_tokens) for x in new_reqs]) + for req_id, scheduled_tokens in scheduler_output.num_scheduled_tokens.items( + ): + if req_id in self._reqs_need_send_layerwise: + metaserver, total_tokens, block_ids = self._reqs_need_send_layerwise[ + req_id] + current_tokens = computed_tokens.get(req_id, + 0) + scheduled_tokens + if current_tokens == total_tokens: + meta.add_new_req( + request_id=req_id, + local_block_ids=block_ids, + kv_transfer_params=defaultdict(lambda: None), + metaserver=metaserver) + self._reqs_need_send_layerwise.pop(req_id) + + meta.requests_to_send = self._reqs_need_send + self._reqs_need_send = {} + + return meta + + def request_finished( + self, + request: "Request", + block_ids: list[int], + ) -> tuple[bool, Optional[dict[str, Any]]]: + """ + Once a request is finished, determine whether request blocks + should be freed now or will be sent asynchronously and freed later. + """ + + params = request.kv_transfer_params + logger.debug( + "MooncakeLayerwiseConnector request_finished, request_status=%s, " + "kv_transfer_params=%s", request.status, params) + + if (params is None or not params.get("do_remote_decode") + or request.status != RequestStatus.FINISHED_LENGTH_CAPPED): + return False, None + + computed_block_ids = block_ids + delay_free_blocks = len(computed_block_ids) > 0 + if delay_free_blocks: + logger.info("Delaying free of %d blocks for request %s", + len(computed_block_ids), request.request_id) + self._reqs_need_send[request.request_id] = time.time() + + return delay_free_blocks, dict( + do_remote_prefill=True, + do_remote_decode=False, + remote_engine_id=self.engine_id, + remote_host=self.side_channel_host, + remote_port=self.side_channel_port, + remote_block_ids=computed_block_ids, + ) + + def get_finished_count(self) -> Optional[int]: + prefill_parallel_config: dict[ + str, + Any] = self.vllm_config.kv_transfer_config.get_from_extra_config( + "prefill", {}) + + assert "tp_size" in prefill_parallel_config.keys() + self._prefill_tp_size = prefill_parallel_config["tp_size"] + decode_parallel_config: dict[ + str, + Any] = self.vllm_config.kv_transfer_config.get_from_extra_config( + "decode", {}) + assert "tp_size" in decode_parallel_config.keys() + self._decode_tp_size = decode_parallel_config["tp_size"] + + if self.vllm_config.model_config.use_mla: + return self._decode_tp_size + else: + # TODO support mha and gqa + return None + + +class MooncakeLayerwiseConnectorWorker: + """Implementation of Worker side methods""" + + def __init__(self, vllm_config: VllmConfig, engine_id: str): + self._get_prefill_decode_size(vllm_config) + if self._prefill_tp_size < self._decode_tp_size: + raise ValueError( + f"prefill_tp_size: {self._prefill_tp_size} must be greater than" + f" or equal to the decode_tp_size: {self._decode_tp_size}") + + if TransferEngine is None: + raise 
RuntimeError("mooncake is not available") + logger.info("Initializing Mooncake work %s", engine_id) + self.engine = TransferEngine() + + # Metadata. + self.completion_event: threading.Event + self.vllm_config = vllm_config + self.engine_id = engine_id + self.tp_rank = get_tensor_model_parallel_rank() + self.tp_size = vllm_config.parallel_config.tensor_parallel_size + self.tp_group = get_tp_group() + self.dp_rank = vllm_config.parallel_config.data_parallel_rank_local + self.dp_size = vllm_config.parallel_config.data_parallel_size_local + self.kv_caches: dict[str, torch.Tensor] = {} + self.side_channel_host = get_ip() + self.max_device_id = self.tp_size * self.dp_size + self.kv_role = vllm_config.kv_transfer_config.kv_role + self.total_layers = vllm_config.model_config.get_num_layers( + vllm_config.parallel_config) + + self.executor = ThreadPoolExecutor(1) + self.metaserver_client = httpx.Client( + limits=httpx.Limits(max_connections=100000), + timeout=None) if self.tp_rank == 0 else None + + # Handshake base port + self.side_channel_port = ( + vllm_config.kv_transfer_config.kv_port + + vllm_config.parallel_config.data_parallel_rank_local * + vllm_config.parallel_config.tensor_parallel_size) + self.handshake_port = self.side_channel_port + self.tp_rank + self.sockets: dict = {} + + # get tp device id + # TODO(kw): https://github.com/vllm-project/vllm-ascend/pull/940 + # introducing some changes + device_ids_str = envs_ascend.PHYSICAL_DEVICES + if device_ids_str is None: + device_ids = list( + range(self.dp_rank * self.tp_size, + (self.dp_rank + 1) * self.tp_size)) + else: + device_ids = list(map(int, device_ids_str.split(','))) + start_index = self.dp_rank * self.tp_size + end_index = start_index + self.tp_size + if len(device_ids) < end_index: + raise ValueError( + f"Not enough physical devices available for DP rank {self.dp_rank}. " + f"Expected at least {end_index} devices, but found {len(device_ids)} " + "in PHYSICAL_DEVICES.") + device_ids = device_ids[start_index:end_index] + assert len(device_ids) > self.tp_rank # type: ignore + self.device_id = device_ids[self.tp_rank] # type: ignore + + if vllm_config.kv_transfer_config.get_from_extra_config( + 'use_ascend_direct', False): + hostname = self.side_channel_host + else: + hostname = f"{self.side_channel_host}:0:npu_{self.device_id}" + self._initialize(hostname=hostname, device_name=None) + self.te_rpc_port = self.engine.get_rpc_port() + + # Background thread for sending or receiving KV caches. 
+ self.kv_send_layer_thread: Optional[KVCacheSendingLayerThread] = None + self.kv_recv_layer_thread: Optional[KVCacheRecvingLayerThread] = None + + self.vllm_config = vllm_config + self.block_size = vllm_config.cache_config.block_size + self.kv_caches_base_addr: list[int] = [] + + self.pd_tp_ratio = get_ascend_config().pd_tp_ratio + self.first_kv_cache = None + + def _get_prefill_decode_size(self, vllm_config: VllmConfig): + # get prefill tp and dp size from extra config + prefill_parallel_config: dict[ + str, Any] = vllm_config.kv_transfer_config.get_from_extra_config( + "prefill", {}) + + assert "tp_size" in prefill_parallel_config.keys() + self._prefill_tp_size = prefill_parallel_config["tp_size"] + + assert "dp_size" in prefill_parallel_config.keys() + self._prefill_dp_size = prefill_parallel_config["dp_size"] + + # get decode tp and dp size from extra config + decode_parallel_config: dict[ + str, Any] = vllm_config.kv_transfer_config.get_from_extra_config( + "decode", {}) + assert "tp_size" in decode_parallel_config.keys() + self._decode_tp_size = decode_parallel_config["tp_size"] + assert "dp_size" in decode_parallel_config.keys() + self._decode_dp_size = decode_parallel_config["dp_size"] + + def _initialize( + self, + hostname: str, + device_name: Optional[str], + ) -> None: + """Initialize the mooncake instance.""" + device_name = device_name if device_name is not None else "" + ret_value = self.engine.initialize(hostname, "P2PHANDSHAKE", "ascend", + device_name) + if ret_value != 0: + raise RuntimeError( + f"Mooncake initialization failed with ret_value: {ret_value}") + + def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): + """Register the KV Cache data.""" + + _, first_kv_cache_tuple = next(iter(kv_caches.items())) + first_kv_cache = first_kv_cache_tuple[0] + self.first_kv_cache = first_kv_cache + + # TODO(tms): Find a more robust way to detect and handle MLA + self.use_mla = first_kv_cache_tuple[0].size( + -1) != first_kv_cache_tuple[1].size(-1) + if self.use_mla: + # MLA case.[num_block, block_size, 1, hidden_dim] + self.num_blocks = first_kv_cache.shape[0] + block_rank = 3 # [block_size, latent_dim] + block_shape_norm = first_kv_cache_tuple[0].shape[-block_rank:] + block_shape_pe = first_kv_cache_tuple[1].shape[-block_rank:] + self.block_len = [ + first_kv_cache[0].element_size() * math.prod(block_shape_norm), + first_kv_cache[1].element_size() * math.prod(block_shape_pe) + ] + logger.info( + "num_blocks: %s, block_shape_norm: %s, block_shape_pe: %s", + self.num_blocks, block_shape_norm, block_shape_pe) + else: + # [num_block, block_size, num_head, hidden_dim] + self.num_blocks = first_kv_cache.shape[0] + kv_elem_size = first_kv_cache.element_size() + block_rank = 3 # [block_size, kv_heads, head_dim] + block_shape = first_kv_cache.shape[-block_rank:] + self.block_len = [kv_elem_size * math.prod(block_shape)] + logger.info("num_blocks: %s, block_shape: %s", self.num_blocks, + block_shape) + + logger.info("Registering KV_Caches. 
use_mla: %s, shape %s", + self.use_mla, first_kv_cache.shape) + + self.kv_caches = kv_caches + kv_caches_base_addr = [] + for cache_or_caches in kv_caches.values(): + # Normalize to always be a list of caches + if self.use_mla: + for i, cache in enumerate(cache_or_caches, 0): + base_addr = cache.data_ptr() + region_len = self.num_blocks * self.block_len[i % 2] + kv_caches_base_addr.append(base_addr) + self._register(base_addr, region_len) + else: + cache_list = [cache_or_caches + ] if self.use_mla else cache_or_caches + for cache in cache_list: + base_addr = cache.data_ptr() + region_len = self.num_blocks * self.block_len[0] + kv_caches_base_addr.append(base_addr) + self._register(base_addr, region_len) + self.kv_caches_base_addr = kv_caches_base_addr + + # After KV Caches registered, start the sending or receiving thread. + metadata = MooncakeAgentMetadata( + engine_id=self.engine_id, + te_rpc_port=self.te_rpc_port, + kv_caches_base_addr=kv_caches_base_addr, + num_blocks=self.num_blocks, + ) + + ready_event = threading.Event() + if self.kv_role == 'kv_producer': + self.kv_send_layer_thread = KVCacheSendingLayerThread( + self.tp_rank, self.tp_size, self._decode_tp_size, + self.engine_id, self.side_channel_host, self.side_channel_port, + metadata, ready_event, self.total_layers, self.engine, + kv_caches_base_addr, self.block_len, self.use_mla, + self.first_kv_cache) + self.kv_send_layer_thread.start() + else: + self.kv_recv_layer_thread = KVCacheRecvingLayerThread( + self.tp_rank, self.side_channel_port, self.tp_size, + self.engine_id, ready_event) + self.kv_recv_layer_thread.start() + ready_event.wait() + + def _register(self, ptr, length): + logger.info( + "Registering KV cache: ptr=0x%x, length=%d, num_blocks=%d, " + "block_lens=%s", ptr, length, self.num_blocks, self.block_len) + ret_value = self.engine.register_memory(ptr, length) + if ret_value != 0: + raise RuntimeError("Mooncake memory registration failed.") + + def _access_metaserver(self, url, message): + self.metaserver_client.post(url, json=message) + + def get_finished(self) -> tuple[set[str], set[str]]: + done_sending = ( + self.kv_send_layer_thread. + get_and_clear_finished_requests( # type: ignore[union-attr] + ) if self.kv_role == 'kv_producer' else set()) + done_recving = ( + self.kv_recv_layer_thread. + get_and_clear_finished_requests( # type: ignore[union-attr] + ) if self.kv_role == 'kv_consumer' else set()) + if self.tp_rank == 0: + logger.debug( + "Number of completed KV cache send requests: %d, receive " + "requests: %d", len(done_sending), len(done_recving)) + return done_sending, done_recving + + def start_load_kv(self, metadata: MooncakeLayerwiseConnectorMetadata): + """Start loading KV blocks from remote engine.""" + self.current_layer = 0 + if self.vllm_config.kv_transfer_config.is_kv_producer: + for req_id, meta in metadata.requests.items(): + logger.debug( + f"Send request: {req_id} to proxy metaserver: {meta.metaserver}" + ) + if self.tp_rank == 0: + # All parameters here should appear in the returned dict of + # request_finished in the scheduler side except "request_id". 
+                    kv_transfer_params = dict(
+                        request_id=req_id,
+                        do_remote_prefill=True,
+                        do_remote_decode=False,
+                        remote_engine_id=self.engine_id,
+                        remote_host=self.side_channel_host,
+                        remote_port=self.side_channel_port)
+
+                    future = self.executor.submit(
+                        self._access_metaserver,
+                        url=meta.metaserver,
+                        message=kv_transfer_params,
+                    )
+
+                    def handle_exception(future):
+                        if future.exception():
+                            logger.error(
+                                f"Access metaserver failed: {future.exception()}"
+                            )
+
+                    future.add_done_callback(handle_exception)
+        else:
+            for req_id, meta in metadata.requests.items():
+                for offset in range(self.pd_tp_ratio):
+                    path = make_zmq_path(
+                        "tcp", meta.remote_host, meta.remote_port +
+                        self.tp_rank * self.pd_tp_ratio + offset)
+                    logger.debug(
+                        f"Notifying the prefiller at {path} that request {req_id} is ready on the decoder."
+                    )
+                    msg_encoder = msgspec.msgpack.Encoder()
+                    decode_metadata = DecodeMooncakeAgentMetadata(
+                        req_id=req_id,
+                        block_ids=meta.local_block_ids,
+                        port=self.handshake_port,
+                        host=self.side_channel_host,
+                        engine_id=self.engine_id,
+                        te_rpc_port=self.te_rpc_port,
+                        kv_caches_base_addr=self.kv_caches_base_addr,
+                        num_blocks=self.num_blocks)
+                    encoded_data = msg_encoder.encode(decode_metadata)
+                    size_in_bytes = len(encoded_data)
+                    logger.debug(
+                        "Size of encoded Mooncake agent metadata: %d bytes",
+                        size_in_bytes)
+                    with zmq_ctx(zmq.REQ, path) as sock:  # type: ignore
+                        ensure_zmq_send(sock, encoded_data)
+                        ack = sock.recv()
+                        if ack != b"ACK":
+                            raise ValueError(
+                                f"Unexpected ACK from prefill node: {ack}")
+
+        if self.kv_send_layer_thread is not None:
+            for req_id, delay_start_time in metadata.requests_to_send.items():
+                if self.tp_rank in self._get_remote_tp_ranks_for_req(req_id):
+                    self.kv_send_layer_thread.add_delayed_request(
+                        req_id, delay_start_time)
+
+    def save_kv_layer(self, layer_name: str, kv_layer: Tuple[torch.Tensor,
+                                                             torch.Tensor],
+                      attn_metadata: "AttentionMetadata",
+                      connector_metadata: MooncakeLayerwiseConnectorMetadata,
+                      **kwargs) -> None:
+        """Queue this layer's KV for layer-wise transfer to the decoder."""
+        if self.kv_role == 'kv_producer':
+            if self.pd_tp_ratio != 1:
+                if self.current_layer != 0:
+                    self.completion_event.wait()
+                self.completion_event = threading.Event()
+                if self.kv_send_layer_thread is not None:
+                    self.kv_send_layer_thread.send_layer_thread.add_event(
+                        self.completion_event,
+                        len(connector_metadata.requests.keys()))
+
+                def sort_kv_cache(input_kv: list[list[int]]):
+                    return torch.cat([
+                        torch.chunk(tensor, self.pd_tp_ratio, dim=0)[x]
+                        for x in range(self.pd_tp_ratio) for tensor in input_kv
+                    ])
+
+                total_block_ids = [
+                    request.local_block_ids
+                    for request in connector_metadata.requests.values()
+                ]
+                keys = [
+                    kv_layer[0][block_ids].reshape(
+                        -1, *kv_layer[0].shape[2:]).clone()
+                    for block_ids in total_block_ids
+                ]
+                values = [
+                    kv_layer[1][block_ids].reshape(
+                        -1, *kv_layer[1].shape[2:]).clone()
+                    for block_ids in total_block_ids
+                ]
+                key_block_size = keys[0].size(0) // len(total_block_ids[0])
+                value_block_size = values[0].size(0) // len(total_block_ids[0])
+                keys = sort_kv_cache(keys)  # [req1_key, req2_key]
+                values = sort_kv_cache(values)
+                (keys,
+                 values) = kv_alltoall_and_rearrange(self.pd_tp_ratio, keys,
+                                                     values)
+                key_start_id = 0
+                value_start_id = 0
+            else:
+                key = None
+                value = None
+            for req_id, request in connector_metadata.requests.items():
+                logger.info(f"Add request {req_id} to kv send layer thread.")
") + if self.pd_tp_ratio != 1: + key_block_num = len( + request.local_block_ids) * key_block_size + value_block_num = len( + request.local_block_ids) * value_block_size + key = keys[key_start_id:key_start_id + + key_block_num] #.clone().contiguous() + value = values[value_start_id:value_start_id + + value_block_num] #.clone().contiguous() + key_start_id += key_block_num + value_start_id += value_block_num + if self.kv_send_layer_thread is not None: + self.kv_send_layer_thread.add_request( + request_id=req_id, + local_block_ids=request.local_block_ids, + layer_index=self.current_layer, + key=key, + value=value) + self.current_layer += 1 + + def wait_for_layer_load(self, layer_name: str) -> None: + pass + + def _get_remote_tp_rank(self, req_id: str) -> int: + return self._get_remote_tp_ranks_for_req(req_id)[self.tp_rank] + + def _get_remote_tp_ranks_for_req(self, req_id: str) -> list[int]: + if self._prefill_tp_size == self._decode_tp_size: + return list(range(self._prefill_tp_size)) + + seed = string_to_int64_hash(req_id) + rand = random.Random(seed) + sampled_nums = rand.sample(range(self._prefill_tp_size), + self._decode_tp_size) + return sampled_nums + + +@contextlib.contextmanager +def zmq_ctx(socket_type: Any, + addr: str) -> Iterator[zmq.Socket]: # type: ignore + """Context manager for a ZMQ socket""" + + if socket_type not in (zmq.ROUTER, zmq.REQ, zmq.DEALER): # type: ignore + raise ValueError(f"Unexpected socket type: {socket_type}") + + ctx: Optional[zmq.Context] = None # type: ignore + try: + ctx = zmq.Context() # type: ignore + yield make_zmq_socket(ctx=ctx, + path=addr, + socket_type=socket_type, + bind=socket_type == zmq.ROUTER) # type: ignore + finally: + if ctx is not None: + ctx.destroy(linger=0) + + +def group_concurrent_contiguous( + src: List[int], + dst: List[int] = [] +) -> Tuple[List[npt.NDArray[np.int64]], List[npt.NDArray[np.int64]]]: + """Vectorised NumPy implementation.""" + if not dst: + src_only_indices: npt.NDArray[np.int64] = np.array(src, dtype=np.int64) + + if src_only_indices.size == 0: + return [], [] + + brk = np.where((np.diff(src_only_indices) != 1))[0] + 1 + src_groups = np.split(src_only_indices, brk) + src_groups = [g.tolist() for g in src_groups] + + return src_groups, [] + + else: + src_indices: npt.NDArray[np.int64] = np.array(src, dtype=np.int64) + dst_indices: npt.NDArray[np.int64] = np.array(dst, dtype=np.int64) + + if src_indices.size == 0: + return [], [] + + brk = np.where((np.diff(src_indices) != 1) + | (np.diff(dst_indices) != 1))[0] + 1 + src_groups = np.split(src_indices, brk) + dst_groups = np.split(dst_indices, brk) + + src_groups = [g.tolist() for g in src_groups] + dst_groups = [g.tolist() for g in dst_groups] + + return src_groups, dst_groups + + +def string_to_int64_hash(input_str): + """ + Hash the string using SHA-256 and convert it into an int64 integer. + """ + hashed_bytes = hashlib.sha256(input_str.encode("utf-8")).digest() + trunked_bytes = hashed_bytes[:8] + uint64_value = struct.unpack(" 0: + logger.warning( + f"Send failed: {e}, retrying... 
({retries_left} " + "attempts left)") + time.sleep(0.1) + else: + logger.error(f"Send failed after all retries: {e}") + raise RuntimeError(f"Failed to send data after {max_retries} " + f"retries: {e}") + + +def ensure_zmq_recv( + socket: zmq.Socket, # type: ignore + poller: zmq.Poller, # type: ignore + timeout: float = 1.0, + max_retries: int = 3) -> bytes: + retries_left = max_retries + while True: + try: + if dict(poller.poll(int(timeout * 1000))): # milliseconds + data = socket.recv() + return data + else: + raise zmq.ZMQError("Receive timeout") # type: ignore + except zmq.ZMQError as e: # type: ignore + retries_left -= 1 + if retries_left > 0: + logger.warning(f"Receive failed: {e}, retrying... " + f"({retries_left} attempts left)") + time.sleep(0.1) + else: + logger.error(f"Receive failed after all retries: {e}") + raise RuntimeError( + f"Failed to receive data after {max_retries} " + f"retries: {e}") diff --git a/vllm_ascend/distributed/parallel_state.py b/vllm_ascend/distributed/parallel_state.py index 07c707e..071a234 100644 --- a/vllm_ascend/distributed/parallel_state.py +++ b/vllm_ascend/distributed/parallel_state.py @@ -13,6 +13,7 @@ _MC2: Optional[GroupCoordinator] = None _MLP_TP: Optional[GroupCoordinator] = None _OTP: Optional[GroupCoordinator] = None _LMTP: Optional[GroupCoordinator] = None +_P_TP: Optional[GroupCoordinator] = None def get_mc2_group() -> GroupCoordinator: @@ -37,6 +38,12 @@ def get_mlp_tp_group() -> GroupCoordinator: return _MLP_TP +def get_p_tp_group() -> GroupCoordinator: + assert _P_TP is not None, ( + "distributed prefill tensor parallel group is not initialized") + return _P_TP + + def model_parallel_initialized(): return (_MC2 is not None) @@ -54,6 +61,22 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ): all_ranks = torch.arange(world_size).reshape( -1, parallel_config.data_parallel_size * parallel_config.tensor_parallel_size) + + pd_tp_ratio = get_ascend_config().pd_tp_ratio + global _P_TP + assert _P_TP is None, ( + "distributed prefill tensor parallel group is already initialized") + prefill_tensor_model_parallel_size = pd_tp_ratio if \ + pd_tp_ratio > 0 and pd_tp_ratio < parallel_config.tensor_parallel_size else parallel_config.tensor_parallel_size + group_ranks = all_ranks.view(-1, + prefill_tensor_model_parallel_size).unbind(0) + group_ranks = [x.tolist() for x in group_ranks] + num = get_world_group().local_rank // pd_tp_ratio + _P_TP = init_model_parallel_group(group_ranks, + get_world_group().local_rank, + backend, + group_name=f"p_tp_{num}") + global _MC2 group_ranks = all_ranks.unbind(0) group_ranks = [x.tolist() for x in group_ranks] @@ -142,3 +165,8 @@ def destroy_ascend_model_parallel(): if _OTP: _OTP.destroy() _OTP = None + + global _P_TP + if _P_TP: + _P_TP.destroy() + _P_TP = None diff --git a/vllm_ascend/distributed/utils.py b/vllm_ascend/distributed/utils.py new file mode 100644 index 0000000..4b1344a --- /dev/null +++ b/vllm_ascend/distributed/utils.py @@ -0,0 +1,47 @@ +import torch +import torch.distributed as dist + +from vllm_ascend.distributed.parallel_state import get_p_tp_group + + +def kv_alltoall_and_rearrange(pd_tp_ratio: int, key: torch.Tensor, + value: torch.TensorType): + if pd_tp_ratio <= 1: + return None, None + elif key is None or value is None: + raise ValueError("key or value is None") + k_output = alltoall_and_rearrange(pd_tp_ratio, key) + v_output = alltoall_and_rearrange(pd_tp_ratio, value) + return k_output, v_output + + +def alltoall_and_rearrange(tp_ratio: int, input_tensor: torch.Tensor): 
+    num_kv_heads = input_tensor.size(1)
+    output_tensor = torch.zeros_like(input_tensor)
+    dist.all_to_all_single(output_tensor,
+                           input_tensor,
+                           group=get_p_tp_group().device_group)
+    input_tensor = None  # drop the reference so the buffer can be freed early
+    result = rearrange_output(output_tensor, tp_ratio, num_kv_heads)
+    output_tensor = None
+    return result
+
+
+def rearrange_output(base_output: torch.Tensor, cut_num: int,
+                     num_kv_heads: int):
+    size_0 = base_output.size(0)
+    if size_0 % cut_num != 0:
+        raise ValueError(
+            f"The size of dim 0 [{size_0}] must be divisible by the cut_num [{cut_num}]"
+        )
+    chunk_size = size_0 // cut_num
+    reshaped = base_output.view(cut_num, chunk_size, -1)
+    transposed = reshaped.transpose(0, 1)
+    return transposed.contiguous().view(size_0, num_kv_heads, -1)
+
+
+def align_memory(tensor: torch.Tensor, alignment: int) -> torch.Tensor:
+    data_ptr = tensor.data_ptr()
+    aligned_addr = (data_ptr + alignment - 1) // alignment * alignment
+    offset = (aligned_addr - data_ptr) // tensor.element_size()
+    return tensor[int(offset):]
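
Usage sketch (illustrative values only; "prefill"/"decode" and
"use_ascend_direct" are the extra-config keys read by
MooncakeLayerwiseConnectorWorker):

    vllm serve ... --kv-transfer-config '{
        "kv_connector": "MooncakeLayerwiseConnector",
        "kv_role": "kv_producer",
        "kv_port": 8000,
        "kv_connector_extra_config": {
            "prefill": {"tp_size": 2, "dp_size": 1},
            "decode": {"tp_size": 2, "dp_size": 1},
            "use_ascend_direct": false
        }
    }'

Decode instances use the same shape with "kv_role": "kv_consumer"; the proxy in
examples/disaggregated_prefill_v1/load_balance_proxy_layerwise_server_example.py
then fronts both pools.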