diff --git a/examples/disaggregated_prefill_v1/load_balance_proxy_layerwise_server_example.py b/examples/disaggregated_prefill_v1/load_balance_proxy_layerwise_server_example.py index 67c34ee8..82e06e08 100644 --- a/examples/disaggregated_prefill_v1/load_balance_proxy_layerwise_server_example.py +++ b/examples/disaggregated_prefill_v1/load_balance_proxy_layerwise_server_example.py @@ -561,9 +561,12 @@ async def metaserver(request: Request): max_retries=global_args.max_retries, base_delay=global_args.retry_delay) proxy_state.release_prefiller(prefiller_idx, prefiller_score) + proxy_state.release_prefiller_kv(prefiller_idx, prefiller_score) except Exception as e: logger.error(f"Post metaserver failed with: {str(e)}") + proxy_state.release_prefiller(prefiller_idx, prefiller_score) + proxy_state.release_prefiller_kv(prefiller_idx, prefiller_score) if __name__ == '__main__': diff --git a/tests/ut/kv_connector/test_mooncake_connector.py b/tests/ut/kv_connector/test_mooncake_connector.py index fa78a46f..a5bc066f 100644 --- a/tests/ut/kv_connector/test_mooncake_connector.py +++ b/tests/ut/kv_connector/test_mooncake_connector.py @@ -978,9 +978,6 @@ class MockTensor: self.data_ptr = MagicMock(return_value=0x1000) -mock_envs_ascend = MagicMock() -mock_envs_ascend.MOONCAKE_CONNECTOR_PROTOCOL = "mock_protocol" - mock_logger = MagicMock() @@ -1017,14 +1014,15 @@ def mock_string_to_int64_hash(s): class TestMooncakeConnectorWorker(unittest.TestCase): def setUp(self): - self.envs_ascend_mock = MockEnvsAscend() self.mock_transfer_engine = MagicMock() self.mock_transfer_engine.get_rpc_port.return_value = 9090 self.mock_transfer_engine.initialize.return_value = 0 self.mock_transfer_engine.register_memory.return_value = 0 self.patches = [ - patch('os.getenv', return_value="10,11"), + patch( + 'vllm_ascend.distributed.mooncake_connector.envs_ascend.PHYSICAL_DEVICES', + '10,11'), patch('torch.Tensor.size', return_value=(10, 16, 8, 16)), patch('torch.Tensor.element_size', 
return_value=4), patch('torch.Tensor.data_ptr', return_value=0x1000), @@ -1053,8 +1051,6 @@ class TestMooncakeConnectorWorker(unittest.TestCase): MagicMock()), patch('vllm_ascend.distributed.mooncake_connector.threading.Event', MagicMock()), - patch.dict('sys.modules', - {'vllm_ascend.envs': self.envs_ascend_mock}), ] for p in self.patches: diff --git a/tests/ut/kv_connector/test_mooncake_layerwise_connector.py b/tests/ut/kv_connector/test_mooncake_layerwise_connector.py index bc9ba253..94ab34dc 100644 --- a/tests/ut/kv_connector/test_mooncake_layerwise_connector.py +++ b/tests/ut/kv_connector/test_mooncake_layerwise_connector.py @@ -792,15 +792,15 @@ class TestMooncakeLayerwiseConnector(unittest.TestCase): class TestMooncakeLayerwiseConnectorWorker(unittest.TestCase): def setUp(self): - self.envs_ascend_mock = type("MockEnvsAscend", (), - {"PHYSICAL_DEVICES": "10,11"})() self.mock_transfer_engine = MagicMock() self.mock_transfer_engine.get_rpc_port.return_value = 9090 self.mock_transfer_engine.initialize.return_value = 0 self.mock_transfer_engine.register_memory.return_value = 0 self.patches = [ - patch('os.getenv', return_value="10,11"), + patch( + 'vllm_ascend.distributed.mooncake_layerwise_connector.envs_ascend.PHYSICAL_DEVICES', + '10,11'), patch('torch.Tensor.size', return_value=(10, 16, 8, 16)), patch('torch.Tensor.element_size', return_value=4), patch('torch.Tensor.data_ptr', return_value=0x1000), @@ -833,8 +833,6 @@ class TestMooncakeLayerwiseConnectorWorker(unittest.TestCase): patch( 'vllm_ascend.distributed.mooncake_layerwise_connector.threading.Event', MagicMock()), - patch.dict('sys.modules', - {'vllm_ascend.envs': self.envs_ascend_mock}), patch( 'vllm_ascend.distributed.mooncake_layerwise_connector.get_ascend_config', return_value=SimpleNamespace(pd_tp_ratio=1, diff --git a/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py b/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py index e72f4eba..d92b724f 100644 --- 
a/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py +++ b/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py @@ -31,6 +31,7 @@ from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.request import Request, RequestStatus import vllm_ascend.envs as envs_ascend +from vllm_ascend.distributed.utils import get_transfer_timeout_value from vllm_ascend.utils import (AscendSocVersion, get_ascend_soc_version, prefill_context_parallel_enable, vllm_version_is) @@ -438,7 +439,7 @@ class LLMDataDistCMgrConnectorWorker(): assert self.local_agent_metadata is not None llm_config = LLMConfig() llm_config.device_id = self.local_rank - llm_config.sync_kv_timeout = 20000 + llm_config.sync_kv_timeout = get_transfer_timeout_value() llm_config.enable_switch_role = True llm_config.enable_cache_manager = True llm_config.enable_remote_cache_accessible = True diff --git a/vllm_ascend/distributed/mooncake_connector.py b/vllm_ascend/distributed/mooncake_connector.py index 5dfb125e..7951760d 100644 --- a/vllm_ascend/distributed/mooncake_connector.py +++ b/vllm_ascend/distributed/mooncake_connector.py @@ -2,6 +2,7 @@ import contextlib import hashlib import math +import os import queue import random import struct @@ -33,6 +34,7 @@ from vllm.v1.request import RequestStatus import vllm_ascend.envs as envs_ascend from vllm_ascend.ascend_config import get_ascend_config, init_ascend_config from vllm_ascend.distributed.mooncake.transfer_engine import get_global_te +from vllm_ascend.distributed.utils import get_transfer_timeout_value from vllm_ascend.utils import vllm_version_is if vllm_version_is("0.11.0"): @@ -855,6 +857,8 @@ class MooncakeConnectorWorker: def __init__(self, vllm_config: VllmConfig, engine_id: str): self._get_prefill_decode_size(vllm_config) + os.environ["ASCEND_TRANSFER_TIMEOUT"] = str( + get_transfer_timeout_value()) if self._prefill_tp_size < self._decode_tp_size: raise ValueError( f"prefill_tp_size: {self._prefill_tp_size} must be greater than" diff --git 
a/vllm_ascend/distributed/mooncake_layerwise_connector.py b/vllm_ascend/distributed/mooncake_layerwise_connector.py index 874adb3e..f1a6bb91 100644 --- a/vllm_ascend/distributed/mooncake_layerwise_connector.py +++ b/vllm_ascend/distributed/mooncake_layerwise_connector.py @@ -3,6 +3,7 @@ import contextlib import copy import hashlib import math +import os import queue import struct import threading @@ -31,6 +32,7 @@ from vllm.v1.core.sched.output import SchedulerOutput import vllm_ascend.envs as envs_ascend from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.distributed.utils import (align_memory, + get_transfer_timeout_value, kv_alltoall_and_rearrange) from vllm_ascend.utils import vllm_version_is @@ -602,6 +604,8 @@ class MooncakeLayerwiseConnectorWorker: def __init__(self, vllm_config: VllmConfig, engine_id: str): self._get_prefill_decode_size(vllm_config) + os.environ["ASCEND_TRANSFER_TIMEOUT"] = str( + get_transfer_timeout_value()) if self._prefill_tp_size < self._decode_tp_size: raise ValueError( f"prefill_tp_size: {self._prefill_tp_size} must be greater than" diff --git a/vllm_ascend/distributed/utils.py b/vllm_ascend/distributed/utils.py index 4b1344a1..c25c1f15 100644 --- a/vllm_ascend/distributed/utils.py +++ b/vllm_ascend/distributed/utils.py @@ -1,3 +1,5 @@ +import os + import torch import torch.distributed as dist @@ -45,3 +47,15 @@ def align_memory(tensor: torch.Tensor, alignment: int) -> torch.Tensor: aligned_addr = (data_ptr + alignment - 1) // alignment * alignment offset = (aligned_addr - data_ptr) // tensor.element_size() return tensor[int(offset):] + + +def get_transfer_timeout_value(): + ascend_transfer_timeout = os.getenv("ASCEND_TRANSFER_TIMEOUT", "") + if len(ascend_transfer_timeout) > 0: + return int(ascend_transfer_timeout) + hccl_rdma_timeout = int(os.getenv('HCCL_RDMA_TIMEOUT', + '20')) # type: ignore + hccl_rdma_retry_cnt = int(os.getenv('HCCL_RDMA_RETRY_CNT', + '7')) # type: ignore + return int((4.096 * 
(2**hccl_rdma_timeout)) * hccl_rdma_retry_cnt // 1000 + + 3000)