adjusting the communication method in graph mode (#1194)
### What this PR does / why we need it? Communication performance optimization: replace allreduce with reduce_scatter+all_gather in MLA layer's TP group,to remove stridedsliced and all_gather in MOE layer. when tp > 1, It is enabled during the decode phase of the graph mode when enable_multistream_moe、MLA, use_v1, and MC2 are used. According to the end-to-end RL inference test results, this PR can bring 3% gain in the decode stage. **Before Improvement** Profiling kernel_details  Evaluation   **After Improvement** Profiling kernel_details  Evaluation   ### Does this PR introduce _any_ user-facing change? Users need to configure enable_multistream_moe=True ### How was this patch tested? Add e2e test cases to cover code logic Signed-off-by: sharonyunyun <zhangying134@huawei.com>
This commit is contained in:
@@ -42,8 +42,7 @@ from vllm.distributed.parallel_state import get_dp_group
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
ReplicatedLinear,
|
||||
RowParallelLinear)
|
||||
ReplicatedLinear)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
@@ -64,7 +63,8 @@ from vllm.sequence import IntermediateTensors
|
||||
|
||||
import vllm_ascend.envs as envs_ascend
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.models.deepseek_v2 import CustomDeepseekV2MLP
|
||||
from vllm_ascend.models.deepseek_v2 import (CustomDeepseekV2MLP,
|
||||
CustomDeepseekV2RowParallelLinear)
|
||||
from vllm_ascend.multistream.base import MSEventKey
|
||||
from vllm_ascend.multistream.context import (
|
||||
advance_step_multistream_layer_context, get_multistream_comm_context,
|
||||
@@ -325,11 +325,12 @@ class CustomDeepseekDBOMLAAttention(DeepseekV2MLAAttention):
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.kv_b_proj")
|
||||
self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim,
|
||||
self.hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.o_proj")
|
||||
self.o_proj = CustomDeepseekV2RowParallelLinear(
|
||||
self.num_heads * self.v_head_dim,
|
||||
self.hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.o_proj")
|
||||
|
||||
if rope_scaling:
|
||||
rope_scaling["rope_type"] = 'deepseek_yarn'
|
||||
|
||||
Reference in New Issue
Block a user