mooncake connector support pipeline parallel & fix pp with flashcomm1 (#4054)

### What this PR does / why we need it?
To support pipeline parallelism (PP) with PD disaggregation, this PR adds PP
support to the mooncake connector and fixes other bugs that occur when PP is
enabled together with other optimization params. It includes the following changes:
- the mooncake connector now supports PP in prefill; PP in decode is not
supported currently
- fix bugs that occur when both PP and flashcomm1 are enabled
- optimize ascend-scheduler to support a full batch in every pipeline stage.
In the original implementation, the batch sizes of all pipeline stages summed
to max_num_seq in total, so the pipeline was never full. With this
optimization, every stage can run with a full batch_size = max_num_seq. The
same change will also be contributed to the vLLM scheduler.

### Does this PR introduce _any_ user-facing change?
Adds a `pp_size` option to the mooncake connector's `kv_connector_extra_config`:
```
"kv_connector_extra_config": {
            "use_ascend_direct": true,
            "prefill": {
                    "dp_size": 1,
                    "tp_size": 4,
                    "pp_size": 4
             },
             "decode": {
                    "dp_size": 16,
                    "tp_size": 1
             }
        }
```

### How was this patch tested?

- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c

---------

Signed-off-by: chenxiao <Jaychou1620@Gmail.com>
Signed-off-by: Kurumi5210 <Jaychou1620@Gmail.com>
Signed-off-by: Kurumi5210 <jaychou1620@gmail.com>
Signed-off-by: 秋刀鱼 <jaychou1620@Gmail.com>
Co-authored-by: chenxiao <Jaychou1620@Gmail.com>
Co-authored-by: zss <zss@qq.com>
Co-authored-by: zss <3265779424@qq.com>
This commit is contained in:
lidenghui1110
2025-12-10 16:01:43 +08:00
committed by GitHub
parent ce5872705e
commit a82b0fa70e
5 changed files with 394 additions and 141 deletions

View File

@@ -19,6 +19,20 @@ fake_engine = types.ModuleType("mooncake.engine")
fake_engine.TransferEngine = MagicMock() # type: ignore[attr-defined]
sys.modules["mooncake.engine"] = fake_engine
_mock_ascend_config = MagicMock(enable_kv_nz=False)
_mock_pp_group = MagicMock(rank_in_group=0, world_size=1)
_mock_tp_group = MagicMock(rank_in_group=0, world_size=4)
patch('vllm_ascend.distributed.mooncake_connector.get_pp_group',
return_value=_mock_pp_group).start()
patch('vllm_ascend.distributed.mooncake_connector.get_tp_group',
return_value=_mock_tp_group).start()
patch(
'vllm_ascend.distributed.mooncake_connector.get_tensor_model_parallel_world_size',
return_value=4).start()
patch(
'vllm_ascend.distributed.mooncake_connector.get_tensor_model_parallel_rank',
return_value=0).start()
from vllm_ascend.distributed.mooncake_connector import ( # noqa: E402
KVCacheRecvingThread, KVCacheSendingThread, KVCacheTaskTracker,
KVConnectorRole, MooncakeAgentMetadata, MooncakeConnector,
@@ -88,6 +102,7 @@ class TestKVCacheSendingThreadInit(unittest.TestCase):
'side_channel_host': 'localhost',
'side_channel_port': 5555,
'metadata': MagicMock(),
'vllm_config': MockVllmConfig(),
'ready_event': threading.Event(),
'kv_caches': kv_caches,
'pcp_rank': 0
@@ -130,6 +145,7 @@ class TestGetAndClearFinishedRequests(unittest.TestCase):
'prefill_tp_size': 4,
'local_engine_id': 'engine_1',
'side_channel_host': 'localhost',
'vllm_config': MockVllmConfig(),
'side_channel_port': 5555,
'metadata': {
"test": "metadata"
@@ -159,27 +175,32 @@ class TestKVCacheSendingThread(unittest.TestCase):
kv_caches_base_addr=[12345678],
num_blocks=2,
)
vllm_config = MockVllmConfig()
host = "127.0.0.1"
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(('', 0))
free_port = s.getsockname()[1]
base_port = s.getsockname()[1]
thread = KVCacheSendingThread(tp_rank=0,
prefill_tp_size=1,
local_engine_id="engine1",
side_channel_host=host,
side_channel_port=free_port,
side_channel_port=base_port,
metadata=metadata,
vllm_config=vllm_config,
ready_event=ready_event,
kv_caches={},
pcp_rank=0)
thread.start()
actual_port = base_port + (thread.pp_rank * thread.tp_size +
thread.tp_rank +
thread.pcp_rank * thread.prefill_tp_size)
self.assertTrue(ready_event.wait(timeout=3),
"Server thread startup timeout")
context = zmq.Context() # type: ignore
sock = context.socket(zmq.DEALER) # type: ignore
sock.connect(f"tcp://{host}:{free_port}")
sock.connect(f"tcp://{host}:{actual_port}")
encoder = msgspec.msgpack.Encoder()
decoder = msgspec.msgpack.Decoder(type=MooncakeAgentMetadata)
@@ -213,6 +234,7 @@ class TestKVCacheRecvingThreadBasic(unittest.TestCase):
self.thread = KVCacheRecvingThread(
tp_rank=0,
tp_size=4,
_prefill_pp_size=1,
engine=self.engine,
local_engine_id="local_engine",
local_handshake_port=5555,
@@ -231,7 +253,7 @@ class TestKVCacheRecvingThreadBasic(unittest.TestCase):
"remote_host": "localhost",
"remote_handshake_port": 6666,
"offset": 0,
"num_need_pulls": 2,
"tp_num_need_pulls": 2,
"all_task_done": False
}
self.thread.add_request(
@@ -242,7 +264,7 @@ class TestKVCacheRecvingThreadBasic(unittest.TestCase):
remote_host=test_req["remote_host"],
remote_handshake_port=test_req["remote_handshake_port"],
offset=test_req["offset"],
num_need_pulls=test_req["num_need_pulls"],
tp_num_need_pulls=test_req["tp_num_need_pulls"],
all_task_done=test_req["all_task_done"])
queued = self.thread.request_queue.get_nowait()
self.assertEqual(queued["request_id"], "req1")
@@ -265,6 +287,7 @@ class TestSocketManagement(unittest.TestCase):
self.thread = KVCacheRecvingThread(
tp_rank=0,
tp_size=4,
_prefill_pp_size=1,
engine=self.engine,
local_engine_id="local_engine",
local_handshake_port=5555,
@@ -315,10 +338,13 @@ class TestCoreFunctionality(unittest.TestCase):
self.ready_event = threading.Event()
self.mock_queue = MagicMock()
self.vllm_config = MockVllmConfig()
self.kv_caches: Dict[str, Any] = {}
self.kv_caches: Dict[str, Any] = {
"layer_0": (MagicMock(), MagicMock())
}
self.thread = KVCacheRecvingThread(
tp_rank=0,
tp_size=4,
_prefill_pp_size=1,
engine=self.engine,
local_engine_id="local_engine",
local_handshake_port=5555,
@@ -337,7 +363,7 @@ class TestCoreFunctionality(unittest.TestCase):
"remote_handshake_port": 6666,
"remote_transfer_port": 7777,
"offset": 0,
"num_need_pulls": 2,
"tp_num_need_pulls": 2,
"all_task_done": False
}
self.thread.task_tracker = MagicMock()
@@ -362,12 +388,14 @@ class TestCoreFunctionality(unittest.TestCase):
@patch.object(KVCacheRecvingThread, '_get_remote_metadata')
def test_transfer_kv_cache(self, mock_get_meta):
self.thread.kv_caches_base_addr["remote_engine"] = {
6666: [0x3000, 0x4000]
}
self.thread._transfer_kv_cache(self.test_req)
with patch(
'vllm_ascend.distributed.mooncake_connector.get_ascend_config'
) as mock_config:
mock_config.return_value.enable_kv_nz = False
self.thread.kv_caches_base_addr["remote_engine"] = {
6666: [0x3000, 0x4000]
}
self.thread._transfer_kv_cache(self.test_req)
self.engine.batch_transfer_sync_read.assert_called_once()
call_args, call_kwargs = self.engine.batch_transfer_sync_read.call_args
self.assertEqual(call_args[0], "localhost:7777")
@@ -398,6 +426,7 @@ class TestMetadataHandling(unittest.TestCase):
self.thread = KVCacheRecvingThread(
tp_rank=0,
tp_size=4,
_prefill_pp_size=1,
engine=self.engine,
local_engine_id="local_engine",
local_handshake_port=5555,
@@ -461,6 +490,7 @@ class TestMainThreadLoop(unittest.TestCase):
self.thread = KVCacheRecvingThread(
tp_rank=0,
tp_size=4,
_prefill_pp_size=1,
engine=self.engine,
local_engine_id="local_engine",
local_handshake_port=5555,
@@ -482,7 +512,7 @@ class TestMainThreadLoop(unittest.TestCase):
"remote_handshake_port": 6666,
"remote_transfer_port": 7777,
"offset": 0,
"num_need_pulls": 2,
"tp_num_need_pulls": 2,
"all_task_done": False
}
@@ -509,6 +539,10 @@ class MockVllmConfig:
self.parallel_config.tensor_parallel_size = 2
self.parallel_config.data_parallel_rank = 0
self.parallel_config.data_parallel_size_local = 1
self.parallel_config.pipeline_parallel_size = 1
self.parallel_config.data_parallel_rank_local = 0
self.model_config.get_num_layers_by_block_type = MagicMock(
return_value=32)
self.cache_config.block_size = 16
self.kv_transfer_config.kv_port = 5000
self.kv_transfer_config.kv_role = 'kv_producer'
@@ -516,11 +550,13 @@ class MockVllmConfig:
self.kv_transfer_config.get_from_extra_config.side_effect = lambda k, d: {
"prefill": {
"tp_size": 2,
"dp_size": 1
"dp_size": 1,
"pp_size": 1
},
"decode": {
"tp_size": 2,
"dp_size": 1
"dp_size": 1,
"pp_size": 1
}
}.get(k, d)
self.additional_config = {}
@@ -1062,12 +1098,13 @@ class TestMooncakeConnectorWorker(unittest.TestCase):
patch('torch.Tensor.element_size', return_value=4),
patch('torch.Tensor.data_ptr', return_value=0x1000),
patch('math.prod', return_value=128),
patch('random.Random'),
patch(
'vllm_ascend.distributed.mooncake_connector.get_tensor_model_parallel_rank',
mock_get_tensor_model_parallel_rank),
patch('vllm_ascend.distributed.mooncake_connector.get_tp_group',
mock_get_tp_group),
patch('vllm_ascend.distributed.mooncake_connector.get_pp_group',
return_value=_mock_pp_group),
patch('vllm_ascend.distributed.mooncake_connector.get_ip',
mock_get_ip),
patch(
@@ -1096,8 +1133,6 @@ class TestMooncakeConnectorWorker(unittest.TestCase):
patch(
'vllm_ascend.distributed.mooncake_connector.get_decode_context_model_parallel_world_size',
return_value=1),
patch('vllm_ascend.distributed.mooncake_connector.get_pcp_group',
return_value=self.mock_pcp_group),
patch(
'vllm_ascend.distributed.mooncake_connector.get_ascend_config',
return_value=MagicMock()),
@@ -1146,6 +1181,83 @@ class TestMooncakeConnectorWorker(unittest.TestCase):
# Default tp_rank is 0, so device_id should be 10
self.assertIsNotNone(worker.engine)
def test_get_remote_tp_rank(self):
def get_tp_rank(prefill_tp_size: int, prefill_pp_size: int,
decode_tp_size: int, num_kv_heads: int,
tp_num_need_pulls: int, is_deepseek_mla: bool):
with patch('vllm_ascend.distributed.mooncake_connector.get_ascend_config',
return_value=MagicMock()), \
patch.object(self.vllm_config.kv_transfer_config, 'get_from_extra_config',
side_effect=lambda k, d=None: {
"prefill": {"tp_size": prefill_tp_size, "dp_size": 1, "pp_size": prefill_pp_size},
"decode": {"tp_size": decode_tp_size, "dp_size": 1, "pp_size": 1}
}.get(k, d)):
self.vllm_config.model_config.hf_config.num_key_value_heads = num_kv_heads
self.vllm_config.model_config.is_deepseek_mla = is_deepseek_mla
worker = MooncakeConnectorWorker(self.vllm_config,
self.engine_id)
worker.tp_num_need_pulls = tp_num_need_pulls
worker.use_sparse = 0
return worker._get_remote_ranks_for_req('test')
self.assertIn(
get_tp_rank(16, 1, 1, 4, 4, False)[0],
[[0, 4, 8, 12], [1, 5, 9, 13], [2, 6, 10, 14], [3, 7, 11, 15]])
self.assertIn(
get_tp_rank(8, 1, 1, 4, 4, False)[0], [[0, 2, 4, 6], [1, 3, 5, 7]])
self.assertIn(get_tp_rank(4, 1, 1, 4, 4, False)[0], [[0, 1, 2, 3]])
self.assertIn(get_tp_rank(16, 1, 4, 4, 1, False),
[[[0], [4], [8], [12]], [[1], [5], [9], [13]],
[[2], [6], [10], [14]], [[3], [7], [11], [15]]])
self.assertIn(get_tp_rank(8, 1, 4, 4, 1, False),
[[[0], [2], [4], [6]], [[1], [3], [5], [7]]])
self.assertIn(get_tp_rank(4, 2, 2, 4, 2, False),
[[[0, 1, 4, 5], [2, 3, 6, 7]]])
self.assertIn(get_tp_rank(4, 1, 4, 4, 1, False),
[[[0], [1], [2], [3]]])
self.assertIn(
get_tp_rank(8, 2, 1, 4, 4, False)[0],
[[0, 2, 4, 6, 8, 10, 12, 14], [1, 3, 5, 7, 9, 11, 13, 15]])
self.assertIn(get_tp_rank(4, 2, 2, 4, 2, False),
[[[0, 1, 4, 5], [2, 3, 6, 7]]])
self.assertIn(get_tp_rank(2, 2, 1, 4, 2, False), [[[0, 1, 2, 3]]])
self.assertIn(
get_tp_rank(4, 4, 2, 8, 2, False),
[[[0, 1, 4, 5, 8, 9, 12, 13], [2, 3, 6, 7, 10, 11, 14, 15]]])
self.assertIn(
get_tp_rank(4, 2, 1, 4, 4, False)[0], [[0, 1, 2, 3, 4, 5, 6, 7]])
self.assertIn(
get_tp_rank(4, 4, 1, 4, 4, False)[0],
[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]])
self.assertIn(get_tp_rank(8, 2, 4, 4, 1, False),
[[[0, 8], [2, 10], [4, 12], [6, 14]],
[[1, 9], [3, 11], [5, 13], [7, 15]]])
self.assertIn(get_tp_rank(4, 2, 4, 4, 4, False),
[[[0, 4], [1, 5], [2, 6], [3, 7]]])
self.assertIn(
get_tp_rank(4, 4, 4, 4, 1, False),
[[[0, 4, 8, 12], [1, 5, 9, 13], [2, 6, 10, 14], [3, 7, 11, 15]]])
self.assertIn(
get_tp_rank(16, 1, 1, 1, 1,
True)[0], [[0], [1], [2], [3], [4], [5], [6], [7], [8],
[9], [10], [11], [12], [13], [14], [15]])
self.assertIn(get_tp_rank(4, 1, 4, 1, 1, True), [[[0], [1], [2], [3]]])
self.assertIn(
get_tp_rank(8, 2, 1, 1, 1, True)[0],
[[0, 8], [2, 10], [4, 12], [6, 14], [1, 9], [3, 11], [5, 13],
[7, 15]])
self.assertIn(
get_tp_rank(4, 4, 1, 1, 1, True)[0],
[[0, 4, 8, 12], [1, 5, 9, 13], [2, 6, 10, 14], [3, 7, 11, 15]])
self.assertIn(
get_tp_rank(8, 2, 4, 1, 1, True)[0],
[[0, 8], [2, 10], [4, 12], [6, 14], [1, 9], [3, 11], [5, 13],
[7, 15]])
self.assertIn(
get_tp_rank(4, 4, 4, 1, 1, True),
[[[0, 4, 8, 12], [1, 5, 9, 13], [2, 6, 10, 14], [3, 7, 11, 15]]])
if __name__ == '__main__':
unittest.main()

View File

@@ -22,29 +22,30 @@ from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
# yapf: disable
@pytest.mark.parametrize(
"soc_version, enable_expert_parallel, world_size, num_tokens, mc2_tokens_capacity, quant_type, expected_method",
"soc_version, enable_expert_parallel, world_size, pipeline_size, num_tokens, mc2_tokens_capacity, quant_type, expected_method",
[
# Case 1: Expert parallel is disabled, should always be 'allgather'
(AscendDeviceType._910B, False, 8, 100, 256, None, MoECommType.ALLGATHER),
(AscendDeviceType._910_93, False, 16, 500, 256, None, MoECommType.ALLGATHER),
(AscendDeviceType._910B, False, 8, 2, 100, 256, None, MoECommType.ALLGATHER),
(AscendDeviceType._910_93, False, 16, 2, 500, 256, None, MoECommType.ALLGATHER),
# Case 2: A2 SOC with w4a8_dynamic -> use alltoall when not mc2
(AscendDeviceType._910B, True, 8, 100, 256, "w4a8_dynamic", MoECommType.ALLTOALL),
(AscendDeviceType._910B, True, 16, 257, 256, "w4a8_dynamic", MoECommType.ALLTOALL),
(AscendDeviceType._910B, True, 16, 100, 256, "w4a8_dynamic", MoECommType.MC2), # meets mc2 condition
(AscendDeviceType._910B, True, 8, 1, 100, 256, "w4a8_dynamic", MoECommType.ALLTOALL),
(AscendDeviceType._910B, True, 16, 1, 257, 256, "w4a8_dynamic", MoECommType.ALLTOALL),
(AscendDeviceType._910B, True, 16, 1, 100, 256, "w4a8_dynamic", MoECommType.MC2), # meets mc2 condition
# Case 3: A2 SOC without w4a8_dynamic -> fallback to allgather
(AscendDeviceType._910B, True, 8, 100, 256, None, MoECommType.ALLGATHER),
(AscendDeviceType._910B, True, 16, 257, 256, None, MoECommType.ALLGATHER),
(AscendDeviceType._910B, True, 8, 2, 100, 256, None, MoECommType.ALLGATHER),
(AscendDeviceType._910B, True, 16, 2, 257, 256, None, MoECommType.ALLGATHER),
# Case 4: A3 SOC
(AscendDeviceType._910_93, True, 8, 100, 256, None, MoECommType.MC2),
(AscendDeviceType._910_93, True, 8, 257, 256, None, MoECommType.ALLTOALL),
(AscendDeviceType._910_93, True, 8, 2, 100, 256, None, MoECommType.MC2),
(AscendDeviceType._910_93, True, 8, 2, 257, 256, None, MoECommType.ALLTOALL),
])
# yapf: enable
def test_select_moe_comm_method(soc_version, enable_expert_parallel,
world_size, num_tokens, mc2_tokens_capacity,
quant_type, expected_method):
world_size, pipeline_size, num_tokens,
mc2_tokens_capacity, quant_type,
expected_method):
"""
Tests the _select_moe_comm_method with various configurations including quant_type.
"""
@@ -53,6 +54,7 @@ def test_select_moe_comm_method(soc_version, enable_expert_parallel,
mock_runner.parallel_config = MagicMock()
mock_runner.parallel_config.enable_expert_parallel = enable_expert_parallel
mock_runner.parallel_config.world_size_across_dp = world_size
mock_runner.parallel_config.pipeline_parallel_size = pipeline_size
mock_runner.mc2_tokens_capacity = mc2_tokens_capacity
# Add vllm_config.model_config.hf_config mock with moe_quantize

View File

@@ -27,8 +27,10 @@ from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
from vllm.distributed.parallel_state import (
get_decode_context_model_parallel_rank,
get_decode_context_model_parallel_world_size, get_pcp_group,
get_tensor_model_parallel_rank, get_tp_group)
get_decode_context_model_parallel_world_size, get_pp_group,
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size,
get_tp_group)
from vllm.distributed.utils import get_pp_indices
from vllm.logger import logger
from vllm.utils.network_utils import get_ip, make_zmq_path, make_zmq_socket
from vllm.v1.core.sched.output import SchedulerOutput
@@ -38,6 +40,14 @@ from vllm.v1.request import RequestStatus
from vllm_ascend.ascend_config import get_ascend_config, init_ascend_config
from vllm_ascend.distributed.mooncake_transfer_engine import global_te
from vllm_ascend.distributed.utils import get_transfer_timeout_value
from vllm_ascend.utils import prefill_context_parallel_enable
# isort: off
if prefill_context_parallel_enable():
from vllm.distributed import (get_prefill_context_model_parallel_rank,
get_prefill_context_model_parallel_world_size
)
# isort: on
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
@@ -159,14 +169,17 @@ class KVCacheTaskTracker:
class KVCacheSendingThread(threading.Thread):
def __init__(self, tp_rank: int, prefill_tp_size: int,
local_engine_id: str, side_channel_host: str,
side_channel_port: int, metadata: MooncakeAgentMetadata,
ready_event: threading.Event, kv_caches: dict[str, Any],
pcp_rank: int):
def __init__(self, vllm_config: VllmConfig, tp_rank: int,
prefill_tp_size: int, local_engine_id: str,
side_channel_host: str, side_channel_port: int,
metadata: MooncakeAgentMetadata, ready_event: threading.Event,
kv_caches: dict[str, Any], pcp_rank: int):
super().__init__(daemon=True, name="KVCacheSendingThread")
self.tp_rank = tp_rank
self.prefill_tp_size = prefill_tp_size
self.pp_rank = get_pp_group().rank_in_group
self.pp_size = vllm_config.parallel_config.pipeline_parallel_size
self.tp_size = get_tensor_model_parallel_world_size()
self.local_engine_id = local_engine_id
self.side_channel_host = side_channel_host
self.side_channel_port = side_channel_port
@@ -205,8 +218,8 @@ class KVCacheSendingThread(threading.Thread):
# NOTE(rob): we need each rank to have a unique port. This hack to keeps
# us moving. We will switch when moving to etcd or where we have a
# single ZMQ socket in the scheduler.
handshake_port = self.side_channel_port + self.pcp_rank * self.prefill_tp_size \
+ self.tp_rank
device_index = self.pp_rank * self.tp_size + self.tp_rank + self.pcp_rank * self.prefill_tp_size
handshake_port = self.side_channel_port + device_index
path = make_zmq_path("tcp", self.side_channel_host, handshake_port)
logger.info("Starting listening on path: %s", path)
with zmq_ctx(zmq.ROUTER, path) as sock: # type: ignore
@@ -258,20 +271,22 @@ class KVCacheSendingThread(threading.Thread):
class KVCacheRecvingThread(threading.Thread):
def __init__(self, tp_rank: int, tp_size: int, engine: TransferEngine,
local_engine_id: str, local_handshake_port: int,
def __init__(self, tp_rank: int, tp_size: int, _prefill_pp_size: int,
engine: TransferEngine, local_engine_id: str,
local_handshake_port: int,
local_kv_caches_base_addr: list[int], block_len: list[int],
ready_event: threading.Event, vllm_config: VllmConfig,
kv_caches: dict[str, Any]):
super().__init__(daemon=True, name="KVCacheRecvingThread")
self.tp_rank = tp_rank
self.tp_size = tp_size
self._prefill_pp_size = _prefill_pp_size
self.local_engine_id = local_engine_id
self.local_handshake_port = local_handshake_port
self.engine = engine
self.ready_event = ready_event
self.kv_caches = kv_caches
self.kv_caches_base_addr: dict[str, dict[int, list[int]]] = \
SizedDict()
self.kv_caches_base_addr[local_engine_id][local_handshake_port] = \
@@ -299,13 +314,22 @@ class KVCacheRecvingThread(threading.Thread):
self.vllm_config = vllm_config
self.model_config = self.vllm_config.model_config
self.num_key_value_heads = self.model_config.hf_config.num_key_value_heads
self.kv_caches = kv_caches
self.block_size = self.vllm_config.cache_config.block_size
if self.use_mla:
self.k_head_dim = self.model_config.hf_config.kv_lora_rank
self.v_head_dim = self.model_config.hf_config.qk_rope_head_dim
self.num_kv_heads = 1
else:
self.k_head_dim = self.model_config.hf_config.head_dim
self.v_head_dim = self.model_config.hf_config.head_dim
self.num_kv_heads = max(
self.model_config.hf_config.num_key_value_heads //
self.tp_size, 1)
def add_request(self, request_id: str, local_block_ids: list[int],
remote_block_ids: list[int], remote_engine_id: str,
remote_host: str, remote_handshake_port: int, offset: int,
num_need_pulls: int, all_task_done: bool):
tp_num_need_pulls: int, all_task_done: bool):
"""Add a new request to the queue for processing."""
logger.debug(f"Adding request {request_id} to the queue.")
self.request_queue.put({
@@ -316,7 +340,7 @@ class KVCacheRecvingThread(threading.Thread):
"remote_host": remote_host,
"remote_handshake_port": remote_handshake_port,
"offset": offset,
"num_need_pulls": num_need_pulls,
"tp_num_need_pulls": tp_num_need_pulls,
"all_task_done": all_task_done
})
@@ -376,7 +400,7 @@ class KVCacheRecvingThread(threading.Thread):
remote_host = req_meta["remote_host"]
remote_handshake_port = req_meta["remote_handshake_port"]
offset = req_meta["offset"]
self.num_need_pulls = req_meta["num_need_pulls"]
tp_num_need_pulls = req_meta["tp_num_need_pulls"]
# Full prefix cache hit: do not need to read remote blocks, just notify
# P worker that we have the blocks we need.
@@ -394,7 +418,7 @@ class KVCacheRecvingThread(threading.Thread):
remote_handshake_port not in self.kv_caches_base_addr[remote_engine_id]:
self._get_remote_metadata(remote_host, remote_handshake_port)
if self.num_need_pulls == 1:
if tp_num_need_pulls == 1:
grouped_remote_block_ids, grouped_local_block_ids = \
group_concurrent_contiguous(remote_block_ids, local_block_ids)
else:
@@ -402,11 +426,25 @@ class KVCacheRecvingThread(threading.Thread):
local_block_ids = list(map(lambda x: [x], local_block_ids))
grouped_remote_block_ids, grouped_local_block_ids = remote_block_ids, local_block_ids
num_transfer_groups = len(grouped_remote_block_ids)
# tp_num_need_pulls: number of KV caches each Decode node needs to pull from each PP stage
# Due to GQA, different KV heads are distributed across different ranks, so there are offsets
# indicating which KV head to pull
global_offset = offset # Global offset of request across all ranks
prefill_pp_rank = offset // tp_num_need_pulls # PP rank where current request resides
inner_offset = offset % tp_num_need_pulls # Offset within each PP stage
remote_kv_caches_base_addrs = \
self.kv_caches_base_addr[remote_engine_id][remote_handshake_port]
num_layers = self.model_config.hf_config.num_hidden_layers
first_layer_index, end_layer_index = get_pp_indices(
num_layers, prefill_pp_rank, self._prefill_pp_size)
num_cache_per_layer = len(list(
self.kv_caches.values())[0]) # Number of KV caches per layer
local_kv_caches_base_addrs = \
self.kv_caches_base_addr[self.local_engine_id][self.local_handshake_port]
self.kv_caches_base_addr[self.local_engine_id][self.local_handshake_port][first_layer_index*num_cache_per_layer : end_layer_index*num_cache_per_layer]
logger.debug(
f"transfer kv cache first_layer_index:{first_layer_index} , end_layer_index:{end_layer_index}"
)
remote_transfer_port = self.remote_te_port[remote_engine_id][
remote_handshake_port]
num_blocks = len(local_block_ids)
@@ -422,11 +460,11 @@ class KVCacheRecvingThread(threading.Thread):
block_len = (self.block_len[k % 3])
else:
block_len = (self.block_len[0])
inner_block_len = block_len // self.num_need_pulls
inner_block_len = block_len // tp_num_need_pulls
for remote_block_id, local_block_id in zip(
grouped_remote_block_ids, grouped_local_block_ids):
src = src_layer_base_addr + local_block_id[
0] * block_len + offset * inner_block_len
0] * block_len + inner_offset * inner_block_len
dst = dst_layer_base_addr + remote_block_id[0] * inner_block_len
length = inner_block_len * len(local_block_id)
src_list.append(src)
@@ -447,10 +485,17 @@ class KVCacheRecvingThread(threading.Thread):
" %d blocks). local_ip %s local_device_id %s remote_session_id %s",
request_id, req_transfer_elapsed, num_transfer_groups, num_blocks,
get_ip(), self.tp_rank, session_id)
if self.num_need_pulls > 1 and offset == self.num_need_pulls - 1:
self._cat_kv_cache(grouped_local_block_ids)
def _cat_kv_cache(self, block_ids: list[list[int]]):
# Determine if the current position is the offset position at the end of the KV transmission.
is_kv_transfer_end = (
global_offset == tp_num_need_pulls * self._prefill_pp_size - 1)
need_cat_cache = tp_num_need_pulls > 1 and is_kv_transfer_end
# need_nz_cache maybe caused error in non-MLA models
if need_cat_cache:
self._cat_kv_cache(grouped_local_block_ids, tp_num_need_pulls)
def _cat_kv_cache(self, block_ids: list[list[int]],
tp_num_need_pulls: int):
# Get necessary parameters
k_cache = list(self.kv_caches.values())[0][0]
dtype = k_cache.dtype
@@ -506,9 +551,11 @@ class KVCacheRecvingThread(threading.Thread):
# Transpose KV cache
k_buffer = self._transpose_kv_cache_between_head(
k_buffer, num_blocks, block_size, block_len, num_kv_head)
k_buffer, num_blocks, block_size, block_len, num_kv_head,
tp_num_need_pulls)
v_buffer = self._transpose_kv_cache_between_head(
v_buffer, num_blocks, block_size, block_len, num_kv_head)
v_buffer, num_blocks, block_size, block_len, num_kv_head,
tp_num_need_pulls)
# Reshape and cache the processed buffers
torch_npu._npu_reshape_and_cache(
@@ -522,11 +569,11 @@ class KVCacheRecvingThread(threading.Thread):
# Clean up buffers
del k_buffer, v_buffer
def _transpose_kv_cache_between_head(self, buffer: torch.Tensor,
num_blocks: int, block_size: int,
block_len: int,
num_kv_head: int) -> torch.Tensor:
buffer = buffer.view(num_blocks, self.num_need_pulls, block_size, -1)
def _transpose_kv_cache_between_head(
self, buffer: torch.Tensor, num_blocks: int, block_size: int,
block_len: int, num_kv_head: int,
tp_num_need_pulls: int) -> torch.Tensor:
buffer = buffer.view(num_blocks, tp_num_need_pulls, block_size, -1)
buffer.transpose_(1, 2)
return buffer.contiguous().view(block_len, num_kv_head, -1)
@@ -631,8 +678,8 @@ class MooncakeConnectorMetadata(KVConnectorMetadata):
remote_engine_id=kv_transfer_params["remote_engine_id"],
remote_host=kv_transfer_params["remote_host"],
remote_port=kv_transfer_params["remote_port"],
remote_pcp_size=kv_transfer_params["remote_pcp_size"],
remote_dcp_size=kv_transfer_params["remote_dcp_size"],
remote_pcp_size=kv_transfer_params.get("remote_pcp_size", 1),
remote_dcp_size=kv_transfer_params.get("remote_dcp_size", 1),
)
@@ -736,13 +783,17 @@ class MooncakeConnectorScheduler:
self.side_channel_host = get_ip()
self.pcp_size = vllm_config.parallel_config.prefill_context_parallel_size
self.dcp_size = vllm_config.parallel_config.decode_context_parallel_size
self.max_device_id = vllm_config.parallel_config.tensor_parallel_size * \
vllm_config.parallel_config.data_parallel_size * \
self.pcp_size * \
vllm_config.parallel_config.pipeline_parallel_size
# Handshake base port
self.side_channel_port = (
vllm_config.kv_transfer_config.kv_port +
vllm_config.parallel_config.data_parallel_rank *
vllm_config.parallel_config.tensor_parallel_size * self.pcp_size)
vllm_config.parallel_config.tensor_parallel_size *
vllm_config.parallel_config.pipeline_parallel_size * self.pcp_size)
# Requests that need to start recv.
# New requests are added by update_state_after_alloc in
# the scheduler. Used to make metadata passed to Worker.
@@ -894,15 +945,24 @@ class MooncakeConnectorWorker:
self.tp_rank = get_tensor_model_parallel_rank()
self.tp_size = vllm_config.parallel_config.tensor_parallel_size
self.tp_group = get_tp_group()
self.pp_rank = get_pp_group().rank_in_group
self.dp_rank = vllm_config.parallel_config.data_parallel_rank_local
self.dp_size = vllm_config.parallel_config.data_parallel_size_local
self.pp_size = vllm_config.parallel_config.pipeline_parallel_size
self.kv_caches: dict[str, torch.Tensor] = {}
self.side_channel_host = get_ip()
self.pcp_size = get_pcp_group().world_size
self.pcp_rank = get_pcp_group(
).rank_in_group if self.pcp_size > 1 else 0
self.pcp_size = get_prefill_context_model_parallel_world_size(
) if prefill_context_parallel_enable() else 1
# Assert that pp_size and pcp_size cannot both be greater than 1
assert not (self.pp_size > 1 and self.pcp_size
> 1), "pp and pcp cannot open in same time"
self.pcp_rank = get_prefill_context_model_parallel_rank(
) if self.pcp_size > 1 else 0
self.dcp_size = get_decode_context_model_parallel_world_size()
self.dcp_rank = get_decode_context_model_parallel_rank(
) if self.dcp_size > 1 else 0
self.max_device_id = self.tp_size * self.dp_size * self.pcp_size * self.pp_size
self.kv_role = vllm_config.kv_transfer_config.kv_role
self.num_key_value_heads = self.vllm_config.model_config.hf_config.num_key_value_heads
@@ -910,10 +970,12 @@ class MooncakeConnectorWorker:
self.side_channel_port = (
vllm_config.kv_transfer_config.kv_port +
vllm_config.parallel_config.data_parallel_rank *
vllm_config.parallel_config.tensor_parallel_size * self.pcp_size)
self.handshake_port = self.side_channel_port + self.pcp_rank * self.tp_size + self.tp_rank
vllm_config.parallel_config.tensor_parallel_size *
vllm_config.parallel_config.pipeline_parallel_size * self.pcp_size)
device_index = (self.pp_rank +
self.pcp_rank) * self.tp_size + self.tp_rank
self.handshake_port = self.side_channel_port + device_index
self.sockets: dict = {}
logger.info("Initializing Mooncake work %s", engine_id)
self.engine = global_te.get_transfer_engine(self.side_channel_host,
device_name=None)
self.te_rpc_port = self.engine.get_rpc_port()
@@ -926,13 +988,13 @@ class MooncakeConnectorWorker:
self.vllm_config = vllm_config
self.block_size = vllm_config.cache_config.block_size
if self.vllm_config.model_config.is_deepseek_mla:
self.num_need_pulls = 1
self.tp_num_need_pulls = 1
else:
num_d_block_heads = max(1,
self.num_key_value_heads // self.tp_size)
num_p_block_heads = max(
1, self.num_key_value_heads // self._prefill_tp_size)
self.num_need_pulls = num_d_block_heads // num_p_block_heads
self.tp_num_need_pulls = num_d_block_heads // num_p_block_heads
def _get_prefill_decode_size(self, vllm_config: VllmConfig):
# get prefill tp and dp size from extra config
@@ -945,7 +1007,8 @@ class MooncakeConnectorWorker:
assert "dp_size" in prefill_parallel_config.keys()
self._prefill_dp_size = prefill_parallel_config["dp_size"]
# get prefill pp size from extra config
self._prefill_pp_size = prefill_parallel_config.get("pp_size", 1)
# get decode tp and dp size from extra config
decode_parallel_config: dict[
str, Any] = vllm_config.kv_transfer_config.get_from_extra_config(
@@ -954,6 +1017,9 @@ class MooncakeConnectorWorker:
self._decode_tp_size = decode_parallel_config["tp_size"]
assert "dp_size" in decode_parallel_config.keys()
self._decode_dp_size = decode_parallel_config["dp_size"]
# get decode pp size from extra config
self._decode_pp_size = decode_parallel_config.get("pp_size", 1)
assert self._decode_pp_size == 1, "decode pp size must be 1"
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
"""Register the KV Cache data."""
@@ -1052,15 +1118,15 @@ class MooncakeConnectorWorker:
ready_event = threading.Event()
if self.kv_role == 'kv_producer':
self.kv_send_thread = KVCacheSendingThread(
self.tp_rank, self._prefill_tp_size, self.engine_id,
self.side_channel_host, self.side_channel_port, metadata,
ready_event, self.kv_caches, self.pcp_rank)
self.vllm_config, self.tp_rank, self._prefill_tp_size,
self.engine_id, self.side_channel_host, self.side_channel_port,
metadata, ready_event, self.kv_caches, self.pcp_rank)
self.kv_send_thread.start()
else:
self.kv_recv_thread = KVCacheRecvingThread(
self.tp_rank, self.tp_size, self.engine, self.engine_id,
self.handshake_port, kv_caches_base_addr, self.block_len,
ready_event, self.vllm_config, self.kv_caches)
self.tp_rank, self.tp_size, self._prefill_pp_size, self.engine,
self.engine_id, self.handshake_port, kv_caches_base_addr,
self.block_len, ready_event, self.vllm_config, self.kv_caches)
self.kv_recv_thread.start()
ready_event.wait()
@@ -1089,7 +1155,7 @@ class MooncakeConnectorWorker:
Use this function to calculate remote port and remote block number of each remote P node that we need to pull.
"""
if meta.remote_pcp_size * meta.remote_dcp_size * self.pcp_size * self.dcp_size == 1:
choosen_rank_list = self._get_remote_tp_rank(req_id)
choosen_rank_list = self._get_remote_rank(req_id)
remote_handshake_port_list = [[
x + meta.remote_port for x in choosen_rank_list
]]
@@ -1174,77 +1240,121 @@ class MooncakeConnectorWorker:
meta.remote_engine_id, len(meta.local_block_ids),
len(meta.remote_block_ids))
remote_handshake_port_list, local_block_ids_list, remote_block_ids_list = self._get_kv_split_metadata(
req_id, meta)
if prefill_context_parallel_enable():
remote_handshake_port_list, local_block_ids_list, remote_block_ids_list = self._get_kv_split_metadata(
req_id, meta)
for pcp_dcp_rank in range(len(remote_handshake_port_list)):
if len(local_block_ids_list[pcp_dcp_rank]) + len(
remote_block_ids_list[pcp_dcp_rank]) == 0:
continue
for i in range(self.num_need_pulls):
for pcp_dcp_rank in range(len(remote_handshake_port_list)):
if len(local_block_ids_list[pcp_dcp_rank]) + len(
remote_block_ids_list[pcp_dcp_rank]) == 0:
continue
for i in range(self.tp_num_need_pulls):
assert self.kv_recv_thread is not None
self.kv_recv_thread.add_request(
request_id=req_id,
local_block_ids=local_block_ids_list[pcp_dcp_rank],
remote_block_ids=remote_block_ids_list[
pcp_dcp_rank],
remote_engine_id=meta.remote_engine_id,
remote_host=meta.remote_host,
remote_handshake_port=remote_handshake_port_list[
pcp_dcp_rank][i],
offset=i,
tp_num_need_pulls=self.tp_num_need_pulls,
all_task_done=(
pcp_dcp_rank
== len(remote_handshake_port_list) - 1
and i == self.tp_num_need_pulls - 1))
else: #TODO: support prefill context parallel and pipeline parallel open at the same time
choosen_rank_list = self._get_remote_rank(req_id)
remote_handshake_port_list = [[x + meta.remote_port]
for x in choosen_rank_list]
for i in range(self.tp_num_need_pulls * self._prefill_pp_size):
assert self.kv_recv_thread is not None
self.kv_recv_thread.add_request(
request_id=req_id,
local_block_ids=local_block_ids_list[pcp_dcp_rank],
remote_block_ids=remote_block_ids_list[pcp_dcp_rank],
local_block_ids=meta.local_block_ids,
remote_block_ids=meta.remote_block_ids,
remote_engine_id=meta.remote_engine_id,
remote_host=meta.remote_host,
remote_handshake_port=remote_handshake_port_list[
pcp_dcp_rank][i],
remote_handshake_port=remote_handshake_port_list[i][0],
offset=i,
num_need_pulls=self.num_need_pulls,
all_task_done=(pcp_dcp_rank
== len(remote_handshake_port_list) - 1
and i == self.num_need_pulls - 1))
tp_num_need_pulls=self.tp_num_need_pulls,
all_task_done=(i == self.tp_num_need_pulls *
self._prefill_pp_size - 1))
if self.kv_send_thread is not None:
for req_id, delay_start_time in metadata.requests_to_send.items():
if self.tp_rank in self._prefill_get_remote_tp_rank(req_id):
if self.tp_rank in self._prefill_get_remote_rank(req_id):
self.kv_send_thread.add_delayed_request(
req_id, delay_start_time)
else:
self.kv_send_thread.add_not_transfer_request(req_id)
def _prefill_get_remote_tp_rank(self, req_id: str) -> List[int]:
return sum(self._get_remote_tp_ranks_for_req(req_id), [])
def _prefill_get_remote_rank(self, req_id: str) -> List[int]:
return sum(self._get_remote_ranks_for_req(req_id), [])
def _get_remote_tp_rank(self, req_id: str) -> List[int]:
return self._get_remote_tp_ranks_for_req(req_id)[self.tp_rank]
def _get_remote_rank(self, req_id: str) -> List[int]:
return self._get_remote_ranks_for_req(req_id)[self.tp_rank]
def _get_remote_tp_ranks_for_req(self, req_id: str) -> List[List[int]]:
if self._prefill_tp_size == self._decode_tp_size:
result = list(map(lambda x: [x], range(self._prefill_tp_size)))
return result
seed = string_to_int64_hash(req_id)
rand = random.Random(seed)
sampled_nums = []
ori_data = np.arange(self._prefill_tp_size)
def _get_remote_tp_ranks(self, tp_ori_data: np.ndarray,
rand_group_index: list[int],
num_groups: int) -> List[List[int]]:
# random split prefill tp list
tp_sampled_nums = []
if self._prefill_tp_size > self.num_key_value_heads or self.vllm_config.model_config.is_deepseek_mla or self.use_sparse:
# use deepseek mla, num_key_value_heads == 128, but consider as 1
if self.vllm_config.model_config.is_deepseek_mla or self.use_sparse:
num_kv_head = 1
else:
num_kv_head = self.num_key_value_heads
num_groups = len(ori_data) // num_kv_head
ori_data = ori_data.reshape(-1, num_groups)
rand_group_index = rand.sample(range(num_groups), \
max(self._decode_tp_size // num_kv_head, 1)) # random choose a group
choosen_group = ori_data[:, [rand_group_index]]
tp_ori_data = tp_ori_data.reshape(-1, num_groups)
choosen_group = tp_ori_data[:, [rand_group_index]]
flattened = choosen_group.reshape(-1).tolist()
sampled_nums = [
flattened[i:i + self.num_need_pulls]
for i in range(0, len(flattened), self.num_need_pulls)
tp_sampled_nums = [
flattened[i:i + self.tp_num_need_pulls]
for i in range(0, len(flattened), self.tp_num_need_pulls)
]
# non-random split
else:
group_size = self._prefill_tp_size // self._decode_tp_size
for i in range(self._decode_tp_size):
ori_data_slice = ori_data[i * group_size:(i + 1) * group_size]
sampled_nums.append(ori_data_slice.tolist())
slice = tp_ori_data[i * group_size:(i + 1) * group_size]
tp_sampled_nums.append(slice.tolist())
return tp_sampled_nums
def _get_remote_ranks_for_req(self, req_id: str) -> List[List[int]]:
# Divide the ports according to the TP within the PP
sampled_nums = []
if self._prefill_tp_size == self._decode_tp_size:
sampled_nums = list(
map(
lambda tp: [
tp + pp * self._prefill_tp_size
for pp in range(self._prefill_pp_size)
], range(self._prefill_tp_size)))
return sampled_nums
# use deepseek mla, num_key_value_heads == 128, but consider as 1
if self.vllm_config.model_config.is_deepseek_mla or self.use_sparse:
num_kv_head = 1
else:
num_kv_head = self.num_key_value_heads
ori_data = np.arange(self._prefill_tp_size * self._prefill_pp_size)
seed = string_to_int64_hash(req_id)
rand = random.Random(seed)
# random split prefill tp list
ori_data = ori_data.reshape(self._prefill_pp_size, -1)
num_groups = max(
1,
len(ori_data[0]) // num_kv_head
) # The number of redundant copies for each KV head within the PP stage
rand_group_index = rand.sample(range(num_groups), \
(max(self._decode_tp_size // num_kv_head, 1))) # random choose a group
all_results = [
self._get_remote_tp_ranks(ori_data[pp_index], rand_group_index,
num_groups)
for pp_index in range(self._prefill_pp_size)
]
for group_index in range(len(all_results[0])):
group = []
for pp_index in range(self._prefill_pp_size):
group.extend(all_results[pp_index][group_index])
sampled_nums.append(group)
return sampled_nums

View File

@@ -46,7 +46,8 @@ from vllm.compilation.counter import compilation_counter
from vllm.compilation.monitor import set_cudagraph_capturing_enabled
from vllm.config import (CompilationMode, CUDAGraphMode, VllmConfig,
get_layers_from_vllm_config)
from vllm.distributed import tensor_model_parallel_all_gather
from vllm.distributed import (get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather)
from vllm.distributed.ec_transfer import get_ec_transfer, has_ec_transfer
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
has_kv_transfer_group)
@@ -1765,11 +1766,22 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
else:
assert intermediate_tensors is not None
assert self.intermediate_tensors is not None
# If both flashcomm1 and pp are used simultaneously,
# the shape of the received data and the shape of the space to be copied to will not match,
# requiring a recalculation of the incoming data's shape.
tp_size = get_tensor_model_parallel_world_size()
num_input_tokens_with_flashcomm1 = num_input_tokens
if enable_sp():
num_input_tokens_with_flashcomm1 = (num_input_tokens +
tp_size - 1) // tp_size
for k, v in intermediate_tensors.items():
self.intermediate_tensors[k][:num_input_tokens].copy_(
v[:num_input_tokens], non_blocking=True)
self.intermediate_tensors[
k][:num_input_tokens_with_flashcomm1].copy_(
v[:num_input_tokens_with_flashcomm1],
non_blocking=True)
intermediate_tensors = IntermediateTensors({
k: v[:num_input_tokens]
k:
v[:num_input_tokens_with_flashcomm1]
for k, v in self.intermediate_tensors.items()
})
@@ -2044,7 +2056,8 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
update_attn_params(self.update_stream, forward_context,
maybe_padded_num_tokens)
if get_forward_context().sp_enabled:
if get_forward_context().sp_enabled and not isinstance(
hidden_states, IntermediateTensors):
hidden_states = tensor_model_parallel_all_gather(hidden_states, 0)
pad_size = get_forward_context().pad_size
if pad_size > 0:
@@ -2366,7 +2379,8 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
moe_comm_type = MoECommType.ALLGATHER
elif soc_version in {AscendDeviceType._910B}:
if (num_tokens <= self.mc2_tokens_capacity
and self.parallel_config.world_size_across_dp >= 16):
and self.parallel_config.world_size_across_dp /
self.parallel_config.pipeline_parallel_size >= 16):
moe_comm_type = MoECommType.MC2
else:
# Currently, w4a8_dynamic does not support allgatherep
@@ -3131,10 +3145,16 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
if get_pp_group().is_first_rank:
intermediate_tensors = None
else:
# When PP and flashcomm1 are enabled, during dummy_run the estimated space should divide num_tokens by tp_size;
# otherwise, on non-first PP ranks it would effectively perform an extra all-gather, leading to incorrect memory estimation and potentially causing OOM.
actual_tokens = num_tokens
if enable_sp():
tp_size = get_tensor_model_parallel_world_size()
actual_tokens = num_tokens // tp_size
if self.intermediate_tensors is None:
self.intermediate_tensors = (
self.model.make_empty_intermediate_tensors(
batch_size=num_tokens,
batch_size=actual_tokens,
dtype=self.dtype,
device=self.device))
intermediate_tensors = IntermediateTensors({

View File

@@ -52,9 +52,9 @@ from vllm_ascend.device_allocator.camem import CaMemAllocator
from vllm_ascend.distributed.parallel_state import init_ascend_model_parallel
from vllm_ascend.ops.triton.triton_utils import init_device_properties_triton
from vllm_ascend.platform import NPUPlatform
from vllm_ascend.utils import (check_ascend_device_type, is_enable_nz,
register_ascend_customop, sleep_mode_enabled,
try_register_lib)
from vllm_ascend.utils import (check_ascend_device_type, enable_sp,
is_enable_nz, register_ascend_customop,
sleep_mode_enabled, try_register_lib)
from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
torch._dynamo.trace_rules.clear_lru_cache() # noqa: E402
@@ -296,9 +296,14 @@ class NPUWorker(WorkerBase):
intermediate_tensors = None
forward_pass = scheduler_output.total_num_scheduled_tokens > 0
if forward_pass and not get_pp_group().is_first_rank:
# If flashcomm1 is used, this all_gather_group parameter needs to be removed, otherwise it will conflict with the all-gather operation in flashcomm1.
if enable_sp():
all_gather_group = None
else:
all_gather_group = get_tp_group()
intermediate_tensors = IntermediateTensors(
get_pp_group().recv_tensor_dict(
all_gather_group=get_tp_group()))
all_gather_group=all_gather_group))
output = self.model_runner.execute_model(scheduler_output,
intermediate_tensors)
@@ -309,9 +314,13 @@ class NPUWorker(WorkerBase):
parallel_config = self.vllm_config.parallel_config
assert parallel_config.distributed_executor_backend != (
"external_launcher") and not get_pp_group().is_last_rank
# If flashcomm1 is used, this all_gather_group parameter needs to be removed, otherwise it will conflict with the all-gather operation in flashcomm1.
if enable_sp():
all_gather_group = None
else:
all_gather_group = get_tp_group()
get_pp_group().send_tensor_dict(output.tensors,
all_gather_group=get_tp_group())
all_gather_group=all_gather_group)
kv_connector_output = output.kv_connector_output
if not kv_connector_output: