diff --git a/vllm_ascend/core/recompute_scheduler.py b/vllm_ascend/core/recompute_scheduler.py index b83ced7e..b35b8e24 100644 --- a/vllm_ascend/core/recompute_scheduler.py +++ b/vllm_ascend/core/recompute_scheduler.py @@ -91,6 +91,11 @@ class RecomputeScheduler(Scheduler): and self.vllm_config.kv_transfer_config and self.vllm_config.kv_transfer_config.is_kv_consumer ) + self.is_kv_producer = self.vllm_config.kv_transfer_config and self.vllm_config.kv_transfer_config.is_kv_producer + self.is_hybrid_model = ( + "qwen3_next" in self.vllm_config.model_config.model_type + or "qwen3_5" in self.vllm_config.model_config.model_type + ) def add_request(self, request: Request) -> None: existing = self.requests.get(request.request_id) @@ -111,6 +116,10 @@ class RecomputeScheduler(Scheduler): request.streaming_queue = deque() # Fill in placeholder tokens to enable full graph compatibility. Without # placeholders, graph matching may fail, forcing eager mode execution. + if self.is_kv_producer and self.is_hybrid_model and request.num_tokens > 1: + request.prompt_token_ids.pop() + request._all_token_ids.pop() + request.num_prompt_tokens -= 1 if self.is_mtp_kv_consumer: request.spec_token_ids = [PLACEHOLDER_TOKEN_ID] * self.num_spec_tokens self.waiting.add_request(request) @@ -118,6 +127,55 @@ class RecomputeScheduler(Scheduler): if self.log_stats: request.record_event(EngineCoreEventType.QUEUED) + def _update_waiting_for_remote_kv(self, request: Request) -> bool: + """ + KV Connector: check if the request_id is finished_recving. + + The finished_recving_kv_req_ids list is populated + on the previous steps()'s update_from_output based + on the worker side connector. + + When the kv transfer is ready, we cache the blocks + and the request state will be moved back to WAITING from + WAITING_FOR_REMOTE_KV. + """ + assert self.connector is not None + if request.request_id not in self.finished_recving_kv_req_ids: + return False + + if request.request_id in self.failed_recving_kv_req_ids: + # Request had KV load failures; num_computed_tokens was already + # updated in _update_requests_with_invalid_blocks + if request.num_computed_tokens: + # Cache any valid computed tokens. + self.kv_cache_manager.cache_blocks(request, request.num_computed_tokens) + else: + # No valid computed tokens, release allocated blocks. + # There may be a local cache hit on retry. + self.kv_cache_manager.free(request) + + self.failed_recving_kv_req_ids.remove(request.request_id) + else: + # Now that the blocks are ready, actually cache them. + block_ids = self.kv_cache_manager.get_block_ids(request.request_id) + if len(block_ids) == 1: + num_computed_tokens = len(block_ids[0]) * self.block_size + # Handle the case where num request tokens less than one block. + num_computed_tokens = min(num_computed_tokens, request.num_tokens) + else: + num_computed_tokens = request.num_tokens + if num_computed_tokens == request.num_tokens: + num_computed_tokens -= 1 + # This will cache the blocks iff caching is enabled. + self.kv_cache_manager.cache_blocks(request, num_computed_tokens) + + # Update the request state for scheduling. + request.num_computed_tokens = num_computed_tokens + + # Return that we are ready. + self.finished_recving_kv_req_ids.remove(request.request_id) + return True + def schedule(self) -> RecomputeSchedulerOutput: # NOTE(woosuk) on the scheduling algorithm: # There's no "decoding phase" nor "prefill phase" in the scheduler. diff --git a/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_layerwise_connector.py b/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_layerwise_connector.py index c49d1621..415bb6c9 100644 --- a/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_layerwise_connector.py +++ b/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_layerwise_connector.py @@ -1,3 +1,4 @@ +# mypy: ignore-errors # SPDX-License-Identifier: Apache-2.0 import contextlib import copy @@ -24,7 +25,12 @@ import zmq from mooncake.engine import TransferEngine # type: ignore from vllm.config import VllmConfig from vllm.distributed import get_pcp_group -from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole +from vllm.distributed.kv_transfer.kv_connector.v1.base import ( + KVConnectorBase_V1, + KVConnectorMetadata, + KVConnectorRole, + SupportsHMA, +) from vllm.distributed.parallel_state import ( get_decode_context_model_parallel_rank, get_tensor_model_parallel_rank, @@ -34,8 +40,18 @@ from vllm.distributed.parallel_state import ( from vllm.logger import logger from vllm.utils.math_utils import round_down from vllm.utils.network_utils import get_ip, make_zmq_path, make_zmq_socket +from vllm.utils.torch_utils import get_dtype_size from vllm.v1.core.sched.output import SchedulerOutput -from vllm.v1.kv_cache_interface import KVCacheConfig +from vllm.v1.kv_cache_interface import ( + AttentionSpec, + FullAttentionSpec, + KVCacheConfig, + KVCacheSpec, + MambaSpec, + SlidingWindowSpec, + UniformTypeKVCacheSpecs, +) +from vllm.v1.worker.utils import extract_layer_index from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector import GET_META_MSG @@ -63,29 +79,38 @@ if TYPE_CHECKING: DONE_SENDING_MSG = b"done_sending_msg" +@dataclass +class LayerMetadata: + tensor_group_idx: list[int] + kv_caches_base_addr: list[int] + block_len: list[int] + block_size_scale: list[int] + + class MooncakeAgentMetadata(msgspec.Struct, omit_defaults=True, dict=True): te_rpc_port: int - kv_caches_base_addr: list[int] + layer_metadata: dict[str, LayerMetadata] @dataclass class ReqMeta: - local_block_ids: list[int] + local_block_ids: list[list[int]] token_ids: list[int] | None # Not None if layer-wise is disabled - remote_block_ids: list[int] + remote_block_ids: list[list[int]] + remote_block_size: list[list[int]] remote_engine_id: str | None remote_host: str | None remote_port: int | None remote_te_rpc_port: int | None - remote_kv_caches_base_addr: list[int] | None + remote_layer_metadata: dict[str, LayerMetadata] | None metaserver: str | None remote_tp_size: int | None remote_pcp_size: int | None remote_dcp_size: int | None chunk_finish: bool = False prompt_len: int = 0 - trans_count: int = 0 + trans_count: list[int] | None = None remote_cache_tokens: int = 0 local_computed_tokens: int = 0 local_transed_tokens: int = 0 @@ -100,13 +125,14 @@ class SendTask: k_cache: torch.Tensor | None = None v_cache: torch.Tensor | None = None layer_idx: int = 0 + layer_name: str = "" # trans block info - rearrange_block_ids: list[int] | None = None - num_blocks: int | None = None - num_tokens: int | None = None - block_table: torch.Tensor | None = None - block_len_tensor: torch.Tensor | None = None - seq_start_tensor: torch.Tensor | None = None + group_rearrange_block_ids: list[list[int]] | None = None + group_num_blocks: list[int] | None = None + group_num_tokens: list[int] | None = None + group_block_table: list[torch.Tensor | None] | None = None + group_block_len_tensor: list[torch.Tensor | None] | None = None + group_seq_start_tensor: list[torch.Tensor | None] | None = None @dataclass @@ -119,14 +145,15 @@ class TransferMeta: @dataclass class SendReqInfo: - local_block_ids: list[int] + local_block_ids: list[list[int]] local_transferred_tokens: int local_computed_tokens: int request: "Request" - def extend_local_block_ids(self, new_block_ids: list[int]) -> None: + def extend_local_block_ids(self, new_block_ids: list[list[int]]) -> None: """extend local block ids for this step""" - self.local_block_ids.extend(new_block_ids) + for i, new_block_id in enumerate(new_block_ids): + self.local_block_ids[i].extend(new_block_id) def update_computed_tokens(self, computed_tokens: int) -> None: """update local computen tokens for this step""" @@ -169,14 +196,18 @@ class KVCacheSendingLayerThread(threading.Thread): def __init__( self, engine: TransferEngine, + vllm_config: VllmConfig, + kv_cache_config: KVCacheConfig, + kv_cache_specs: list[KVCacheSpec], + attn_resharding_group_idx: set, total_layers: int, ready_event: threading.Event, + tp_size: int, tp_rank: int, pd_head_ratio: int, num_head_replica: int, - kv_cache_base_addr: list[int], + layer_metadata: dict[str, LayerMetadata], use_mla: bool, - block_len: list[int], k_buffer: torch.Tensor, v_buffer: torch.Tensor, resharding_stream: torch.npu.Stream, @@ -184,14 +215,17 @@ class KVCacheSendingLayerThread(threading.Thread): ): super().__init__(daemon=True, name="KVCacheSendingLayerThread") self.engine = engine + self.vllm_config = vllm_config + self.kv_cache_config = kv_cache_config + self.kv_cache_specs = kv_cache_specs + self.attn_resharding_group_idx = attn_resharding_group_idx + self.tp_size = tp_size self.tp_rank = tp_rank self.pd_head_ratio = pd_head_ratio self.num_head_replica = num_head_replica - self.kv_caches_base_addr = kv_cache_base_addr + self.layer_metadata = layer_metadata self.total_layers = total_layers self.use_mla = use_mla - self.use_sparse = len(block_len) == 3 - self.block_len = block_len self.resharding_stream = resharding_stream self.current_layer = -1 @@ -216,86 +250,132 @@ class KVCacheSendingLayerThread(threading.Thread): except Exception as e: logger.error(f"Failed to transfer KV cache for layer idx {send_task.layer_idx}, {e}") - def get_transfer_meta(self, send_task: SendTask, req_id: str, req_meta: ReqMeta): - src_list: list[str] = [] - dst_list: list[str] = [] + def get_transfer_meta(self, send_task: SendTask, req_id: str, req_meta: ReqMeta, layer_group_idx: int): + src_list: list[int] = [] + dst_list: list[int] = [] length_list: list[int] = [] - layer_idx = send_task.layer_idx - remote_block_ids = req_meta.remote_block_ids - remote_kv_base_addrs = req_meta.remote_kv_caches_base_addr - local_kv_base_addr = self.kv_caches_base_addr - local_block_ids = req_meta.local_block_ids + layer_name = send_task.layer_name + layer_kv_cache_spec = self.kv_cache_specs[layer_group_idx] + remote_block_ids = req_meta.remote_block_ids[layer_group_idx] + remote_layer_metadata = req_meta.remote_layer_metadata[layer_name] + local_layer_metadata = self.layer_metadata[layer_name] + local_block_ids = req_meta.local_block_ids[layer_group_idx] - if self.pd_head_ratio == 1: - if self.use_sparse: - layer_local_kv_base_addr = [ - local_kv_base_addr[i] for i in [3 * layer_idx, 3 * layer_idx + 1, 3 * layer_idx + 2] - ] - layer_remote_kv_base_addr = [ - remote_kv_base_addrs[i] # type:ignore - for i in [3 * layer_idx, 3 * layer_idx + 1, 3 * layer_idx + 2] - ] + if isinstance(layer_kv_cache_spec, MambaSpec): + # only support one block transfer for mamba + local_conv_addr, local_ssm_addr = local_layer_metadata.kv_caches_base_addr + remote_conv_addr, remote_ssm_addr = remote_layer_metadata.kv_caches_base_addr + local_conv_len, local_ssm_len = local_layer_metadata.block_len + tp_ratio = self.tp_size // req_meta.remote_tp_size + if tp_ratio == 1: + src_list.extend( + [ + local_conv_addr + local_block_ids[0] * local_conv_len, + local_ssm_addr + local_block_ids[0] * local_ssm_len, + ] + ) + dst_list.extend( + [ + remote_conv_addr + remote_block_ids[0] * local_conv_len, + remote_ssm_addr + remote_block_ids[0] * local_ssm_len, + ] + ) + length_list.extend([local_conv_len, local_ssm_len]) else: - layer_local_kv_base_addr = [local_kv_base_addr[i] for i in [2 * layer_idx, 2 * layer_idx + 1]] - layer_remote_kv_base_addr = [ - remote_kv_base_addrs[i] # type:ignore - for i in [2 * layer_idx, 2 * layer_idx + 1] + conv_shape, ssm_shape = layer_kv_cache_spec.shapes + conv_dtype, ssm_dtype = layer_kv_cache_spec.dtypes + remote_conv_len, remote_ssm_len = remote_layer_metadata.block_len + # conv + linear_key_head_dim = self.vllm_config.model_config.hf_text_config.linear_key_head_dim + linear_num_key_heads = self.vllm_config.model_config.hf_text_config.linear_num_key_heads + linear_value_head_dim = self.vllm_config.model_config.hf_text_config.linear_value_head_dim + linear_num_value_heads = self.vllm_config.model_config.hf_text_config.linear_num_value_heads + local_num_key_heads = linear_num_key_heads // self.tp_size + local_num_value_heads = linear_num_value_heads // self.tp_size + local_conv_offsets = [ + 0, + local_num_key_heads * linear_key_head_dim, + local_num_key_heads * 2 * linear_key_head_dim, ] - grouped_remote_block_ids, grouped_local_block_ids = group_concurrent_contiguous( - remote_block_ids, local_block_ids - ) - - block_length = len(self.block_len) - for k, (src_layer_base_addr, dst_layer_base_addr) in enumerate( - zip(layer_local_kv_base_addr, layer_remote_kv_base_addr) - ): - block_len = self.block_len[k % block_length] - for group_remote_block_id, group_local_block_id in zip( - grouped_remote_block_ids, grouped_local_block_ids - ): - src = src_layer_base_addr + group_local_block_id[0] * block_len - dst = dst_layer_base_addr + group_remote_block_id[0] * block_len - length = len(group_local_block_id) * block_len - src_list.append(src) - dst_list.append(dst) - length_list.append(length) + local_conv_sizes = [ + local_num_key_heads * linear_key_head_dim, + local_num_key_heads * linear_key_head_dim, + local_num_value_heads * linear_value_head_dim, + ] + for i in range(conv_shape[0]): + for local_conv_offset, local_conv_size in zip(local_conv_offsets, local_conv_sizes): + local_addr_offset = (i * conv_shape[1] + local_conv_offset) * get_dtype_size(conv_dtype) + remote_addr_offset = ( + (i * conv_shape[1] * tp_ratio) + (self.tp_rank % tp_ratio) * local_conv_size + ) * get_dtype_size(conv_dtype) + src_list.append(local_conv_addr + local_block_ids[0] * local_conv_len + local_addr_offset) + dst_list.append(remote_conv_addr + remote_block_ids[0] * remote_conv_len + remote_addr_offset) + length_list.append(local_conv_size * get_dtype_size(conv_dtype)) + # ssm + remote_addr_offset = (self.tp_rank % tp_ratio) * math.prod(ssm_shape) * get_dtype_size(ssm_dtype) + src_list.append(local_ssm_addr + local_block_ids[0] * local_ssm_len) + dst_list.append(remote_ssm_addr + remote_block_ids[0] * remote_ssm_len + remote_addr_offset) + length_list.append(local_ssm_len) else: - rearrange_block_ids = send_task.rearrange_block_ids - rearrange_block_dict = { - value: index - for index, value in enumerate(rearrange_block_ids) # type:ignore - } - layer_local_kv_base_addr = [self.k_buffer.data_ptr(), self.v_buffer.data_ptr()] - - layer_remote_kv_base_addr = [ - remote_kv_base_addrs[i] # type:ignore - for i in [2 * layer_idx, 2 * layer_idx + 1] - ] - - src_list, dst_list, length_list = [], [], [] - for k, (src_layer_base_addr, dst_layer_base_addr) in enumerate( - zip(layer_local_kv_base_addr, layer_remote_kv_base_addr) - ): - block_len = self.block_len[0] - remote_block_len = self.block_len[0] * self.pd_head_ratio - for remote_block_id, local_block_id in zip(remote_block_ids, local_block_ids): - src = src_layer_base_addr + rearrange_block_dict[local_block_id] * block_len - dst = ( - dst_layer_base_addr - + remote_block_id * remote_block_len - + block_len * ((self.tp_rank // self.num_head_replica) % self.pd_head_ratio) - ) - src_list.append(src) - dst_list.append(dst) - length_list.append(block_len) + if self.pd_head_ratio == 1: + layer_local_kv_base_addr = local_layer_metadata.kv_caches_base_addr + layer_remote_kv_base_addr = remote_layer_metadata.kv_caches_base_addr + block_lens = local_layer_metadata.block_len + grouped_remote_block_ids, grouped_local_block_ids = group_concurrent_contiguous( + remote_block_ids, local_block_ids + ) + for k, (src_layer_base_addr, dst_layer_base_addr) in enumerate( + zip(layer_local_kv_base_addr, layer_remote_kv_base_addr) + ): + block_len = block_lens[k] + for group_remote_block_id, group_local_block_id in zip( + grouped_remote_block_ids, grouped_local_block_ids + ): + src = src_layer_base_addr + group_local_block_id[0] * block_len + dst = dst_layer_base_addr + group_remote_block_id[0] * block_len + length = len(group_local_block_id) * block_len + src_list.append(src) + dst_list.append(dst) + length_list.append(length) + else: + rearrange_block_ids = send_task.group_rearrange_block_ids[layer_group_idx] + rearrange_block_dict = { + value: index + for index, value in enumerate(rearrange_block_ids) # type:ignore + } + layer_local_kv_base_addr = [self.k_buffer.data_ptr(), self.v_buffer.data_ptr()] + layer_remote_kv_base_addr = remote_layer_metadata.kv_caches_base_addr + block_lens = local_layer_metadata.block_len + remote_block_lens = remote_layer_metadata.block_len + assert len(layer_remote_kv_base_addr) == 2, ( + "Layer kv_cache resharding only supports two kv cache tensors." + ) + src_list, dst_list, length_list = [], [], [] + for k, (src_layer_base_addr, dst_layer_base_addr) in enumerate( + zip(layer_local_kv_base_addr, layer_remote_kv_base_addr) + ): + block_len = block_lens[k] + remote_block_len = remote_block_lens[k] + for remote_block_id, local_block_id in zip(remote_block_ids, local_block_ids): + src = src_layer_base_addr + rearrange_block_dict[local_block_id] * block_len + dst = ( + dst_layer_base_addr + + remote_block_id * remote_block_len + + block_len * ((self.tp_rank // self.num_head_replica) % self.pd_head_ratio) + ) + src_list.append(src) + dst_list.append(dst) + length_list.append(block_len) return (src_list, dst_list, length_list) def _transfer_kv_cache(self, send_task: SendTask): - if self.pd_head_ratio > 1: + layer_name = send_task.layer_name + layer_group_idx = self.layer_metadata[layer_name].tensor_group_idx[0] + key = send_task.k_cache + value = send_task.v_cache + if self.pd_head_ratio > 1 and key is not None and value is not None: with npu_stream_switch(self.resharding_stream): - key = send_task.k_cache - value = send_task.v_cache key = key.view(-1, key.shape[-1]) # type:ignore value = value.view(-1, key.shape[-1]) # type:ignore self.k_buffer[: key.shape[0]].copy_(key) # [:4, 128] -> @@ -308,7 +388,7 @@ class KVCacheSendingLayerThread(threading.Thread): if session_id not in session_meta: session_meta[session_id] = TransferMeta(src=[], dst=[], length=[], req_ids=[]) - (src_list, dst_list, length_list) = self.get_transfer_meta(send_task, req_id, req_meta) + (src_list, dst_list, length_list) = self.get_transfer_meta(send_task, req_id, req_meta, layer_group_idx) session_meta[session_id].src.extend(src_list) session_meta[session_id].dst.extend(dst_list) @@ -340,14 +420,14 @@ class KVCacheSendingLayerThread(threading.Thread): req_meta = send_task.send_request[req_id] if req_meta.chunk_finish: self.callback_func( - req_id, req_meta + req_id, req_meta, layer_group_idx ) # TODO Send a signal indicating transmission failure else: if send_task.layer_idx == (self.total_layers - 1): for req_id in transfer_meta.req_ids: req_meta = send_task.send_request[req_id] if req_meta.chunk_finish: - self.callback_func(req_id, req_meta) + self.callback_func(req_id, req_meta, layer_group_idx) class KVCacheRecvingLayerThread(threading.Thread): @@ -441,7 +521,7 @@ class MooncakeLayerwiseConnectorMetadata(KVConnectorMetadata): def add_new_req( self, request_id: str, - local_block_ids: list[int], + local_block_ids: list[list[int]], kv_transfer_params: dict[str, Any], token_ids: list[int] | None = None, chunk_finish: bool = False, @@ -454,11 +534,12 @@ class MooncakeLayerwiseConnectorMetadata(KVConnectorMetadata): token_ids=token_ids or [], local_block_ids=local_block_ids, remote_block_ids=kv_transfer_params.get("remote_block_ids", []), + remote_block_size=kv_transfer_params.get("remote_block_size", []), remote_engine_id=kv_transfer_params.get("remote_engine_id"), remote_host=kv_transfer_params.get("remote_host"), remote_port=kv_transfer_params.get("remote_port"), remote_te_rpc_port=kv_transfer_params.get("remote_te_rpc_port"), - remote_kv_caches_base_addr=kv_transfer_params.get("remote_kv_caches_base_addr"), + remote_layer_metadata=kv_transfer_params.get("remote_layer_metadata"), metaserver=kv_transfer_params.get("metaserver"), remote_tp_size=kv_transfer_params.get("remote_tp_size"), remote_pcp_size=kv_transfer_params.get("remote_pcp_size"), @@ -468,10 +549,11 @@ class MooncakeLayerwiseConnectorMetadata(KVConnectorMetadata): local_computed_tokens=local_computed_tokens, prompt_len=prompt_len, local_transed_tokens=local_transed_tokens, + trans_count=[], ) -class MooncakeLayerwiseConnector(KVConnectorBase_V1): +class MooncakeLayerwiseConnector(KVConnectorBase_V1, SupportsHMA): def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole, kv_cache_config: KVCacheConfig | None = None): super().__init__(vllm_config, role, kv_cache_config) assert vllm_config.kv_transfer_config is not None @@ -480,12 +562,12 @@ class MooncakeLayerwiseConnector(KVConnectorBase_V1): if role == KVConnectorRole.SCHEDULER: self.connector_scheduler: MooncakeLayerwiseConnectorScheduler | None = MooncakeLayerwiseConnectorScheduler( - vllm_config, str(self.engine_id) + vllm_config, kv_cache_config, str(self.engine_id) ) self.connector_worker: MooncakeLayerwiseConnectorWorker | None = None elif role == KVConnectorRole.WORKER: self.connector_scheduler = None - self.connector_worker = MooncakeLayerwiseConnectorWorker(vllm_config, str(self.engine_id)) + self.connector_worker = MooncakeLayerwiseConnectorWorker(vllm_config, kv_cache_config, str(self.engine_id)) ############################################################ # Scheduler Side Methods @@ -514,6 +596,14 @@ class MooncakeLayerwiseConnector(KVConnectorBase_V1): assert self.connector_scheduler is not None return self.connector_scheduler.request_finished(request, block_ids) + def request_finished_all_groups( + self, + request: "Request", + block_ids: tuple[list[int], ...], + ) -> tuple[bool, dict[str, Any] | None]: + assert self.connector_scheduler is not None + return self.connector_scheduler.request_finished_all_groups(request, block_ids) + ############################################################ # Worker Side Methods ############################################################ @@ -553,14 +643,21 @@ class MooncakeLayerwiseConnector(KVConnectorBase_V1): class MooncakeLayerwiseConnectorScheduler: """Implementation of Scheduler side methods""" - def __init__(self, vllm_config: VllmConfig, engine_id: str): + def __init__(self, vllm_config: VllmConfig, kv_cache_config: KVCacheConfig, engine_id: str): self.vllm_config = vllm_config - self.block_size = vllm_config.cache_config.block_size + self.kv_cache_config = kv_cache_config + self.block_size = [group_spec.kv_cache_spec.block_size for group_spec in kv_cache_config.kv_cache_groups] self.engine_id = engine_id logger.info("Initializing Mooncake Scheduler %s", engine_id) self.side_channel_host = get_ip() + # disable prefill context parallel on decoder nodes + if vllm_config.kv_transfer_config.is_kv_consumer: + assert vllm_config.parallel_config.prefill_context_parallel_size == 1, ( + "Prefill context parallel is not support on decoder nodes" + ) + # Handshake base port self.side_channel_port = ( vllm_config.kv_transfer_config.kv_port @@ -570,7 +667,7 @@ class MooncakeLayerwiseConnectorScheduler: # Requests that need to start recv. # New requests are added by update_state_after_alloc in # the scheduler. Used to make metadata passed to Worker. - self._reqs_need_recv: dict[str, tuple[Request, list[int], list[int]]] = {} + self._reqs_need_recv: dict[str, tuple[Request, list[int], list[list[int]]]] = {} self._reqs_need_send_layerwise: dict[str, SendReqInfo] = {} self.executor = ThreadPoolExecutor(32) tls_config: dict[str, Any] = vllm_config.kv_transfer_config.get_from_extra_config("tls_config", {}) @@ -613,7 +710,7 @@ class MooncakeLayerwiseConnectorScheduler: if params is not None and params.get("do_remote_prefill"): # Remote prefill: get all prompt blocks from remote. - assert num_computed_tokens % self.block_size == 0 + assert num_computed_tokens % min(self.block_size) == 0 # Note: We use the full token count as transmit data here. count = max(len(request.prompt_token_ids) - num_computed_tokens, 0) return count, count > 0 @@ -630,7 +727,7 @@ class MooncakeLayerwiseConnectorScheduler: ) if params is not None and params.get("do_remote_prefill"): - local_block_ids = (blocks.get_block_ids()[0]) if num_external_tokens > 0 else [] + local_block_ids = (blocks.get_block_ids()) if num_external_tokens > 0 else [] remote_cached_tokens = request.num_computed_tokens # Get unhashed blocks to pull from remote. logger.debug( @@ -655,6 +752,7 @@ class MooncakeLayerwiseConnectorScheduler: do_remote_prefill=False, do_remote_decode=True, remote_block_ids=local_block_ids, + remote_block_size=self.block_size, remote_engine_id=self.engine_id, remote_host=self.side_channel_host, remote_port=self.side_channel_port, @@ -676,14 +774,11 @@ class MooncakeLayerwiseConnectorScheduler: # Layerwise prefiller add request need send if params is not None and params.get("do_remote_decode"): - local_block_ids = blocks.get_block_ids()[0] + local_block_ids = list(blocks.get_block_ids()) logger.debug( f"MooncakeLayerwiseConnector update_state_after_alloc: add {request.request_id} to need send queue" ) - remote_block_ids = copy.deepcopy(params["remote_block_ids"]) - remote_cache_tokens = ( - (len(request.all_token_ids) + self.block_size - 1) // self.block_size - len(remote_block_ids) - ) * self.block_size + remote_cache_tokens = params["remote_cached_tokens"] local_transferred_tokens = remote_cache_tokens local_computed_tokens = 0 self._reqs_need_send_layerwise[request.request_id] = SendReqInfo( @@ -721,7 +816,7 @@ class MooncakeLayerwiseConnectorScheduler: scheduled_spec_decode_tokens = scheduler_output.scheduled_spec_decode_tokens for req_id, new_blocks in zip(cached_reqs.req_ids, cached_reqs.new_block_ids): if req_id in self._reqs_need_send_layerwise and new_blocks is not None: - self._reqs_need_send_layerwise[req_id].extend_local_block_ids(new_blocks[0]) + self._reqs_need_send_layerwise[req_id].extend_local_block_ids(new_blocks) computed_tokens = dict( list(zip(cached_reqs.req_ids, cached_reqs.num_computed_tokens)) + [(x.req_id, x.num_computed_tokens) for x in new_reqs] @@ -731,7 +826,7 @@ class MooncakeLayerwiseConnectorScheduler: send_req_info = self._reqs_need_send_layerwise[req_id] # update local transferred tokens send_req_info.update_transferred_tokens( - round_down(send_req_info.local_computed_tokens, self.block_size) + round_down(send_req_info.local_computed_tokens, min(self.block_size)) ) # update local computed tokens, not transfer spec decode tokens spec_decode_tokens = ( @@ -801,11 +896,23 @@ class MooncakeLayerwiseConnectorScheduler: # layer_wise push, not need delay_free_blocks return False, None + def request_finished_all_groups( + self, + request: "Request", + block_ids: tuple[list[int], ...], + ) -> tuple[bool, dict[str, Any] | None]: + """ + Once a request is finished, determine whether request blocks + should be freed now or will be sent asynchronously and freed later. + """ + # layer_wise push, not need delay_free_blocks + return False, None + class MooncakeLayerwiseConnectorWorker: """Implementation of Worker side methods""" - def __init__(self, vllm_config: VllmConfig, engine_id: str): + def __init__(self, vllm_config: VllmConfig, kv_cache_config: KVCacheConfig, engine_id: str): os.environ["ASCEND_TRANSFER_TIMEOUT"] = str(get_transfer_timeout_value()) if TransferEngine is None: @@ -814,16 +921,18 @@ class MooncakeLayerwiseConnectorWorker: # Metadata. self.vllm_config = vllm_config + self.kv_cache_config = kv_cache_config + self.num_kv_cache_groups = len(self.kv_cache_config.kv_cache_groups) + self.kv_cache_specs: list[KVCacheSpec] = [spec.kv_cache_spec for spec in self.kv_cache_config.kv_cache_groups] self.local_engine_id: str = " " self.engine_id = engine_id - self.tp_rank = get_tensor_model_parallel_rank() - self.tp_size = vllm_config.parallel_config.tensor_parallel_size - self.pcp_size = vllm_config.parallel_config.prefill_context_parallel_size - self.pcp_rank = get_pcp_group().rank_in_group if self.pcp_size > 1 else 0 - self.dcp_size = vllm_config.parallel_config.decode_context_parallel_size - self.dcp_rank = get_decode_context_model_parallel_rank() if self.dcp_size > 1 else 0 + self.tp_rank: int = get_tensor_model_parallel_rank() + self.tp_size: int = vllm_config.parallel_config.tensor_parallel_size + self.pcp_size: int = vllm_config.parallel_config.prefill_context_parallel_size + self.pcp_rank: int = get_pcp_group().rank_in_group if self.pcp_size > 1 else 0 + self.dcp_size: int = vllm_config.parallel_config.decode_context_parallel_size + self.dcp_rank: int = get_decode_context_model_parallel_rank() if self.dcp_size > 1 else 0 self.tp_group = get_tp_group() - self._decode_tp_size: int | None = None self.kv_caches: dict[str, torch.Tensor] = {} self.side_channel_host = get_ip() self.total_layers = vllm_config.model_config.get_num_layers(vllm_config.parallel_config) @@ -848,11 +957,11 @@ class MooncakeLayerwiseConnectorWorker: self.kv_recv_layer_thread: KVCacheRecvingLayerThread | None = None self.kv_send_layer_thread: KVCacheSendingLayerThread | None = None - self.vllm_config = vllm_config - self.block_size = vllm_config.cache_config.block_size - self.kv_caches_base_addr: list[int] = [] + self.block_size: list[int] = [spec.block_size for spec in self.kv_cache_specs] + self.kernel_block_size_scale: list[int] = [1 for _ in range(self.num_kv_cache_groups)] + self.layer_metadata: dict[str, LayerMetadata] = {} + self.attn_resharding_group_idx = set[int]() - self.pd_tp_ratio = get_ascend_config().pd_tp_ratio self.pd_head_ratio = get_ascend_config().pd_head_ratio self.num_head_replica = get_ascend_config().num_head_replica self.resharding_stream = None @@ -863,7 +972,8 @@ class MooncakeLayerwiseConnectorWorker: self.decoder = msgspec.msgpack.Decoder(MooncakeAgentMetadata) self.encoder = msgspec.msgpack.Encoder() - self.remote_kv_caches_base_addr: dict[str, dict[int, list[int]]] = SizedDict() + self.index_to_name = defaultdict(list) + self.remote_layer_metadata: dict[str, dict[int, dict[str, LayerMetadata]]] = SizedDict() self.remote_te_port: dict[str, dict[int, int]] = SizedDict() self.remote_sockets_lock = threading.Lock() self.remote_sockets: dict[ # type: ignore @@ -907,72 +1017,120 @@ class MooncakeLayerwiseConnectorWorker: def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): """Register the KV Cache data.""" + layer2group_ids: dict[str, int] = {} + kv_cache_groups = self.kv_cache_config.kv_cache_groups + for i, kv_cache_group_spec in enumerate(kv_cache_groups): + for layer_name in kv_cache_group_spec.layer_names: + layer2group_ids[layer_name] = i - _, first_kv_cache_tuple = next(iter(kv_caches.items())) - first_kv_cache = first_kv_cache_tuple[0] - self.create_kv_buffer(first_kv_cache) + use_mamba, use_attn, use_attn_mamba_hybrid = False, False, False + conv_total_padding_size = 0 + for kv_cache_tensor in self.kv_cache_config.kv_cache_tensors: + for layer_name in kv_cache_tensor.shared_by: + layer_kv_cache_spec = self.kv_cache_specs[layer2group_ids[layer_name]] + if isinstance(layer_kv_cache_spec, MambaSpec): + use_mamba = True + conv_shape, conv_dtype = layer_kv_cache_spec.shapes[0], layer_kv_cache_spec.dtypes[0] + conv_total_padding_size = ( + self.kv_cache_config.num_blocks * math.prod(conv_shape) * get_dtype_size(conv_dtype) + ) + if isinstance(layer_kv_cache_spec, AttentionSpec): + use_attn = True + if use_mamba and use_attn: + use_attn_mamba_hybrid = True + break - # TODO(tms): Find a more robust way to detect and handle MLA - self.use_mla = ( - first_kv_cache_tuple[0].size(-1) != first_kv_cache_tuple[1].size(-1) and len(first_kv_cache_tuple) == 2 - ) - self.use_sparse = len(first_kv_cache_tuple) == 3 - - self.num_blocks = first_kv_cache.shape[0] - logger.info("num_blocks: %s", self.num_blocks) - block_rank = 3 - self.block_len = [] - if self.use_mla or self.use_sparse: - for i in range(len(first_kv_cache_tuple)): - block_shape = first_kv_cache_tuple[i].shape[-block_rank:] - logger.info("block_shape: %s", block_shape) - self.block_len.append(first_kv_cache[i].element_size() * math.prod(block_shape)) - else: - # [num_block, block_size, num_head, hidden_dim] - block_shape = first_kv_cache.shape[-block_rank:] - logger.info("block_shape: %s", block_shape) - self.block_len = [first_kv_cache.element_size() * math.prod(block_shape)] - - logger.info( - "Registering KV_Caches. use_mla: %s, use_sparse: %s, shape %s", - self.use_mla, - self.use_sparse, - first_kv_cache.shape, - ) - - self.kv_caches = kv_caches - kv_caches_base_addr = [] ptrs = [] lengths = [] - length = len(self.block_len) - for cache_or_caches in kv_caches.values(): - # Normalize to always be a list of caches - for i, cache in enumerate(cache_or_caches, 0): - base_addr = cache.data_ptr() - region_len = self.num_blocks * self.block_len[i % length] - kv_caches_base_addr.append(base_addr) - ptrs.append(base_addr) - lengths.append(region_len) + use_resharding_buffer = False + resharding_buffer = None + for layer_name, kv_cache_tuple in kv_caches.items(): + if isinstance(kv_cache_tuple, (list, tuple)) is False: + kv_cache_tuple = [kv_cache_tuple] + layer_kv_group_id = layer2group_ids[layer_name] + layer_kv_cache_spec = kv_cache_groups[layer_kv_group_id].kv_cache_spec + if isinstance(layer_kv_cache_spec, UniformTypeKVCacheSpecs): + layer_kv_cache_spec = layer_kv_cache_spec.kv_cache_specs[layer_name] + if self.pd_head_ratio > 1 and (isinstance(layer_kv_cache_spec, (FullAttentionSpec, SlidingWindowSpec))): + self.attn_resharding_group_idx.add(layer_kv_group_id) + if use_resharding_buffer is False: + use_resharding_buffer = True + resharding_buffer = kv_cache_tuple[0] + self.resharding_stream = torch.npu.Stream() + single_layer_meta = LayerMetadata([], [], [], []) + for single_kv_cache in kv_cache_tuple: + block_start_rank = 1 + num_blocks = self.kv_cache_config.num_blocks + tensor_num_blocks = single_kv_cache.shape[0] + assert tensor_num_blocks % num_blocks == 0, ( + "The external block size must be an integer multiple of the kernel block size." + ) + block_size_scale = tensor_num_blocks // num_blocks + block_shape = single_kv_cache.shape[block_start_rank:] + single_layer_meta.tensor_group_idx.append(layer_kv_group_id) + single_layer_meta.kv_caches_base_addr.append(single_kv_cache.data_ptr()) + single_layer_meta.block_len.append(single_kv_cache.element_size() * math.prod(block_shape)) + single_layer_meta.block_size_scale.append(block_size_scale) + self.kernel_block_size_scale[layer2group_ids[layer_name]] = block_size_scale + if single_kv_cache.data_ptr() not in ptrs and not use_attn_mamba_hybrid: + ptrs.append(single_kv_cache.data_ptr()) + lengths.append( + num_blocks * single_kv_cache.element_size() * math.prod(block_shape) * block_size_scale + ) + logger.info(f"layer: {layer_name}, num_blocks: {num_blocks}, block_shape: {block_shape}") + self.layer_metadata[layer_name] = single_layer_meta + + if use_attn_mamba_hybrid: + for kv_cache_tensor in self.kv_cache_config.kv_cache_tensors: + tensor_addrs = [] + for layer_name in kv_cache_tensor.shared_by: + tensor_addrs.extend(self.layer_metadata[layer_name].kv_caches_base_addr) + if "mtp" in layer_name: + tensor_addrs.append(min(tensor_addrs) - conv_total_padding_size) + assert min(set(tensor_addrs)) % (2 * 1024 * 1024) == 0, "Tensor start addr is not align with 2M." + ptrs.append(min(set(tensor_addrs))) + lengths.append(kv_cache_tensor.size) + global_te.register_buffer(ptrs, lengths) - self.kv_caches_base_addr = kv_caches_base_addr + if use_resharding_buffer: + self.create_kv_buffer(resharding_buffer) + + num_attn_module = 2 if self.vllm_config.model_config.hf_text_config.model_type == "longcat_flash" else 1 + mtp_layer_name = "" + for layer_name in kv_caches: + if "mtp" in layer_name: + mtp_layer_name = layer_name + continue + self.index_to_name[extract_layer_index(layer_name, num_attn_module)].append(layer_name) + assert len(self.index_to_name[extract_layer_index(layer_name, num_attn_module)]) == 1, ( + "Mooncake Layerwise Connector does not support multiple `attn_module` in one layer now." + ) + if mtp_layer_name != "": + self.index_to_name[max(self.index_to_name.keys()) + 1].append(mtp_layer_name) + if self.total_layers < len(self.layer_metadata.keys()): + self.total_layers = len(self.layer_metadata.keys()) # After KV Caches registered, start the sending or receiving thread. metadata = MooncakeAgentMetadata( te_rpc_port=self.te_rpc_port, - kv_caches_base_addr=self.kv_caches_base_addr, + layer_metadata=self.layer_metadata, ) if self.vllm_config.kv_transfer_config.is_kv_producer: ready_event = threading.Event() self.kv_send_layer_thread = KVCacheSendingLayerThread( engine=self.engine, + vllm_config=self.vllm_config, + kv_cache_config=self.kv_cache_config, + kv_cache_specs=self.kv_cache_specs, + attn_resharding_group_idx=self.attn_resharding_group_idx, total_layers=self.total_layers, ready_event=ready_event, + tp_size=self.tp_size, tp_rank=self.tp_rank, pd_head_ratio=self.pd_head_ratio, num_head_replica=self.num_head_replica, - kv_cache_base_addr=self.kv_caches_base_addr, + layer_metadata=self.layer_metadata, use_mla=self.use_mla, - block_len=self.block_len, k_buffer=self.k_buffer, v_buffer=self.v_buffer, resharding_stream=self.resharding_stream, @@ -1009,7 +1167,7 @@ class MooncakeLayerwiseConnectorWorker: return set(), done_recving # {(ip, port)]: {local_block_ids: [], remote_block_ids: {}}} - def _get_kv_split_metadata(self, req_meta, req_idx, req_id): + def _get_kv_split_metadata(self, req_meta: ReqMeta, req_idx: int, req_id: str, group_idx: int): remote_pcp_size = req_meta.remote_pcp_size remote_dcp_size = req_meta.remote_dcp_size remote_tp_size = req_meta.remote_tp_size @@ -1036,10 +1194,10 @@ class MooncakeLayerwiseConnectorWorker: cp_size = self.pcp_size * self.dcp_size # to_trans_idx all tokens that have been processed up to the current step if req_meta.chunk_finish: - to_trans_idx = math.ceil(local_computed_tokens / self.block_size) + to_trans_idx = math.ceil(local_computed_tokens / self.block_size[group_idx]) else: - to_trans_idx = math.floor(local_computed_tokens / self.block_size) - prompt_block_size = math.ceil(prompt_len / self.block_size) + to_trans_idx = math.floor(local_computed_tokens / self.block_size[group_idx]) + prompt_block_size = math.ceil(prompt_len / self.block_size[group_idx]) # num_local_blocks = prompt_block_size // cp_size + int( (prompt_block_size % cp_size) > (self.pcp_rank * self.dcp_size + self.dcp_rank) @@ -1049,7 +1207,7 @@ class MooncakeLayerwiseConnectorWorker: ) if num_local_blocks == already_send_blocks: req_meta.chunk_finish = True - transed_idx = math.floor(local_transed_tokens / self.block_size) + transed_idx = math.floor(local_transed_tokens / self.block_size[group_idx]) p_cp_group = get_cp_group(self.tp_size, self.total_num_kv_heads, self.dcp_size) d_cp_group = get_cp_group(remote_tp_size, self.total_num_kv_heads, remote_dcp_size) @@ -1095,7 +1253,7 @@ class MooncakeLayerwiseConnectorWorker: selected_p_cp_group, selected_d_cp_group, prompt_len, - self.block_size, + self.block_size[group_idx], req_meta, self.total_num_kv_heads, req_id, @@ -1107,6 +1265,7 @@ class MooncakeLayerwiseConnectorWorker: pd_head_mapping, d_trans_count_mapping, req_meta, + group_idx, p_parallel_info, req_id, transed_idx, @@ -1117,6 +1276,57 @@ class MooncakeLayerwiseConnectorWorker: ) return transfer_mappings + def _get_kv_split_metadata_for_mamba(self, req_meta: ReqMeta, req_idx: int, req_id: str, group_idx: int): + assert self.tp_size >= req_meta.remote_tp_size, ( + "Mamba group prefill TP_size must equal or larger than decode TP_size." + ) + remote_tp_size = req_meta.remote_tp_size + tp_raito = self.tp_size // remote_tp_size + remote_host = req_meta.remote_host + remote_port = req_meta.remote_port + self.tp_rank // tp_raito + transfer_mappings: dict[tuple[str, int], dict[str, Any]] = {} + if req_meta.chunk_finish: + transfer_mappings[(remote_host, remote_port)] = { + "local_block_ids": req_meta.local_block_ids[group_idx], + "remote_block_ids": req_meta.remote_block_ids[group_idx], + "trans_count": self.tp_size // remote_tp_size, + } + + return transfer_mappings + + def _align_remote_block_ids(self, req_meta: ReqMeta): + remote_block_size = req_meta.remote_block_size + remote_block_ids = req_meta.remote_block_ids + for i in range(self.num_kv_cache_groups): + if isinstance(self.kv_cache_specs[i], MambaSpec): + continue + if remote_block_size[i] != self.block_size[i] and len(req_meta.remote_block_ids[i]) > 0: + assert remote_block_size[i] > self.block_size[i] and remote_block_size[i] % self.block_size[i] == 0, ( + "Remote block size must be divisible by local block size." + ) + assert self.pcp_size * self.dcp_size * req_meta.remote_pcp_size * req_meta.remote_dcp_size == 1, ( + "Context parallel does not support different P/D block size now." + ) + pd_block_size_ratio = remote_block_size[i] // self.block_size[i] + remtote_block_ids_with_scale = [ + block_id * pd_block_size_ratio + j + for block_id in remote_block_ids[i] + for j in range(pd_block_size_ratio) + ] + req_meta.remote_block_ids[i] = remtote_block_ids_with_scale + + def _get_kernel_block_ids(self, block_ids): + for i in range(self.num_kv_cache_groups): + if isinstance(self.kv_cache_specs[i], MambaSpec): + continue + if len(block_ids[i]) > 0: + block_ids[i] = [ + block_id * self.kernel_block_size_scale[i] + j + for block_id in block_ids[i] + for j in range(self.kernel_block_size_scale[i]) + ] + return block_ids + def start_load_kv(self, metadata: MooncakeLayerwiseConnectorMetadata): """Start loading KV blocks from remote engine.""" self.current_layer = 0 @@ -1131,15 +1341,38 @@ class MooncakeLayerwiseConnectorWorker: # update trans info update_metadata = {} for req_idx, (req_id, req_meta) in enumerate(metadata.requests.items()): - self._decode_tp_size = req_meta.remote_tp_size - transfer_mappings = self._get_kv_split_metadata(req_meta, req_idx, req_id) + transfer_mappings: dict[tuple[str, int], dict[str, Any]] = {} + self._align_remote_block_ids(req_meta) + for i, kv_cache_spec in enumerate(self.kv_cache_specs): + if isinstance(kv_cache_spec, MambaSpec): + single_group_transfer_mappings = self._get_kv_split_metadata_for_mamba( + req_meta, req_idx, req_id, i + ) + else: + single_group_transfer_mappings = self._get_kv_split_metadata(req_meta, req_idx, req_id, i) + for (host, port), block_dict in single_group_transfer_mappings.items(): + if (host, port) not in transfer_mappings: + transfer_mappings[(host, port)] = { + "local_block_ids": [[] for _ in range(self.num_kv_cache_groups)], + "remote_block_ids": [[] for _ in range(self.num_kv_cache_groups)], + "trans_count": [0 for _ in range(self.num_kv_cache_groups)], + } + transfer_mappings[(host, port)]["local_block_ids"][i].extend( + single_group_transfer_mappings[(host, port)]["local_block_ids"] + ) + transfer_mappings[(host, port)]["remote_block_ids"][i].extend( + single_group_transfer_mappings[(host, port)]["remote_block_ids"] + ) + transfer_mappings[(host, port)]["trans_count"][i] = single_group_transfer_mappings[ + (host, port) + ]["trans_count"] assert len(transfer_mappings) <= 1, f"Not support add mutil transfer task for req_id:{req_id}" update_req_meta = copy.deepcopy(req_meta) for (host, port), block_dict in transfer_mappings.items(): update_req_meta.remote_host = host update_req_meta.remote_port = port - update_req_meta.local_block_ids = block_dict["local_block_ids"] - update_req_meta.remote_block_ids = block_dict["remote_block_ids"] + update_req_meta.local_block_ids = self._get_kernel_block_ids(block_dict["local_block_ids"]) + update_req_meta.remote_block_ids = self._get_kernel_block_ids(block_dict["remote_block_ids"]) update_req_meta.trans_count = block_dict["trans_count"] update_metadata[req_id] = update_req_meta metadata.requests = {} @@ -1149,19 +1382,35 @@ class MooncakeLayerwiseConnectorWorker: # update send task trans block info if self.pd_head_ratio != 1: send_task = metadata.send_task - send_task.rearrange_block_ids = sorted( - {block_id for req_id in metadata.requests for block_id in metadata.requests[req_id].local_block_ids} - ) - + send_task.group_rearrange_block_ids = [[] for _ in range(self.num_kv_cache_groups)] + send_task.group_num_blocks = [0 for _ in range(self.num_kv_cache_groups)] + send_task.group_num_tokens = [0 for _ in range(self.num_kv_cache_groups)] + send_task.group_block_table = [None for _ in range(self.num_kv_cache_groups)] + send_task.group_block_len_tensor = [None for _ in range(self.num_kv_cache_groups)] + send_task.group_seq_start_tensor = [None for _ in range(self.num_kv_cache_groups)] device = self.k_buffer.device # type: ignore - flat_block_ids = send_task.rearrange_block_ids - block_ids_tensor = torch.tensor(flat_block_ids, dtype=torch.int32, device=device) - send_task.num_blocks = len(flat_block_ids) - send_task.num_tokens = send_task.num_blocks * self.block_size + for i in self.attn_resharding_group_idx: + send_task.group_rearrange_block_ids[i].extend( + sorted( + { + block_id + for req_id in metadata.requests + for block_id in metadata.requests[req_id].local_block_ids[i] + } + ) + ) + flat_block_ids = send_task.group_rearrange_block_ids[i] + block_ids_tensor = torch.tensor(flat_block_ids, dtype=torch.int32, device=device) + send_task.group_num_blocks[i] = len(flat_block_ids) + send_task.group_num_tokens[i] = send_task.group_num_blocks[i] * ( + self.block_size[i] // self.kernel_block_size_scale[i] + ) - send_task.block_table = block_ids_tensor.view(1, -1) - send_task.block_len_tensor = torch.tensor([send_task.num_tokens], dtype=torch.int32, device=device) - send_task.seq_start_tensor = torch.tensor([0], dtype=torch.int32, device=device) + send_task.group_block_table[i] = block_ids_tensor.view(1, -1) + send_task.group_block_len_tensor[i] = torch.tensor( + [send_task.group_num_tokens[i]], dtype=torch.int32, device=device + ) + send_task.group_seq_start_tensor[i] = torch.tensor([0], dtype=torch.int32, device=device) def save_kv_layer( self, @@ -1174,41 +1423,65 @@ class MooncakeLayerwiseConnectorWorker: """MooncakeLayerwiseConnector does not save explicitly.""" if self.vllm_config.kv_transfer_config.is_kv_producer and connector_metadata.requests.keys(): # get reshape and cache event - if self.use_mla or self.use_sparse: + if layer_name == "": + layer_name = self.index_to_name[self.current_layer][0] + if ( + type(attn_metadata) is dict and not getattr(attn_metadata[layer_name], "reshape_cache_event", None) + ) or (not getattr(attn_metadata, "reshape_cache_event", None)): + reshape_cache_event = torch.npu.Event() + reshape_cache_event.record() + elif self.use_mla: reshape_cache_event = attn_metadata[layer_name].reshape_cache_event else: reshape_cache_event = attn_metadata.reshape_cache_event send_task = connector_metadata.send_task - if self.pd_head_ratio != 1: + layer_group_idx = self.layer_metadata[layer_name].tensor_group_idx[0] + keys = None + values = None + if ( + self.pd_head_ratio != 1 + and (isinstance(self.kv_cache_specs[layer_group_idx], (FullAttentionSpec, SlidingWindowSpec))) + and send_task.group_num_blocks[layer_group_idx] > 0 + ): assert self.resharding_stream is not None with npu_stream_switch(self.resharding_stream): reshape_cache_event.wait() dtype = self.k_buffer.dtype # type: ignore device = self.k_buffer.device # type: ignore # Initialize buffers - keys = torch.empty((send_task.num_tokens, *kv_layer[0].size()[-2:]), dtype=dtype, device=device) - values = torch.empty((send_task.num_tokens, *kv_layer[1].size()[-2:]), dtype=dtype, device=device) + keys = torch.empty( + (send_task.group_num_tokens[layer_group_idx], *kv_layer[0].size()[-2:]), + dtype=dtype, + device=device, + ) + values = torch.empty( + (send_task.group_num_tokens[layer_group_idx], *kv_layer[1].size()[-2:]), + dtype=dtype, + device=device, + ) # Load cache data into buffers torch_npu.atb.npu_paged_cache_load( kv_layer[0], kv_layer[1], - send_task.block_table, - send_task.block_len_tensor, - seq_starts=send_task.seq_start_tensor, + send_task.group_block_table[layer_group_idx], + send_task.group_block_len_tensor[layer_group_idx], + seq_starts=send_task.group_seq_start_tensor[layer_group_idx], key=keys, value=values, ) # sort kv caches for each block keys = ( - keys.view(send_task.num_blocks, self.pd_head_ratio, -1, *keys.shape[1:]) + keys.view(send_task.group_num_blocks[layer_group_idx], self.pd_head_ratio, -1, *keys.shape[1:]) .transpose(0, 1) .reshape_as(keys) ) values = ( - values.view(send_task.num_blocks, self.pd_head_ratio, -1, *values.shape[1:]) + values.view( + send_task.group_num_blocks[layer_group_idx], self.pd_head_ratio, -1, *values.shape[1:] + ) .transpose(0, 1) .reshape_as(values) ) @@ -1216,9 +1489,6 @@ class MooncakeLayerwiseConnectorWorker: keys = keys.reshape(-1, *kv_layer[0].shape[2:]) values = values.reshape(-1, *kv_layer[1].shape[2:]) (keys, values) = kv_alltoall_and_rearrange(self.pd_head_ratio, keys, values) - else: - keys = None - values = None assert self.kv_send_layer_thread is not None assert reshape_cache_event is not None @@ -1227,9 +1497,12 @@ class MooncakeLayerwiseConnectorWorker: k_cache=keys, v_cache=values, layer_idx=self.current_layer, - rearrange_block_ids=send_task.rearrange_block_ids, + layer_name=layer_name, + group_rearrange_block_ids=send_task.group_rearrange_block_ids, ) for req_id, req_meta in connector_metadata.requests.items(): + if len(req_meta.local_block_ids[layer_group_idx]) == 0: + continue req_meta_update = self.update_decoder_info(req_id, req_meta) logger.debug(f"Add request {req_id} to kv send layer thread. {req_meta_update=}") layer_send_task.send_request[req_id] = req_meta_update @@ -1258,10 +1531,10 @@ class MooncakeLayerwiseConnectorWorker: self.remote_poller.register(sock, zmq.POLLIN) # type: ignore return sock - def update_decoder_info(self, req_id, req_meta): + def update_decoder_info(self, req_id, req_meta: ReqMeta): if ( - req_meta.remote_engine_id not in self.remote_kv_caches_base_addr - or req_meta.remote_port not in self.remote_kv_caches_base_addr[req_meta.remote_engine_id] + req_meta.remote_engine_id not in self.remote_layer_metadata + or req_meta.remote_port not in self.remote_layer_metadata[req_meta.remote_engine_id] ): try: encoded_data = self.encoder.encode((GET_META_MSG, req_id)) @@ -1269,7 +1542,7 @@ class MooncakeLayerwiseConnectorWorker: path = f"{req_meta.remote_host}:{req_meta.remote_port}" ensure_zmq_send(sock, encoded_data, path) metadata_bytes = ensure_zmq_recv(sock, self.remote_poller, path) - agent_meta = self.decoder.decode(metadata_bytes) + agent_meta: MooncakeAgentMetadata = self.decoder.decode(metadata_bytes) except Exception as e: logger.error( f"Query to port and kv base addr for request {req_id}" @@ -1279,30 +1552,30 @@ class MooncakeLayerwiseConnectorWorker: assert req_meta.remote_engine_id != self.engine_id, ( f"Conflict engine id {req_meta.remote_engine_id} with local engine id {self.local_engine_id}." ) - self.remote_kv_caches_base_addr[req_meta.remote_engine_id][req_meta.remote_port] = ( - agent_meta.kv_caches_base_addr - ) + self.remote_layer_metadata[req_meta.remote_engine_id][req_meta.remote_port] = agent_meta.layer_metadata self.remote_te_port[req_meta.remote_engine_id][req_meta.remote_port] = agent_meta.te_rpc_port - logger.info( + logger.debug( f"Query to port and kv base addr for request {req_id}" f"from {req_meta.remote_host}:{req_meta.remote_port}" - f"success {agent_meta.kv_caches_base_addr=} {agent_meta.te_rpc_port=}" + f"success {agent_meta.layer_metadata=} {agent_meta.te_rpc_port=}" ) if self.pd_head_ratio > 1: # for tp inequal, pre-create link to prevent alltoall out of memory session_id = f"{req_meta.remote_host}:{agent_meta.te_rpc_port}" + first_layer_name = next(iter(self.layer_metadata.keys())) ret = self.engine.batch_transfer_sync_write( - session_id, [self.kv_caches_base_addr[0]], [agent_meta.kv_caches_base_addr[0]], [128] + session_id, + [self.layer_metadata[first_layer_name].kv_caches_base_addr[0]], + [agent_meta.layer_metadata[first_layer_name].kv_caches_base_addr[0]], + [128], ) if ret < 0: logger.error(f"Mooncake transfer failed to create link to device {session_id}") req_meta.remote_te_rpc_port = self.remote_te_port[req_meta.remote_engine_id][req_meta.remote_port] - req_meta.remote_kv_caches_base_addr = self.remote_kv_caches_base_addr[req_meta.remote_engine_id][ - req_meta.remote_port - ] + req_meta.remote_layer_metadata = self.remote_layer_metadata[req_meta.remote_engine_id][req_meta.remote_port] return req_meta - def send_done_send_signal(self, req_id, req_meta): + def send_done_send_signal(self, req_id, req_meta, group_idx): external_req_id = get_external_request_id(req_id) logger.info( "Sending done sending signal for request %s to %s:%d", @@ -1313,7 +1586,7 @@ class MooncakeLayerwiseConnectorWorker: try: path = make_zmq_path("tcp", req_meta.remote_host, req_meta.remote_port) msg_encoder = msgspec.msgpack.Encoder() - encoded_data = msg_encoder.encode((DONE_SENDING_MSG, external_req_id, req_meta.trans_count)) + encoded_data = msg_encoder.encode((DONE_SENDING_MSG, external_req_id, req_meta.trans_count[group_idx])) with zmq_ctx(zmq.REQ, path) as sock: # type: ignore ensure_zmq_send(sock, encoded_data, f"{req_meta.remote_host}:{req_meta.remote_port}") ack = sock.recv() diff --git a/vllm_ascend/distributed/kv_transfer/utils/utils.py b/vllm_ascend/distributed/kv_transfer/utils/utils.py index 19ef0ed0..250f1d51 100644 --- a/vllm_ascend/distributed/kv_transfer/utils/utils.py +++ b/vllm_ascend/distributed/kv_transfer/utils/utils.py @@ -261,6 +261,7 @@ def get_transfer_mappings( pd_head_mapping: dict[int, set], d_trans_count_mapping: dict[tuple[str, int], int], req_meta, + block_group_idx: int, p_parallel_info: parallel_info, req_id: str, transed_idx: int, @@ -272,15 +273,17 @@ def get_transfer_mappings( transfer_mappings: dict[tuple[str, int], dict[str, Any]] = {} p_head_group_rank = (tp_rank - dcp_rank) // p_parallel_info.dcp_size p_block_idxs: list[int] = p_rank_block_mapping[pcp_rank][p_head_group_rank][dcp_rank] + p_block_ids = req_meta.local_block_ids[block_group_idx] + d_block_ids = req_meta.remote_block_ids[block_group_idx] for p_block_idx, logic_block_idx in enumerate(p_block_idxs): if logic_block_idx < transed_idx or logic_block_idx >= to_trans_idx: continue for d_head_group_rank in pd_head_mapping[p_head_group_rank]: - p_block_id = req_meta.local_block_ids[p_block_idx] + p_block_id = p_block_ids[p_block_idx] remote_host = d_block_rank_mapping[logic_block_idx][d_head_group_rank]["host"] remote_port = d_block_rank_mapping[logic_block_idx][d_head_group_rank]["port"] d_block_idx = d_block_rank_mapping[logic_block_idx][d_head_group_rank]["block_idx"] - d_block_id = req_meta.remote_block_ids[d_block_idx] + d_block_id = d_block_ids[d_block_idx] if (remote_host, remote_port) not in transfer_mappings: transfer_mappings[(remote_host, remote_port)] = { "local_block_ids": [], diff --git a/vllm_ascend/distributed/parallel_state.py b/vllm_ascend/distributed/parallel_state.py index 965794ee..e4a7c399 100644 --- a/vllm_ascend/distributed/parallel_state.py +++ b/vllm_ascend/distributed/parallel_state.py @@ -59,6 +59,7 @@ def init_ascend_model_parallel( global _P_TP assert _P_TP is None, "distributed prefill tensor parallel group is already initialized" prefill_tensor_model_parallel_size = pd_tp_ratio + pcp_size = parallel_config.prefill_context_parallel_size # divide alltoall groups if pd_head_ratio > 1 and get_current_vllm_config().kv_transfer_config.is_kv_producer: num_head_replica = get_ascend_config().num_head_replica @@ -67,13 +68,13 @@ def init_ascend_model_parallel( group_ranks = all_ranks.view(-1, prefill_tensor_model_parallel_size).unbind(0) else: group_ranks = all_ranks.clone().view( - global_dp_size, -1, num_head_replica + global_dp_size * pcp_size, -1, num_head_replica ) # [DP_size, num_head, num_head_replica] group_ranks = group_ranks.permute(0, 2, 1) group_ranks = group_ranks.reshape(-1, group_ranks.size(-1)) # [DP_size * num_head_replica, num_head] alltoall_group_size = group_ranks.size(-1) // remote_tp_size group_ranks = group_ranks.unsqueeze(-1).view( - global_dp_size, num_head_replica, -1, alltoall_group_size + global_dp_size * pcp_size, num_head_replica, -1, alltoall_group_size ) # [DP_size, num_head_replica, num_alltoall_group, alltoall_group_size] group_ranks = group_ranks.reshape(-1, alltoall_group_size).unbind(0) group_ranks = [x.tolist() for x in group_ranks] diff --git a/vllm_ascend/patch/worker/patch_qwen3_next.py b/vllm_ascend/patch/worker/patch_qwen3_next.py index 7e7a5eec..29694aed 100644 --- a/vllm_ascend/patch/worker/patch_qwen3_next.py +++ b/vllm_ascend/patch/worker/patch_qwen3_next.py @@ -29,6 +29,7 @@ from vllm.v1.attention.backend import AttentionMetadata # type: ignore from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata from vllm.v1.attention.backends.utils import PAD_SLOT_ID +from vllm_ascend.attention.utils import maybe_save_kv_layer_to_connector from vllm_ascend.ops.triton.fla.fused_qkvzba_split_reshape import fused_qkvzba_split_reshape_cat from vllm_ascend.ops.triton.fused_gdn_gating import fused_gdn_gating_patch from vllm_ascend.utils import enable_sp @@ -85,6 +86,7 @@ class AscendQwen3Next_GatedDeltaNet(Qwen3NextGatedDeltaNet): # ============================================================ # Part 3: Output Projection # ============================================================ + maybe_save_kv_layer_to_connector("", []) z_shape_og = z.shape # Reshape input data into 2D tensor core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1]) diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py index b7d1bae1..92230b1e 100644 --- a/vllm_ascend/utils.py +++ b/vllm_ascend/utils.py @@ -1073,7 +1073,11 @@ def refresh_block_size(vllm_config): return # TODO(MengqingCao): Remove the model_type check, after resolving the hidden error in get_kv_cache_groups. - if model_config.hf_text_config.model_type != "qwen3_next" and cache_config.block_size != 128: + if ( + "qwen3_next" not in model_config.hf_text_config.model_type + and "qwen3_5" not in model_config.hf_text_config.model_type + and cache_config.block_size != 128 + ): if cache_config.enable_prefix_caching or scheduler_config.enable_chunked_prefill: logger.info("Block size is set to 128 if prefix cache or chunked prefill is enabled.") cache_config.block_size = 128