[main][BugFix] Avoided a bug of torch_npu.npu_mm_reduce_scatter_base when sp size >= 16 (#6168)
### What this PR does / why we need it?
If `sp` is enabled and `tp_size` >= 16,
`torch_npu.npu_mm_reduce_scatter_base` raises an exception.
After consulting the operator's developers, we learned that the
operator does not work when `tp_size` >= 16.
So, we disable the operator in that case.
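To make the workaround concrete, here is a minimal sketch of the gating decision. The helper name `use_mmrs_fusion` is illustrative (not a public API); the threshold of 8 mirrors the `tp_world_size <= 8` check introduced by this patch.

```python
def use_mmrs_fusion(tp_world_size: int) -> bool:
    """Decide whether the fused matmul + reduce-scatter op may be used.

    Hypothetical helper: torch_npu.npu_mm_reduce_scatter_base is reported
    to fail for tp_size >= 16, so the patch only enables the fused op for
    tensor-parallel groups of at most 8 ranks.
    """
    # Mirrors the patch's gate: mmrs_fusion = tp_world_size <= 8
    return tp_world_size <= 8

print(use_mmrs_fusion(8))   # True: fused op allowed
print(use_mmrs_fusion(16))  # False: fall back to separate ops
```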
### Does this PR introduce _any_ user-facing change?
N/A
### How was this patch tested?
We started a server with `sp` enabled and `tp` = 16.
It started successfully.
```text
(APIServer pid=1855938) INFO:     Started server process [1855938]
(APIServer pid=1855938) INFO:     Waiting for application startup.
(APIServer pid=1855938) INFO:     Application startup complete.
```
- vLLM version: v0.13.0
- vLLM main:
d68209402d
Signed-off-by: drslark <slarksblood@qq.com>
```diff
@@ -72,13 +72,16 @@ def set_ascend_forward_context(
         # due to multiple warmups before actual capturing
         forward_context.capturing = False
 
+    # TODO: remove it when torch_npu.npu_mm_reduce_scatter_base supports tp_size >= 16.
+    mmrs_fusion = tp_world_size <= 8
+
     # set for sequence parallelism, 1000 is the batch size concurrency threshold
     # for enabling the flashcomm_v1 or sequence_parallelism feature.
     # Currently, it is an empirical value. In normal scenarios, if the concurrency
     # exceeds this threshold, the performance benefits can be maximized.
     # Conversely, if the concurrency is below the threshold,
     # the performance may degrade due to the switching of communication methods.
-    mmrs_fusion = True
     # main model and drafter model may have different architecture
     is_context_moe_model = is_drafter_moe_model(vllm_config) if is_draft_model else is_moe_model(vllm_config)
     if is_context_moe_model:
@@ -86,6 +89,7 @@ def set_ascend_forward_context(
         mmrs_fusion = False
     else:
         sp_enabled = enable_sp(vllm_config) and num_tokens is not None and num_tokens > 1000
+
     forward_context.mmrs_fusion = mmrs_fusion
     forward_context.num_tokens = num_tokens
     forward_context.sp_enabled = sp_enabled
```
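When fusion is disabled, a matmul followed by an explicit reduce-scatter yields the same result the fused kernel would produce. The following pure-Python sketch of that reference semantics is purely illustrative (the function name and shapes are assumptions; the real kernel runs across NPU devices via torch_npu collectives):

```python
def mm_reduce_scatter_reference(x_shards, w, tp_size):
    """Reference semantics of matmul + reduce-scatter across tp_size ranks.

    Illustrative pure-Python model: each rank multiplies its activation
    shard by the weight, the partial results are summed (reduce), and each
    rank keeps an equal slice of the rows (scatter).
    """
    def matmul(a, b):
        # Plain nested-loop matrix multiply on lists of lists.
        return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
                 for j in range(len(b[0]))] for i in range(len(a))]

    # Each rank computes its partial product.
    partials = [matmul(x, w) for x in x_shards]
    # Reduce: element-wise sum of the partial results.
    rows, cols = len(partials[0]), len(partials[0][0])
    reduced = [[sum(p[i][j] for p in partials) for j in range(cols)]
               for i in range(rows)]
    # Scatter: split the reduced rows evenly, one chunk per rank.
    chunk = rows // tp_size
    return [reduced[r * chunk:(r + 1) * chunk] for r in range(tp_size)]

x_shards = [[[1.0, 1.0]] * 4 for _ in range(2)]  # two ranks, 4x2 activations each
w = [[1.0, 1.0, 1.0]] * 2                        # 2x3 weight
out = mm_reduce_scatter_reference(x_shards, w, tp_size=2)
# Each of the 2 ranks receives a 2x3 slice of the summed result.
```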