diff --git a/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_layerwise_connector.py b/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_layerwise_connector.py
index d9ef029b..4288b0d1 100644
--- a/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_layerwise_connector.py
+++ b/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_layerwise_connector.py
@@ -212,6 +212,7 @@ class KVCacheSendingLayerThread(threading.Thread):
         num_head_replica: int,
         layer_metadata: dict[str, LayerMetadata],
         use_mla: bool,
+        use_attn_mamba_hybrid: bool,
         k_buffer: torch.Tensor,
         v_buffer: torch.Tensor,
         enable_kv_quant: bool,
@@ -239,10 +240,17 @@ class KVCacheSendingLayerThread(threading.Thread):
         self.layer_metadata = layer_metadata
         self.total_layers = total_layers
         self.use_mla = use_mla
+        self.use_attn_mamba_hybrid = use_attn_mamba_hybrid
         self.resharding_stream = resharding_stream
         self.current_layer = -1
-        self.send_queue = queue.Queue[SendTask]()
+        send_queue_size = 0
+        if self.pd_head_ratio != 1:
+            if self.use_attn_mamba_hybrid:
+                send_queue_size = len(self.kv_cache_specs)
+            else:
+                send_queue_size = 1
+        self.send_queue = queue.Queue[SendTask](maxsize=send_queue_size)
         self.k_buffer = k_buffer
         self.v_buffer = v_buffer
         self.enable_kv_quant = enable_kv_quant
@@ -465,6 +473,7 @@ class KVCacheSendingLayerThread(threading.Thread):
 
             for session_id, transfer_meta in session_meta.items():
                 if len(transfer_meta.src) > 0:
+                    req_start_time = time.perf_counter()
                     ret = self.engine.batch_transfer_sync_write(
                         session_id, transfer_meta.src, transfer_meta.dst, transfer_meta.length
                     )
@@ -480,6 +489,16 @@ class KVCacheSendingLayerThread(threading.Thread):
                                 req_id, req_meta, layer_group_idx
                             )  # TODO Send a signal indicating transmission failure
                     else:
+                        req_end_time = time.perf_counter()
+                        total_transfer_size = sum(transfer_meta.length) / 1024
+                        req_transfer_elapsed = (req_end_time - req_start_time) * 1000
+                        logger.debug(
+                            "Layer%d KV cache transfer task %dKB to remote_session_id [%s] took %.3f ms.",
+                            send_task.layer_idx,
+                            total_transfer_size,
+                            session_id,
+                            req_transfer_elapsed,
+                        )
                         if send_task.layer_idx == (self.total_layers - 1):
                             for req_id in transfer_meta.req_ids:
                                 req_meta = send_task.send_request[req_id]
@@ -996,6 +1015,7 @@ class MooncakeLayerwiseConnectorWorker:
         self.side_channel_host = get_ip()
         self.total_layers = vllm_config.model_config.get_num_layers(vllm_config.parallel_config)
         self.use_mla = self.vllm_config.model_config.use_mla
+        self.use_attn_mamba_hybrid = False
         self.request_map = dict[str, str]()
         if self.use_mla:
             self.total_num_kv_heads = 1
@@ -1074,7 +1094,7 @@ class MooncakeLayerwiseConnectorWorker:
             buffer_list.append(self.k_buffer)
             buffer_list.append(self.v_buffer)
             if self.enable_kv_quant:
-                quant_k_cache_numel = first_kv_cache_tuple[0].numel() // 2
+                quant_k_cache_numel = first_kv_cache_tuple[0].numel()
                 self.k_quant_buffer = torch.zeros(
                     quant_k_cache_numel + alignment, dtype=torch.int8, device=first_kv_cache.device
                 )
@@ -1093,10 +1113,8 @@ class MooncakeLayerwiseConnectorWorker:
             for tensor in buffer_list:
                 assert tensor.data_ptr() % alignment == 0, "The address of the registered kv cache should be aligned to 2M"
-                ret_value = self.engine.register_memory(tensor.data_ptr(), tensor.numel())
-                logger.info(
-                    f"Register memory for prefill when pd head ratio > 1 {tensor.data_ptr()} {tensor.numel()} {ret_value=}"
-                )
+                ret_value = self.engine.register_memory(tensor.data_ptr(), tensor.numel() * tensor.element_size())
+                logger.info(f"Register memory buffer for transfer, buffer size:{tensor.numel() * tensor.element_size()}")
                 if ret_value != 0:
                     raise RuntimeError("Mooncake memory registration failed. ")
@@ -1108,7 +1126,7 @@ class MooncakeLayerwiseConnectorWorker:
             for layer_name in kv_cache_group_spec.layer_names:
                 layer2group_ids[layer_name] = i
 
-        use_mamba, use_attn, use_attn_mamba_hybrid = False, False, False
+        use_mamba, use_attn = False, False
         conv_total_padding_size = 0
         for kv_cache_tensor in self.kv_cache_config.kv_cache_tensors:
             for layer_name in kv_cache_tensor.shared_by:
@@ -1122,7 +1140,7 @@ class MooncakeLayerwiseConnectorWorker:
                 if isinstance(layer_kv_cache_spec, AttentionSpec):
                     use_attn = True
             if use_mamba and use_attn:
-                use_attn_mamba_hybrid = True
+                self.use_attn_mamba_hybrid = True
                 break
 
         ptrs = []
@@ -1158,7 +1176,7 @@ class MooncakeLayerwiseConnectorWorker:
                 single_layer_meta.block_len.append(single_kv_cache.element_size() * math.prod(block_shape))
                 single_layer_meta.block_size_scale.append(block_size_scale)
                 self.kernel_block_size_scale[layer2group_ids[layer_name]] = block_size_scale
-                if single_kv_cache.data_ptr() not in ptrs and not use_attn_mamba_hybrid:
+                if single_kv_cache.data_ptr() not in ptrs and not self.use_attn_mamba_hybrid:
                     ptrs.append(single_kv_cache.data_ptr())
                     lengths.append(
                         num_blocks * single_kv_cache.element_size() * math.prod(block_shape) * block_size_scale
@@ -1166,7 +1184,7 @@ class MooncakeLayerwiseConnectorWorker:
                 logger.info(f"layer: {layer_name}, num_blocks: {num_blocks}, block_shape: {block_shape}")
                 self.layer_metadata[layer_name] = single_layer_meta
 
-        if use_attn_mamba_hybrid:
+        if self.use_attn_mamba_hybrid:
             for kv_cache_tensor in self.kv_cache_config.kv_cache_tensors:
                 tensor_addrs = []
                 for layer_name in kv_cache_tensor.shared_by:
@@ -1217,6 +1235,7 @@ class MooncakeLayerwiseConnectorWorker:
             num_head_replica=self.num_head_replica,
            layer_metadata=self.layer_metadata,
            use_mla=self.use_mla,
+            use_attn_mamba_hybrid=self.use_attn_mamba_hybrid,
            k_buffer=self.k_buffer,
            v_buffer=self.v_buffer,
            enable_kv_quant=self.enable_kv_quant,