adjusting the communication method in graph mode (#1194)

### What this PR does / why we need it?
Communication performance optimization: replace allreduce with
reduce_scatter+all_gather in MLA layer's TP group,to remove
stridedsliced and all_gather in MOE layer.
when tp > 1, It is enabled during the decode phase of the graph mode
when enable_multistream_moe、MLA, use_v1, and MC2 are used.
According to the end-to-end RL inference test results, this PR can bring
3% gain in the decode stage.

**Before Improvement**
Profiling kernel_details

![image](https://github.com/user-attachments/assets/1bb5dfa1-809b-410a-90c9-c5fd23cff003)
Evaluation

![image](https://github.com/user-attachments/assets/0b8ea0c7-88e7-410f-9ef4-f0cfe910cdc7)

![image](https://github.com/user-attachments/assets/94fde910-c125-4c2e-8de4-88fc3fafc057)

**After Improvement**
Profiling kernel_details

![image](https://github.com/user-attachments/assets/55fac0e0-11f2-4654-8fd4-287949e0b29e)
Evaluation

![image](https://github.com/user-attachments/assets/e923f74b-29c4-4171-9382-40a00cf05df0)

![image](https://github.com/user-attachments/assets/5dba7967-07ea-4926-a8be-804bfd34e3e4)

### Does this PR introduce _any_ user-facing change?
Users need to configure enable_multistream_moe=True

### How was this patch tested?
Add e2e test cases to cover code logic

Signed-off-by: sharonyunyun <zhangying134@huawei.com>
This commit is contained in:
sharonyunyun
2025-06-25 19:56:49 +08:00
committed by GitHub
parent 205cb85a1e
commit 941269a6c5
6 changed files with 195 additions and 37 deletions

View File

@@ -358,6 +358,7 @@ jobs:
pytest -sv tests/e2e/multicard/test_ilama_lora_tp2.py
# Fixme: run VLLM_USE_MODELSCOPE=True pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py will raise error.
# To avoid oom, we need to run the test in a single process.
pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_multistream_moe
pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_QwQ
pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek
pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_topk

View File

@@ -47,6 +47,32 @@ def test_models_distributed_QwQ():
vllm_model.generate_greedy(example_prompts, max_tokens)
def test_models_distributed_DeepSeek_multistream_moe():
example_prompts = [
"Hello, my name is",
]
dtype = "half"
max_tokens = 5
with VllmRunner(
"vllm-ascend/DeepSeek-V3-Pruning",
dtype=dtype,
tensor_parallel_size=2,
distributed_executor_backend="mp",
additional_config={
"torchair_graph_config": {
"enabled": True,
"enable_multistream_moe": True,
},
"ascend_scheduler_config": {
"enabled": True,
},
"refresh": True,
},
enforce_eager=False,
) as vllm_model:
vllm_model.generate_greedy(example_prompts, max_tokens)
def test_models_distributed_DeepSeek():
example_prompts = [
"Hello, my name is",

View File

@@ -9,6 +9,7 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
MLAAttentionImpl)
from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.config import get_current_vllm_config
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.linear import (LinearBase,
UnquantizedLinearMethod)
from vllm.utils import cdiv, round_down
@@ -557,6 +558,7 @@ class AscendMLAImpl(MLAAttentionImpl):
self.kv_a_proj_with_mqa = kwargs.get('kv_a_proj_with_mqa', None)
self.kv_a_layernorm = kwargs.get('kv_a_layernorm', None)
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
self.tp_size = get_tensor_model_parallel_world_size()
ascend_config = get_ascend_config()
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
@@ -586,7 +588,7 @@ class AscendMLAImpl(MLAAttentionImpl):
x = torch.bmm(x, self.W_UV)
# Convert from (N, B, V) to (B, N * V)
x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
return self.o_proj(x)[0]
return self.o_proj(x, is_prefill=False)[0]
# Return `ql_nope`, `q_pe`
def _q_proj_and_k_up_proj(self, x):
@@ -847,12 +849,12 @@ class AscendMLAImpl(MLAAttentionImpl):
current_ms_metadata = get_multistream_comm_context()
if current_ms_metadata is None:
return self.o_proj(attn_output)[0]
return self.o_proj(attn_output, is_prefill=True)[0]
else:
current_ms_metadata.before_comm_event.record()
with torch.npu.stream(current_ms_metadata.comm_stream):
current_ms_metadata.before_comm_event.wait()
return self.o_proj(attn_output)[0]
return self.o_proj(attn_output, is_prefill=True)[0]
def exec_kv(
self,

View File

@@ -42,8 +42,7 @@ from vllm.distributed.parallel_state import get_dp_group
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
ReplicatedLinear,
RowParallelLinear)
ReplicatedLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
@@ -64,7 +63,8 @@ from vllm.sequence import IntermediateTensors
import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.models.deepseek_v2 import CustomDeepseekV2MLP
from vllm_ascend.models.deepseek_v2 import (CustomDeepseekV2MLP,
CustomDeepseekV2RowParallelLinear)
from vllm_ascend.multistream.base import MSEventKey
from vllm_ascend.multistream.context import (
advance_step_multistream_layer_context, get_multistream_comm_context,
@@ -325,11 +325,12 @@ class CustomDeepseekDBOMLAAttention(DeepseekV2MLAAttention):
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.kv_b_proj")
self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim,
self.hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.o_proj")
self.o_proj = CustomDeepseekV2RowParallelLinear(
self.num_heads * self.v_head_dim,
self.hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.o_proj")
if rope_scaling:
rope_scaling["rope_type"] = 'deepseek_yarn'

View File

@@ -34,9 +34,12 @@ from torch import nn
from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.distributed import (get_pp_group,
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
get_tp_group)
get_tp_group, split_tensor_along_last_dim,
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce,
tensor_model_parallel_reduce_scatter)
from vllm.distributed.parallel_state import get_dp_group
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.activation import SiluAndMul
@@ -133,6 +136,80 @@ class CustomDeepseekV2MergedReplicatedLinear(ReplicatedLinear):
shard.copy_(loaded_weight)
class CustomDeepseekV2RowParallelLinearReplaceAllreduce(RowParallelLinear):
def forward(
self,
input_,
is_prefill=True
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[nn.Parameter]]]:
if self.input_is_parallel:
input_parallel = input_
else:
tp_rank = get_tensor_model_parallel_rank()
splitted_input = split_tensor_along_last_dim(
input_, num_partitions=self.tp_size)
input_parallel = splitted_input[tp_rank].contiguous()
# Matrix multiply.
assert self.quant_method is not None
# Only fuse bias add into GEMM for rank 0 (this ensures that
# bias will not get added more than once in TP>1 case)
bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
output_parallel = self.quant_method.apply(self,
input_parallel,
bias=bias_)
if self.reduce_results and self.tp_size > 1:
if not is_prefill and output_parallel.shape[0] % self.tp_size == 0:
output = tensor_model_parallel_reduce_scatter(output_parallel,
dim=0)
else:
output = tensor_model_parallel_all_reduce(output_parallel)
else:
output = output_parallel
output_bias = self.bias if self.skip_bias_add else None
if not self.return_bias:
return output
return output, output_bias
class CustomDeepseekV2RowParallelLinear(RowParallelLinear):
def forward(
self,
input_,
is_prefill=True
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[nn.Parameter]]]:
if self.input_is_parallel:
input_parallel = input_
else:
tp_rank = get_tensor_model_parallel_rank()
splitted_input = split_tensor_along_last_dim(
input_, num_partitions=self.tp_size)
input_parallel = splitted_input[tp_rank].contiguous()
# Matrix multiply.
assert self.quant_method is not None
# Only fuse bias add into GEMM for rank 0 (this ensures that
# bias will not get added more than once in TP>1 case)
bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
output_parallel = self.quant_method.apply(self,
input_parallel,
bias=bias_)
if self.reduce_results and self.tp_size > 1:
output = tensor_model_parallel_all_reduce(output_parallel)
else:
output = output_parallel
output_bias = self.bias if self.skip_bias_add else None
if not self.return_bias:
return output
return output, output_bias
class CustomDeepseekV2MLP(nn.Module):
def __init__(
@@ -289,10 +366,11 @@ class CustomDeepseekV2MoE(nn.Module):
self.params_dtype = torch.get_default_dtype()
def forward(
self,
hidden_states: torch.Tensor,
attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
def forward(self,
hidden_states: torch.Tensor,
attn_metadata: Optional[AttentionMetadata] = None,
replace_allreduce: bool = False) -> torch.Tensor:
if attn_metadata is None:
attn_metadata = get_forward_context().attn_metadata
# when profile runs, force experts to load balanced tokens
@@ -318,7 +396,7 @@ class CustomDeepseekV2MoE(nn.Module):
top_k=CustomDeepseekV2MoE.top_k,
enable_force_load_balance=enable_force_load_balance,
shared_experts=self.shared_experts,
)
replace_allreduce=replace_allreduce)
hidden_states = (
experts_hidden_states[0] * self.routed_scaling_factor +
@@ -365,6 +443,14 @@ class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
self.prefix = prefix
self.debug_layer_idx = int(self.prefix.split(".")[-2])
ascend_config = get_ascend_config()
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
self.enable_multistream_mla = \
ascend_config.torchair_graph_config.enable_multistream_mla
if self.q_lora_rank is not None:
self.q_a_proj = ReplicatedLinear(self.hidden_size,
self.q_lora_rank,
@@ -401,11 +487,23 @@ class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.kv_b_proj")
self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim,
self.hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.o_proj")
if (config.n_routed_experts is not None
and self.debug_layer_idx >= config.first_k_dense_replace
and self.debug_layer_idx % config.moe_layer_freq == 0 and
ascend_config.torchair_graph_config.enable_multistream_moe):
self.o_proj = CustomDeepseekV2RowParallelLinearReplaceAllreduce(
self.num_heads * self.v_head_dim,
self.hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.o_proj")
else:
self.o_proj = CustomDeepseekV2RowParallelLinear(
self.num_heads * self.v_head_dim,
self.hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.o_proj")
if rope_scaling:
rope_scaling["rope_type"] = 'deepseek_yarn'
@@ -451,14 +549,6 @@ class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):
o_proj=self.o_proj,
)
self.prefix = prefix
self.debug_layer_idx = int(self.prefix.split(".")[-2])
ascend_config = get_ascend_config()
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
self.enable_multistream_mla = \
ascend_config.torchair_graph_config.enable_multistream_mla
def forward(
self,
positions: torch.Tensor,
@@ -524,6 +614,10 @@ class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
# with the layer's index.
layer_idx = int(prefix.split(sep='.')[-1])
self.layer_idx = layer_idx
self.layers = config.num_hidden_layers
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tp_group().rank_in_group
ascend_config = get_ascend_config()
# TODO: enable mla in vllm-ascend
if model_config.use_mla:
attn_cls = CustomDeepseekV2MLAAttention
@@ -555,6 +649,8 @@ class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
quant_config=quant_config,
prefix=f"{prefix}.mlp",
)
self.mla_moe_communication = ascend_config.torchair_graph_config.enable_multistream_moe \
and model_config.use_mla and envs.VLLM_USE_V1 and self.tp_size > 1
else:
self.mlp = CustomDeepseekV2MLP(
hidden_size=config.hidden_size,
@@ -563,11 +659,13 @@ class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
quant_config=quant_config,
prefix=f"{prefix}.mlp",
)
self.mla_moe_communication = False
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.routed_scaling_factor = config.routed_scaling_factor
self.first_k_dense_replace = config.first_k_dense_replace
def forward(
self,
@@ -576,8 +674,13 @@ class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
residual: Optional[torch.Tensor],
kv_cache: Optional[torch.Tensor] = None,
attn_metadata: Optional[AttentionMetadata] = None,
replace_allreduce: bool = False,
) -> torch.Tensor:
# Self Attention
if attn_metadata is not None and attn_metadata.num_decodes > 0:
mla_moe_communication = self.mla_moe_communication and replace_allreduce
else:
mla_moe_communication = False
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
@@ -589,6 +692,9 @@ class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
# to save npu memory because they're no longer used.
dispose_tensor(previous_hidden_states)
dispose_tensor(previous_residual)
if mla_moe_communication and self.layer_idx > self.first_k_dense_replace:
hidden_states = tensor_model_parallel_all_gather(hidden_states,
dim=0)
hidden_states = self.self_attn(
positions=positions,
@@ -597,6 +703,13 @@ class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
attn_metadata=attn_metadata,
)
if mla_moe_communication and residual.shape[0] != hidden_states.shape[
0]:
chunk_hidden_states = torch.tensor_split(residual,
self.tp_size,
dim=0)
residual = chunk_hidden_states[self.tp_rank]
if hidden_states.dtype == torch.float16:
# Fix FP16 overflow
# We scale both hidden_states and residual before
@@ -612,7 +725,9 @@ class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
hidden_states, residual)
if isinstance(self.mlp, CustomDeepseekV2MoE):
hidden_states = self.mlp(hidden_states, attn_metadata)
hidden_states = self.mlp(hidden_states,
attn_metadata,
replace_allreduce=mla_moe_communication)
else:
hidden_states = self.mlp(hidden_states)
@@ -625,6 +740,10 @@ class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
# The scaling of DeepseekV2MOE output would be done in the forward
# of DeepseekV2MOE
hidden_states *= 1. / self.routed_scaling_factor
if mla_moe_communication and self.layer_idx == self.layers - 1:
hidden_states = tensor_model_parallel_all_gather(hidden_states,
dim=0)
residual = tensor_model_parallel_all_gather(residual, dim=0)
return hidden_states, residual
@@ -643,6 +762,7 @@ class CustomDeepseekV2Model(nn.Module):
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.tp_size = get_tensor_model_parallel_world_size()
if get_pp_group().is_first_rank:
self.embed_tokens = VocabParallelEmbedding(
@@ -695,13 +815,18 @@ class CustomDeepseekV2Model(nn.Module):
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
replace_allreduce = hidden_states.shape[0] % self.tp_size == 0
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states, residual = layer(
positions, hidden_states, residual,
positions,
hidden_states,
residual,
kv_caches[i -
self.start_layer] if kv_caches is not None else None,
attn_metadata)
attn_metadata,
replace_allreduce=replace_allreduce)
if not get_pp_group().is_last_rank:
return IntermediateTensors({

View File

@@ -1211,7 +1211,8 @@ class AscendFusedMoE(FusedMoE):
is_prefill: bool,
enable_force_load_balance: bool = False,
top_k: Optional[int] = None,
shared_experts: Optional[Any] = None):
shared_experts: Optional[Any] = None,
replace_allreduce: bool = False):
assert self.quant_method is not None
if top_k:
@@ -1230,7 +1231,8 @@ class AscendFusedMoE(FusedMoE):
tp_size = get_tensor_model_parallel_world_size()
if (tp_size > 1 and fused_moe_state != FusedMoEState.AllGather
and fused_moe_state != FusedMoEState.AllGatherEP):
and fused_moe_state != FusedMoEState.AllGatherEP
and not replace_allreduce):
if num_tokens < tp_size:
hidden_states = nn.functional.pad(
hidden_states, (0, 0, 0, tp_size - num_tokens))
@@ -1289,7 +1291,8 @@ class AscendFusedMoE(FusedMoE):
e_hidden_states, shared_hidden_states = e_hidden_states
if (tp_size > 1 and fused_moe_state != FusedMoEState.AllGather
and fused_moe_state != FusedMoEState.AllGatherEP):
and fused_moe_state != FusedMoEState.AllGatherEP
and not replace_allreduce):
dist.all_gather(list(chunk_hidden_states), e_hidden_states,
self.tp_group)
final_hidden_states = torch.cat(chunk_hidden_states, dim=0)