diff --git a/vllm_ascend/ascend_forward_context.py b/vllm_ascend/ascend_forward_context.py
index 49693474..ea137b04 100644
--- a/vllm_ascend/ascend_forward_context.py
+++ b/vllm_ascend/ascend_forward_context.py
@@ -72,13 +72,16 @@ def set_ascend_forward_context(
         # due to multiple warmups before actual capturing
         forward_context.capturing = False
 
+    # TODO: remove it when torch_npu.npu_mm_reduce_scatter_base supports tp_size >= 16.
+    mmrs_fusion = tp_world_size <= 8
+
     # set for sequence parallelism, 1000 is the batch size concurrency threshold
     # for enabling the flashcomm_v1 or sequence_parallelism feature.
     # Currently, it is an empirical value. In normal scenarios, if the concurrency
     # exceeds this threshold, the performance benefits can be maximized.
     # Conversely, if the concurrency is below the threshold,
     # the performance may degrade due to the switching of communication methods.
-    mmrs_fusion = True
+
     # main model and drafter model may have different architecture
     is_context_moe_model = is_drafter_moe_model(vllm_config) if is_draft_model else is_moe_model(vllm_config)
     if is_context_moe_model:
@@ -86,6 +89,7 @@ def set_ascend_forward_context(
            mmrs_fusion = False
        else:
            sp_enabled = enable_sp(vllm_config) and num_tokens is not None and num_tokens > 1000
 
+    forward_context.mmrs_fusion = mmrs_fusion
    forward_context.num_tokens = num_tokens
    forward_context.sp_enabled = sp_enabled
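
For context, here is a minimal sketch of how a downstream layer might consume the new `forward_context.mmrs_fusion` flag. This is not the repository's implementation: the function name, the `hcom_name` parameter, and the exact `torch_npu.npu_mm_reduce_scatter_base` signature are assumptions for illustration (check the torch_npu docs for your version); only the flag itself and the fused-kernel name come from the diff above.

```python
# Hypothetical consumer sketch: a row-parallel matmul that checks
# forward_context.mmrs_fusion to choose between the fused
# matmul+reduce-scatter kernel and the unfused fallback.
import torch
import torch.distributed as dist


def matmul_reduce_scatter(x: torch.Tensor, weight: torch.Tensor,
                          forward_context, hcom_name: str,
                          tp_world_size: int) -> torch.Tensor:
    if getattr(forward_context, "mmrs_fusion", False):
        # Fused path: a single kernel performs the matmul and the
        # reduce-scatter. With this patch the flag is only True when
        # tp_world_size <= 8; the signature below follows current
        # torch_npu docs and may differ across versions.
        import torch_npu
        return torch_npu.npu_mm_reduce_scatter_base(
            x, weight, hcom_name, tp_world_size, reduce_op="sum")
    # Unfused fallback: plain matmul followed by a reduce-scatter
    # collective, which larger tp sizes fall back to.
    out = torch.matmul(x, weight)
    scattered = torch.empty(out.shape[0] // tp_world_size,
                            *out.shape[1:],
                            dtype=out.dtype,
                            device=out.device)
    dist.reduce_scatter_tensor(scattered, out)
    return scattered
```

Computing the flag once in `set_ascend_forward_context` and stashing it on the forward context keeps the tp-size check out of each layer's hot path and gives a single place to delete the workaround once the kernel supports larger tp sizes.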