[P/D]Mooncake Layerwise Connector supports hybrid attention manager with multiple kvcache groups (#7022)
### What this PR does / why we need it?
Enable the Mooncake Layerwise Connector to work with the hybrid attention manager, which manages multiple kvcache groups.
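A minimal sketch of the idea, with hypothetical block ids (not taken from the connector code): a hybrid model such as qwen3_next keeps more than one kvcache group (full-attention layers plus linear-attention/GDN layers), so the layerwise connector has to look up local and remote block ids per group index instead of assuming a single flat list. This is the same per-group indexing that `get_transfer_mappings` gains through the new `block_group_idx` parameter in the diff below.

```python
# Hypothetical per-group block ids; a hybrid model has one list per kvcache
# group (e.g. group 0: full-attention KV blocks, group 1: GDN state blocks).
local_block_ids = [[3, 7, 11], [2]]       # producer (prefill) side
remote_block_ids = [[40, 41, 42], [9]]    # consumer (decode) side, same grouping

for block_group_idx, p_block_ids in enumerate(local_block_ids):
    d_block_ids = remote_block_ids[block_group_idx]
    for p_block_id, d_block_id in zip(p_block_ids, d_block_ids):
        # A real connector would issue a layerwise KV transfer here.
        print(f"group {block_group_idx}: local block {p_block_id} -> remote block {d_block_id}")
```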
### Does this PR introduce _any_ user-facing change?
Yes.
### How was this patch tested?
By CI.
- vLLM version: v0.16.0
- vLLM main: 15d76f74e2
---------
Signed-off-by: nwpu-zxr <zhouxuerong2@huawei.com>
@@ -91,6 +91,11 @@ class RecomputeScheduler(Scheduler):
             and self.vllm_config.kv_transfer_config
             and self.vllm_config.kv_transfer_config.is_kv_consumer
         )
+        self.is_kv_producer = self.vllm_config.kv_transfer_config and self.vllm_config.kv_transfer_config.is_kv_producer
+        self.is_hybrid_model = (
+            "qwen3_next" in self.vllm_config.model_config.model_type
+            or "qwen3_5" in self.vllm_config.model_config.model_type
+        )
 
     def add_request(self, request: Request) -> None:
         existing = self.requests.get(request.request_id)
@@ -111,6 +116,10 @@ class RecomputeScheduler(Scheduler):
         request.streaming_queue = deque()
         # Fill in placeholder tokens to enable full graph compatibility. Without
         # placeholders, graph matching may fail, forcing eager mode execution.
+        if self.is_kv_producer and self.is_hybrid_model and request.num_tokens > 1:
+            request.prompt_token_ids.pop()
+            request._all_token_ids.pop()
+            request.num_prompt_tokens -= 1
         if self.is_mtp_kv_consumer:
             request.spec_token_ids = [PLACEHOLDER_TOKEN_ID] * self.num_spec_tokens
         self.waiting.add_request(request)
@@ -118,6 +127,55 @@ class RecomputeScheduler(Scheduler):
         if self.log_stats:
             request.record_event(EngineCoreEventType.QUEUED)
 
+    def _update_waiting_for_remote_kv(self, request: Request) -> bool:
+        """
+        KV Connector: check if the request_id is finished_recving.
+
+        The finished_recving_kv_req_ids list is populated
+        on the previous steps()'s update_from_output based
+        on the worker side connector.
+
+        When the kv transfer is ready, we cache the blocks
+        and the request state will be moved back to WAITING from
+        WAITING_FOR_REMOTE_KV.
+        """
+        assert self.connector is not None
+        if request.request_id not in self.finished_recving_kv_req_ids:
+            return False
+
+        if request.request_id in self.failed_recving_kv_req_ids:
+            # Request had KV load failures; num_computed_tokens was already
+            # updated in _update_requests_with_invalid_blocks
+            if request.num_computed_tokens:
+                # Cache any valid computed tokens.
+                self.kv_cache_manager.cache_blocks(request, request.num_computed_tokens)
+            else:
+                # No valid computed tokens, release allocated blocks.
+                # There may be a local cache hit on retry.
+                self.kv_cache_manager.free(request)
+
+            self.failed_recving_kv_req_ids.remove(request.request_id)
+        else:
+            # Now that the blocks are ready, actually cache them.
+            block_ids = self.kv_cache_manager.get_block_ids(request.request_id)
+            if len(block_ids) == 1:
+                num_computed_tokens = len(block_ids[0]) * self.block_size
+                # Handle the case where num request tokens less than one block.
+                num_computed_tokens = min(num_computed_tokens, request.num_tokens)
+            else:
+                num_computed_tokens = request.num_tokens
+            if num_computed_tokens == request.num_tokens:
+                num_computed_tokens -= 1
+            # This will cache the blocks iff caching is enabled.
+            self.kv_cache_manager.cache_blocks(request, num_computed_tokens)
+
+            # Update the request state for scheduling.
+            request.num_computed_tokens = num_computed_tokens
+
+        # Return that we are ready.
+        self.finished_recving_kv_req_ids.remove(request.request_id)
+        return True
+
     def schedule(self) -> RecomputeSchedulerOutput:
         # NOTE(woosuk) on the scheduling algorithm:
         # There's no "decoding phase" nor "prefill phase" in the scheduler.
(File diff suppressed because it is too large.)
@@ -261,6 +261,7 @@ def get_transfer_mappings(
     pd_head_mapping: dict[int, set],
     d_trans_count_mapping: dict[tuple[str, int], int],
     req_meta,
+    block_group_idx: int,
     p_parallel_info: parallel_info,
     req_id: str,
     transed_idx: int,
@@ -272,15 +273,17 @@ def get_transfer_mappings(
     transfer_mappings: dict[tuple[str, int], dict[str, Any]] = {}
     p_head_group_rank = (tp_rank - dcp_rank) // p_parallel_info.dcp_size
     p_block_idxs: list[int] = p_rank_block_mapping[pcp_rank][p_head_group_rank][dcp_rank]
+    p_block_ids = req_meta.local_block_ids[block_group_idx]
+    d_block_ids = req_meta.remote_block_ids[block_group_idx]
     for p_block_idx, logic_block_idx in enumerate(p_block_idxs):
         if logic_block_idx < transed_idx or logic_block_idx >= to_trans_idx:
             continue
         for d_head_group_rank in pd_head_mapping[p_head_group_rank]:
-            p_block_id = req_meta.local_block_ids[p_block_idx]
+            p_block_id = p_block_ids[p_block_idx]
             remote_host = d_block_rank_mapping[logic_block_idx][d_head_group_rank]["host"]
             remote_port = d_block_rank_mapping[logic_block_idx][d_head_group_rank]["port"]
             d_block_idx = d_block_rank_mapping[logic_block_idx][d_head_group_rank]["block_idx"]
-            d_block_id = req_meta.remote_block_ids[d_block_idx]
+            d_block_id = d_block_ids[d_block_idx]
             if (remote_host, remote_port) not in transfer_mappings:
                 transfer_mappings[(remote_host, remote_port)] = {
                     "local_block_ids": [],
@@ -59,6 +59,7 @@ def init_ascend_model_parallel(
     global _P_TP
     assert _P_TP is None, "distributed prefill tensor parallel group is already initialized"
     prefill_tensor_model_parallel_size = pd_tp_ratio
+    pcp_size = parallel_config.prefill_context_parallel_size
     # divide alltoall groups
     if pd_head_ratio > 1 and get_current_vllm_config().kv_transfer_config.is_kv_producer:
         num_head_replica = get_ascend_config().num_head_replica
@@ -67,13 +68,13 @@ def init_ascend_model_parallel(
             group_ranks = all_ranks.view(-1, prefill_tensor_model_parallel_size).unbind(0)
         else:
             group_ranks = all_ranks.clone().view(
-                global_dp_size, -1, num_head_replica
+                global_dp_size * pcp_size, -1, num_head_replica
             )  # [DP_size, num_head, num_head_replica]
             group_ranks = group_ranks.permute(0, 2, 1)
             group_ranks = group_ranks.reshape(-1, group_ranks.size(-1))  # [DP_size * num_head_replica, num_head]
             alltoall_group_size = group_ranks.size(-1) // remote_tp_size
             group_ranks = group_ranks.unsqueeze(-1).view(
-                global_dp_size, num_head_replica, -1, alltoall_group_size
+                global_dp_size * pcp_size, num_head_replica, -1, alltoall_group_size
             )  # [DP_size, num_head_replica, num_alltoall_group, alltoall_group_size]
             group_ranks = group_ranks.reshape(-1, alltoall_group_size).unbind(0)
         group_ranks = [x.tolist() for x in group_ranks]
@@ -29,6 +29,7 @@ from vllm.v1.attention.backend import AttentionMetadata  # type: ignore
 from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata
 from vllm.v1.attention.backends.utils import PAD_SLOT_ID
 
+from vllm_ascend.attention.utils import maybe_save_kv_layer_to_connector
 from vllm_ascend.ops.triton.fla.fused_qkvzba_split_reshape import fused_qkvzba_split_reshape_cat
 from vllm_ascend.ops.triton.fused_gdn_gating import fused_gdn_gating_patch
 from vllm_ascend.utils import enable_sp
@@ -85,6 +86,7 @@ class AscendQwen3Next_GatedDeltaNet(Qwen3NextGatedDeltaNet):
         # ============================================================
         # Part 3: Output Projection
         # ============================================================
+        maybe_save_kv_layer_to_connector("", [])
         z_shape_og = z.shape
         # Reshape input data into 2D tensor
         core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1])
@@ -1073,7 +1073,11 @@ def refresh_block_size(vllm_config):
         return
 
     # TODO(MengqingCao): Remove the model_type check, after resolving the hidden error in get_kv_cache_groups.
-    if model_config.hf_text_config.model_type != "qwen3_next" and cache_config.block_size != 128:
+    if (
+        "qwen3_next" not in model_config.hf_text_config.model_type
+        and "qwen3_5" not in model_config.hf_text_config.model_type
+        and cache_config.block_size != 128
+    ):
         if cache_config.enable_prefix_caching or scheduler_config.enable_chunked_prefill:
             logger.info("Block size is set to 128 if prefix cache or chunked prefill is enabled.")
             cache_config.block_size = 128