[BugFix] Fix mooncake bug in PCP scenario (#5055)
### What this PR does / why we need it?
The mooncake_connector.py file was importing the wrong arguments to the
file, which could cause errors when use PCP; this issue has been
corrected.
### Does this PR introduce _any_ user-facing change?
NO
### How was this patch tested?
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
---------
Signed-off-by: daishixun <dsxsteven@sina.com>
This commit is contained in:
@@ -22,6 +22,7 @@ sys.modules["mooncake.engine"] = fake_engine
|
||||
_mock_ascend_config = MagicMock(enable_kv_nz=False)
|
||||
_mock_pp_group = MagicMock(rank_in_group=0, world_size=1)
|
||||
_mock_tp_group = MagicMock(rank_in_group=0, world_size=4)
|
||||
_mock_pcp_group = MagicMock(rank_in_group=0, world_size=1)
|
||||
patch('vllm_ascend.distributed.mooncake_connector.get_pp_group',
|
||||
return_value=_mock_pp_group).start()
|
||||
patch('vllm_ascend.distributed.mooncake_connector.get_tp_group',
|
||||
@@ -32,6 +33,8 @@ patch(
|
||||
patch(
|
||||
'vllm_ascend.distributed.mooncake_connector.get_tensor_model_parallel_rank',
|
||||
return_value=0).start()
|
||||
patch('vllm_ascend.distributed.mooncake_connector.get_pcp_group',
|
||||
return_value=_mock_pcp_group).start()
|
||||
|
||||
from vllm_ascend.distributed.mooncake_connector import ( # noqa: E402
|
||||
KVCacheRecvingThread, KVCacheSendingThread, KVCacheTaskTracker,
|
||||
|
||||
@@ -23,6 +23,7 @@ import zmq
|
||||
from mooncake.engine import TransferEngine # type: ignore
|
||||
from vllm import envs
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed import get_pcp_group
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||
KVConnectorBase_V1, KVConnectorHandshakeMetadata, KVConnectorMetadata,
|
||||
KVConnectorRole)
|
||||
@@ -43,13 +44,6 @@ from vllm_ascend.distributed.mooncake_transfer_engine import global_te
|
||||
from vllm_ascend.distributed.utils import get_transfer_timeout_value
|
||||
from vllm_ascend.utils import prefill_context_parallel_enable
|
||||
|
||||
# isort: off
|
||||
if prefill_context_parallel_enable():
|
||||
from vllm.distributed import (get_prefill_context_model_parallel_rank,
|
||||
get_prefill_context_model_parallel_world_size
|
||||
)
|
||||
# isort: on
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.forward_context import ForwardContext
|
||||
@@ -1003,13 +997,12 @@ class MooncakeConnectorWorker:
|
||||
self.pp_size = vllm_config.parallel_config.pipeline_parallel_size
|
||||
self.kv_caches: dict[str, torch.Tensor] = {}
|
||||
self.side_channel_host = get_ip()
|
||||
self.pcp_size = get_prefill_context_model_parallel_world_size(
|
||||
) if prefill_context_parallel_enable() else 1
|
||||
self.pcp_size = get_pcp_group().world_size
|
||||
# Assert that pp_size and pcp_size cannot both be greater than 1
|
||||
assert not (self.pp_size > 1 and self.pcp_size
|
||||
> 1), "pp and pcp cannot open in same time"
|
||||
self.pcp_rank = get_prefill_context_model_parallel_rank(
|
||||
) if self.pcp_size > 1 else 0
|
||||
self.pcp_rank = get_pcp_group(
|
||||
).rank_in_group if self.pcp_size > 1 else 0
|
||||
self.dcp_size = get_decode_context_model_parallel_world_size()
|
||||
self.dcp_rank = get_decode_context_model_parallel_rank(
|
||||
) if self.dcp_size > 1 else 0
|
||||
|
||||
Reference in New Issue
Block a user