From af7a56550b3c3489d42c2e00cf0795d0d7d47d0e Mon Sep 17 00:00:00 2001
From: wangxiaoteng888 <56506195+wangxiaoteng888@users.noreply.github.com>
Date: Thu, 30 Oct 2025 22:21:11 +0800
Subject: [PATCH] [bugfix_v0.11.0-dev] layerwise D first plan (#3907)

### What this PR does / why we need it?
Refactored the layerwise code to send requests to the D (decode) node
first, preventing P (prefill) node hangs due to communication timeouts
when DP > 1.

---------

Signed-off-by: nwpu-zxr
Signed-off-by: liziyu
Signed-off-by: wangxiaoteng
Co-authored-by: nwpu-zxr
Co-authored-by: liziyu
---
 ..._balance_proxy_layerwise_server_example.py |  252 +---
 .../kv_connector/test_mooncake_connector.py   |    2 +-
 .../test_mooncake_layerwise_connector.py      | 1205 ++++++++---------
 .../mooncake_layerwise_connector.py           |  859 +++++-------
 vllm_ascend/distributed/parallel_state.py     |    3 +-
 5 files changed, 965 insertions(+), 1356 deletions(-)

diff --git a/examples/disaggregated_prefill_v1/load_balance_proxy_layerwise_server_example.py b/examples/disaggregated_prefill_v1/load_balance_proxy_layerwise_server_example.py
index 7e80b55..ea0c5be 100644
--- a/examples/disaggregated_prefill_v1/load_balance_proxy_layerwise_server_example.py
+++ b/examples/disaggregated_prefill_v1/load_balance_proxy_layerwise_server_example.py
@@ -88,18 +88,17 @@ import argparse
 import asyncio
 import functools
 import heapq
-import json
 import os
 import sys
 import threading
 import uuid
 from contextlib import asynccontextmanager
-from dataclasses import dataclass
-from typing import Any, List
+from typing import List
 
 import httpx
 from fastapi import FastAPI, Request
 from fastapi.responses import StreamingResponse
+from transformers import AutoTokenizer
 from vllm.logger import init_logger
 
 logger = init_logger(__name__)
@@ -107,7 +106,6 @@ logger = init_logger(__name__)
 # Add uvloop for faster event loop if available
 try:
     import uvloop
-
     asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 except ImportError:
     pass
@@ -154,6 +152,9 @@ class ProxyState:
         heapq.heapify(self.prefiller_heap)
         heapq.heapify(self.decoder_heap)
         self.req_id_future = {}
+        self.req_data_dict = {}
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            global_args.tokenizer_dir)
 
     def _update_prefiller_priority(self, server_idx: int):
         """Update the priority of a prefiller server in the heap."""
@@ -280,6 +281,10 @@ def parse_args():
                         nargs="+",
                         default=["localhost"])
     parser.add_argument("--decoder-ports", type=int, nargs="+", default=[8002])
+    parser.add_argument("--tokenizer-dir",
+                        type=str,
+                        default="/mnt/weight/Qwen3-235B-A22B-W8A8",
+                        help="Path to the tokenizer directory for the proxy")
     parser.add_argument("--max-retries",
                         type=int,
                         default=3,
@@ -356,17 +361,6 @@ async def send_request_to_service(client: httpx.AsyncClient,
         aborted_requests = proxy_state.aquire_aborted_prefiller_requests(
             prefiller_id)
         req_data = req_data.copy()
-        req_data['kv_transfer_params'] = {
-            "do_remote_decode": True,
-            "do_remote_prefill": False,
-            "remote_engine_id": None,
-            "remote_block_ids": None,
-            "remote_host": None,
-            "remote_port": None,
-            "aborted_request": list(aborted_requests),
-            "metaserver":
-            f"http://{global_args.host}:{global_args.port}/v1/metaserver"
-        }
         req_data["stream"] = False
         req_data["max_tokens"] = 1
         if "stream_options" in req_data:
@@ -458,59 +452,11 @@ def get_api_request_id(api, req_id):
         return "chatcmpl-" + req_id
 
 
-async def _handle_select_instance(api: str, req_data: Any,
-                                  request_length: int):
-    prefiller_score = proxy_state.calculate_prefill_scores(request_length)
-    logger.debug(
-        f"Request length: 
{request_length}, Prefiller score: {prefiller_score}"
-    )
-    request_id = await proxy_state.next_req_id()
-    # Select prefiller
-    prefiller_idx = proxy_state.select_prefiller(prefiller_score)
-    prefiller = proxy_state.prefillers[prefiller_idx]
-    result_future = asyncio.Future()  # type: ignore
-    request_id_api = get_api_request_id(api, request_id)
-    proxy_state.req_id_future[request_id_api] = result_future
-    # Send request to prefiller
-    asyncio.get_running_loop().create_task(
-        send_request_to_service(prefiller.client,
-                                prefiller_idx,
-                                api,
-                                req_data,
-                                request_id,
-                                max_retries=global_args.max_retries,
-                                base_delay=global_args.retry_delay))
-    proxy_state.release_prefiller(prefiller_idx, prefiller_score)
-
-    response = await result_future
-    del proxy_state.req_id_future[request_id_api]
-    req_data["kv_transfer_params"] = response
-
-    # Select decoder
-    decoder_score = proxy_state.calculate_decode_scores(request_length)
-    logger.debug("Decoder score: %f", decoder_score)
-    # Use the prefiller's kv_transfer_params to select decoder
-    decoder_idx = proxy_state.select_decoder(decoder_score)
-    decoder = proxy_state.decoders[decoder_idx]
-    logger.debug("Using %s %s", prefiller.url, decoder.url)
-    return InstanceInfo(request_id=request_id,
-                        prefiller_idx=prefiller_idx,
-                        prefiller_score=prefiller_score,
-                        prefiller=prefiller,
-                        decoder=decoder,
-                        decoder_idx=decoder_idx,
-                        decoder_score=decoder_score)
-
-
-@dataclass
-class InstanceInfo:
-    request_id: str
-    prefiller_idx: int
-    prefiller_score: float
-    prefiller: ServerState
-    decoder_idx: int
-    decoder_score: float
-    decoder: ServerState
+def get_origin_request_id(api, req_id):
+    if api == "/completions":
+        return req_id.replace("cmpl-", "").replace("-0", "")
+    elif api == "/chat/completions":
+        return req_id.replace("chatcmpl-", "")
 
 
 async def _handle_completions(api: str, request: Request):
     try:
         req_data = await request.json()
         req_body = await request.body()
         request_length = len(req_body)
-        instance_info = await _handle_select_instance(api, req_data,
-                                                      request_length)
-        stream_flag = bool(req_data.get("stream", False))
-        chat_flag = "messages" in req_data
-
-        if "prompt" in req_data:
-            origin_prompt = req_data["prompt"]
-        elif chat_flag:
-            messages = req_data["messages"]
-            origin_prompt = messages[0].get("content", "")
-        else:
-            origin_prompt = ""
-        # refer to vLLM sampling_params: max_token default value
-        origin_max_tokens = req_data.get("max_tokens", 16)
+        request_id = await proxy_state.next_req_id()
+        request_id_api = get_api_request_id(api, request_id)
+        proxy_state.req_data_dict[request_id_api] = (req_data, request_length,
+                                                     api)
+        req_data['kv_transfer_params'] = {
+            "do_remote_decode":
+            False,
+            "do_remote_prefill":
+            True,
+            "metaserver":
+            f"http://{global_args.host}:{global_args.port}/v1/metaserver"
+        }
+        # Select decoder
+        decoder_score = proxy_state.calculate_decode_scores(request_length)
+        logger.debug("Decoder score: %f", decoder_score)
+        # D-first: select the decoder before any prefiller is contacted
+        decoder_idx = proxy_state.select_decoder(decoder_score)
+        decoder = proxy_state.decoders[decoder_idx]
+        # logger.debug("Using %s %s", prefiller.url, decoder.url)
+        # Stream response from decoder
+        released_kv = False
 
         async def generate_stream():
-            nonlocal instance_info
-            generated_token = ""
-            released_kv = False
-            retry_count = 0
-            retry = True
-            completion_tokens = 0
+            nonlocal released_kv
             # Only one await per chunk, minimal logic in loop
             try:
-                while 
retry:
-                    retry = False
-                    async for chunk in stream_service_response_with_retry(
-                            instance_info.decoder.client,
-                            api,
-                            req_data,
-                            request_id=instance_info.request_id,
-                            max_retries=global_args.max_retries,
-                            base_delay=global_args.retry_delay):
-                        if not released_kv and chunk:
-                            proxy_state.release_prefiller_kv(
-                                instance_info.prefiller_idx,
-                                instance_info.prefiller_score)
-                            released_kv = True
-                        try:
-                            chunk_str = chunk.decode("utf-8").strip()
-                        except UnicodeDecodeError:
-                            logger.debug(
-                                f"Skipping chunk: {chunk}")
-                            yield chunk
-                            continue
-                        if not chunk_str:
-                            continue
-                        if chunk_str.startswith("data: "):
-                            chunk_str = chunk_str[len("data: "):]
-                        try:
-                            chunk_json = json.loads(chunk_str)
-                        except json.JSONDecodeError:
-                            # if chunk is [done], skip it.
-                            logger.debug(
-                                f"Skipping chunk: {chunk_str}")
-                            yield chunk
-                            continue
-                        choices = chunk_json.get("choices", [])
-                        if not choices:
-                            yield chunk
-                            continue
-
-                        choice = choices[0]
-                        delta = choice.get("delta") or {}
-                        message = choice.get("message") or {}
-                        content = (
-                            delta.get("content")
-                            or message.get("content")
-                            or choice.get("text")
-                            or ""
-                        )
-                        generated_token += content
-
-                        stop_reason = choice.get(
-                            "stop_reason")
-                        usage = chunk_json.get("usage", {})
-                        completion_tokens = (completion_tokens + 1) if stream_flag else \
-                            (completion_tokens + usage.get("completion_tokens"))
-                        if stop_reason == "recomputed":
-                            retry = True
-                            retry_count += 1
-                            if chat_flag:
-                                messages[0][
-                                    "content"] = origin_prompt + generated_token
-                            else:
-                                req_data[
-                                    "prompt"] = origin_prompt + generated_token
-                            req_data[
-                                "max_tokens"] = origin_max_tokens - completion_tokens + retry_count
-                            tmp_request_length = len(
-                                json.dumps(req_data).encode("utf-8"))
-                            instance_info = await _handle_select_instance(
-                                api, req_data, tmp_request_length)
-                            break
-                        if retry_count > 0 and not stream_flag:
-                            if chat_flag:
-                                choices[0]["message"][
-                                    "content"] = generated_token
-                            else:
-                                choices[0]["text"] = generated_token
-                            chunk = json.dumps(chunk_json).encode("utf-8")
-                        yield chunk
+                async for chunk in stream_service_response_with_retry(
+                        decoder.client,
+                        api,
+                        req_data,
+                        request_id=request_id,
+                        max_retries=global_args.max_retries,
+                        base_delay=global_args.retry_delay):
+                    yield chunk
             except Exception as e:
                 logger.error(
-                    f"Error during streaming from decoder {instance_info.decoder.url}: {str(e)} the aborted request {instance_info.request_id} will be routing to the target prefiller when new request is ready to dispatch to it"
+                    f"Error during streaming from decoder {decoder.url}: {str(e)}. The aborted request {request_id} will be routed to the target prefiller when a new request is ready to dispatch to it"
                 )
-                proxy_state.abort_prefiller_request(
-                    instance_info.prefiller_idx, instance_info.request_id)
-                proxy_state.release_prefiller_kv(instance_info.prefiller_idx,
-                                                 instance_info.prefiller_score)
             # After streaming done, release tokens
-            proxy_state.release_decoder(instance_info.decoder_idx,
-                                        instance_info.decoder_score)
+            proxy_state.release_decoder(decoder_idx, decoder_score)
 
         return StreamingResponse(generate_stream(),
                                  media_type="application/json")
@@ -669,11 +542,33 @@ async def healthcheck():
 
 @app.post("/v1/metaserver")
 async def metaserver(request: Request):
     try:
-        req_data = await request.json()
-        request_id = req_data.pop("request_id", None)
-        if request_id in proxy_state.req_id_future:
-            result_future = proxy_state.req_id_future[request_id]
-            result_future.set_result(req_data)
+        kv_transfer_params = await request.json()
+
+        request_id = 
kv_transfer_params["request_id"] + assert request_id in proxy_state.req_data_dict + req_data, request_length, api = proxy_state.req_data_dict[request_id] + request_id = get_origin_request_id(api, request_id) + req_data["kv_transfer_params"] = kv_transfer_params + prefiller_score = proxy_state.calculate_prefill_scores(request_length) + logger.debug( + f"Request length: {request_length}, Prefiller score: {prefiller_score}" + ) + + # Select prefiller + prefiller_idx = proxy_state.select_prefiller(prefiller_score) + prefiller = proxy_state.prefillers[prefiller_idx] + logger.debug(f"Using prefill {prefiller.url=} {req_data=}") + # Send request to prefiller + response = await send_request_to_service( + prefiller.client, + prefiller_idx, + api, + req_data, + request_id, + max_retries=global_args.max_retries, + base_delay=global_args.retry_delay) + proxy_state.release_prefiller(prefiller_idx, prefiller_score) + except Exception as e: logger.error(f"Post metaserver failed with: {str(e)}") @@ -682,5 +577,4 @@ if __name__ == '__main__': global global_args global_args = parse_args() import uvicorn - uvicorn.run(app, host=global_args.host, port=global_args.port) diff --git a/tests/ut/kv_connector/test_mooncake_connector.py b/tests/ut/kv_connector/test_mooncake_connector.py index 9a4084d..8226073 100644 --- a/tests/ut/kv_connector/test_mooncake_connector.py +++ b/tests/ut/kv_connector/test_mooncake_connector.py @@ -1136,4 +1136,4 @@ class TestMooncakeConnectorWorker(unittest.TestCase): if __name__ == '__main__': - unittest.main() + unittest.main() \ No newline at end of file diff --git a/tests/ut/kv_connector/test_mooncake_layerwise_connector.py b/tests/ut/kv_connector/test_mooncake_layerwise_connector.py index ae9ff04..b454282 100644 --- a/tests/ut/kv_connector/test_mooncake_layerwise_connector.py +++ b/tests/ut/kv_connector/test_mooncake_layerwise_connector.py @@ -1,7 +1,6 @@ import os import sys import threading -import time import types import unittest from types import SimpleNamespace @@ -10,363 +9,355 @@ from unittest.mock import MagicMock, patch import torch import zmq +# fake mooncake.engine.TransferEngine fake_engine = types.ModuleType("mooncake.engine") fake_engine.TransferEngine = MagicMock() # type: ignore[attr-defined] sys.modules["mooncake.engine"] = fake_engine from vllm_ascend.distributed.mooncake_layerwise_connector import ( # noqa: E402 - DecodeMooncakeAgentMetadata, KVCacheRecvingLayerThread, - KVCacheSendingLayerThread, KVCacheTaskTracker, KVConnectorRole, - MooncakeLayerwiseConnector, MooncakeLayerwiseConnectorMetadata, - MooncakeLayerwiseConnectorScheduler, MooncakeLayerwiseConnectorWorker, - ReqMeta, SendingLayerThread, ensure_zmq_recv, ensure_zmq_send, - group_concurrent_contiguous, string_to_int64_hash, zmq_ctx) + KVCacheRecvingLayerThread, KVCacheSendingLayerThread, KVConnectorRole, + MooncakeAgentMetadata, MooncakeLayerwiseConnector, + MooncakeLayerwiseConnectorMetadata, MooncakeLayerwiseConnectorScheduler, + MooncakeLayerwiseConnectorWorker, ReqMeta, ensure_zmq_recv, + ensure_zmq_send, group_concurrent_contiguous, string_to_int64_hash, + zmq_ctx) GET_META_MSG = b"get_meta_msg" -DONE_RECVING_MSG = b"done_recving_msg" +DONE_SENDING_MSG = b"done_sending_msg" -class TestKVCacheTaskTrackerInit(unittest.TestCase): - - def test_init_basic_properties(self): - tracker = KVCacheTaskTracker() - self.assertIsInstance(tracker.done_task_lock, type(threading.Lock())) - self.assertIsInstance(tracker.finished_requests, set) - self.assertIsInstance(tracker.delayed_free_requests, dict) - - 
-class TestGetAndClearFinishedSingleRequests(unittest.TestCase): +class TestKVCacheSendingLayerThread(unittest.TestCase): def setUp(self): - self.tracker = KVCacheTaskTracker() - self.tracker.finished_requests = set() - self.tracker.done_task_lock = threading.Lock() - - def test_empty_requests(self): - result = self.tracker.get_and_clear_finished_requests() - self.assertEqual(result, set()) - self.assertEqual(len(self.tracker.finished_requests), 0) - - def test_single_request(self): - self.tracker.finished_requests = {"req_123"} - result = self.tracker.get_and_clear_finished_requests() - self.assertEqual(result, {"req_123"}) - self.assertEqual(len(self.tracker.finished_requests), 0) - - def test_multiple_requests(self): - self.tracker.finished_requests = {"req_1", "req_2", "req_3"} - result = self.tracker.get_and_clear_finished_requests() - self.assertSetEqual(result, {"req_1", "req_2", "req_3"}) - self.assertEqual(len(self.tracker.finished_requests), 0) - - @patch("vllm_ascend.distributed.mooncake_layerwise_connector.logger") - def test_concurrent_access(self, mock_logger): - from concurrent.futures import ThreadPoolExecutor - self.tracker.finished_requests = {"req_1", "req_2"} - with ThreadPoolExecutor(max_workers=3) as executor: - futures = [ - executor.submit(self.tracker.get_and_clear_finished_requests) - for _ in range(3) - ] - results = [f.result() for f in futures] - self.assertEqual(sum(1 for r in results if r), 1) - self.assertEqual(len(self.tracker.finished_requests), 0) - - -class TestKVCacheSendingLayerThreadBasic(unittest.TestCase): - - def setUp(self): - self.p1 = patch( - 'vllm_ascend.distributed.mooncake_layerwise_connector.get_ascend_config', - new=MagicMock(return_value=SimpleNamespace( - pd_tp_ratio=1, num_head_replica=1, pd_head_ratio=1))) - self.p2 = patch( - 'vllm_ascend.distributed.mooncake_layerwise_connector.get_current_vllm_config', - new=MagicMock(return_value=SimpleNamespace( - scheduler_config=SimpleNamespace(max_model_len=128)))) - self.p1.start() - self.addCleanup(self.p1.stop) - self.p2.start() - self.addCleanup(self.p2.stop) self.engine = MagicMock() self.engine.register_memory.return_value = 0 - self.ready_event = threading.Event() - - batch_size, seq_len, hidden_dim, num_heads = 8, 128, 512, 8 - head_dim = hidden_dim // num_heads - self.first_kv_cache = torch.zeros( - (batch_size, num_heads, seq_len, head_dim), - dtype=torch.float32, - device='cpu') - - self.thread = KVCacheSendingLayerThread( - tp_rank=0, - tp_size=4, - decode_tp_size=2, - local_engine_id="local_engine", - side_channel_host="localhost", - side_channel_port=5555, - metadata=MagicMock(), - ready_event=self.ready_event, - total_layers=3, - engine=self.engine, - local_kv_base_addr=[0x1000, 0x2000], - block_len=[1024, 2048], - use_mla=True, - first_kv_cache=self.first_kv_cache) - - def test_add_request(self): - req_id = "req1" - meta = DecodeMooncakeAgentMetadata( - req_id=req_id, - block_ids=[3, 4], - host="localhost", - port=6666, - engine_id="remote_engine", - te_rpc_port=6000, - kv_caches_base_addr=[0x3000, 0x4000], - num_blocks=8) - with self.thread.lock: - self.thread.ready_decode[req_id] = meta - - local_block_ids = [1, 2] - key = torch.zeros((1, 1), dtype=torch.float32) - value = torch.zeros((1, 1), dtype=torch.float32) - - self.thread.add_request(request_id=req_id, - local_block_ids=local_block_ids, - layer_index=5, - key=key, - value=value) - - queued = self.thread.send_layer_thread.send_queue.get_nowait() - # queued: (metadata, request_id, local_block_ids, layer_index, key, value) 
- self.assertEqual(queued[1], "req1") - self.assertEqual(queued[0].host, "localhost") - - @patch.object(KVCacheTaskTracker, 'get_and_clear_finished_requests') - def test_get_finished_requests(self, mock_tracker): - mock_tracker.return_value = {"req1", "req2"} - result = self.thread.get_and_clear_finished_requests() - self.assertEqual(result, {"req1", "req2"}) - - @patch.object(KVCacheTaskTracker, 'add_delayed_request') - def test_add_delayed_request_passthrough(self, mock_add): - mock_add.return_value = None - ret = self.thread.add_delayed_request("req1", 123.456) - mock_add.assert_called_once_with("req1", 123.456) - self.assertIsNone(ret) - - def test_abort_requests_removes_pending(self): - with self.thread.lock: - self.thread.pending_decode["keep"] = [([9], 1)] - self.thread.pending_decode["dropA"] = [([1], 0)] - self.thread.pending_decode["dropB"] = [([2], 0)] - - self.thread._abort_requests({"dropA", "dropB"}) - - with self.thread.lock: - self.assertNotIn("dropA", self.thread.pending_decode) - self.assertNotIn("dropB", self.thread.pending_decode) - self.assertIn("keep", self.thread.pending_decode) - - @patch('vllm_ascend.distributed.mooncake_layerwise_connector.zmq.Context') - @patch( - 'vllm_ascend.distributed.mooncake_layerwise_connector.make_zmq_socket') - @patch( - 'vllm_ascend.distributed.mooncake_layerwise_connector.ensure_zmq_send') - def test_post_transfer_sends_and_receives_ack(self, mock_send, - mock_make_socket, - mock_context): - req_id = "req_ok" - meta = DecodeMooncakeAgentMetadata( - req_id=req_id, - block_ids=[1], - host="127.0.0.1", - port=7777, - engine_id="remote", - te_rpc_port=6000, - kv_caches_base_addr=[0x1], - num_blocks=1, - ) - with self.thread.lock: - self.thread.ready_decode[req_id] = meta - - fake_sock = MagicMock() - fake_sock.recv.return_value = b"ACK" - mock_make_socket.return_value = fake_sock - - self.thread._post_transfer(req_id) - - self.assertTrue(mock_make_socket.called) - _, kwargs = mock_make_socket.call_args - self.assertEqual(kwargs.get('path'), 'tcp://127.0.0.1:7777') - self.assertEqual(kwargs.get('socket_type'), zmq.REQ) # type: ignore - self.assertFalse(kwargs.get('bind', True)) - - mock_send.assert_called_once() - with self.thread.lock: - self.assertNotIn(req_id, self.thread.ready_decode) - - @patch('vllm_ascend.distributed.mooncake_layerwise_connector.zmq.Context') - @patch( - 'vllm_ascend.distributed.mooncake_layerwise_connector.make_zmq_socket') - @patch( - 'vllm_ascend.distributed.mooncake_layerwise_connector.ensure_zmq_send') - def test_post_transfer_bad_ack_raises_value_error(self, _mock_send, - mock_make_socket, - _mock_context): - req_id = "req_bad" - meta = DecodeMooncakeAgentMetadata( - req_id=req_id, - block_ids=[1], - host="127.0.0.1", - port=8888, - engine_id="remote", - te_rpc_port=6000, - kv_caches_base_addr=[0x2], - num_blocks=1, - ) - with self.thread.lock: - self.thread.ready_decode[req_id] = meta - - fake_sock = MagicMock() - fake_sock.recv.return_value = b"NOT_ACK" - mock_make_socket.return_value = fake_sock - - with self.assertRaises(ValueError): - self.thread._post_transfer(req_id) - - -class TestSendingLayerThread(unittest.TestCase): - - def setUp(self): - self.p1 = patch( - 'vllm_ascend.distributed.mooncake_layerwise_connector.get_ascend_config', - new=MagicMock(return_value=SimpleNamespace( - pd_tp_ratio=1, num_head_replica=1, pd_head_ratio=1))) - self.p2 = patch( - 'vllm_ascend.distributed.mooncake_layerwise_connector.get_current_vllm_config', - new=MagicMock(return_value=SimpleNamespace( - 
scheduler_config=SimpleNamespace(max_model_len=128)))) - self.p1.start() - self.addCleanup(self.p1.stop) - self.p2.start() - self.addCleanup(self.p2.stop) - self.task_tracker = MagicMock(KVCacheTaskTracker) - self.engine = MagicMock() - self.engine.register_memory.side_effect = lambda addr, size: 0 - batch_size = 8 - seq_len = 128 - hidden_dim = 512 - num_heads = 8 - head_dim = hidden_dim // num_heads # 512 // 8 = 64 - self.first_kv_cache = torch.zeros( - (batch_size, num_heads, seq_len, head_dim), - dtype=torch.float32, - device='cpu') - self.thread = SendingLayerThread( - task_tracker=self.task_tracker, - total_layers=3, - engine=self.engine, - local_kv_base_addr=["0x1000", "0x2000"], - block_len=[1024, 2048], - use_mla=True, - tp_rank=0, - first_kv_cache=self.first_kv_cache) - - @patch.object(SendingLayerThread, "_transfer_kv_cache", autospec=True) - def test_handle_request(self, mock_transfer): - req_id = "req_1" - req_meta = MagicMock(spec=DecodeMooncakeAgentMetadata) - key = torch.zeros((1, 1), dtype=torch.float32) - value = torch.zeros((1, 1), dtype=torch.float32) - item = (req_meta, req_id, [10, 11], 0, key, value) - with patch.object(self.thread.task_tracker, "update_done_task_count") as mock_update_done, \ - patch.object(self.thread.send_queue, "task_done", autospec=True) as mock_task_done: - self.thread._handle_request(item) - mock_transfer.assert_called_once_with(self.thread, req_meta, [10, 11], - 0, key, value) - mock_update_done.assert_called_once_with(req_id) - mock_task_done.assert_called_once() - - @patch('torch.npu.synchronize') - @patch( - 'vllm_ascend.distributed.mooncake_layerwise_connector.group_concurrent_contiguous' - ) - def test_transfer_kv_cache(self, mock_group, mock_sync): - key = torch.zeros((1, 1), dtype=torch.float32) - value = torch.zeros((1, 1), dtype=torch.float32) - mock_sync.return_value = None - self.thread.pd_tp_ratio = 1 - - self.thread.local_kv_base_addr = [1000, 2000] - - meta = DecodeMooncakeAgentMetadata( - req_id="req-ok", - block_ids=[0], - host="127.0.0.1", - port=7777, - engine_id="remote", - te_rpc_port=6000, - kv_caches_base_addr=[4000, 8000], - num_blocks=256, - ) - - mock_group.return_value = ( - [[10, 11, 12], [20, 21]], # grouped_remote_block_ids - [[5, 6, 7], [8, 9]], # grouped_local_block_ids - ) - self.engine.batch_transfer_sync_write.return_value = 1 - self.thread._transfer_kv_cache(meta, - local_block_ids=[123], - layer_index=0, + self.first_kv_cache = torch.zeros((2, 2, 2, 8), + dtype=torch.float32, + device="cpu") + + self.ready_event = threading.Event() + + self.thread = KVCacheSendingLayerThread( + engine=self.engine, + total_layers=3, + ready_event=self.ready_event, + tp_rank=0, + pd_head_ratio=1, + num_head_replica=1, + kv_cache_base_addr=[1000, 2000, 3000, 4000, 5000, + 6000], # 2 * total_layers + use_mla=True, + block_len=[1024, 2048], + first_kv_cache=self.first_kv_cache, + callback_func=MagicMock()) + + self.req_meta_base = ReqMeta( + local_block_ids=[5, 8], + token_ids=[1, 2, 3], + remote_block_ids=[10, 20], + remote_engine_id="remote_engine", + remote_host="127.0.0.1", + remote_port=7777, + remote_te_rpc_port=6000, + remote_kv_caches_base_addr=[4000, 8000, 14000, 18000], + metaserver="http://dummy") + + @patch( + "vllm_ascend.distributed.mooncake_layerwise_connector.torch.Tensor.data_ptr", + autospec=True, + return_value=0x200000) + @patch("vllm_ascend.distributed.mooncake_layerwise_connector.align_memory", + side_effect=lambda x, _align: x) + @patch( + 
"vllm_ascend.distributed.mooncake_layerwise_connector.torch.npu.synchronize" + ) + @patch( + "vllm_ascend.distributed.mooncake_layerwise_connector.group_concurrent_contiguous" + ) + def test_transfer_pd_gt1_uses_buffers_and_calls_engine( + self, mock_group, _mock_sync, _mock_align, _mock_dataptr): + + thread = KVCacheSendingLayerThread( + engine=self.engine, + total_layers=2, + ready_event=self.ready_event, + tp_rank=0, + pd_head_ratio=2, + num_head_replica=1, + kv_cache_base_addr=[1111, 2222, 3333, 4444], + use_mla=False, + block_len=[64], + first_kv_cache=self.first_kv_cache, + callback_func=MagicMock()) + + req_meta = self.req_meta_base + req_meta.remote_kv_caches_base_addr = [4000, 8000] + + mock_group.return_value = ([[10, 11], [20, 21]], []) + + cap = self.first_kv_cache.numel() // self.first_kv_cache.shape[-1] + dim = self.first_kv_cache.shape[-1] + + key = torch.zeros((cap, dim), dtype=torch.float32) + value = torch.zeros((cap, dim), dtype=torch.float32) + + thread._transfer_kv_cache(req_id="req1", + req_meta=req_meta, + layer_index=0, + key=key, + value=value) + + self.engine.batch_transfer_sync_write.assert_called_once() + session_id, src_list, dst_list, length_list = self.engine.batch_transfer_sync_write.call_args[ + 0] + self.assertEqual(session_id, "127.0.0.1:6000") + + self.assertEqual(len(src_list), 4) + self.assertEqual(len(dst_list), 4) + self.assertEqual(len(length_list), 4) + + for L in length_list: + self.assertGreater(L, 0) + self.assertEqual(L % 64, 0) + + remote_block_len = 64 * 2 # 128 + expected_offsets = [10 * remote_block_len, 20 * remote_block_len] + self.assertEqual(dst_list[0] - 4000, expected_offsets[0]) # K, group1 + self.assertEqual(dst_list[1] - 4000, expected_offsets[1]) # K, group2 + self.assertEqual(dst_list[2] - 8000, expected_offsets[0]) # V, group1 + self.assertEqual(dst_list[3] - 8000, expected_offsets[1]) # V, group2) + + def test_transfer_skips_when_no_local_blocks(self): + req_meta = self.req_meta_base + req_meta.local_block_ids = [] + self.thread._transfer_kv_cache("req2", req_meta, 0, torch.zeros( + (1, 8)), torch.zeros((1, 8))) + self.engine.batch_transfer_sync_write.assert_not_called() + + def test_transfer_skips_when_tp_not_sender(self): + + thread = KVCacheSendingLayerThread(engine=self.engine, + total_layers=2, + ready_event=self.ready_event, + tp_rank=1, + pd_head_ratio=1, + num_head_replica=2, + kv_cache_base_addr=[1000, 2000], + use_mla=False, + block_len=[1024], + first_kv_cache=self.first_kv_cache, + callback_func=MagicMock()) + req_meta = self.req_meta_base + thread._transfer_kv_cache("req3", req_meta, 0, torch.zeros((1, 8)), + torch.zeros((1, 8))) + self.engine.batch_transfer_sync_write.assert_not_called() + + @patch( + "vllm_ascend.distributed.mooncake_layerwise_connector.group_concurrent_contiguous", + side_effect=group_concurrent_contiguous) + @patch( + "vllm_ascend.distributed.mooncake_layerwise_connector.torch.npu.synchronize" + ) + def test_callback_invoked_on_final_layer(self, _mock_sync, _mock_group): + + req_meta = self.req_meta_base + req_meta.local_block_ids = [5, 6] + req_meta.remote_block_ids = [10, 11] + + req_meta.remote_kv_caches_base_addr = [ + 7000, 8000, 9000, 10000, 11000, 12000 + ] + + key = torch.zeros((1, 8), dtype=torch.float32) + value = torch.zeros((1, 8), dtype=torch.float32) + + self.thread._transfer_kv_cache("req5", + req_meta, + layer_index=2, key=key, value=value) - # k=0 (block_len=1024): - # grp1: src=1000+5*1024=6120, dst=4000+10*1024=14240, len=3*1024=3072 - # grp2: src=1000+8*1024=9192, 
dst=4000+20*1024=24480, len=2*1024=2048 - # k=1 (block_len=2048): - # grp1: src=2000+5*2048=12240, dst=8000+10*2048=28480, len=3*2048=6144 - # grp2: src=2000+8*2048=18384, dst=8000+20*2048=48960, len=2*2048=4096 - exp_session = "127.0.0.1:6000" - exp_src = [6120, 9192, 12240, 18384] - exp_dst = [14240, 24480, 28480, 48960] - exp_len = [3072, 2048, 6144, 4096] - - self.engine.batch_transfer_sync_write.assert_called_once() - args, _ = self.engine.batch_transfer_sync_write.call_args - self.assertEqual(args[0], exp_session) - self.assertEqual(args[1], exp_src) - self.assertEqual(args[2], exp_dst) - self.assertEqual(args[3], exp_len) + self.thread.callback_func.assert_called_once() -class TestKVCacheRecvingLayerThreadBasic(unittest.TestCase): +class TestKVCacheRecvingLayerThread(unittest.TestCase): def setUp(self): + + self.meta = MooncakeAgentMetadata(te_rpc_port=6000, + kv_caches_base_addr=[0x1, 0x2]) self.ready_event = threading.Event() - self.thread = KVCacheRecvingLayerThread( - tp_rank=0, - side_channel_port=5555, - tp_size=4, - local_engine_id="local_engine", - ready_event=self.ready_event, - ) - def test_get_finished_requests(self): + def test_get_and_clear_finished_requests(self): + th = KVCacheRecvingLayerThread(tp_rank=0, + side_channel_port=5555, + tp_size=2, + pd_head_ratio=1, + local_engine_id="engineA", + metadata=self.meta, + ready_event=self.ready_event) - with self.thread.lock: - self.thread.done_requests.update({"req1", "req2"}) + with th.lock: + th.done_requests.update({"r1", "r2"}) + got = th.get_and_clear_finished_requests() + self.assertEqual(got, {"r1", "r2"}) - result = self.thread.get_and_clear_finished_requests() - self.assertEqual(result, {"req1", "req2"}) + got2 = th.get_and_clear_finished_requests() + self.assertEqual(got2, set()) - result2 = self.thread.get_and_clear_finished_requests() - self.assertEqual(result2, set()) + def test_update_task_aggregates_by_pd_head_ratio(self): + th = KVCacheRecvingLayerThread(tp_rank=0, + side_channel_port=5555, + tp_size=2, + pd_head_ratio=2, + local_engine_id="engineA", + metadata=self.meta, + ready_event=self.ready_event) + + with th.lock: + th.task_tracker["reqX"] = 0 + + th.update_task("reqX") + with th.lock: + self.assertIn("reqX", th.task_tracker) + self.assertNotIn("reqX", th.done_requests) + + th.update_task("reqX") + with th.lock: + self.assertNotIn("reqX", th.task_tracker) + self.assertIn("reqX", th.done_requests) + + @patch("vllm_ascend.distributed.mooncake_layerwise_connector.logger") + @patch("vllm_ascend.distributed.mooncake_layerwise_connector.get_ip", + return_value="127.0.0.1") + @patch( + "vllm_ascend.distributed.mooncake_layerwise_connector.make_zmq_socket") + @patch( + "vllm_ascend.distributed.mooncake_layerwise_connector.make_zmq_path", + side_effect=lambda proto, host, port: f"{proto}://{host}:{port}") + @patch( + "vllm_ascend.distributed.mooncake_layerwise_connector.msgspec.msgpack.Decoder" + ) + @patch( + "vllm_ascend.distributed.mooncake_layerwise_connector.msgspec.msgpack.Encoder" + ) + @patch("vllm_ascend.distributed.mooncake_layerwise_connector.zmq_ctx") + def test_run_loop_handles_meta_done_invalid_unexpected_and_ack( + self, mock_zmq_ctx, mock_Encoder, mock_Decoder, _mock_make_path, + _mock_make_sock, _mock_get_ip, mock_logger): + + enc_inst = MagicMock() + enc_inst.encode.return_value = b"ENCODED_META" + mock_Encoder.return_value = enc_inst + + dec_inst = MagicMock() + dec_inst.decode.side_effect = [ + (GET_META_MSG, ), + (DONE_SENDING_MSG, "reqA"), + (b"weird_msg", ), + ] + 
mock_Decoder.return_value = dec_inst + + sock = MagicMock() + + sock.recv_multipart.side_effect = [ + [b"ID", b"SOME_PAYLOAD"], + [b"ID", b"SOME_PAYLOAD2"], + [b"ONLY_ID"], # invalid + [b"ID", b"SOME_PAYLOAD3"], + SystemExit, + ] + + cm = MagicMock() + cm.__enter__.return_value = sock + mock_zmq_ctx.return_value = cm + + ready_event = threading.Event() + th = KVCacheRecvingLayerThread(tp_rank=1, + side_channel_port=6000, + tp_size=2, + pd_head_ratio=1, + local_engine_id="engineZ", + metadata=self.meta, + ready_event=ready_event) + + with th.lock: + th.task_tracker["reqA"] = 0 + + with self.assertRaises(SystemExit): + th.run() + + self.assertTrue(ready_event.is_set()) + + self.assertGreaterEqual(sock.send_multipart.call_count, 2) + calls = [c.args for c in sock.send_multipart.call_args_list] + + meta_call = calls[0] + self.assertEqual(meta_call[0][0], b"ID") + self.assertEqual(meta_call[0][1], b"") + self.assertEqual(meta_call[0][2], b"ENCODED_META") + + ack_call = calls[1] + self.assertEqual(ack_call[0][0], b"ID") + self.assertEqual(ack_call[0][1], b"") + self.assertEqual(ack_call[0][2], b"ACK") + + self.assertTrue(mock_logger.error.called) + + finished = th.get_and_clear_finished_requests() + self.assertIn("reqA", finished) + + @patch("vllm_ascend.distributed.mooncake_layerwise_connector.logger") + @patch("vllm_ascend.distributed.mooncake_layerwise_connector.get_ip", + return_value="127.0.0.1") + @patch( + "vllm_ascend.distributed.mooncake_layerwise_connector.msgspec.msgpack.Decoder" + ) + @patch( + "vllm_ascend.distributed.mooncake_layerwise_connector.msgspec.msgpack.Encoder" + ) + @patch("vllm_ascend.distributed.mooncake_layerwise_connector.zmq_ctx") + def test_run_loop_pd_head_ratio_gt1_requires_multiple_done( + self, mock_zmq_ctx, mock_Encoder, mock_Decoder, _mock_get_ip, + _mock_logger): + + enc_inst = MagicMock() + enc_inst.encode.return_value = b"ENC" + mock_Encoder.return_value = enc_inst + + dec_inst = MagicMock() + dec_inst.decode.side_effect = [ + (DONE_SENDING_MSG, "reqB"), + (DONE_SENDING_MSG, "reqB"), + ] + mock_Decoder.return_value = dec_inst + + sock = MagicMock() + sock.recv_multipart.side_effect = [ + [b"ID", b"PAY1"], + [b"ID", b"PAY2"], + SystemExit, + ] + cm = MagicMock() + cm.__enter__.return_value = sock + mock_zmq_ctx.return_value = cm + + th = KVCacheRecvingLayerThread(tp_rank=0, + side_channel_port=5555, + tp_size=2, + pd_head_ratio=2, + local_engine_id="engineY", + metadata=self.meta, + ready_event=self.ready_event) + with th.lock: + th.task_tracker["reqB"] = 0 + with self.assertRaises(SystemExit): + th.run() + + finished = th.get_and_clear_finished_requests() + self.assertIn("reqB", finished) class MockVllmConfig: @@ -380,9 +371,14 @@ class MockVllmConfig: self.parallel_config.tensor_parallel_size = 2 self.parallel_config.data_parallel_rank_local = 0 self.parallel_config.data_parallel_size_local = 1 + self.parallel_config.data_parallel_size = 1 + self.parallel_config.data_parallel_rank = 0 self.cache_config.block_size = 16 + + self.kv_transfer_config.engine_id = "test_engine" self.kv_transfer_config.kv_port = 5000 - self.kv_transfer_config.kv_role = 'kv_producer' + self.kv_transfer_config.is_kv_producer = True + self.kv_transfer_config.is_kv_consumer = False self.kv_transfer_config.get_from_extra_config = MagicMock() self.kv_transfer_config.get_from_extra_config.side_effect = lambda k, d: { "prefill": { @@ -392,7 +388,8 @@ class MockVllmConfig: "decode": { "tp_size": 2, "dp_size": 1 - } + }, + "use_ascend_direct": True, }.get(k, d) @@ -409,58 +406,7 @@ class 
MockRequest: self.status = status or "running" self.output_token_ids = [101, 102] - -class TestKVCacheTaskTracker(unittest.TestCase): - - def setUp(self): - self.tracker = KVCacheTaskTracker() - - def test_update_done_task_count(self): - - self.assertEqual(len(self.tracker.finished_requests), 0) - self.assertEqual(len(self.tracker.delayed_free_requests), 0) - - current_time = time.time() - self.tracker.add_delayed_request("req_1", current_time) - - result = self.tracker.delayed_free_requests - self.assertEqual(len(result), 1) - self.assertIn("req_1", result) - self.assertEqual(result["req_1"], current_time) - - with patch.object(self.tracker, "on_done") as mock_on_done: - for _ in range(getattr(self.tracker, "target_count", 1)): - self.tracker.update_done_task_count("req_1") - mock_on_done.assert_called_once_with("req_1") - - self.assertEqual(self.tracker.finished_requests, {"req_1"}) - - result_delayed = self.tracker.delayed_free_requests - self.assertEqual(len(result_delayed), 1) - self.assertIn("req_1", result_delayed) - self.assertEqual(result_delayed["req_1"], current_time) - - def test_retrieve_expired_requests(self): - current_time = time.time() - self.tracker.add_delayed_request("req_1", current_time - 600) - self.tracker.add_delayed_request("req_2", current_time) - result = self.tracker._retrieve_expired_requests() - self.assertEqual(result, { - "req_1", - }) - result_delay = self.tracker.delayed_free_requests # dict - self.assertEqual(len(result_delay), 1) - - self.assertIn("req_2", result_delay) - self.assertEqual(result_delay["req_2"], current_time) - - def test_duplicate_task_update(self): - self.tracker.update_done_task_count("req1") - self.tracker.update_done_task_count("req1") - self.tracker.update_done_task_count("req1") - - finished = self.tracker.get_and_clear_finished_requests() - self.assertEqual(finished, {"req1"}) + self.all_token_ids = list(self.prompt_token_ids) class TestMooncakeLayerwiseConnectorMetadata(unittest.TestCase): @@ -468,7 +414,6 @@ class TestMooncakeLayerwiseConnectorMetadata(unittest.TestCase): def test_add_new_req(self): meta = MooncakeLayerwiseConnectorMetadata() self.assertEqual(len(meta.requests), 0) - self.assertEqual(len(meta.requests_to_send), 0) meta.add_new_req(request_id="req1", local_block_ids=[1, 2, 3], @@ -511,14 +456,13 @@ class TestMooncakeLayerwiseConnectorSchedulerMatchedTokens(unittest.TestCase): def test_build_connector_meta(self): request = MockRequest("req1") - blocks_mock = MagicMock() - blocks_mock.get_unhashed_block_ids.return_value = [4, 5, 6] - self.scheduler._reqs_need_recv["req1"] = (request, [4, 5, 6]) + + self.scheduler._reqs_need_recv["req1"] = (request, [], [4, 5, 6]) request.kv_transfer_params = { "remote_block_ids": [1, 2, 3], "remote_engine_id": "remote", "remote_host": "localhost", - "remote_port": 5000 + "remote_port": 5000, } meta = self.scheduler.build_connector_meta(MagicMock()) @@ -528,9 +472,161 @@ class TestMooncakeLayerwiseConnectorSchedulerMatchedTokens(unittest.TestCase): self.assertEqual(meta.requests["req1"].remote_block_ids, [1, 2, 3]) self.assertEqual(len(self.scheduler._reqs_need_recv), 0) - def test_get_finished_count(self): - count = self.scheduler.get_finished_count() - self.assertEqual(count, 2) + +class _MockBlocks: + + def __init__(self, unhashed, block_ids_tuple=None): + self._unhashed = list(unhashed) + self._block_ids_tuple = block_ids_tuple if block_ids_tuple is not None else ( + [1, 2], ) + + def get_unhashed_block_ids(self): + return list(self._unhashed) + + def get_block_ids(self): + + 
return self._block_ids_tuple + + +class _MockSchedulerOutput: + + def __init__(self, + cached_req_ids=None, + cached_new_block_ids=None, + cached_num_computed=None, + new_reqs=None, + num_sched=None): + self.scheduled_cached_reqs = SimpleNamespace( + req_ids=cached_req_ids or [], + new_block_ids=cached_new_block_ids or [], + num_computed_tokens=cached_num_computed or [], + ) + self.scheduled_new_reqs = new_reqs or [] + self.num_scheduled_tokens = num_sched or {} + + +class TestMooncakeLayerwiseConnectorScheduler_More(unittest.TestCase): + + def setUp(self): + self.config = MockVllmConfig() + self.scheduler = MooncakeLayerwiseConnectorScheduler( + self.config, "test_engine") + + def test_get_num_new_matched_tokens_with_prefill_block_aligned(self): + + req = MockRequest("req_prefill", + prompt_token_ids=list(range(32)), + kv_transfer_params={"do_remote_prefill": True}) + tokens, async_flag = self.scheduler.get_num_new_matched_tokens( + req, num_computed_tokens=16) + self.assertEqual(tokens, 16) + self.assertTrue(async_flag) + + def test_update_state_after_alloc_prefill_records_and_resets_flag(self): + req = MockRequest("req_u1", + prompt_token_ids=list(range(24)), + kv_transfer_params={"do_remote_prefill": True}) + blocks = _MockBlocks(unhashed=[4, 5, 6]) + + self.scheduler.update_state_after_alloc(req, + blocks, + num_external_tokens=8) + self.assertIn("req_u1", self.scheduler._reqs_need_recv) + record = self.scheduler._reqs_need_recv["req_u1"] + self.assertIs(record[0], req) + self.assertEqual(record[1], []) + self.assertEqual(record[2], [4, 5, 6]) + self.assertFalse(req.kv_transfer_params.get("do_remote_prefill", True)) + + def test_update_state_after_alloc_decode_records_send_layerwise(self): + req = MockRequest("req_u2", + prompt_token_ids=list(range(10)), + kv_transfer_params={"do_remote_decode": True}) + blocks = _MockBlocks(unhashed=[], block_ids_tuple=([7, 8, 9], )) + self.scheduler.update_state_after_alloc(req, + blocks, + num_external_tokens=0) + self.assertIn("req_u2", self.scheduler._reqs_need_send_layerwise) + total_tokens, local_block_ids, req_ref = self.scheduler._reqs_need_send_layerwise[ + "req_u2"] + self.assertEqual(total_tokens, 10) + self.assertEqual(local_block_ids, [7, 8, 9]) + self.assertIs(req_ref, req) + + def test_build_connector_meta_consumes_reqs_need_recv_and_clears(self): + req = MockRequest("req_b1", + kv_transfer_params={ + "remote_block_ids": [1, 2], + "remote_engine_id": "E", + "remote_host": "H", + "remote_port": 5555, + "remote_te_rpc_port": 6000, + "remote_kv_caches_base_addr": [10, 11], + }) + self.scheduler._reqs_need_recv["req_b1"] = (req, [], [100, 101]) + meta = self.scheduler.build_connector_meta(_MockSchedulerOutput()) + self.assertIsInstance(meta, MooncakeLayerwiseConnectorMetadata) + self.assertIn("req_b1", meta.requests) + self.assertEqual(meta.requests["req_b1"].local_block_ids, [100, 101]) + self.assertEqual(len(self.scheduler._reqs_need_recv), 0) + + def test_build_connector_meta_accumulates_cached_blocks(self): + req = MockRequest("req_b2", + prompt_token_ids=list(range(8)), + kv_transfer_params={"do_remote_decode": True}) + + self.scheduler._reqs_need_send_layerwise["req_b2"] = (8, [1, 2], req) + + out = _MockSchedulerOutput( + cached_req_ids=["req_b2"], + cached_new_block_ids=[([3, 4], )], + cached_num_computed=[4], + new_reqs=[], + num_sched={}, + ) + meta = self.scheduler.build_connector_meta(out) + self.assertEqual(len(meta.requests), 0) + total, block_ids, _ = self.scheduler._reqs_need_send_layerwise[ + "req_b2"] + 
self.assertEqual(total, 8) + self.assertEqual(block_ids, [1, 2, 3, 4]) + + def test_build_connector_meta_emits_when_tokens_reach_total(self): + + req = MockRequest("req_b3", + prompt_token_ids=list(range(12)), + kv_transfer_params={ + "do_remote_decode": True, + "remote_block_ids": [9], + "remote_engine_id": "E", + "remote_host": "H", + "remote_port": 5555, + "remote_te_rpc_port": 6000, + "remote_kv_caches_base_addr": [10, 11], + }) + self.scheduler._reqs_need_send_layerwise["req_b3"] = (12, [100, + 101], req) + + out = _MockSchedulerOutput( + cached_req_ids=["req_b3"], + cached_new_block_ids=[([50], )], + cached_num_computed=[8], + new_reqs=[SimpleNamespace(req_id="other", num_computed_tokens=0)], + num_sched={"req_b3": 4}, + ) + meta = self.scheduler.build_connector_meta(out) + self.assertIn("req_b3", meta.requests) + rmeta = meta.requests["req_b3"] + + self.assertEqual(rmeta.local_block_ids, [100, 101, 50]) + + self.assertNotIn("req_b3", self.scheduler._reqs_need_send_layerwise) + + def test_request_finished_returns_false_none(self): + ok, params = self.scheduler.request_finished(MockRequest("req_fin"), + [1, 2]) + self.assertFalse(ok) + self.assertIsNone(params) class TestHelperFunctions(unittest.TestCase): @@ -538,9 +634,7 @@ class TestHelperFunctions(unittest.TestCase): def test_group_concurrent_contiguous(self): src: list[int] = [1, 2, 3, 5, 6] dst: list[int] = [10, 11, 12, 14, 15] - src_groups, dst_groups = group_concurrent_contiguous(src, dst) - self.assertEqual(len(src_groups), 2) self.assertEqual(src_groups[0], [1, 2, 3]) self.assertEqual(src_groups[1], [5, 6]) @@ -562,6 +656,56 @@ class TestHelperFunctions(unittest.TestCase): hash3 = string_to_int64_hash("different_string") self.assertNotEqual(hash1, hash3) + def test_zmq_ctx_invalid_type(self): + with self.assertRaises(ValueError): + with zmq_ctx("INVALID", "tcp://127.0.0.1:5555"): + pass + + @patch( + "vllm_ascend.distributed.mooncake_layerwise_connector.make_zmq_socket") + def test_zmq_ctx_ok(self, mock_make_socket): + mock_socket = MagicMock() + mock_make_socket.return_value = mock_socket + with zmq_ctx(zmq.REQ, "tcp://localhost:1234") as s: # type: ignore + self.assertEqual(s, mock_socket) + + @patch("vllm_ascend.distributed.mooncake_layerwise_connector.logger") + def test_ensure_zmq_send_success(self, _): + mock_socket = MagicMock() + ensure_zmq_send(mock_socket, b"hello") + mock_socket.send.assert_called_once_with(b"hello") + + @patch("vllm_ascend.distributed.mooncake_layerwise_connector.logger") + def test_ensure_zmq_send_retry_and_fail(self, _): + mock_socket = MagicMock() + mock_socket.send.side_effect = zmq.ZMQError( # type: ignore + "send failed") + with self.assertRaises(RuntimeError): + ensure_zmq_send(mock_socket, b"hello", max_retries=2) + self.assertEqual(mock_socket.send.call_count, 2) + + @patch("vllm_ascend.distributed.mooncake_layerwise_connector.logger") + def test_ensure_zmq_recv_success(self, _): + mock_socket = MagicMock() + mock_socket.recv.return_value = b"response" + mock_poller = MagicMock() + mock_poller.poll.return_value = [ + (mock_socket, zmq.POLLIN) # type: ignore + ] + data = ensure_zmq_recv(mock_socket, mock_poller) + self.assertEqual(data, b"response") + + @patch("vllm_ascend.distributed.mooncake_layerwise_connector.logger") + def test_ensure_zmq_recv_timeout_and_fail(self, _): + mock_socket = MagicMock() + mock_poller = MagicMock() + mock_poller.poll.return_value = [] + with self.assertRaises(RuntimeError): + ensure_zmq_recv(mock_socket, + mock_poller, + timeout=0.01, + max_retries=2) + 
class TestMooncakeLayerwiseConnectorForScheduler(unittest.TestCase): @@ -645,220 +789,11 @@ class TestMooncakeLayerwiseConnector(unittest.TestCase): mock_method.assert_called_once_with(request, [1, 2, 3]) -class TestMooncakeLayerwiseConnectorScheduler(unittest.TestCase): - - def setUp(self): - self.config = MockVllmConfig() - self.scheduler = MooncakeLayerwiseConnectorScheduler( - self.config, "test_engine") - - def test_get_num_new_matched_tokens_no_remote_prefill(self): - request = MockRequest("req1") - tokens, async_flag = self.scheduler.get_num_new_matched_tokens( - request, 0) - self.assertEqual(tokens, 0) - self.assertFalse(async_flag) - - def test_get_num_new_matched_tokens_with_remote_prefill(self): - request = MockRequest("req1", - kv_transfer_params={"do_remote_prefill": True}) - tokens, async_flag = self.scheduler.get_num_new_matched_tokens( - request, 0) - self.assertEqual(tokens, 4) - self.assertTrue(async_flag) - - def test_update_state_after_alloc_no_remote_prefill(self): - request = MockRequest("req1") - blocks = MagicMock() - self.scheduler.update_state_after_alloc(request, blocks, 0) - self.assertEqual(len(self.scheduler._reqs_need_recv), 0) - - def test_update_state_after_alloc_with_remote_prefill(self): - request = MockRequest("req1", - kv_transfer_params={ - "do_remote_prefill": True, - "remote_block_ids": [1, 2, 3], - "remote_engine_id": "remote", - "remote_host": "localhost", - "remote_port": 5000 - }) - blocks = MockKVCacheBlocks() - self.scheduler.update_state_after_alloc(request, blocks, 3) - self.assertEqual(len(self.scheduler._reqs_need_recv), 1) - self.assertEqual(self.scheduler._reqs_need_recv["req1"][0], request) - self.assertEqual(self.scheduler._reqs_need_recv["req1"][1], [4, 5, 6]) - - def test_request_finished_no_remote_decode(self): - request = MockRequest("req1") - delay_free, params = self.scheduler.request_finished( - request, [1, 2, 3]) - self.assertFalse(delay_free) - self.assertIsNone(params) - - -class TestUtils(unittest.TestCase): - - def test_string_to_int64_hash(self): - h1 = string_to_int64_hash("hello") - h2 = string_to_int64_hash("hello") - h3 = string_to_int64_hash("world") - self.assertEqual(h1, h2) - self.assertNotEqual(h1, h3) - self.assertIsInstance(h1, int) - - def test_group_concurrent_contiguous(self): - src: list[int] = [1, 2, 3, 5, 6] - dst: list[int] = [10, 11, 12, 20, 21] - src_g, dst_g = group_concurrent_contiguous(src, dst) - self.assertEqual(src_g, [[1, 2, 3], [5, 6]]) - self.assertEqual(dst_g, [[10, 11, 12], [20, 21]]) - - def test_group_empty(self): - src_g, dst_g = group_concurrent_contiguous([], []) - self.assertEqual(src_g, []) - self.assertEqual(dst_g, []) - - def test_zmq_ctx_invalid_type(self): - with self.assertRaises(ValueError): - with zmq_ctx("INVALID", "tcp://127.0.0.1:5555"): - pass - - @patch( - "vllm_ascend.distributed.mooncake_layerwise_connector.make_zmq_socket") - def test_zmq_ctx_ok(self, mock_make_socket): - mock_socket = MagicMock() - mock_make_socket.return_value = mock_socket - with zmq_ctx(zmq.REQ, "tcp://localhost:1234") as s: # type: ignore - self.assertEqual(s, mock_socket) - - @patch("vllm_ascend.distributed.mooncake_layerwise_connector.logger") - def test_ensure_zmq_send_success(self, mock_logger): - mock_socket = MagicMock() - ensure_zmq_send(mock_socket, b"hello") - mock_socket.send.assert_called_once_with(b"hello") - - @patch("vllm_ascend.distributed.mooncake_layerwise_connector.logger") - def test_ensure_zmq_send_retry_and_fail(self, mock_logger): - mock_socket = MagicMock() - 
mock_socket.send.side_effect = zmq.ZMQError( # type: ignore - "send failed") - with self.assertRaises(RuntimeError): - ensure_zmq_send(mock_socket, b"hello", max_retries=2) - self.assertEqual(mock_socket.send.call_count, 2) - - @patch("vllm_ascend.distributed.mooncake_layerwise_connector.logger") - def test_ensure_zmq_recv_success(self, mock_logger): - mock_socket = MagicMock() - mock_socket.recv.return_value = b"response" - mock_poller = MagicMock() - mock_poller.poll.return_value = [ - (mock_socket, zmq.POLLIN) # type: ignore - ] - data = ensure_zmq_recv(mock_socket, mock_poller) - self.assertEqual(data, b"response") - - @patch("vllm_ascend.distributed.mooncake_layerwise_connector.logger") - def test_ensure_zmq_recv_timeout_and_fail(self, mock_logger): - mock_socket = MagicMock() - mock_poller = MagicMock() - mock_poller.poll.return_value = [] - with self.assertRaises(RuntimeError): - ensure_zmq_recv(mock_socket, - mock_poller, - timeout=0.01, - max_retries=2) - - -class MockMooncakeAgentMetadata: - - def __init__(self, **kwargs): - pass - - -class MockMooncakeLayerwiseConnectorMetadata: - - def __init__(self): - self.requests = {} - - -class MockKVCacheSendingThread(threading.Thread): - - def __init__(self, *args, **kwargs): - super().__init__() - self.daemon = True - self._finished_requests = set() - - def get_and_clear_finished_requests(self): - return self._finished_requests - - def start(self): - pass - - -class MockKVCacheRecvingThread(threading.Thread): - - def __init__(self, *args, **kwargs): - super().__init__() - self.daemon = True - self._finished_requests = set() - self.add_request = MagicMock() - - def get_and_clear_finished_requests(self): - return self._finished_requests - - def start(self): - pass - - -class MockTensor: - - def __init__(self, *args, **kwargs): - self.size = MagicMock(return_value=(10, 16, 8, 16)) - self.element_size = MagicMock(return_value=4) - self.shape = (10, 16, 8, 16) - self.data_ptr = MagicMock(return_value=0x1000) - - -mock_envs_ascend = MagicMock() -mock_envs_ascend.MOONCAKE_CONNECTOR_PROTOCOL = "mock_protocol" - -mock_logger = MagicMock() - - -class MockTransferEngine: - - def initialize(self, *args, **kwargs): - return 0 - - def register_memory(self, *args, **kwargs): - return 1 - - -class MockEnvsAscend: - MOONCAKE_CONNECTOR_PROTOCOL = "mock_protocol" - PHYSICAL_DEVICES = "10,11" - - -def mock_get_tensor_model_parallel_rank(): - return 0 - - -def mock_get_tp_group(): - return MagicMock() - - -def mock_get_ip(): - return "127.0.0.1" - - -def mock_string_to_int64_hash(s): - return hash(s) - - class TestMooncakeLayerwiseConnectorWorker(unittest.TestCase): def setUp(self): - self.envs_ascend_mock = MockEnvsAscend() + self.envs_ascend_mock = type("MockEnvsAscend", (), + {"PHYSICAL_DEVICES": "10,11"})() self.mock_transfer_engine = MagicMock() self.mock_transfer_engine.get_rpc_port.return_value = 9090 self.mock_transfer_engine.initialize.return_value = 0 @@ -873,16 +808,16 @@ class TestMooncakeLayerwiseConnectorWorker(unittest.TestCase): patch('random.Random'), patch( 'vllm_ascend.distributed.mooncake_layerwise_connector.get_tensor_model_parallel_rank', - mock_get_tensor_model_parallel_rank), + return_value=0), patch( 'vllm_ascend.distributed.mooncake_layerwise_connector.get_tp_group', - mock_get_tp_group), + return_value=None), patch( 'vllm_ascend.distributed.mooncake_layerwise_connector.get_ip', - mock_get_ip), + return_value="127.0.0.1"), patch( 'vllm_ascend.distributed.mooncake_layerwise_connector.string_to_int64_hash', - 
mock_string_to_int64_hash), + side_effect=lambda s: hash(s)), patch( 'vllm_ascend.distributed.mooncake_layerwise_connector.TransferEngine', return_value=self.mock_transfer_engine), @@ -904,13 +839,7 @@ class TestMooncakeLayerwiseConnectorWorker(unittest.TestCase): 'vllm_ascend.distributed.mooncake_layerwise_connector.get_ascend_config', return_value=SimpleNamespace(pd_tp_ratio=1, num_head_replica=1, - pd_head_ratio=1), - ), - patch( - 'vllm_ascend.distributed.mooncake_layerwise_connector.get_current_vllm_config', - return_value=SimpleNamespace(scheduler_config=SimpleNamespace( - max_model_len=128)), - ) + pd_head_ratio=1)), ] for p in self.patches: @@ -925,12 +854,9 @@ class TestMooncakeLayerwiseConnectorWorker(unittest.TestCase): p.stop() # type: ignore def test_worker_use_ascend_direct(self): - test_case = [True, False] - - for use_ascend_direct in test_case: + for use_ascend_direct in (True, False): with self.subTest(use_ascend_direct=use_ascend_direct): - config = MagicMock() - config.kv_transfer_config = MagicMock() + config = MockVllmConfig() config.kv_transfer_config.get_from_extra_config.side_effect = ( lambda k, d: { "prefill": { @@ -943,28 +869,14 @@ class TestMooncakeLayerwiseConnectorWorker(unittest.TestCase): }, "use_ascend_direct": use_ascend_direct, }.get(k, d)) - - config.parallel_config = MagicMock() - config.parallel_config.tensor_parallel_size = 2 - config.parallel_config.data_parallel_rank_local = 0 - config.parallel_config.data_parallel_size_local = 1 - config.kv_transfer_config.kv_port = 8000 - config.kv_transfer_config.kv_role = 'worker' - - with patch( - "vllm_ascend.distributed.mooncake_layerwise_connector.get_tensor_model_parallel_rank", - return_value=0): - with patch( - "vllm_ascend.distributed.mooncake_layerwise_connector.get_tp_group", - return_value=None): - with patch( - "vllm_ascend.distributed.mooncake_layerwise_connector.get_ip", - return_value="127.0.0.1"): - worker = MooncakeLayerwiseConnectorWorker( - config, self.engine_id) - self.assertIsNotNone(worker) + worker = MooncakeLayerwiseConnectorWorker( + config, self.engine_id) + self.assertIsNotNone(worker) def test_register_kv_caches_producer(self): + + self.vllm_config.kv_transfer_config.is_kv_producer = True + self.vllm_config.kv_transfer_config.is_kv_consumer = False worker = MooncakeLayerwiseConnectorWorker(self.vllm_config, self.engine_id) worker.register_kv_caches(self.kv_caches) @@ -973,7 +885,9 @@ class TestMooncakeLayerwiseConnectorWorker(unittest.TestCase): self.assertIsNone(worker.kv_recv_layer_thread) def test_register_kv_caches_consumer(self): - self.vllm_config.kv_transfer_config.kv_role = 'kv_consumer' + + self.vllm_config.kv_transfer_config.is_kv_producer = False + self.vllm_config.kv_transfer_config.is_kv_consumer = True worker = MooncakeLayerwiseConnectorWorker(self.vllm_config, self.engine_id) worker.register_kv_caches(self.kv_caches) @@ -986,7 +900,6 @@ class TestMooncakeLayerwiseConnectorWorker(unittest.TestCase): mla_cache2 = MagicMock() mla_cache2.size.return_value = (10, 16, 1, 8) mla_caches = {"layer1": (mla_cache1, mla_cache2)} - worker = MooncakeLayerwiseConnectorWorker(self.vllm_config, self.engine_id) worker.register_kv_caches(mla_caches) @@ -994,12 +907,10 @@ class TestMooncakeLayerwiseConnectorWorker(unittest.TestCase): self.assertEqual(len(worker.block_len), 2) def test_device_id_selection_with_physical_devices(self): - # Test with physical devices set worker = MooncakeLayerwiseConnectorWorker(self.vllm_config, self.engine_id) - # Default tp_rank is 0, so device_id should 
be 10 self.assertEqual(worker.device_id, 10) if __name__ == '__main__': - unittest.main() + unittest.main() \ No newline at end of file diff --git a/vllm_ascend/distributed/mooncake_layerwise_connector.py b/vllm_ascend/distributed/mooncake_layerwise_connector.py index c3fb6e1..e500f53 100644 --- a/vllm_ascend/distributed/mooncake_layerwise_connector.py +++ b/vllm_ascend/distributed/mooncake_layerwise_connector.py @@ -1,13 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 import contextlib +import copy import hashlib import math import queue -import random import struct import threading import time -from collections import defaultdict +from collections import defaultdict, deque from collections.abc import Iterator from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass @@ -20,15 +20,13 @@ import numpy.typing as npt import torch import zmq from mooncake.engine import TransferEngine # type: ignore -from vllm import envs -from vllm.config import VllmConfig, get_current_vllm_config +from vllm.config import VllmConfig from vllm.distributed.kv_transfer.kv_connector.v1.base import ( KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) from vllm.distributed.parallel_state import (get_tensor_model_parallel_rank, get_tp_group, get_world_group) from vllm.utils import get_ip, logger, make_zmq_path, make_zmq_socket from vllm.v1.core.sched.output import SchedulerOutput -from vllm.v1.request import RequestStatus import vllm_ascend.envs as envs_ascend from vllm_ascend.ascend_config import get_ascend_config @@ -42,329 +40,125 @@ if TYPE_CHECKING: from vllm.v1.request import Request GET_META_MSG = b"get_meta_msg" -DONE_RECVING_MSG = b"done_recving_msg" +DONE_SENDING_MSG = b"done_sending_msg" class MooncakeAgentMetadata(msgspec.Struct, omit_defaults=True, dict=True): - engine_id: str te_rpc_port: int kv_caches_base_addr: list[int] - num_blocks: int @dataclass class ReqMeta: local_block_ids: list[int] + token_ids: list[int] # Not None if layer-wise is disabled - remote_block_ids: Optional[list[int]] + remote_block_ids: list[int] + remote_engine_id: Optional[str] remote_host: Optional[str] remote_port: Optional[int] - remote_engine_id: Optional[str] - # Not None if layer-wise is enabled + remote_te_rpc_port: Optional[int] + remote_kv_caches_base_addr: Optional[list[int]] metaserver: Optional[str] - remote_tp_size: Optional[int] - - -class DecodeMooncakeAgentMetadata(msgspec.Struct, - omit_defaults=True, - dict=True): - req_id: str - block_ids: list[int] - host: str - port: int - engine_id: str - te_rpc_port: int - kv_caches_base_addr: list[int] - num_blocks: int - - -class KVCacheTaskTracker: - - def __init__(self, - target_count: int = 1, - on_done: Callable[[str], None] = lambda x: None, - on_timeout: Callable[[set[str]], Any] = lambda x: None): - super().__init__() - self.target_count = target_count - self.done_task_lock = threading.Lock() - self.done_task_counts: defaultdict[str, int] = defaultdict(int) - self.finished_requests: set[str] = set() - # Only used in prefill node. Tracks requests whose kv blocks freeing is - # intentionally delayed. Each entry is a tuple of (request_id, - # timestamp). If a request remains in this queue for too long, it will - # be force-freed. - # Notice: In layer-wise mode, the transfer may complete before it is - # added to delayed_free_requests when prefill node finishes forwarding. - # Therefore we need to track requests that are removed before being added. 
- self.delayed_free_requests: dict[str, float] = {} - self.removed_delayed_free_requests: set[str] = set() - self.on_done = on_done - self.on_timeout = on_timeout - - def update_done_task_count(self, request_id: str): - self.done_task_counts[request_id] += 1 - if self.done_task_counts[request_id] == self.target_count: - with self.done_task_lock: - self.finished_requests.add(request_id) - self.done_task_counts.pop(request_id) - self.on_done(request_id) - - def get_and_clear_finished_requests(self) -> set[str]: - """ - Get and clear the requests that have been completed. - Returns: - A set of request IDs that have been completed. - """ - with self.done_task_lock: - finished_requests = self.finished_requests.copy() - expired_requests = self._retrieve_expired_requests() - finished_requests.update(expired_requests) - self.finished_requests.clear() - self.on_timeout(expired_requests) - return finished_requests - - def add_delayed_request(self, request_id: str, delay_start_time: float): - """Add a delayed free request, where delay_start_time is monotonic increasing.""" - with self.done_task_lock: - if request_id in self.removed_delayed_free_requests: - self.removed_delayed_free_requests.remove(request_id) - else: - self.delayed_free_requests[request_id] = delay_start_time - - def _retrieve_expired_requests(self): - """Retrieve all expired delayed requests.""" - expired_requests: set[str] = set() - # Free delayed requests if they exceed the timeout - current_time = time.time() - while self.delayed_free_requests: - request_id, delay_start_time = next( - iter(self.delayed_free_requests.items())) - if (current_time - delay_start_time - > envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT): - self.delayed_free_requests.pop(request_id) - expired_requests.add(request_id) - logger.info("Force freed request: %s", request_id) - else: - break - return expired_requests - - def remove_delayed_request(self, request_id: str): - """Remove all delayed free requests matching the given request_id.""" - with self.done_task_lock: - if self.delayed_free_requests.pop(request_id, None) is None: - self.removed_delayed_free_requests.add(request_id) class KVCacheSendingLayerThread(threading.Thread): - def __init__(self, tp_rank: int, tp_size: int, decode_tp_size: int, - local_engine_id: str, side_channel_host: str, - side_channel_port: int, metadata: MooncakeAgentMetadata, - ready_event: threading.Event, total_layers: int, - engine: TransferEngine, local_kv_base_addr: list[int], - block_len: list[int], use_mla: bool, - first_kv_cache: torch.Tensor): + def __init__(self, + engine: TransferEngine, + total_layers: int, + ready_event: threading.Event, + tp_rank: int, + pd_head_ratio: int, + num_head_replica: int, + kv_cache_base_addr: list[int], + use_mla: bool, + block_len: list[int], + first_kv_cache: torch.Tensor, + callback_func: Callable[..., None] = lambda x: None): super().__init__(daemon=True, name="KVCacheSendingLayerThread") - self.tp_rank = tp_rank - self.tp_size = tp_size - self.decode_tp_size = decode_tp_size - self.local_engine_id = local_engine_id - self.side_channel_host = side_channel_host - self.side_channel_port = side_channel_port - self.task_tracker = KVCacheTaskTracker(total_layers, - on_done=self._post_transfer, - on_timeout=self._abort_requests) - self.send_layer_thread = SendingLayerThread( - self.task_tracker, total_layers, engine, local_kv_base_addr, - block_len, use_mla, self.tp_rank, first_kv_cache) - self.ready_decode = dict[str, DecodeMooncakeAgentMetadata]() - self.pending_decode = dict[str, - 
list[tuple[list[int], int, torch.Tensor, - torch.Tensor]]]() - self.total_layers = total_layers - self.lock = threading.Lock() - self.ready_event = ready_event - - def get_and_clear_finished_requests(self) -> set[str]: - """ - Get and clear the requests that have been completed. - Returns: - A set of request IDs that have been completed. - """ - # vllm won't call us if all inference is done, so we can't do step 9 here - return self.task_tracker.get_and_clear_finished_requests() - - def add_delayed_request(self, request_id: str, delay_start_time: float): - return self.task_tracker.add_delayed_request(request_id, - delay_start_time) - - def run(self): - """Run the thread to handle KV cache transfer requests.""" - self.send_layer_thread.start() - handshake_port = self.side_channel_port + self.tp_rank - path = make_zmq_path("tcp", self.side_channel_host, handshake_port) - logger.info("Starting listening on path: %s", path) - with zmq_ctx(zmq.ROUTER, path) as sock: # type: ignore - self.ready_event.set() - decoder = msgspec.msgpack.Decoder(type=DecodeMooncakeAgentMetadata) - while True: - try: - frames = sock.recv_multipart() - if len(frames) < 2: - logger.error("Invalid message format: %s", frames) - continue - - identity = frames[0] - payload = [f for f in frames[1:] if f != b""] - if len(payload) != 1: - logger.error("Invalid message format: %s", frames) - continue - - metadata = decoder.decode(payload[0]) - request_id = metadata.req_id - logger.debug( - f"Prefiller has received that request {request_id} from the decoder." - ) - sock.send_multipart((identity, b"", b"ACK")) - self.task_tracker.remove_delayed_request(request_id) - with self.lock: - self.ready_decode[request_id] = metadata - pending = self.pending_decode.pop(request_id, []) - for local_block_ids, layer_index, key, value in pending: - self.send_layer_thread.send_queue.put( - (metadata, request_id, local_block_ids, - layer_index, key, value)) - except Exception as e: - logger.error("Failed to decode message: %s", e) - - def _post_transfer(self, request_id: str): - with self.lock: - decoder_meta = self.ready_decode.pop(request_id) - path = make_zmq_path("tcp", decoder_meta.host, decoder_meta.port) - msg_encoder = msgspec.msgpack.Encoder() - encoded_data = msg_encoder.encode(request_id) - with zmq_ctx(zmq.REQ, path) as sock: # type: ignore - ensure_zmq_send(sock, encoded_data) - ack = sock.recv() - if ack != b"ACK": - raise ValueError(f"Unexpected ACK response: {ack}") - - def add_request(self, request_id: str, local_block_ids: list[int], - layer_index: int, key: torch.Tensor, value: torch.Tensor): - # add request to send layer thread - with self.lock: - if request_id in self.ready_decode: - self.send_layer_thread.send_queue.put( - (self.ready_decode[request_id], request_id, - local_block_ids, layer_index, key, value)) - else: - self.pending_decode.setdefault(request_id, []).append( - (local_block_ids, layer_index, key, value)) - - def _abort_requests(self, request_ids: set[str]): - with self.lock: - for request_id in request_ids: - self.pending_decode.pop(request_id, None) - - -class SendingLayerThread(threading.Thread): - - def __init__(self, task_tracker: KVCacheTaskTracker, total_layers: int, - engine: TransferEngine, local_kv_base_addr: list[int], - block_len: list[int], use_mla: bool, tp_rank: int, - first_kv_cache: torch.Tensor): - super().__init__(daemon=True, name="KVCacheRecvingPrefillerByeThread") - self.send_queue = queue.Queue[tuple[DecodeMooncakeAgentMetadata, str, - list[int], int, torch.Tensor, - torch.Tensor]]() - 
self.completion_event: Optional[threading.Event] = None - self.completion_event_count: int - self.task_tracker = task_tracker - self.total_layers = total_layers - self.local_kv_base_addr = local_kv_base_addr - self.block_len = block_len - self.use_mla = use_mla self.engine = engine self.tp_rank = tp_rank - self.pd_tp_ratio = get_ascend_config().pd_tp_ratio - self.num_head_replica = get_ascend_config().num_head_replica - self.pd_head_ratio = get_ascend_config().pd_head_ratio - vllm_config = get_current_vllm_config() - max_model_len = vllm_config.scheduler_config.max_model_len - first_kv_cache = first_kv_cache[:max_model_len] - alignment = 2 * 1024 * 1024 - self.k_buffer = torch.zeros( - first_kv_cache.numel() + alignment, - dtype=first_kv_cache.dtype, - device=first_kv_cache.device) # 【4,1,128】-》【1000, 128】 - self.k_buffer = align_memory(self.k_buffer, - alignment)[:first_kv_cache.numel()].view( - -1, first_kv_cache.shape[-1]) - self.v_buffer = torch.zeros(first_kv_cache.numel() + alignment, - dtype=first_kv_cache.dtype, - device=first_kv_cache.device) - self.v_buffer = align_memory(self.v_buffer, - alignment)[:first_kv_cache.numel()].view( - -1, first_kv_cache.shape[-1]) + self.pd_head_ratio = pd_head_ratio + self.num_head_replica = num_head_replica + self.kv_caches_base_addr = kv_cache_base_addr + self.total_layers = total_layers + self.use_mla = use_mla + self.block_len = block_len - for tensor in (self.k_buffer, self.v_buffer): - assert tensor.data_ptr( - ) % alignment == 0, "The address of the registered kv cache should be aligned to 2M" - ret_value = self.engine.register_memory(tensor.data_ptr(), - tensor.numel()) - logger.info( - f"Sendinglayerthread register_memory {tensor.data_ptr()} {tensor.numel()} {ret_value=}" - ) - if ret_value != 0: - raise RuntimeError("Mooncake memory registration failed. ") + if self.pd_head_ratio > 1: + # regesit kv buffer for tp inequal + alignment = 2 * 1024 * 1024 + self.k_buffer = torch.zeros(first_kv_cache.numel() + alignment, + dtype=first_kv_cache.dtype, + device=first_kv_cache.device) + self.k_buffer = align_memory( + self.k_buffer, alignment)[:first_kv_cache.numel()].view( + -1, first_kv_cache.shape[-1]) + self.v_buffer = torch.zeros(first_kv_cache.numel() + alignment, + dtype=first_kv_cache.dtype, + device=first_kv_cache.device) + self.v_buffer = align_memory( + self.v_buffer, alignment)[:first_kv_cache.numel()].view( + -1, first_kv_cache.shape[-1]) + + for tensor in (self.k_buffer, self.v_buffer): + assert tensor.data_ptr( + ) % alignment == 0, "The address of the registered kv cache should be aligned to 2M" + ret_value = self.engine.register_memory( + tensor.data_ptr(), tensor.numel()) + logger.info( + f"Register memory for prefill when pd head ratio > 1 {tensor.data_ptr()} {tensor.numel()} {ret_value=}" + ) + if ret_value != 0: + raise RuntimeError("Mooncake memory registration failed. 
") + + self.send_queue = queue.Queue[Tuple[str, ReqMeta, int, torch.Tensor, + torch.Tensor]]() + + self.ready_event = ready_event + self.callback_func = callback_func def run(self): - """Run the thread to handle KV cache receiving for prefiller bye messages.""" - # send kv cache for request in send_queue local_rank = get_world_group().local_rank device = torch.device(f"npu:{local_rank}") torch.npu.set_device(device) + self.ready_event.set() while True: - request = self.send_queue.get() - self._handle_request(request) - - def _handle_request(self, request: tuple[DecodeMooncakeAgentMetadata, str, - list[int], int, torch.Tensor, - torch.Tensor]): - # send kv layer to remote - req_meta, request_id, local_block_ids, layer_index, key, value = request + req_id, req_meta, layer_index, key, value = self.send_queue.get() + self._handle_request(req_id, req_meta, layer_index, key, value) + def _handle_request(self, req_id, req_meta, layer_index, key, value): try: logger.debug( - f"Starting to transfer KV cache for request {request_id}.") - self._transfer_kv_cache(req_meta, local_block_ids, layer_index, - key, value) + f"Starting to transfer KV cache for request {req_id} {req_meta.remote_te_rpc_port=}." + ) + self._transfer_kv_cache(req_id, req_meta, layer_index, key, value) logger.debug( - f"Finished transferring KV cache for request {request_id}.") + f"Finished transferring KV cache for request {req_id} {req_meta.remote_te_rpc_port=}." + ) except Exception as e: logger.error("Failed to transfer KV cache for request " - f"{request_id}: {e}") - finally: - self.task_tracker.update_done_task_count(request_id) - self.send_queue.task_done() + f"{req_id}: {e}") - def _transfer_kv_cache(self, req_meta: DecodeMooncakeAgentMetadata, - local_block_ids: list[int], layer_index: int, key, - value): + def _transfer_kv_cache(self, req_id, req_meta, layer_index, key, value): # send kv layer to remote - if len(local_block_ids) == 0: + if len(req_meta.local_block_ids) == 0: + return + # not need to send kv cache + if self.tp_rank % self.num_head_replica != 0: return - remote_host = req_meta.host - remote_te_port = req_meta.te_rpc_port - remote_kv_base_addrs = req_meta.kv_caches_base_addr + remote_host = req_meta.remote_host + remote_block_ids = req_meta.remote_block_ids + remote_te_port = req_meta.remote_te_rpc_port + remote_kv_base_addrs = req_meta.remote_kv_caches_base_addr + local_kv_base_addr = self.kv_caches_base_addr + local_block_ids = req_meta.local_block_ids - remote_block_ids = req_meta.block_ids - if self.tp_rank % self.num_head_replica != 0: - pass - elif self.pd_head_ratio == 1: + if self.pd_head_ratio == 1: layer_local_kv_base_addr = [ - self.local_kv_base_addr[i] + local_kv_base_addr[i] for i in [2 * layer_index, 2 * layer_index + 1] ] layer_remote_kv_base_addr = [ @@ -393,10 +187,8 @@ class SendingLayerThread(threading.Thread): torch.npu.synchronize() ret = self.engine.batch_transfer_sync_write( session_id, src_list, dst_list, length_list) - if ret < 0: - logger.error("Mooncake transfer failed for request %s", - req_meta.req_id) + logger.error("Mooncake transfer failed for request %s", req_id) raise RuntimeError(f"Mooncake transfer failed, ret: {ret}") else: key = key.view(-1, key.shape[-1]) @@ -447,33 +239,31 @@ class SendingLayerThread(threading.Thread): ret = self.engine.batch_transfer_sync_write( session_id, src_list, dst_list, length_list) if ret < 0: - logger.error("Mooncake transfer failed for request %s", - req_meta.req_id) + logger.error("Mooncake transfer failed for request %s", req_id) raise 
RuntimeError(f"Mooncake transfer failed, ret: {ret}") - if self.completion_event is not None: - self.completion_event_count -= 1 - if self.completion_event_count == 0: - self.completion_event.set() - self.completion_event = None - def add_event(self, event: threading.Event, count: int) -> None: - self.completion_event = event - self.completion_event_count = count + if layer_index == (self.total_layers - 1): + self.callback_func(req_id, req_meta) class KVCacheRecvingLayerThread(threading.Thread): def __init__(self, tp_rank: int, side_channel_port: int, tp_size: int, - local_engine_id: str, ready_event: threading.Event): + pd_head_ratio: int, local_engine_id: str, + metadata: MooncakeAgentMetadata, + ready_event: threading.Event): super().__init__(daemon=True, name="KVCacheRecvingLayerThread") self.tp_rank = tp_rank self.tp_size = tp_size + self.pd_head_ratio = pd_head_ratio self.local_engine_id = local_engine_id self.side_channel_host = get_ip() self.side_channel_port = side_channel_port self.lock = threading.Lock() self.done_requests = set[str]() + self.task_tracker = dict[str, int]() self.ready_event = ready_event + self.metadata = metadata def get_and_clear_finished_requests(self) -> set[str]: """ @@ -486,22 +276,23 @@ class KVCacheRecvingLayerThread(threading.Thread): self.done_requests = set() return finished_requests + def update_task(self, req_id): + with self.lock: + self.task_tracker[req_id] += 1 + if self.task_tracker[req_id] == self.pd_head_ratio: + self.task_tracker.pop(req_id) + self.done_requests.add(req_id) + def run(self): """Run the thread to handle KV cache transfer requests.""" - #TODO layerwise step9 - # with zmq_ctx(zmq.ROUTER, path) as sock: # type: ignore - # while True: - # recv_msg from prefill request send finish= - # Listen for new requests for metadata. - # NOTE(rob): we need each rank to have a unique port. This hack to keeps - # us moving. We will switch when moving to etcd or where we have a - # single ZMQ socket in the scheduler. 
handshake_port = self.side_channel_port + self.tp_rank path = make_zmq_path("tcp", self.side_channel_host, handshake_port) logger.info("Starting listening on path: %s", path) + encoder = msgspec.msgpack.Encoder() + encoded_data = encoder.encode(self.metadata) with zmq_ctx(zmq.ROUTER, path) as sock: # type: ignore self.ready_event.set() - decoder = msgspec.msgpack.Decoder(type=str) + decoder = msgspec.msgpack.Decoder(type=tuple) while True: try: frames = sock.recv_multipart() @@ -515,10 +306,20 @@ class KVCacheRecvingLayerThread(threading.Thread): logger.error("Invalid message format: %s", frames) continue - request_id = decoder.decode(payload[0]) - with self.lock: - self.done_requests.add(request_id) - sock.send_multipart((identity, b"", b"ACK")) + msg = decoder.decode(payload[0]) + if msg[0] == GET_META_MSG: + logger.info("Got GET META INFO for request %s", msg[0]) + sock.send_multipart((identity, b"", encoded_data)) + elif msg[0] == DONE_SENDING_MSG: + logger.debug("Got DONE_RECVING_MSG for request %s", + msg[1]) + request_id = msg[1] + self.update_task(request_id) + sock.send_multipart((identity, b"", b"ACK")) + else: + logger.error( + "Connection listener got unexpected message %s", + msg) except Exception as e: logger.error("Failed to decode message: %s", e) @@ -527,21 +328,24 @@ class MooncakeLayerwiseConnectorMetadata(KVConnectorMetadata): def __init__(self): self.requests: dict[str, ReqMeta] = {} - self.requests_to_send: dict[str, float] = {} def add_new_req(self, request_id: str, local_block_ids: list[int], kv_transfer_params: dict[str, Any], - metaserver=None): + token_ids: Optional[list[int]] = None): self.requests[request_id] = ReqMeta( + token_ids=token_ids or [], local_block_ids=local_block_ids, - remote_block_ids=kv_transfer_params.get("remote_block_ids", None), - remote_engine_id=kv_transfer_params["remote_engine_id"], - remote_host=kv_transfer_params["remote_host"], - remote_port=kv_transfer_params["remote_port"], - metaserver=metaserver, - remote_tp_size=kv_transfer_params.get("remote_tp_size", None), + remote_block_ids=kv_transfer_params.get("remote_block_ids", []), + remote_engine_id=kv_transfer_params.get("remote_engine_id", None), + remote_host=kv_transfer_params.get("remote_host", None), + remote_port=kv_transfer_params.get("remote_port", None), + remote_te_rpc_port=kv_transfer_params.get("remote_te_rpc_port", + None), + remote_kv_caches_base_addr=kv_transfer_params.get( + "remote_kv_caches_base_addr", None), + metaserver=kv_transfer_params.get("metaserver", None), ) @@ -550,6 +354,7 @@ class MooncakeLayerwiseConnector(KVConnectorBase_V1): def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole): assert vllm_config.kv_transfer_config is not None self.engine_id = vllm_config.kv_transfer_config.engine_id + self._connector_metadata = MooncakeLayerwiseConnectorMetadata() if role == KVConnectorRole.SCHEDULER: self.connector_scheduler: Optional[MooncakeLayerwiseConnectorScheduler] = \ @@ -594,10 +399,6 @@ class MooncakeLayerwiseConnector(KVConnectorBase_V1): assert self.connector_scheduler is not None return self.connector_scheduler.request_finished(request, block_ids) - def get_finished_count(self) -> Optional[int]: - assert self.connector_scheduler is not None - return self.connector_scheduler.get_finished_count() - ############################################################ # Worker Side Methods ############################################################ @@ -656,16 +457,17 @@ class MooncakeLayerwiseConnectorScheduler: # Handshake base port 
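# The side-channel base port below is offset per DP rank so the tp_size
# handshake ports of different DP replicas never collide; each worker then
# adds its own tp_rank on top. A standalone sketch of the layout (port
# numbers are illustrative):
def side_channel_ports(kv_port: int, dp_rank: int, tp_size: int) -> list[int]:
    base = kv_port + dp_rank * tp_size
    return [base + tp_rank for tp_rank in range(tp_size)]

# kv_port=8000, tp_size=4: DP rank 0 uses 8000-8003, DP rank 1 uses 8004-8007.
assert side_channel_ports(8000, 1, 4) == [8004, 8005, 8006, 8007]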
self.side_channel_port = ( vllm_config.kv_transfer_config.kv_port + - vllm_config.parallel_config.data_parallel_rank_local * + vllm_config.parallel_config.data_parallel_rank * vllm_config.parallel_config.tensor_parallel_size) # Requests that need to start recv. # New requests are added by update_state_after_alloc in # the scheduler. Used to make metadata passed to Worker. - self._reqs_need_recv: dict[str, tuple[Request, list[int]]] = {} - self._reqs_need_send: dict[str, float] = {} - self._reqs_need_send_layerwise: dict[str, tuple[str, int, - list[int]]] = {} + self._reqs_need_recv: dict[str, tuple[Request, list[int], + list[int]]] = {} + self._reqs_need_send_layerwise: dict[str, tuple[ + int, list[int], + Request]] = {} # req_id, (len(prompt), local_block_ids, request) def get_num_new_matched_tokens( self, request: "Request", @@ -692,13 +494,11 @@ class MooncakeLayerwiseConnectorScheduler: num_computed_tokens, params) if params is not None and params.get("do_remote_prefill"): - assert num_computed_tokens == 0, "Currently only support " \ - "prefill with num_computed_tokens == 0." - # Assume that the request's KV cache is already fully prefilled and - # can be fetched entirely from the prefill node. - count = len(request.prompt_token_ids) - if count > 0: - return count, True + # Remote prefill: get all prompt blocks from remote. + assert num_computed_tokens % self.block_size == 0 + # Note: We use the full token count as transmit data here. + count = max(len(request.prompt_token_ids) - num_computed_tokens, 0) + return count, count > 0 # No remote prefill for this request. return 0, False @@ -714,25 +514,21 @@ class MooncakeLayerwiseConnectorScheduler: num_external_tokens, params) if params is not None and params.get("do_remote_prefill"): - if all(p in params for p in ("remote_engine_id", "remote_host", - "remote_port")): - local_block_ids = (blocks.get_unhashed_block_ids() - if num_external_tokens > 0 else []) - # Get unhashed blocks to pull from remote. - self._reqs_need_recv[request.request_id] = (request, - local_block_ids) - else: - logger.warning( - "Got invalid KVTransferParams: %s. This " - "request will not utilize KVTransfer", params) + local_block_ids = (blocks.get_unhashed_block_ids() + if num_external_tokens > 0 else []) + # Get unhashed blocks to pull from remote. + self._reqs_need_recv[request.request_id] = ( + request, + [], #request._all_token_ids, + local_block_ids) + params["do_remote_prefill"] = False # Layerwise prefiller add request need send if params is not None and params.get("do_remote_decode"): local_block_ids = (blocks.get_block_ids()[0]) - self._reqs_need_send_layerwise[request.request_id] = ( - params["metaserver"], len(request.all_token_ids), - local_block_ids) + self._reqs_need_send_layerwise[request.request_id] = (len( + request.all_token_ids), local_block_ids, request) def build_connector_meta( self, @@ -741,16 +537,16 @@ class MooncakeLayerwiseConnectorScheduler: meta = MooncakeLayerwiseConnectorMetadata() # Loop through scheduled reqs and convert to ReqMeta. - for req_id, (req, block_ids) in self._reqs_need_recv.items(): + for req_id, (req, token_ids, + block_ids) in self._reqs_need_recv.items(): assert req.kv_transfer_params is not None # For the case where there are no remote blocks to pull # (block_ids is empty), we don't need to schedule # an async read on the worker side. 
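# get_num_new_matched_tokens() above now reports every prompt token beyond
# the locally computed prefix as remotely available, instead of requiring
# num_computed_tokens == 0. A minimal sketch of that accounting (block size
# and token counts are illustrative):
def num_new_matched_tokens(prompt_len: int, num_computed: int,
                           block_size: int) -> tuple[int, bool]:
    assert num_computed % block_size == 0
    count = max(prompt_len - num_computed, 0)
    return count, count > 0

assert num_new_matched_tokens(300, 0, 128) == (300, True)   # full remote prefill
assert num_new_matched_tokens(256, 256, 128) == (0, False)  # fully cached locally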
- meta.add_new_req( - request_id=req_id, - local_block_ids=block_ids, - kv_transfer_params=req.kv_transfer_params, - ) + meta.add_new_req(request_id=req_id, + local_block_ids=block_ids, + kv_transfer_params=req.kv_transfer_params, + token_ids=token_ids) # Clear the list once workers start the transfers self._reqs_need_recv.clear() @@ -760,7 +556,7 @@ class MooncakeLayerwiseConnectorScheduler: for req_id, new_blocks in zip(cached_reqs.req_ids, cached_reqs.new_block_ids): if req_id in self._reqs_need_send_layerwise and new_blocks is not None: - metaserver, total_tokens, block_ids = self._reqs_need_send_layerwise[ + total_tokens, block_ids, req = self._reqs_need_send_layerwise[ req_id] block_ids.extend(new_blocks[0]) @@ -770,21 +566,16 @@ class MooncakeLayerwiseConnectorScheduler: for req_id, scheduled_tokens in scheduler_output.num_scheduled_tokens.items( ): if req_id in self._reqs_need_send_layerwise: - metaserver, total_tokens, block_ids = self._reqs_need_send_layerwise[ + total_tokens, block_ids, req = self._reqs_need_send_layerwise[ req_id] current_tokens = computed_tokens.get(req_id, 0) + scheduled_tokens if current_tokens == total_tokens: - meta.add_new_req( - request_id=req_id, - local_block_ids=block_ids, - kv_transfer_params=defaultdict(lambda: None), - metaserver=metaserver) + meta.add_new_req(request_id=req_id, + local_block_ids=block_ids, + kv_transfer_params=req.kv_transfer_params, + token_ids=[]) self._reqs_need_send_layerwise.pop(req_id) - - meta.requests_to_send = self._reqs_need_send - self._reqs_need_send = {} - return meta def request_finished( @@ -796,52 +587,8 @@ class MooncakeLayerwiseConnectorScheduler: Once a request is finished, determine whether request blocks should be freed now or will be sent asynchronously and freed later. """ - - params = request.kv_transfer_params - logger.debug( - "MooncakeLayerwiseConnector request_finished, request_status=%s, " - "kv_transfer_params=%s", request.status, params) - - if (params is None or not params.get("do_remote_decode") - or request.status != RequestStatus.FINISHED_LENGTH_CAPPED): - return False, None - - computed_block_ids = block_ids - delay_free_blocks = len(computed_block_ids) > 0 - if delay_free_blocks: - logger.info("Delaying free of %d blocks for request %s", - len(computed_block_ids), request.request_id) - self._reqs_need_send[request.request_id] = time.time() - - return delay_free_blocks, dict( - do_remote_prefill=True, - do_remote_decode=False, - remote_engine_id=self.engine_id, - remote_host=self.side_channel_host, - remote_port=self.side_channel_port, - remote_block_ids=computed_block_ids, - ) - - def get_finished_count(self) -> Optional[int]: - prefill_parallel_config: dict[ - str, - Any] = self.vllm_config.kv_transfer_config.get_from_extra_config( - "prefill", {}) - - assert "tp_size" in prefill_parallel_config.keys() - self._prefill_tp_size = prefill_parallel_config["tp_size"] - decode_parallel_config: dict[ - str, - Any] = self.vllm_config.kv_transfer_config.get_from_extra_config( - "decode", {}) - assert "tp_size" in decode_parallel_config.keys() - self._decode_tp_size = decode_parallel_config["tp_size"] - - if self.vllm_config.model_config.use_mla: - return self._decode_tp_size - else: - # TODO support mha and gqa - return None + # layer_wise push, not need delay_free_blocks + return False, None class MooncakeLayerwiseConnectorWorker: @@ -860,22 +607,21 @@ class MooncakeLayerwiseConnectorWorker: self.engine = TransferEngine() # Metadata. 
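# build_connector_meta() above only hands a prefill request to the sending
# thread once the cumulative scheduled tokens cover the whole prompt, so a
# chunked prefill is never pushed layer-wise before its KV is complete. A
# minimal sketch of that readiness check, with illustrative token counts:
def ready_to_send(total_tokens: int, computed_tokens: int,
                  scheduled_tokens: int) -> bool:
    return computed_tokens + scheduled_tokens == total_tokens

assert not ready_to_send(1024, 0, 512)  # first chunk of a chunked prefill
assert ready_to_send(1024, 512, 512)    # final chunk completes the prompt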
- self.completion_event: threading.Event self.vllm_config = vllm_config + self.local_engine_id: str = " " self.engine_id = engine_id self.tp_rank = get_tensor_model_parallel_rank() self.tp_size = vllm_config.parallel_config.tensor_parallel_size self.tp_group = get_tp_group() - self.dp_rank = vllm_config.parallel_config.data_parallel_rank_local + self.dp_rank = vllm_config.parallel_config.data_parallel_rank self.dp_size = vllm_config.parallel_config.data_parallel_size_local self.kv_caches: dict[str, torch.Tensor] = {} self.side_channel_host = get_ip() self.max_device_id = self.tp_size * self.dp_size - self.kv_role = vllm_config.kv_transfer_config.kv_role self.total_layers = vllm_config.model_config.get_num_layers( vllm_config.parallel_config) - self.executor = ThreadPoolExecutor(1) + self.executor = ThreadPoolExecutor(32) self.metaserver_client = httpx.Client( limits=httpx.Limits(max_connections=100000), timeout=None) if self.tp_rank == 0 else None @@ -883,7 +629,7 @@ class MooncakeLayerwiseConnectorWorker: # Handshake base port self.side_channel_port = ( vllm_config.kv_transfer_config.kv_port + - vllm_config.parallel_config.data_parallel_rank_local * + vllm_config.parallel_config.data_parallel_rank * vllm_config.parallel_config.tensor_parallel_size) self.handshake_port = self.side_channel_port + self.tp_rank self.sockets: dict = {} @@ -910,7 +656,7 @@ class MooncakeLayerwiseConnectorWorker: self.device_id = device_ids[self.tp_rank] # type: ignore if vllm_config.kv_transfer_config.get_from_extra_config( - 'use_ascend_direct', False): + 'use_ascend_direct', True): hostname = self.side_channel_host else: hostname = f"{self.side_channel_host}:0:npu_{self.device_id}" @@ -918,8 +664,8 @@ class MooncakeLayerwiseConnectorWorker: self.te_rpc_port = self.engine.get_rpc_port() # Background thread for sending or receiving KV caches. - self.kv_send_layer_thread: Optional[KVCacheSendingLayerThread] = None self.kv_recv_layer_thread: Optional[KVCacheRecvingLayerThread] = None + self.kv_send_layer_thread: Optional[KVCacheSendingLayerThread] = None self.vllm_config = vllm_config self.block_size = vllm_config.cache_config.block_size @@ -927,8 +673,23 @@ class MooncakeLayerwiseConnectorWorker: self.pd_tp_ratio = get_ascend_config().pd_tp_ratio self.pd_head_ratio = get_ascend_config().pd_head_ratio + self.num_head_replica = get_ascend_config().num_head_replica self.first_kv_cache = None + self.remote_poller = zmq.Poller() # type: ignore + self.decoder = msgspec.msgpack.Decoder(MooncakeAgentMetadata) + self.encoder = msgspec.msgpack.Encoder() + + self.remote_kv_caches_base_addr: dict[str, dict[int, list[int]]] = \ + defaultdict(dict) + self.remote_te_port: dict[str, dict[int, int]] = \ + defaultdict(dict) + self.remote_sockets_lock = threading.Lock() + self.remote_sockets: dict[ # type: ignore + str, deque[zmq.Socket]] = defaultdict( # type: ignore + deque) + self.remote_poller = zmq.Poller() # type: ignore + self.timeout = 1.0 # seconds def _get_prefill_decode_size(self, vllm_config: VllmConfig): # get prefill tp and dp size from extra config @@ -1022,27 +783,33 @@ class MooncakeLayerwiseConnectorWorker: # After KV Caches registered, start the sending or receiving thread. 
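# Below, KV producers start KVCacheSendingLayerThread and KV consumers start
# KVCacheRecvingLayerThread, each gated on a ready_event so that
# register_kv_caches() only returns once the background thread is actually
# serving. A minimal sketch of that start-up pattern (the worker body here
# is a stand-in):
import threading

def start_and_wait(worker) -> threading.Thread:
    ready = threading.Event()
    thread = threading.Thread(target=worker, args=(ready, ), daemon=True)
    thread.start()
    ready.wait()  # block until the thread signals it is listening
    return thread

start_and_wait(lambda ready: ready.set())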
metadata = MooncakeAgentMetadata( - engine_id=self.engine_id, te_rpc_port=self.te_rpc_port, - kv_caches_base_addr=kv_caches_base_addr, - num_blocks=self.num_blocks, + kv_caches_base_addr=self.kv_caches_base_addr, ) - - ready_event = threading.Event() - if self.kv_role == 'kv_producer': + if self.vllm_config.kv_transfer_config.is_kv_producer: + ready_event = threading.Event() self.kv_send_layer_thread = KVCacheSendingLayerThread( - self.tp_rank, self.tp_size, self._decode_tp_size, - self.engine_id, self.side_channel_host, self.side_channel_port, - metadata, ready_event, self.total_layers, self.engine, - kv_caches_base_addr, self.block_len, self.use_mla, - self.first_kv_cache) + engine=self.engine, + total_layers=self.total_layers, + ready_event=ready_event, + tp_rank=self.tp_rank, + pd_head_ratio=self.pd_head_ratio, + num_head_replica=self.num_head_replica, + kv_cache_base_addr=self.kv_caches_base_addr, + use_mla=self.use_mla, + block_len=self.block_len, + first_kv_cache=first_kv_cache, + callback_func=self.send_done_send_signal) self.kv_send_layer_thread.start() - else: + ready_event.wait() + + if self.vllm_config.kv_transfer_config.is_kv_consumer: + ready_event = threading.Event() self.kv_recv_layer_thread = KVCacheRecvingLayerThread( self.tp_rank, self.side_channel_port, self.tp_size, - self.engine_id, ready_event) + self.pd_head_ratio, self.engine_id, metadata, ready_event) self.kv_recv_layer_thread.start() - ready_event.wait() + ready_event.wait() def _register(self, ptr, length): logger.info( @@ -1053,42 +820,52 @@ class MooncakeLayerwiseConnectorWorker: raise RuntimeError("Mooncake memory registration failed.") def _access_metaserver(self, url, message): - self.metaserver_client.post(url, json=message) + success = False + retry = 0 + while retry < 3 and success is False: + retry += 1 + try: + self.metaserver_client.post(url, json=message) + success = True + except Exception as e: + logger.error( + f"Failed to connect to metaserver: {url}, retry {retry} time." + ) + if retry == 3: + raise e def get_finished(self) -> tuple[set[str], set[str]]: - done_sending = ( - self.kv_send_layer_thread. - get_and_clear_finished_requests( # type: ignore[union-attr] - ) if self.kv_role == 'kv_producer' else set()) done_recving = ( self.kv_recv_layer_thread. get_and_clear_finished_requests( # type: ignore[union-attr] - ) if self.kv_role == 'kv_consumer' else set()) - if self.tp_rank == 0: - logger.debug( - "Number of completed KV cache send requests: %d, receive " - "requests: %d", len(done_sending), len(done_recving)) - return done_sending, done_recving + ) if self.vllm_config.kv_transfer_config.is_kv_consumer else set()) + if len(done_recving) > 0: + logger.info( + "Number of completed KV cache recv requests: %d, receive " + "requests: %d", 0, len(done_recving)) + return set(), done_recving def start_load_kv(self, metadata: MooncakeLayerwiseConnectorMetadata): """Start loading KV blocks from remote engine.""" self.current_layer = 0 - if self.vllm_config.kv_transfer_config.is_kv_producer: + if self.vllm_config.kv_transfer_config.is_kv_consumer: for req_id, meta in metadata.requests.items(): - logger.debug( - f"Send request: {req_id} to proxy metaserver: {meta.metaserver}" - ) - if self.tp_rank == 0: + if self.tp_rank % self.tp_size == 0: + logger.info( + f"Send request: {req_id} to proxy metaserver: {meta.metaserver}" + ) # All parameters here should appear in the returned dict of # request_finished in the scheduler side except "request_id". 
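# This is where the D-first plan begins: the decode side announces each
# request to the proxy metaserver, and _access_metaserver() above retries
# the POST up to three times before surfacing the error. A minimal sketch
# of that retry policy (the client, URL, and payload shape here are
# illustrative):
import httpx

def post_with_retry(client: httpx.Client, url: str, message: dict,
                    retries: int = 3) -> None:
    for attempt in range(1, retries + 1):
        try:
            client.post(url, json=message)
            return
        except Exception:
            if attempt == retries:
                raise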
kv_transfer_params = dict( + token_ids=meta.token_ids, request_id=req_id, - do_remote_prefill=True, - do_remote_decode=False, + do_remote_prefill=False, + do_remote_decode=True, + remote_block_ids=meta.local_block_ids, remote_engine_id=self.engine_id, remote_host=self.side_channel_host, - remote_port=self.side_channel_port) - + remote_port=self.side_channel_port, + ) future = self.executor.submit( self._access_metaserver, url=meta.metaserver, @@ -1102,42 +879,9 @@ class MooncakeLayerwiseConnectorWorker: ) future.add_done_callback(handle_exception) - else: - for req_id, meta in metadata.requests.items(): - for offset in range(self.pd_tp_ratio): - path = make_zmq_path( - "tcp", meta.remote_host, meta.remote_port + - self.tp_rank * self.pd_tp_ratio + offset) - logger.info( - f"Notify the prefiller: {path} that request: {req_id} from decoder is ready." - ) - msg_encoder = msgspec.msgpack.Encoder() - docode_metadata = DecodeMooncakeAgentMetadata( - req_id=req_id, - block_ids=meta.local_block_ids, - port=self.handshake_port, - host=self.side_channel_host, - engine_id=self.engine_id, - te_rpc_port=self.te_rpc_port, - kv_caches_base_addr=self.kv_caches_base_addr, - num_blocks=self.num_blocks) - encoded_data = msg_encoder.encode(docode_metadata) - size_in_bytes = len(encoded_data) - logger.debug( - "Size of encoded Mooncake agent metadata: %d bytes", - size_in_bytes) - with zmq_ctx(zmq.REQ, path) as sock: # type: ignore - ensure_zmq_send(sock, encoded_data) - ack = sock.recv() - if ack != b"ACK": - raise ValueError( - f"Unexpected ACK from prefill node: {ack}") - - if self.kv_send_layer_thread is not None: - for req_id, delay_start_time in metadata.requests_to_send.items(): - if self.tp_rank in self._get_remote_tp_ranks_for_req(req_id): - self.kv_send_layer_thread.add_delayed_request( - req_id, delay_start_time) + assert self.kv_recv_layer_thread is not None + with self.kv_recv_layer_thread.lock: + self.kv_recv_layer_thread.task_tracker[req_id] = 0 def save_kv_layer(self, layer_name: str, kv_layer: Tuple[torch.Tensor, torch.Tensor], @@ -1145,16 +889,16 @@ class MooncakeLayerwiseConnectorWorker: connector_metadata: MooncakeLayerwiseConnectorMetadata, **kwargs) -> None: """MooncakeLayerwiseConnector does not save explicitly.""" - if self.kv_role == 'kv_producer' and connector_metadata.requests.keys( + if self.vllm_config.kv_transfer_config.is_kv_producer and connector_metadata.requests.keys( ): + # enable decode prefix cache + for request in connector_metadata.requests.values(): + assert len(request.local_block_ids) >= len( + request.remote_block_ids + ), "When prefix cache enabled, remote KVCacheBlocks num should not larger than local KVCacheBlocks num." + request.local_block_ids = request.local_block_ids[ + -len(request.remote_block_ids):] if self.pd_head_ratio != 1: - if self.current_layer != 0: - self.completion_event.wait() - self.completion_event = threading.Event() - if self.kv_send_layer_thread is not None: - self.kv_send_layer_thread.send_layer_thread.add_event( - self.completion_event, - len(connector_metadata.requests.keys())) def sort_kv_cache(input_kv: list[list[int]]): return torch.cat([ @@ -1189,44 +933,103 @@ class MooncakeLayerwiseConnectorWorker: else: key = None value = None - for req_id, request in connector_metadata.requests.items(): - logger.info(f"Add request {req_id} to kv send layer thread. ") + for req_id, req_meta in connector_metadata.requests.items(): + logger.debug( + f"Add request {req_id} to kv send layer thread. 
{req_meta=}" + ) if self.pd_head_ratio != 1: key_block_num = len( - request.local_block_ids) * key_block_size + req_meta.local_block_ids) * key_block_size value_block_num = len( - request.local_block_ids) * value_block_size - key = keys[key_start_id:key_start_id + - key_block_num] #.clone().contiguous() + req_meta.local_block_ids) * value_block_size + key = keys[key_start_id:key_start_id + key_block_num] value = values[value_start_id:value_start_id + - value_block_num] #.clone().contiguous() + value_block_num] key_start_id += key_block_num value_start_id += value_block_num - if self.kv_send_layer_thread is not None: - self.kv_send_layer_thread.add_request( - request_id=req_id, - local_block_ids=request.local_block_ids, - layer_index=self.current_layer, - key=key, - value=value) + req_meta_update = self.update_decoder_info(req_id, req_meta) + assert self.kv_send_layer_thread is not None + self.kv_send_layer_thread.send_queue.put( + (req_id, req_meta_update, self.current_layer, key, value)) self.current_layer += 1 + def _get_remote_socket( + self, remote_host: str, + remote_handshake_port: int) -> zmq.Socket: # type: ignore + """Get a socket to the remote host.""" + remote_path = make_zmq_path("tcp", remote_host, remote_handshake_port) + with self.remote_sockets_lock: + if self.remote_sockets[remote_path]: + return self.remote_sockets[remote_path].popleft() + + ctx = zmq.Context() # type: ignore + sock = make_zmq_socket( + ctx=ctx, + path=remote_path, + socket_type=zmq.REQ, # type: ignore + bind=False) + sock.setsockopt( + zmq.SNDTIMEO, # type: ignore + int(self.timeout * 1000)) + self.remote_poller.register(sock, zmq.POLLIN) # type: ignore + return sock + + def update_decoder_info(self, req_id, req_meta): + req_meta_update = copy.deepcopy(req_meta) + if self.pd_tp_ratio > 1: + req_meta_update.remote_port = req_meta_update.remote_port + self.tp_rank // self.pd_tp_ratio + else: + req_meta_update.remote_port = req_meta_update.remote_port + self.tp_rank + if req_meta_update.remote_engine_id not in self.remote_kv_caches_base_addr or \ + req_meta_update.remote_port not in self.remote_kv_caches_base_addr[req_meta_update.remote_engine_id]: + try: + encoded_data = self.encoder.encode((GET_META_MSG, req_id)) + sock = self._get_remote_socket(req_meta_update.remote_host, + req_meta_update.remote_port) + ensure_zmq_send(sock, encoded_data) + metadata_bytes = ensure_zmq_recv(sock, self.remote_poller) + agent_meta = self.decoder.decode(metadata_bytes) + except Exception as e: + logger.error( + f"Query to port and kv base addr for request {req_id} from {req_meta_update.remote_host}:{req_meta_update.remote_port} fail with error: {e}" + ) + assert req_meta_update.remote_engine_id != self.engine_id, ( + f"Conflict engine id {req_meta_update.remote_engine_id} with local engine id " + f"{self.local_engine_id}.") + self.remote_kv_caches_base_addr[req_meta_update.remote_engine_id][ + req_meta_update.remote_port] = agent_meta.kv_caches_base_addr + self.remote_te_port[req_meta_update.remote_engine_id][ + req_meta_update.remote_port] = agent_meta.te_rpc_port + logger.info( + f"Query to port and kv base addr for request {req_id} from {req_meta_update.remote_host}:{req_meta_update.remote_port} success {agent_meta.kv_caches_base_addr=} {agent_meta.te_rpc_port=}" + ) + req_meta_update.remote_te_rpc_port = self.remote_te_port[ + req_meta_update.remote_engine_id][req_meta_update.remote_port] + req_meta_update.remote_kv_caches_base_addr = self.remote_kv_caches_base_addr[ + 
req_meta_update.remote_engine_id][req_meta_update.remote_port]
+        return req_meta_update
+
+    def send_done_send_signal(self, req_id, req_meta):
+        logger.info("Sending done sending signal for request %s to %s:%d",
+                    req_id, req_meta.remote_host, req_meta.remote_port)
+        try:
+            path = make_zmq_path("tcp", req_meta.remote_host,
+                                 req_meta.remote_port)
+            msg_encoder = msgspec.msgpack.Encoder()
+            encoded_data = msg_encoder.encode((DONE_SENDING_MSG, req_id))
+            with zmq_ctx(zmq.REQ, path) as sock:  # type: ignore
+                ensure_zmq_send(sock, encoded_data)
+                ack = sock.recv()
+                if ack != b"ACK":
+                    raise ValueError(f"Unexpected ACK response: {ack}")
+        except Exception as e:
+            logger.error(
+                f"Sending done sending signal for request {req_id} to {req_meta.remote_host}:{req_meta.remote_port} failed with error: {e}"
+            )
+
     def wait_for_layer_load(self, layer_name: str) -> None:
         pass
 
-    def _get_remote_tp_rank(self, req_id: str) -> int:
-        return self._get_remote_tp_ranks_for_req(req_id)[self.tp_rank]
-
-    def _get_remote_tp_ranks_for_req(self, req_id: str) -> list[int]:
-        if self._prefill_tp_size == self._decode_tp_size:
-            return list(range(self._prefill_tp_size))
-
-        seed = string_to_int64_hash(req_id)
-        rand = random.Random(seed)
-        sampled_nums = rand.sample(range(self._prefill_tp_size),
-                                   self._decode_tp_size)
-        return sampled_nums
-
 
 @contextlib.contextmanager
 def zmq_ctx(socket_type: Any,
diff --git a/vllm_ascend/distributed/parallel_state.py b/vllm_ascend/distributed/parallel_state.py
index 49f6e47..f373cb5 100644
--- a/vllm_ascend/distributed/parallel_state.py
+++ b/vllm_ascend/distributed/parallel_state.py
@@ -89,7 +89,8 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
         parallel_config.data_parallel_size, num_head_replica, -1,
         alltoall_group_size
     )  # [DP_size, num_head_replica, num_alltoall_group, alltoall_group_size]
-    group_ranks = group_ranks.view(-1, alltoall_group_size).unbind(0)
+    group_ranks = group_ranks.reshape(-1,
+                                      alltoall_group_size).unbind(0)
     group_ranks = [x.tolist() for x in group_ranks]
     local_rank = get_world_group().local_rank
     num = next(
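# The parallel_state.py hunk above swaps .view() for .reshape(), presumably
# because the permuted [DP, replica, group, group_size] rank tensor is no
# longer contiguous: .view() raises on non-contiguous storage while
# .reshape() copies when needed. A minimal reproduction of the failure mode
# (the 4x2 layout is illustrative):
import torch

ranks = torch.arange(8).reshape(2, 4).t()  # transpose -> non-contiguous
try:
    ranks.view(-1, 2)
except RuntimeError:
    pass  # .view() requires contiguous memory
groups = ranks.reshape(-1, 2).unbind(0)  # .reshape() handles the copy
assert [g.tolist() for g in groups] == [[0, 4], [1, 5], [2, 6], [3, 7]]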