[bugfix] align max_num_batched_tokens with tp*pcp when using FLASHCOMM1 (#6000)
### What this PR does / why we need it?
Align max_num_batched_tokens with tp*pcp when using FLASHCOMM1 to avoid
assert error in `NPUModelRunner._dummy_run`.
- vLLM version: v0.13.0
- vLLM main:
2c24bc6996
---------
Signed-off-by: QiuChunshuo <qiuchunshuo@huawei.com>
This commit is contained in:
@@ -5,6 +5,7 @@ import torch
|
||||
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
||||
from vllm.distributed.parallel_state import GroupCoordinator
|
||||
|
||||
if 'torch_npu._inductor' not in sys.modules:
|
||||
sys.modules['torch_npu._inductor'] = MagicMock()
|
||||
@@ -81,7 +82,13 @@ class TestAscendSFAMetadata(TestBase):
|
||||
|
||||
class TestAscendSFAMetadataBuilder(TestBase):
|
||||
|
||||
def setUp(self):
|
||||
@patch('vllm.distributed.parallel_state._TP',
|
||||
new_callable=lambda: MagicMock(spec=GroupCoordinator))
|
||||
def setUp(self, mock_tp):
|
||||
mock_tp.world_size = 2
|
||||
mock_tp.rank_in_group = MagicMock()
|
||||
mock_tp.device_group = MagicMock()
|
||||
|
||||
self.mock_cfg = MagicMock()
|
||||
|
||||
self.mock_cfg.parallel_config = MagicMock()
|
||||
|
||||
Reference in New Issue
Block a user