[bugfix] align max_num_batched_tokens with tp*pcp when using FLASHCOMM1 (#6000)
### What this PR does / why we need it?
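When FLASHCOMM1 is used together with prefill context parallelism
(pcp_size > 1), round `max_num_batched_tokens` up to the nearest multiple of
tp_size * pcp_size and log a warning if the value was changed. This avoids an
assertion error in `NPUModelRunner._dummy_run`.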
- vLLM version: v0.13.0
- vLLM main: 2c24bc6996
---------
Signed-off-by: QiuChunshuo <qiuchunshuo@huawei.com>
This commit is contained in:
@@ -176,12 +176,16 @@ class TestAscendMLAImpl(TestBase):
|
|||||||
vllm_config = MagicMock()
|
vllm_config = MagicMock()
|
||||||
speculative_config = MagicMock()
|
speculative_config = MagicMock()
|
||||||
model_config = MagicMock()
|
model_config = MagicMock()
|
||||||
|
parallel_config = MagicMock()
|
||||||
|
parallel_config.prefill_context_parallel_size = 1
|
||||||
|
parallel_config.tensor_parallel_size = 2
|
||||||
speculative_config.num_speculative_tokens = 4
|
speculative_config.num_speculative_tokens = 4
|
||||||
vllm_config.speculative_config = speculative_config
|
vllm_config.speculative_config = speculative_config
|
||||||
model_config.dtype = torch.float16
|
model_config.dtype = torch.float16
|
||||||
vllm_config.model_config = model_config
|
vllm_config.model_config = model_config
|
||||||
get_current_vllm_config.return_value = vllm_config
|
get_current_vllm_config.return_value = vllm_config
|
||||||
vllm_config.additional_config = {"refresh": True}
|
vllm_config.additional_config = {"refresh": True}
|
||||||
|
vllm_config.parallel_config = parallel_config
|
||||||
init_ascend_config(vllm_config)
|
init_ascend_config(vllm_config)
|
||||||
|
|
||||||
num_heads = 256
|
num_heads = 256
|
||||||
|
|||||||
@@ -757,12 +757,15 @@ class TestAscendMLAImpl(TestBase):
|
|||||||
vllm_config = MagicMock()
|
vllm_config = MagicMock()
|
||||||
speculative_config = MagicMock()
|
speculative_config = MagicMock()
|
||||||
model_config = MagicMock()
|
model_config = MagicMock()
|
||||||
|
parallel_config = MagicMock()
|
||||||
|
parallel_config.prefill_context_parallel_size = 1
|
||||||
speculative_config.num_speculative_tokens = 4
|
speculative_config.num_speculative_tokens = 4
|
||||||
vllm_config.speculative_config = speculative_config
|
vllm_config.speculative_config = speculative_config
|
||||||
model_config.dtype = torch.float16
|
model_config.dtype = torch.float16
|
||||||
vllm_config.model_config = model_config
|
vllm_config.model_config = model_config
|
||||||
get_current_vllm_config.return_value = vllm_config
|
get_current_vllm_config.return_value = vllm_config
|
||||||
vllm_config.additional_config = {"refresh": True}
|
vllm_config.additional_config = {"refresh": True}
|
||||||
|
vllm_config.parallel_config = parallel_config
|
||||||
init_ascend_config(vllm_config)
|
init_ascend_config(vllm_config)
|
||||||
|
|
||||||
num_heads = 256
|
num_heads = 256
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import torch
|
|||||||
|
|
||||||
from tests.ut.base import TestBase
|
from tests.ut.base import TestBase
|
||||||
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
||||||
|
from vllm.distributed.parallel_state import GroupCoordinator
|
||||||
|
|
||||||
if 'torch_npu._inductor' not in sys.modules:
|
if 'torch_npu._inductor' not in sys.modules:
|
||||||
sys.modules['torch_npu._inductor'] = MagicMock()
|
sys.modules['torch_npu._inductor'] = MagicMock()
|
||||||
@@ -81,7 +82,13 @@ class TestAscendSFAMetadata(TestBase):
|
|||||||
|
|
||||||
class TestAscendSFAMetadataBuilder(TestBase):
|
class TestAscendSFAMetadataBuilder(TestBase):
|
||||||
|
|
||||||
def setUp(self):
|
@patch('vllm.distributed.parallel_state._TP',
|
||||||
|
new_callable=lambda: MagicMock(spec=GroupCoordinator))
|
||||||
|
def setUp(self, mock_tp):
|
||||||
|
mock_tp.world_size = 2
|
||||||
|
mock_tp.rank_in_group = MagicMock()
|
||||||
|
mock_tp.device_group = MagicMock()
|
||||||
|
|
||||||
self.mock_cfg = MagicMock()
|
self.mock_cfg = MagicMock()
|
||||||
|
|
||||||
self.mock_cfg.parallel_config = MagicMock()
|
self.mock_cfg.parallel_config = MagicMock()
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ from typing import TYPE_CHECKING
|
|||||||
|
|
||||||
from vllm.logger import logger
|
from vllm.logger import logger
|
||||||
from vllm.triton_utils import HAS_TRITON
|
from vllm.triton_utils import HAS_TRITON
|
||||||
|
from vllm.utils.math_utils import cdiv
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
@@ -62,10 +63,25 @@ class AscendConfig:
|
|||||||
additional_config.get("enable_shared_expert_dp", False)
|
additional_config.get("enable_shared_expert_dp", False)
|
||||||
and vllm_config.parallel_config.enable_expert_parallel
|
and vllm_config.parallel_config.enable_expert_parallel
|
||||||
)
|
)
|
||||||
if self.enable_shared_expert_dp:
|
from vllm_ascend.utils import enable_sp
|
||||||
from vllm_ascend.utils import enable_sp
|
|
||||||
|
|
||||||
|
if self.enable_shared_expert_dp:
|
||||||
assert enable_sp(vllm_config=vllm_config, enable_shared_expert_dp=True)
|
assert enable_sp(vllm_config=vllm_config, enable_shared_expert_dp=True)
|
||||||
|
|
||||||
|
if vllm_config.parallel_config.prefill_context_parallel_size > 1 and enable_sp(vllm_config=vllm_config):
|
||||||
|
tp_pcp_size = (
|
||||||
|
vllm_config.parallel_config.tensor_parallel_size
|
||||||
|
* vllm_config.parallel_config.prefill_context_parallel_size
|
||||||
|
)
|
||||||
|
if vllm_config.scheduler_config.max_num_batched_tokens % tp_pcp_size != 0:
|
||||||
|
vllm_config.scheduler_config.max_num_batched_tokens = (
|
||||||
|
cdiv(vllm_config.scheduler_config.max_num_batched_tokens, tp_pcp_size) * tp_pcp_size
|
||||||
|
)
|
||||||
|
logger.warning_once(
|
||||||
|
f"When using FLASHCOMM1, the max_num_batched_tokens should be divisible "
|
||||||
|
f"by tp_size * pcp_size ({tp_pcp_size}). It has been adjusted to "
|
||||||
|
f"{vllm_config.scheduler_config.max_num_batched_tokens}."
|
||||||
|
)
|
||||||
self.multistream_overlap_shared_expert = additional_config.get("multistream_overlap_shared_expert", False)
|
self.multistream_overlap_shared_expert = additional_config.get("multistream_overlap_shared_expert", False)
|
||||||
self.multistream_overlap_gate = additional_config.get("multistream_overlap_gate", False)
|
self.multistream_overlap_gate = additional_config.get("multistream_overlap_gate", False)
|
||||||
self.recompute_scheduler_enable = additional_config.get("recompute_scheduler_enable", False)
|
self.recompute_scheduler_enable = additional_config.get("recompute_scheduler_enable", False)
|
||||||
|
|||||||
Reference in New Issue
Block a user