[Bugfix] TP size larger than KV cache head causes accuracy issues (#3366)

### What this PR does / why we need it?
Resolve the issue where, in the case of unequal TP (Tensor Parallelism),
the TP size is larger than the number of model attention kvcache heads,
causing the KV cache to generate duplicates, which leads to transmission
errors in the original code.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
By ci
- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

---------

Signed-off-by: nwpu-zxr <zhouxuerong2@huawei.com>
Signed-off-by: wangxiaoteng <wangxiaoteng@huawei.com>
Co-authored-by: nwpu-zxr <zhouxuerong2@huawei.com>
This commit is contained in:
wangxiaoteng888
2025-10-11 11:22:23 +08:00
committed by GitHub
parent ace300a549
commit ca05f7d632
8 changed files with 685 additions and 36 deletions

View File

@@ -277,7 +277,7 @@ class SendingLayerThread(threading.Thread):
self.send_queue = queue.Queue[tuple[DecodeMooncakeAgentMetadata, str,
list[int], int, torch.Tensor,
torch.Tensor]]()
self.completion_event: threading.Event
self.completion_event: Optional[threading.Event] = None
self.completion_event_count: int
self.task_tracker = task_tracker
self.total_layers = total_layers
@@ -287,6 +287,8 @@ class SendingLayerThread(threading.Thread):
self.engine = engine
self.tp_rank = tp_rank
self.pd_tp_ratio = get_ascend_config().pd_tp_ratio
self.num_head_replica = get_ascend_config().num_head_replica
self.pd_head_ratio = get_ascend_config().pd_head_ratio
vllm_config = get_current_vllm_config()
max_model_len = vllm_config.scheduler_config.max_model_len
first_kv_cache = first_kv_cache[:max_model_len]
@@ -358,7 +360,9 @@ class SendingLayerThread(threading.Thread):
remote_kv_base_addrs = req_meta.kv_caches_base_addr
remote_block_ids = req_meta.block_ids
if self.pd_tp_ratio == 1:
if self.num_head_replica >= 1 and self.tp_rank % self.num_head_replica != 0:
pass
elif self.pd_head_ratio == 1:
layer_local_kv_base_addr = [
self.local_kv_base_addr[i]
for i in [2 * layer_index, 2 * layer_index + 1]
@@ -420,7 +424,7 @@ class SendingLayerThread(threading.Thread):
src_layer_addr = src_layer_base_addr
for group_remote_block_id in grouped_remote_block_ids:
block_len = self.block_len[0]
remote_block_len = self.block_len[0] * self.pd_tp_ratio
remote_block_len = self.block_len[0] * self.pd_head_ratio
src_list.append(src_layer_addr)
if src_layer_addr + len(
@@ -436,23 +440,21 @@ class SendingLayerThread(threading.Thread):
dst_list.append(dst_layer_base_addr +
group_remote_block_id[0] *
remote_block_len + length *
(self.tp_rank % self.pd_tp_ratio))
((self.tp_rank // self.num_head_replica) %
self.pd_head_ratio))
src_layer_addr += length
torch.npu.synchronize()
ret = self.engine.batch_transfer_sync_write(
session_id, src_list, dst_list, length_list)
self.completion_event_count -= 1
if self.completion_event_count == 0 and self.completion_event is not None:
print(
f"[_transfer_kv_cache] {self.completion_event_count} self.event.set()"
)
self.completion_event.set()
if ret < 0:
logger.error("Mooncake transfer failed for request %s",
req_meta.req_id)
raise RuntimeError(f"Mooncake transfer failed, ret: {ret}")
if self.completion_event is not None:
self.completion_event_count -= 1
if self.completion_event_count == 0:
self.completion_event.set()
self.completion_event = None
def add_event(self, event: threading.Event, count: int) -> None:
self.completion_event = event
@@ -924,6 +926,8 @@ class MooncakeLayerwiseConnectorWorker:
self.kv_caches_base_addr: list[int] = []
self.pd_tp_ratio = get_ascend_config().pd_tp_ratio
self.pd_head_ratio = get_ascend_config().pd_head_ratio
self.first_kv_cache = None
def _get_prefill_decode_size(self, vllm_config: VllmConfig):
@@ -1104,7 +1108,7 @@ class MooncakeLayerwiseConnectorWorker:
path = make_zmq_path(
"tcp", meta.remote_host, meta.remote_port +
self.tp_rank * self.pd_tp_ratio + offset)
logger.debug(
logger.info(
f"Notify the prefiller: {path} that request: {req_id} from decoder is ready."
)
msg_encoder = msgspec.msgpack.Encoder()
@@ -1142,7 +1146,7 @@ class MooncakeLayerwiseConnectorWorker:
**kwargs) -> None:
"""MooncakeLayerwiseConnector does not save explicitly."""
if self.kv_role == 'kv_producer':
if self.pd_tp_ratio != 1:
if self.pd_head_ratio != 1:
if self.current_layer != 0:
self.completion_event.wait()
self.completion_event = threading.Event()
@@ -1153,8 +1157,9 @@ class MooncakeLayerwiseConnectorWorker:
def sort_kv_cache(input_kv: list[list[int]]):
return torch.cat([
torch.chunk(tensor, self.pd_tp_ratio, dim=0)[x]
for x in range(self.pd_tp_ratio) for tensor in input_kv
torch.chunk(tensor, self.pd_head_ratio, dim=0)[x]
for x in range(self.pd_head_ratio)
for tensor in input_kv
])
total_block_ids = [
@@ -1176,7 +1181,7 @@ class MooncakeLayerwiseConnectorWorker:
keys = sort_kv_cache(keys) # [req1_key, req2_key]
values = sort_kv_cache(values)
(keys,
values) = kv_alltoall_and_rearrange(self.pd_tp_ratio, keys,
values) = kv_alltoall_and_rearrange(self.pd_head_ratio, keys,
values)
key_start_id = 0
value_start_id = 0
@@ -1185,7 +1190,7 @@ class MooncakeLayerwiseConnectorWorker:
value = None
for req_id, request in connector_metadata.requests.items():
logger.info(f"Add request {req_id} to kv send layer thread. ")
if self.pd_tp_ratio != 1:
if self.pd_head_ratio != 1:
key_block_num = len(
request.local_block_ids) * key_block_size
value_block_num = len(