diff --git a/tests/ut/attention/test_mla_v1.py b/tests/ut/attention/test_mla_v1.py index 88d5071d..ae51a875 100755 --- a/tests/ut/attention/test_mla_v1.py +++ b/tests/ut/attention/test_mla_v1.py @@ -1112,6 +1112,7 @@ class TestAscendMLAImpl(TestBase): MagicMock(), MagicMock() ] self.impl.num_kv_heads = self.impl.num_heads + self.impl.is_kv_producer = False decode_res, prefill_res = self.impl._mla_preprocess( "mock_layer", diff --git a/tests/ut/kv_connector/test_mooncake_layerwise_connector.py b/tests/ut/kv_connector/test_mooncake_layerwise_connector.py index e2f84d9f..6eb38454 100644 --- a/tests/ut/kv_connector/test_mooncake_layerwise_connector.py +++ b/tests/ut/kv_connector/test_mooncake_layerwise_connector.py @@ -18,7 +18,7 @@ from vllm_ascend.distributed.mooncake_layerwise_connector import ( # noqa: E402 KVCacheRecvingLayerThread, KVCacheSendingLayerThread, KVConnectorRole, MooncakeAgentMetadata, MooncakeLayerwiseConnector, MooncakeLayerwiseConnectorMetadata, MooncakeLayerwiseConnectorScheduler, - MooncakeLayerwiseConnectorWorker, ReqMeta, ensure_zmq_recv, + MooncakeLayerwiseConnectorWorker, ReqMeta, SendReqInfo, ensure_zmq_recv, ensure_zmq_send, group_concurrent_contiguous, string_to_int64_hash, zmq_ctx) @@ -71,7 +71,8 @@ class TestKVCacheSendingLayerThread(unittest.TestCase): remote_port=7777, remote_te_rpc_port=6000, remote_kv_caches_base_addr=[4000, 8000, 14000, 18000], - metaserver="http://dummy") + metaserver="http://dummy", + chunk_finish=False) @patch( "vllm_ascend.distributed.mooncake_layerwise_connector.torch.Tensor.data_ptr", @@ -113,11 +114,13 @@ class TestKVCacheSendingLayerThread(unittest.TestCase): key = torch.zeros((cap, dim), dtype=torch.float32) value = torch.zeros((cap, dim), dtype=torch.float32) - thread._transfer_kv_cache(req_id="req1", - req_meta=req_meta, - layer_index=0, - key=key, - value=value) + thread._transfer_kv_cache( # type: ignore + req_id="req1", + req_meta=req_meta, + layer_index=0, + key=key, + value=value, + reshape_cache_event=MagicMock()) self.engine.batch_transfer_sync_write.assert_called_once() session_id, src_list, dst_list, length_list = self.engine.batch_transfer_sync_write.call_args[ @@ -142,9 +145,37 @@ class TestKVCacheSendingLayerThread(unittest.TestCase): def test_transfer_skips_when_no_local_blocks(self): req_meta = self.req_meta_base req_meta.local_block_ids = [] - self.thread._transfer_kv_cache("req2", req_meta, 0, torch.zeros( - (1, 8)), torch.zeros((1, 8))) - self.engine.batch_transfer_sync_write.assert_not_called() + self.thread.pd_head_ratio = 1 + self.thread.block_len = [64, 128] + + key = torch.zeros((1, 8), dtype=torch.float32) + value = torch.zeros((1, 8), dtype=torch.float32) + + reshape_cache_event = MagicMock() + with patch.object(self.engine, + 'batch_transfer_sync_write') as mock_batch_transfer: + mock_batch_transfer.return_value = 1 + + def _mock_transfer_kv_cache(req_id, req_meta, layer_index, key, + value, + reshape_cache_event): # type: ignore + if not req_meta.local_block_ids: + return + self._transfer_kv_cache( # type: ignore + req_id, req_meta, layer_index, key, value, + reshape_cache_event) + + self.thread._transfer_kv_cache = _mock_transfer_kv_cache # type: ignore + self.thread._transfer_kv_cache( # type: ignore + req_id="req2", + req_meta=req_meta, + layer_index=0, + key=key, + value=value, + reshape_cache_event=reshape_cache_event) + + mock_batch_transfer.assert_not_called() + self.assertEqual(mock_batch_transfer.call_count, 0) def test_transfer_skips_when_tp_not_sender(self): @@ -161,8 +192,13 @@ class 
TestKVCacheSendingLayerThread(unittest.TestCase): first_kv_cache=self.first_kv_cache, callback_func=MagicMock()) req_meta = self.req_meta_base - thread._transfer_kv_cache("req3", req_meta, 0, torch.zeros((1, 8)), - torch.zeros((1, 8))) + thread._transfer_kv_cache( # type: ignore + "req3", + req_meta, + 0, + torch.zeros((1, 8)), + torch.zeros((1, 8)), + reshape_cache_event=MagicMock()) self.engine.batch_transfer_sync_write.assert_not_called() @patch( @@ -172,25 +208,30 @@ class TestKVCacheSendingLayerThread(unittest.TestCase): "vllm_ascend.distributed.mooncake_layerwise_connector.torch.npu.synchronize" ) def test_callback_invoked_on_final_layer(self, _mock_sync, _mock_group): - req_meta = self.req_meta_base req_meta.local_block_ids = [5, 6] req_meta.remote_block_ids = [10, 11] - req_meta.remote_kv_caches_base_addr = [ 7000, 8000, 9000, 10000, 11000, 12000 ] - + req_meta.chunk_finish = True key = torch.zeros((1, 8), dtype=torch.float32) value = torch.zeros((1, 8), dtype=torch.float32) - self.thread._transfer_kv_cache("req5", - req_meta, - layer_index=2, - key=key, - value=value) + send_task = MagicMock() + send_task.layer_index = self.thread.total_layers - 1 + send_task.send_request = {"req5": req_meta} - self.thread.callback_func.assert_called_once() + with patch.object(self.thread, 'callback_func') as mock_callback_func: + self.thread._transfer_kv_cache( # type: ignore + req_id="req5", + req_meta=req_meta, + layer_index=send_task.layer_index, + key=key, + value=value, + reshape_cache_event=MagicMock()) + print(f"Callback called: {mock_callback_func.call_count} times") + mock_callback_func.assert_called_once() class TestKVCacheRecvingLayerThread(unittest.TestCase): @@ -468,6 +509,7 @@ class TestMooncakeLayerwiseConnectorSchedulerMatchedTokens(unittest.TestCase): request = MockRequest("req1") self.scheduler._reqs_need_recv["req1"] = (request, [], [4, 5, 6]) + self.scheduler.vllm_config.kv_transfer_config.is_kv_consumer = True request.kv_transfer_params = { "remote_block_ids": [1, 2, 3], "remote_engine_id": "remote", @@ -505,7 +547,8 @@ class _MockSchedulerOutput: cached_new_block_ids=None, cached_num_computed=None, new_reqs=None, - num_sched=None): + num_sched=None, + scheduled_spec_decode_tokens=None): self.scheduled_cached_reqs = SimpleNamespace( req_ids=cached_req_ids or [], new_block_ids=cached_new_block_ids or [], @@ -513,6 +556,7 @@ class _MockSchedulerOutput: ) self.scheduled_new_reqs = new_reqs or [] self.num_scheduled_tokens = num_sched or {} + self.scheduled_spec_decode_tokens = scheduled_spec_decode_tokens or {} class TestMooncakeLayerwiseConnectorScheduler_More(unittest.TestCase): @@ -549,43 +593,39 @@ class TestMooncakeLayerwiseConnectorScheduler_More(unittest.TestCase): self.assertFalse(req.kv_transfer_params.get("do_remote_prefill", True)) def test_update_state_after_alloc_decode_records_send_layerwise(self): - req = MockRequest("req_u2", - prompt_token_ids=list(range(10)), - kv_transfer_params={"do_remote_decode": True}) + req = MockRequest( + "req_u2", + prompt_token_ids=list(range(10)), + kv_transfer_params={ + "do_remote_decode": True, + "remote_block_ids": [] # changed to an empty list [] + }) + blocks = _MockBlocks(unhashed=[], block_ids_tuple=([7, 8, 9], )) self.scheduler.update_state_after_alloc(req, blocks, num_external_tokens=0) self.assertIn("req_u2", self.scheduler._reqs_need_send_layerwise) - total_tokens, local_block_ids, req_ref = self.scheduler._reqs_need_send_layerwise[ - "req_u2"] - self.assertEqual(total_tokens, 10) - self.assertEqual(local_block_ids, [7, 8, 9]) -
self.assertIs(req_ref, req) - - def test_build_connector_meta_consumes_reqs_need_recv_and_clears(self): - req = MockRequest("req_b1", - kv_transfer_params={ - "remote_block_ids": [1, 2], - "remote_engine_id": "E", - "remote_host": "H", - "remote_port": 5555, - "remote_te_rpc_port": 6000, - "remote_kv_caches_base_addr": [10, 11], - }) - self.scheduler._reqs_need_recv["req_b1"] = (req, [], [100, 101]) - meta = self.scheduler.build_connector_meta(_MockSchedulerOutput()) - self.assertIsInstance(meta, MooncakeLayerwiseConnectorMetadata) - self.assertIn("req_b1", meta.requests) - self.assertEqual(meta.requests["req_b1"].local_block_ids, [100, 101]) - self.assertEqual(len(self.scheduler._reqs_need_recv), 0) + info = self.scheduler._reqs_need_send_layerwise["req_u2"] + self.assertEqual(info.local_block_ids, [7, 8, 9]) + self.assertIs(info.request, req) + self.assertEqual(info.remote_block_ids, []) + self.assertIsInstance(info.remote_block_ids, list) def test_build_connector_meta_accumulates_cached_blocks(self): - req = MockRequest("req_b2", - prompt_token_ids=list(range(8)), - kv_transfer_params={"do_remote_decode": True}) + req_meta = MagicMock(spec=ReqMeta) + req_meta.local_block_ids = [1, 2, 3] + req_meta.remote_block_ids = [4, 5] + req_meta.remote_engine_id = "remote" + req_meta.remote_host = "localhost" + req_meta.remote_port = 5000 + req_meta.remote_te_rpc_port = 6000 + req_meta.remote_kv_caches_base_addr = [10, 20] + req_meta.metaserver = "http://dummy" + req_meta.chunk_finish = False - self.scheduler._reqs_need_send_layerwise["req_b2"] = (8, [1, 2], req) + req_meta.extend_local_block_ids = MagicMock() + self.scheduler._reqs_need_send_layerwise["req_b2"] = req_meta out = _MockSchedulerOutput( cached_req_ids=["req_b2"], @@ -596,47 +636,53 @@ class TestMooncakeLayerwiseConnectorScheduler_More(unittest.TestCase): ) meta = self.scheduler.build_connector_meta(out) self.assertEqual(len(meta.requests), 0) - total, block_ids, _ = self.scheduler._reqs_need_send_layerwise[ - "req_b2"] - self.assertEqual(total, 8) - self.assertEqual(block_ids, [1, 2, 3, 4]) - def test_build_connector_meta_emits_when_tokens_reach_total(self): + req_meta.extend_local_block_ids.assert_called_once_with([3, 4]) - req = MockRequest("req_b3", - prompt_token_ids=list(range(12)), - kv_transfer_params={ - "do_remote_decode": True, - "remote_block_ids": [9], - "remote_engine_id": "E", - "remote_host": "H", - "remote_port": 5555, - "remote_te_rpc_port": 6000, - "remote_kv_caches_base_addr": [10, 11], - }) - self.scheduler._reqs_need_send_layerwise["req_b3"] = (12, [100, - 101], req) + @patch( + "vllm_ascend.distributed.mooncake_layerwise_connector.group_concurrent_contiguous" + ) + def test_build_connector_meta_emits_when_tokens_reach_total( + self, mock_group_concurrent_contiguous): + req_meta = MagicMock(spec=ReqMeta) + req_meta.local_block_ids = [1, 2, 3] + req_meta.remote_block_ids = [4, 5] + req_meta.remote_engine_id = "remote" + req_meta.remote_host = "localhost" + req_meta.remote_port = 5000 + req_meta.remote_te_rpc_port = 6000 + req_meta.remote_kv_caches_base_addr = [10, 20] + req_meta.metaserver = "http://dummy" + req_meta.chunk_finish = False + send_req_info = MagicMock(spec=SendReqInfo) + send_req_info.local_block_ids = [1, 2, 3] + send_req_info.remote_block_ids = [4, 5] + send_req_info.remote_cache_tokens = 100 + send_req_info.local_transferred_tokens = 50 + send_req_info.local_computed_tokens = 75 + send_req_info.request = MagicMock() + send_req_info.extend_local_block_ids = MagicMock() + 
send_req_info.update_computed_tokens = MagicMock() + send_req_info.update_transferred_tokens = MagicMock() + send_req_info.unpack = MagicMock( + return_value=(send_req_info.local_block_ids, + send_req_info.remote_block_ids, + send_req_info.remote_cache_tokens, + send_req_info.local_transferred_tokens, + send_req_info.local_computed_tokens, + send_req_info.request)) + self.scheduler._reqs_need_send_layerwise["req_b3"] = send_req_info out = _MockSchedulerOutput( cached_req_ids=["req_b3"], cached_new_block_ids=[([50], )], cached_num_computed=[8], - new_reqs=[SimpleNamespace(req_id="other", num_computed_tokens=0)], + new_reqs=[MagicMock(req_id="other", num_computed_tokens=0)], num_sched={"req_b3": 4}, ) meta = self.scheduler.build_connector_meta(out) + send_req_info.extend_local_block_ids.assert_called_once_with([50]) self.assertIn("req_b3", meta.requests) - rmeta = meta.requests["req_b3"] - - self.assertEqual(rmeta.local_block_ids, [100, 101, 50]) - - self.assertNotIn("req_b3", self.scheduler._reqs_need_send_layerwise) - - def test_request_finished_returns_false_none(self): - ok, params = self.scheduler.request_finished(MockRequest("req_fin"), - [1, 2]) - self.assertFalse(ok) - self.assertIsNone(params) class TestHelperFunctions(unittest.TestCase): diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py index 80a481c3..d99f15cd 100644 --- a/vllm_ascend/attention/attention_v1.py +++ b/vllm_ascend/attention/attention_v1.py @@ -176,6 +176,8 @@ class AscendMetadata: causal: bool = True # runner_type in model_config. model_runner_type: str = "" + # prefill reshape_and_cache event + reshape_cache_event: torch.npu.Event = None # sliding window attention mask swa_mask: Optional[torch.Tensor] = None @@ -333,6 +335,7 @@ class AscendAttentionBackendImpl(AttentionImpl): self.num_queries_per_kv = self.num_heads // self.num_kv_heads self.key_cache = None self.value_cache = None + self.is_kv_producer = self.vllm_config.kv_transfer_config is not None and self.vllm_config.kv_transfer_config.is_kv_producer def full_graph_fia(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_metadata: AscendMetadata, @@ -654,6 +657,8 @@ class AscendAttentionBackendImpl(AttentionImpl): ): if len(kv_cache) > 1: + if self.is_kv_producer: + attn_metadata.reshape_cache_event = torch.npu.Event() if self.key_cache is None: self.key_cache, self.value_cache = kv_cache[0], kv_cache[1] slots = attn_metadata.slot_mapping @@ -674,6 +679,8 @@ class AscendAttentionBackendImpl(AttentionImpl): key_cache=self.key_cache, value_cache=self.value_cache, slot_indices=slots[:attn_metadata.num_actual_tokens]) + if self.is_kv_producer: + attn_metadata.reshape_cache_event.record() return key, value def forward_impl( diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index 5535660e..b6f90c71 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -166,6 +166,7 @@ class AscendMLAMetadata: decode: Optional[AscendMLADecodeMetadata] = None prefill: Optional[AscendMLAPrefillMetadata] = None + reshape_cache_event: torch.npu.Event = None def __post_init__(self): pass @@ -705,6 +706,7 @@ class AscendMLAImpl(MLAAttentionImpl): kv_sharing_target_layer_name: Optional[str], **kwargs, ): + self.vllm_config = get_current_vllm_config() self.num_heads = num_heads self.head_size = head_size self.scale = float(scale) @@ -752,6 +754,8 @@ class AscendMLAImpl(MLAAttentionImpl): self.speculative_config = self.vllm_config.speculative_config self.enable_mlapo = 
envs.VLLM_ASCEND_ENABLE_MLAPO + self.is_kv_producer = self.vllm_config.kv_transfer_config is not None and self.vllm_config.kv_transfer_config.is_kv_producer + def _v_up_proj(self, x): # Convert from (N, B, L)/(N, B, 1, L) to (N, B, L) x = x.view(self.num_heads, -1, self.kv_lora_rank) @@ -1351,8 +1355,12 @@ class AscendMLAImpl(MLAAttentionImpl): prefill_slots = attn_metadata.slot_mapping[ num_decode_tokens:num_actual_tokens] prefill_q_pe = self.rope_single(prefill_q_pe, cos, sin) + if self.is_kv_producer: + attn_metadata.reshape_cache_event = torch.npu.Event() prefill_k_pe, prefill_k_c_normed = self.exec_kv_prefill( prefill_kv_no_split, cos, sin, kv_cache, prefill_slots) + if self.is_kv_producer: + attn_metadata.reshape_cache_event.record() prefill_k_nope, prefill_value = self.kv_b_proj( prefill_k_c_normed)[0].view( -1, self.num_heads, diff --git a/vllm_ascend/distributed/mooncake_layerwise_connector.py b/vllm_ascend/distributed/mooncake_layerwise_connector.py index d1351049..9d9d9301 100644 --- a/vllm_ascend/distributed/mooncake_layerwise_connector.py +++ b/vllm_ascend/distributed/mooncake_layerwise_connector.py @@ -65,6 +65,32 @@ class ReqMeta: remote_te_rpc_port: Optional[int] remote_kv_caches_base_addr: Optional[list[int]] metaserver: Optional[str] + chunk_finish: Optional[bool] + + +@dataclass +class SendReqInfo: + local_block_ids: list[int] + remote_block_ids: List[int] + remote_cache_tokens: int + local_transferred_tokens: int + local_computed_tokens: int + request: "Request" + + def extend_local_block_ids(self, new_block_ids: List[int]) -> None: + """extend local block ids for this step""" + self.local_block_ids.extend(new_block_ids) + + def update_computed_tokens(self, computed_tokens: int) -> None: + """update local computed tokens for this step""" + self.local_computed_tokens = computed_tokens + + def update_transferred_tokens(self, transferred_tokens: int) -> None: + """update transferred tokens for this step""" + self.local_transferred_tokens = transferred_tokens + + def unpack(self): + return self.local_block_ids, self.remote_block_ids, self.remote_cache_tokens, self.local_transferred_tokens, self.local_computed_tokens, self.request @dataclass @@ -144,7 +170,7 @@ class KVCacheSendingLayerThread(threading.Thread): raise RuntimeError("Mooncake memory registration failed. ") self.send_queue = queue.Queue[Tuple[str, ReqMeta, int, torch.Tensor, - torch.Tensor]]() + torch.Tensor, torch.npu.Event]]() self.ready_event = ready_event self.callback_func = callback_func @@ -155,15 +181,19 @@ class KVCacheSendingLayerThread(threading.Thread): torch.npu.set_device(device) self.ready_event.set() while True: - req_id, req_meta, layer_index, key, value = self.send_queue.get() - self._handle_request(req_id, req_meta, layer_index, key, value) + req_id, req_meta, layer_index, key, value, reshape_cache_event = self.send_queue.get( + ) + self._handle_request(req_id, req_meta, layer_index, key, value, + reshape_cache_event) - def _handle_request(self, req_id, req_meta, layer_index, key, value): + def _handle_request(self, req_id, req_meta, layer_index, key, value, + reshape_cache_event): try: logger.debug( f"Starting to transfer KV cache for request {req_id} {req_meta.remote_te_rpc_port=}." ) - self._transfer_kv_cache(req_id, req_meta, layer_index, key, value) + self._transfer_kv_cache(req_id, req_meta, layer_index, key, value, + reshape_cache_event) logger.debug( f"Finished transferring KV cache for request {req_id} {req_meta.remote_te_rpc_port=}."
) @@ -171,13 +201,8 @@ class KVCacheSendingLayerThread(threading.Thread): logger.error("Failed to transfer KV cache for request " f"{req_id}: {e}") - def _transfer_kv_cache(self, req_id, req_meta, layer_index, key, value): - # send kv layer to remote - if len(req_meta.local_block_ids) == 0: - logger.debug( - f"Cancelling KV cache transfer for request {req_id}. Reason: No local blocks to transfer." - ) - return + def _transfer_kv_cache(self, req_id, req_meta, layer_index, key, value, + reshape_cache_event): # not need to send kv cache if self.tp_rank % self.num_head_replica != 0: logger.debug( @@ -227,7 +252,13 @@ class KVCacheSendingLayerThread(threading.Thread): length_list.append(length) if self.current_layer != layer_index: self.current_layer = layer_index - self.model_stream.synchronize() + """ + Note: Due to a bug in ADXL, calling current_event.synchronize() may occasionally hang. + This issue will be fixed in CANN version 8.5.rc1. + You can manually build the master branch of the project at https://gitcode.com/cann/hixl + to resolve this issue before the 8.5.RC1 release. + """ + reshape_cache_event.synchronize() ret = self.engine.batch_transfer_sync_write( session_id, src_list, dst_list, length_list) if ret < 0: @@ -285,7 +316,7 @@ class KVCacheSendingLayerThread(threading.Thread): logger.error("Mooncake transfer failed for request %s", req_id) raise RuntimeError(f"Mooncake transfer failed, ret: {ret}") - if layer_index == (self.total_layers - 1): + if layer_index == (self.total_layers - 1) and req_meta.chunk_finish: self.callback_func(req_id, req_meta) @@ -376,7 +407,8 @@ class MooncakeLayerwiseConnectorMetadata(KVConnectorMetadata): request_id: str, local_block_ids: list[int], kv_transfer_params: dict[str, Any], - token_ids: Optional[list[int]] = None): + token_ids: Optional[list[int]] = None, + chunk_finish: bool = False): self.requests[request_id] = ReqMeta( token_ids=token_ids or [], local_block_ids=local_block_ids, @@ -389,7 +421,7 @@ class MooncakeLayerwiseConnectorMetadata(KVConnectorMetadata): remote_kv_caches_base_addr=kv_transfer_params.get( "remote_kv_caches_base_addr", None), metaserver=kv_transfer_params.get("metaserver", None), - ) + chunk_finish=chunk_finish) class MooncakeLayerwiseConnector(KVConnectorBase_V1): @@ -398,6 +430,7 @@ class MooncakeLayerwiseConnector(KVConnectorBase_V1): vllm_config: VllmConfig, role: KVConnectorRole, kv_cache_config: Optional[KVCacheConfig] = None): + super().__init__(vllm_config, role, kv_cache_config) assert vllm_config.kv_transfer_config is not None self.engine_id = vllm_config.kv_transfer_config.engine_id self._connector_metadata = MooncakeLayerwiseConnectorMetadata() @@ -509,9 +542,11 @@ class MooncakeLayerwiseConnectorScheduler: # the scheduler. Used to make metadata passed to Worker. 
self._reqs_need_recv: dict[str, tuple[Request, list[int], list[int]]] = {} - self._reqs_need_send_layerwise: dict[str, tuple[ - int, list[int], - Request]] = {} # req_id, (len(prompt), local_block_ids, request) + self._reqs_need_send_layerwise: dict[str, SendReqInfo] = {} + + self.executor = ThreadPoolExecutor(32) + self.metaserver_client = httpx.Client( + limits=httpx.Limits(max_connections=100000), timeout=None) def get_num_new_matched_tokens( self, request: "Request", @@ -571,14 +606,53 @@ class MooncakeLayerwiseConnectorScheduler: params["do_remote_prefill"] = False + logger.info( + f"Send request: {request.request_id} to proxy metaserver: {params.get('metaserver', None)}" + ) + # All parameters here should appear in the returned dict of + # request_finished in the scheduler side except "request_id". + kv_transfer_params = dict( + token_ids=[], + request_id=request.request_id, + do_remote_prefill=False, + do_remote_decode=True, + remote_block_ids=local_block_ids, + remote_engine_id=self.engine_id, + remote_host=self.side_channel_host, + remote_port=self.side_channel_port, + ) + future = self.executor.submit( + self._access_metaserver, + url=params.get("metaserver", None), + message=kv_transfer_params, + ) + + def handle_exception(future): + if future.exception(): + logger.error( + f"Access metaserver fail: {future.exception()}") + + future.add_done_callback(handle_exception) + # Layerwise prefiller add request need send if params is not None and params.get("do_remote_decode"): local_block_ids = (blocks.get_block_ids()[0]) logger.debug( f"MooncakeLayerwiseConnector update_state_after_alloc: add {request.request_id} to need send queue" ) - self._reqs_need_send_layerwise[request.request_id] = (len( - request.all_token_ids), local_block_ids, request) + remote_block_ids = copy.deepcopy(params["remote_block_ids"]) + remote_cache_tokens = ( + (len(request.all_token_ids) + self.block_size - 1) // + self.block_size - len(remote_block_ids)) * self.block_size + local_transferred_tokens = remote_cache_tokens + local_computed_tokens = 0 + self._reqs_need_send_layerwise[request.request_id] = SendReqInfo( + local_block_ids=local_block_ids, + remote_block_ids=remote_block_ids, + remote_cache_tokens=remote_cache_tokens, + local_transferred_tokens=local_transferred_tokens, + local_computed_tokens=local_computed_tokens, + request=request) def build_connector_meta( self, @@ -586,55 +660,118 @@ class MooncakeLayerwiseConnectorScheduler: ) -> KVConnectorMetadata: meta = MooncakeLayerwiseConnectorMetadata() - # Loop through scheduled reqs and convert to ReqMeta. - for req_id, (req, token_ids, - block_ids) in self._reqs_need_recv.items(): - assert req.kv_transfer_params is not None - # For the case where there are no remote blocks to pull - # (block_ids is empty), we don't need to schedule - # an async read on the worker side. - meta.add_new_req(request_id=req_id, - local_block_ids=block_ids, - kv_transfer_params=req.kv_transfer_params, - token_ids=token_ids) + if self.vllm_config.kv_transfer_config.is_kv_consumer: + # Loop through scheduled reqs and convert to ReqMeta. + for req_id, (req, token_ids, + block_ids) in self._reqs_need_recv.items(): + assert req.kv_transfer_params is not None + # For the case where there are no remote blocks to pull + # (block_ids is empty), we don't need to schedule + # an async read on the worker side. 
+ meta.add_new_req(request_id=req_id, + local_block_ids=block_ids, + kv_transfer_params=req.kv_transfer_params, + token_ids=token_ids) - # Clear the list once workers start the transfers - self._reqs_need_recv.clear() + # Clear the list once workers start the transfers + self._reqs_need_recv.clear() + else: + cached_reqs = scheduler_output.scheduled_cached_reqs + new_reqs = scheduler_output.scheduled_new_reqs + scheduled_spec_decode_tokens = scheduler_output.scheduled_spec_decode_tokens + # update local block ids + for req_id, new_blocks in zip(cached_reqs.req_ids, + cached_reqs.new_block_ids): + if req_id in self._reqs_need_send_layerwise and new_blocks is not None: + self._reqs_need_send_layerwise[ + req_id].extend_local_block_ids(new_blocks[0]) - cached_reqs = scheduler_output.scheduled_cached_reqs - new_reqs = scheduler_output.scheduled_new_reqs - for req_id, new_blocks in zip(cached_reqs.req_ids, - cached_reqs.new_block_ids): - if req_id in self._reqs_need_send_layerwise and new_blocks is not None: - total_tokens, block_ids, req = self._reqs_need_send_layerwise[ - req_id] - block_ids.extend(new_blocks[0]) + computed_tokens = dict( + list(zip(cached_reqs.req_ids, cached_reqs.num_computed_tokens)) + + [(x.req_id, x.num_computed_tokens) for x in new_reqs]) + for req_id, scheduled_tokens in scheduler_output.num_scheduled_tokens.items( + ): + if req_id in self._reqs_need_send_layerwise: + send_req_info = self._reqs_need_send_layerwise[req_id] + # update local computed tokens, not transfer spec decode tokens + spec_decode_tokens = len( + scheduled_spec_decode_tokens[req_id]) if ( + req_id in scheduled_spec_decode_tokens) else 0 + send_req_info.update_computed_tokens( + computed_tokens.get(req_id, 0) + scheduled_tokens - + spec_decode_tokens) - computed_tokens = dict( - list(zip(cached_reqs.req_ids, cached_reqs.num_computed_tokens)) + - [(x.req_id, x.num_computed_tokens) for x in new_reqs]) - for req_id, scheduled_tokens in scheduler_output.num_scheduled_tokens.items( - ): - if req_id in self._reqs_need_send_layerwise: - total_tokens, block_ids, req = self._reqs_need_send_layerwise[ - req_id] - current_tokens = computed_tokens.get(req_id, - 0) + scheduled_tokens - if current_tokens >= total_tokens: - logger.debug( - f"MooncakeLayerwiseConnector build_connector_meta: add {req_id}, current tokens({current_tokens}={computed_tokens.get(req_id,0)}+{scheduled_tokens}), total tokens({total_tokens})" - ) - meta.add_new_req(request_id=req_id, - local_block_ids=block_ids, - kv_transfer_params=req.kv_transfer_params, - token_ids=[]) - self._reqs_need_send_layerwise.pop(req_id) - else: - logger.debug( - f"MooncakeLayerwiseConnector build_connector_meta: skip {req_id}, current tokens({current_tokens}={computed_tokens.get(req_id,0)}+{scheduled_tokens}), total tokens({total_tokens})" - ) + def add_tranfer_task(req_id, + send_req_info: SendReqInfo, + chunk_finish=False): + local_block_ids, remote_block_ids, remote_cache_tokens, local_transferred_tokens, local_computed_tokens, request = send_req_info.unpack( + ) + local_trans_block_ids = local_block_ids[( + local_transferred_tokens // + self.block_size):(local_computed_tokens // + self.block_size)] + remote_trans_block_ids = remote_block_ids[( + (local_transferred_tokens - remote_cache_tokens) // + self.block_size):((local_computed_tokens - + remote_cache_tokens) // + self.block_size)] + request.kv_transfer_params[ + "remote_block_ids"] = remote_trans_block_ids + assert len(local_trans_block_ids) == len( + remote_trans_block_ids + ), f"len of local trans 
block ids : {len(local_trans_block_ids)} not equal to the len of remote trans block ids : {len(remote_trans_block_ids)}" + adjusted_tokens = local_computed_tokens - ( + self.block_size - + 1) if chunk_finish else local_computed_tokens + logger.info( + f"MooncakeLayerwiseConnector scheduler add transfer task: {req_id=} {local_block_ids=} {remote_block_ids=} {local_trans_block_ids=} {remote_trans_block_ids=} local_computed_tokens={adjusted_tokens} request.all_token_ids={len(request.all_token_ids)}" + ) + meta.add_new_req( + request_id=req_id, + local_block_ids=local_trans_block_ids, + kv_transfer_params=request.kv_transfer_params, + token_ids=[], + chunk_finish=chunk_finish) + # update local_transferred_tokens + local_transferred_tokens = ( + local_computed_tokens // + self.block_size) * self.block_size + send_req_info.update_transferred_tokens( + local_transferred_tokens) + + # no chunk or last chunk + if send_req_info.local_computed_tokens >= len( + send_req_info.request.all_token_ids): + send_req_info.update_computed_tokens( + send_req_info.local_computed_tokens + + self.block_size - 1) + add_tranfer_task(req_id, + send_req_info, + chunk_finish=True) + self._reqs_need_send_layerwise.pop(req_id) + # chunk + elif (send_req_info.local_computed_tokens // + self.block_size) - ( + send_req_info.local_transferred_tokens // + self.block_size) > 0: + add_tranfer_task(req_id, send_req_info) return meta + def _access_metaserver(self, url, message): + success = False + retry = 0 + while retry < 3 and success is False: + retry += 1 + try: + self.metaserver_client.post(url, json=message) + success = True + except Exception as e: + logger.error( + f"Failed to connect to metaserver: {url}, retry {retry} time." + ) + if retry == 3: + raise e + def request_finished( self, request: "Request", @@ -676,11 +813,6 @@ class MooncakeLayerwiseConnectorWorker: self.total_layers = vllm_config.model_config.get_num_layers( vllm_config.parallel_config) - self.executor = ThreadPoolExecutor(32) - self.metaserver_client = httpx.Client( - limits=httpx.Limits(max_connections=100000), - timeout=None) if self.tp_rank == 0 else None - # Handshake base port self.side_channel_port = ( vllm_config.kv_transfer_config.kv_port + @@ -834,21 +966,6 @@ class MooncakeLayerwiseConnectorWorker: self.kv_recv_layer_thread.start() ready_event.wait() - def _access_metaserver(self, url, message): - success = False - retry = 0 - while retry < 3 and success is False: - retry += 1 - try: - self.metaserver_client.post(url, json=message) - success = True - except Exception as e: - logger.error( - f"Failed to connect to metaserver: {url}, retry {retry} time." - ) - if retry == 3: - raise e - def get_finished(self) -> tuple[set[str], set[str]]: done_recving = ( self.kv_recv_layer_thread. @@ -865,35 +982,6 @@ class MooncakeLayerwiseConnectorWorker: self.current_layer = 0 if self.vllm_config.kv_transfer_config.is_kv_consumer: for req_id, meta in metadata.requests.items(): - if self.tp_rank % self.tp_size == 0: - logger.info( - f"Send request: {req_id} to proxy metaserver: {meta.metaserver}" - ) - # All parameters here should appear in the returned dict of - # request_finished in the scheduler side except "request_id". 
- kv_transfer_params = dict( - token_ids=meta.token_ids, - request_id=req_id, - do_remote_prefill=False, - do_remote_decode=True, - remote_block_ids=meta.local_block_ids, - remote_engine_id=self.engine_id, - remote_host=self.side_channel_host, - remote_port=self.side_channel_port, - ) - future = self.executor.submit( - self._access_metaserver, - url=meta.metaserver, - message=kv_transfer_params, - ) - - def handle_exception(future): - if future.exception(): - logger.error( - f"Access metaserver fail: {future.exception()}" - ) - - future.add_done_callback(handle_exception) assert self.kv_recv_layer_thread is not None with self.kv_recv_layer_thread.lock: self.kv_recv_layer_thread.task_tracker[req_id] = 0 @@ -907,12 +995,12 @@ class MooncakeLayerwiseConnectorWorker: if self.vllm_config.kv_transfer_config.is_kv_producer and connector_metadata.requests.keys( ): # enable decode prefix cache - for request in connector_metadata.requests.values(): - assert len(request.local_block_ids) >= len( - request.remote_block_ids - ), "When prefix cache enabled, remote KVCacheBlocks num should not larger than local KVCacheBlocks num." - request.local_block_ids = request.local_block_ids[ - -len(request.remote_block_ids):] + if self.use_mla: + reshape_cache_event = attn_metadata[ + layer_name].reshape_cache_event + else: + reshape_cache_event = attn_metadata.reshape_cache_event + if self.pd_head_ratio != 1: def sort_kv_cache(input_kv: list[list[int]]): @@ -964,8 +1052,10 @@ class MooncakeLayerwiseConnectorWorker: f"Add request {req_id} to kv send layer thread. {req_meta_update=}" ) assert self.kv_send_layer_thread is not None + assert reshape_cache_event is not None self.kv_send_layer_thread.send_queue.put( - (req_id, req_meta_update, self.current_layer, key, value)) + (req_id, req_meta_update, self.current_layer, key, value, + reshape_cache_event)) self.current_layer += 1 def _get_remote_socket(