adjusting the communication method in graph mode (#1194)

### What this PR does / why we need it?
Communication performance optimization: replace allreduce with
reduce_scatter+all_gather in MLA layer's TP group,to remove
stridedsliced and all_gather in MOE layer.
when tp > 1, It is enabled during the decode phase of the graph mode
when enable_multistream_moe、MLA, use_v1, and MC2 are used.
According to the end-to-end RL inference test results, this PR can bring
3% gain in the decode stage.

**Before Improvement**
Profiling kernel_details

![image](https://github.com/user-attachments/assets/1bb5dfa1-809b-410a-90c9-c5fd23cff003)
Evaluation

![image](https://github.com/user-attachments/assets/0b8ea0c7-88e7-410f-9ef4-f0cfe910cdc7)

![image](https://github.com/user-attachments/assets/94fde910-c125-4c2e-8de4-88fc3fafc057)

**After Improvement**
Profiling kernel_details

![image](https://github.com/user-attachments/assets/55fac0e0-11f2-4654-8fd4-287949e0b29e)
Evaluation

![image](https://github.com/user-attachments/assets/e923f74b-29c4-4171-9382-40a00cf05df0)

![image](https://github.com/user-attachments/assets/5dba7967-07ea-4926-a8be-804bfd34e3e4)

### Does this PR introduce _any_ user-facing change?
Users need to configure enable_multistream_moe=True

### How was this patch tested?
Add e2e test cases to cover code logic

Signed-off-by: sharonyunyun <zhangying134@huawei.com>
This commit is contained in:
sharonyunyun
2025-06-25 19:56:49 +08:00
committed by GitHub
parent 205cb85a1e
commit 941269a6c5
6 changed files with 195 additions and 37 deletions

View File

@@ -358,6 +358,7 @@ jobs:
pytest -sv tests/e2e/multicard/test_ilama_lora_tp2.py pytest -sv tests/e2e/multicard/test_ilama_lora_tp2.py
# Fixme: run VLLM_USE_MODELSCOPE=True pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py will raise error. # Fixme: run VLLM_USE_MODELSCOPE=True pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py will raise error.
# To avoid oom, we need to run the test in a single process. # To avoid oom, we need to run the test in a single process.
pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_multistream_moe
pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_QwQ pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_QwQ
pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek
pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_topk pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_topk

View File

@@ -47,6 +47,32 @@ def test_models_distributed_QwQ():
vllm_model.generate_greedy(example_prompts, max_tokens) vllm_model.generate_greedy(example_prompts, max_tokens)
def test_models_distributed_DeepSeek_multistream_moe():
example_prompts = [
"Hello, my name is",
]
dtype = "half"
max_tokens = 5
with VllmRunner(
"vllm-ascend/DeepSeek-V3-Pruning",
dtype=dtype,
tensor_parallel_size=2,
distributed_executor_backend="mp",
additional_config={
"torchair_graph_config": {
"enabled": True,
"enable_multistream_moe": True,
},
"ascend_scheduler_config": {
"enabled": True,
},
"refresh": True,
},
enforce_eager=False,
) as vllm_model:
vllm_model.generate_greedy(example_prompts, max_tokens)
def test_models_distributed_DeepSeek(): def test_models_distributed_DeepSeek():
example_prompts = [ example_prompts = [
"Hello, my name is", "Hello, my name is",

View File

@@ -9,6 +9,7 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
MLAAttentionImpl) MLAAttentionImpl)
from vllm.attention.backends.utils import PAD_SLOT_ID from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.config import get_current_vllm_config from vllm.config import get_current_vllm_config
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.linear import (LinearBase, from vllm.model_executor.layers.linear import (LinearBase,
UnquantizedLinearMethod) UnquantizedLinearMethod)
from vllm.utils import cdiv, round_down from vllm.utils import cdiv, round_down
@@ -557,6 +558,7 @@ class AscendMLAImpl(MLAAttentionImpl):
self.kv_a_proj_with_mqa = kwargs.get('kv_a_proj_with_mqa', None) self.kv_a_proj_with_mqa = kwargs.get('kv_a_proj_with_mqa', None)
self.kv_a_layernorm = kwargs.get('kv_a_layernorm', None) self.kv_a_layernorm = kwargs.get('kv_a_layernorm', None)
self.num_queries_per_kv = self.num_heads // self.num_kv_heads self.num_queries_per_kv = self.num_heads // self.num_kv_heads
self.tp_size = get_tensor_model_parallel_world_size()
ascend_config = get_ascend_config() ascend_config = get_ascend_config()
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
@@ -586,7 +588,7 @@ class AscendMLAImpl(MLAAttentionImpl):
x = torch.bmm(x, self.W_UV) x = torch.bmm(x, self.W_UV)
# Convert from (N, B, V) to (B, N * V) # Convert from (N, B, V) to (B, N * V)
x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim) x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
return self.o_proj(x)[0] return self.o_proj(x, is_prefill=False)[0]
# Return `ql_nope`, `q_pe` # Return `ql_nope`, `q_pe`
def _q_proj_and_k_up_proj(self, x): def _q_proj_and_k_up_proj(self, x):
@@ -847,12 +849,12 @@ class AscendMLAImpl(MLAAttentionImpl):
current_ms_metadata = get_multistream_comm_context() current_ms_metadata = get_multistream_comm_context()
if current_ms_metadata is None: if current_ms_metadata is None:
return self.o_proj(attn_output)[0] return self.o_proj(attn_output, is_prefill=True)[0]
else: else:
current_ms_metadata.before_comm_event.record() current_ms_metadata.before_comm_event.record()
with torch.npu.stream(current_ms_metadata.comm_stream): with torch.npu.stream(current_ms_metadata.comm_stream):
current_ms_metadata.before_comm_event.wait() current_ms_metadata.before_comm_event.wait()
return self.o_proj(attn_output)[0] return self.o_proj(attn_output, is_prefill=True)[0]
def exec_kv( def exec_kv(
self, self,

View File

@@ -42,8 +42,7 @@ from vllm.distributed.parallel_state import get_dp_group
from vllm.forward_context import get_forward_context from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
ReplicatedLinear, ReplicatedLinear)
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
@@ -64,7 +63,8 @@ from vllm.sequence import IntermediateTensors
import vllm_ascend.envs as envs_ascend import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.models.deepseek_v2 import CustomDeepseekV2MLP from vllm_ascend.models.deepseek_v2 import (CustomDeepseekV2MLP,
CustomDeepseekV2RowParallelLinear)
from vllm_ascend.multistream.base import MSEventKey from vllm_ascend.multistream.base import MSEventKey
from vllm_ascend.multistream.context import ( from vllm_ascend.multistream.context import (
advance_step_multistream_layer_context, get_multistream_comm_context, advance_step_multistream_layer_context, get_multistream_comm_context,
@@ -325,11 +325,12 @@ class CustomDeepseekDBOMLAAttention(DeepseekV2MLAAttention):
bias=False, bias=False,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.kv_b_proj") prefix=f"{prefix}.kv_b_proj")
self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim, self.o_proj = CustomDeepseekV2RowParallelLinear(
self.hidden_size, self.num_heads * self.v_head_dim,
bias=False, self.hidden_size,
quant_config=quant_config, bias=False,
prefix=f"{prefix}.o_proj") quant_config=quant_config,
prefix=f"{prefix}.o_proj")
if rope_scaling: if rope_scaling:
rope_scaling["rope_type"] = 'deepseek_yarn' rope_scaling["rope_type"] = 'deepseek_yarn'

View File

@@ -34,9 +34,12 @@ from torch import nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, ModelConfig, VllmConfig from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.distributed import (get_pp_group, from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
get_tp_group) get_tp_group, split_tensor_along_last_dim,
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce,
tensor_model_parallel_reduce_scatter)
from vllm.distributed.parallel_state import get_dp_group from vllm.distributed.parallel_state import get_dp_group
from vllm.forward_context import get_forward_context from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
@@ -133,6 +136,80 @@ class CustomDeepseekV2MergedReplicatedLinear(ReplicatedLinear):
shard.copy_(loaded_weight) shard.copy_(loaded_weight)
class CustomDeepseekV2RowParallelLinearReplaceAllreduce(RowParallelLinear):
def forward(
self,
input_,
is_prefill=True
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[nn.Parameter]]]:
if self.input_is_parallel:
input_parallel = input_
else:
tp_rank = get_tensor_model_parallel_rank()
splitted_input = split_tensor_along_last_dim(
input_, num_partitions=self.tp_size)
input_parallel = splitted_input[tp_rank].contiguous()
# Matrix multiply.
assert self.quant_method is not None
# Only fuse bias add into GEMM for rank 0 (this ensures that
# bias will not get added more than once in TP>1 case)
bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
output_parallel = self.quant_method.apply(self,
input_parallel,
bias=bias_)
if self.reduce_results and self.tp_size > 1:
if not is_prefill and output_parallel.shape[0] % self.tp_size == 0:
output = tensor_model_parallel_reduce_scatter(output_parallel,
dim=0)
else:
output = tensor_model_parallel_all_reduce(output_parallel)
else:
output = output_parallel
output_bias = self.bias if self.skip_bias_add else None
if not self.return_bias:
return output
return output, output_bias
class CustomDeepseekV2RowParallelLinear(RowParallelLinear):
def forward(
self,
input_,
is_prefill=True
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[nn.Parameter]]]:
if self.input_is_parallel:
input_parallel = input_
else:
tp_rank = get_tensor_model_parallel_rank()
splitted_input = split_tensor_along_last_dim(
input_, num_partitions=self.tp_size)
input_parallel = splitted_input[tp_rank].contiguous()
# Matrix multiply.
assert self.quant_method is not None
# Only fuse bias add into GEMM for rank 0 (this ensures that
# bias will not get added more than once in TP>1 case)
bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
output_parallel = self.quant_method.apply(self,
input_parallel,
bias=bias_)
if self.reduce_results and self.tp_size > 1:
output = tensor_model_parallel_all_reduce(output_parallel)
else:
output = output_parallel
output_bias = self.bias if self.skip_bias_add else None
if not self.return_bias:
return output
return output, output_bias
class CustomDeepseekV2MLP(nn.Module): class CustomDeepseekV2MLP(nn.Module):
def __init__( def __init__(
@@ -289,10 +366,11 @@ class CustomDeepseekV2MoE(nn.Module):
self.params_dtype = torch.get_default_dtype() self.params_dtype = torch.get_default_dtype()
def forward( def forward(self,
self, hidden_states: torch.Tensor,
hidden_states: torch.Tensor, attn_metadata: Optional[AttentionMetadata] = None,
attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor: replace_allreduce: bool = False) -> torch.Tensor:
if attn_metadata is None: if attn_metadata is None:
attn_metadata = get_forward_context().attn_metadata attn_metadata = get_forward_context().attn_metadata
# when profile runs, force experts to load balanced tokens # when profile runs, force experts to load balanced tokens
@@ -318,7 +396,7 @@ class CustomDeepseekV2MoE(nn.Module):
top_k=CustomDeepseekV2MoE.top_k, top_k=CustomDeepseekV2MoE.top_k,
enable_force_load_balance=enable_force_load_balance, enable_force_load_balance=enable_force_load_balance,
shared_experts=self.shared_experts, shared_experts=self.shared_experts,
) replace_allreduce=replace_allreduce)
hidden_states = ( hidden_states = (
experts_hidden_states[0] * self.routed_scaling_factor + experts_hidden_states[0] * self.routed_scaling_factor +
@@ -365,6 +443,14 @@ class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):
self.rope_theta = rope_theta self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings self.max_position_embeddings = max_position_embeddings
self.prefix = prefix
self.debug_layer_idx = int(self.prefix.split(".")[-2])
ascend_config = get_ascend_config()
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
self.enable_multistream_mla = \
ascend_config.torchair_graph_config.enable_multistream_mla
if self.q_lora_rank is not None: if self.q_lora_rank is not None:
self.q_a_proj = ReplicatedLinear(self.hidden_size, self.q_a_proj = ReplicatedLinear(self.hidden_size,
self.q_lora_rank, self.q_lora_rank,
@@ -401,11 +487,23 @@ class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):
bias=False, bias=False,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.kv_b_proj") prefix=f"{prefix}.kv_b_proj")
self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim, if (config.n_routed_experts is not None
self.hidden_size, and self.debug_layer_idx >= config.first_k_dense_replace
bias=False, and self.debug_layer_idx % config.moe_layer_freq == 0 and
quant_config=quant_config, ascend_config.torchair_graph_config.enable_multistream_moe):
prefix=f"{prefix}.o_proj") self.o_proj = CustomDeepseekV2RowParallelLinearReplaceAllreduce(
self.num_heads * self.v_head_dim,
self.hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.o_proj")
else:
self.o_proj = CustomDeepseekV2RowParallelLinear(
self.num_heads * self.v_head_dim,
self.hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.o_proj")
if rope_scaling: if rope_scaling:
rope_scaling["rope_type"] = 'deepseek_yarn' rope_scaling["rope_type"] = 'deepseek_yarn'
@@ -451,14 +549,6 @@ class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):
o_proj=self.o_proj, o_proj=self.o_proj,
) )
self.prefix = prefix
self.debug_layer_idx = int(self.prefix.split(".")[-2])
ascend_config = get_ascend_config()
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
self.enable_multistream_mla = \
ascend_config.torchair_graph_config.enable_multistream_mla
def forward( def forward(
self, self,
positions: torch.Tensor, positions: torch.Tensor,
@@ -524,6 +614,10 @@ class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
# with the layer's index. # with the layer's index.
layer_idx = int(prefix.split(sep='.')[-1]) layer_idx = int(prefix.split(sep='.')[-1])
self.layer_idx = layer_idx self.layer_idx = layer_idx
self.layers = config.num_hidden_layers
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tp_group().rank_in_group
ascend_config = get_ascend_config()
# TODO: enable mla in vllm-ascend # TODO: enable mla in vllm-ascend
if model_config.use_mla: if model_config.use_mla:
attn_cls = CustomDeepseekV2MLAAttention attn_cls = CustomDeepseekV2MLAAttention
@@ -555,6 +649,8 @@ class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.mlp", prefix=f"{prefix}.mlp",
) )
self.mla_moe_communication = ascend_config.torchair_graph_config.enable_multistream_moe \
and model_config.use_mla and envs.VLLM_USE_V1 and self.tp_size > 1
else: else:
self.mlp = CustomDeepseekV2MLP( self.mlp = CustomDeepseekV2MLP(
hidden_size=config.hidden_size, hidden_size=config.hidden_size,
@@ -563,11 +659,13 @@ class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.mlp", prefix=f"{prefix}.mlp",
) )
self.mla_moe_communication = False
self.input_layernorm = RMSNorm(config.hidden_size, self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps) eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size, self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps) eps=config.rms_norm_eps)
self.routed_scaling_factor = config.routed_scaling_factor self.routed_scaling_factor = config.routed_scaling_factor
self.first_k_dense_replace = config.first_k_dense_replace
def forward( def forward(
self, self,
@@ -576,8 +674,13 @@ class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
kv_cache: Optional[torch.Tensor] = None, kv_cache: Optional[torch.Tensor] = None,
attn_metadata: Optional[AttentionMetadata] = None, attn_metadata: Optional[AttentionMetadata] = None,
replace_allreduce: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
# Self Attention # Self Attention
if attn_metadata is not None and attn_metadata.num_decodes > 0:
mla_moe_communication = self.mla_moe_communication and replace_allreduce
else:
mla_moe_communication = False
if residual is None: if residual is None:
residual = hidden_states residual = hidden_states
hidden_states = self.input_layernorm(hidden_states) hidden_states = self.input_layernorm(hidden_states)
@@ -589,6 +692,9 @@ class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
# to save npu memory because they're no longer used. # to save npu memory because they're no longer used.
dispose_tensor(previous_hidden_states) dispose_tensor(previous_hidden_states)
dispose_tensor(previous_residual) dispose_tensor(previous_residual)
if mla_moe_communication and self.layer_idx > self.first_k_dense_replace:
hidden_states = tensor_model_parallel_all_gather(hidden_states,
dim=0)
hidden_states = self.self_attn( hidden_states = self.self_attn(
positions=positions, positions=positions,
@@ -597,6 +703,13 @@ class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
attn_metadata=attn_metadata, attn_metadata=attn_metadata,
) )
if mla_moe_communication and residual.shape[0] != hidden_states.shape[
0]:
chunk_hidden_states = torch.tensor_split(residual,
self.tp_size,
dim=0)
residual = chunk_hidden_states[self.tp_rank]
if hidden_states.dtype == torch.float16: if hidden_states.dtype == torch.float16:
# Fix FP16 overflow # Fix FP16 overflow
# We scale both hidden_states and residual before # We scale both hidden_states and residual before
@@ -612,7 +725,9 @@ class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
hidden_states, residual) hidden_states, residual)
if isinstance(self.mlp, CustomDeepseekV2MoE): if isinstance(self.mlp, CustomDeepseekV2MoE):
hidden_states = self.mlp(hidden_states, attn_metadata) hidden_states = self.mlp(hidden_states,
attn_metadata,
replace_allreduce=mla_moe_communication)
else: else:
hidden_states = self.mlp(hidden_states) hidden_states = self.mlp(hidden_states)
@@ -625,6 +740,10 @@ class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
# The scaling of DeepseekV2MOE output would be done in the forward # The scaling of DeepseekV2MOE output would be done in the forward
# of DeepseekV2MOE # of DeepseekV2MOE
hidden_states *= 1. / self.routed_scaling_factor hidden_states *= 1. / self.routed_scaling_factor
if mla_moe_communication and self.layer_idx == self.layers - 1:
hidden_states = tensor_model_parallel_all_gather(hidden_states,
dim=0)
residual = tensor_model_parallel_all_gather(residual, dim=0)
return hidden_states, residual return hidden_states, residual
@@ -643,6 +762,7 @@ class CustomDeepseekV2Model(nn.Module):
self.padding_idx = config.pad_token_id self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size self.vocab_size = config.vocab_size
self.tp_size = get_tensor_model_parallel_world_size()
if get_pp_group().is_first_rank: if get_pp_group().is_first_rank:
self.embed_tokens = VocabParallelEmbedding( self.embed_tokens = VocabParallelEmbedding(
@@ -695,13 +815,18 @@ class CustomDeepseekV2Model(nn.Module):
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
replace_allreduce = hidden_states.shape[0] % self.tp_size == 0
for i in range(self.start_layer, self.end_layer): for i in range(self.start_layer, self.end_layer):
layer = self.layers[i] layer = self.layers[i]
hidden_states, residual = layer( hidden_states, residual = layer(
positions, hidden_states, residual, positions,
hidden_states,
residual,
kv_caches[i - kv_caches[i -
self.start_layer] if kv_caches is not None else None, self.start_layer] if kv_caches is not None else None,
attn_metadata) attn_metadata,
replace_allreduce=replace_allreduce)
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
return IntermediateTensors({ return IntermediateTensors({

View File

@@ -1211,7 +1211,8 @@ class AscendFusedMoE(FusedMoE):
is_prefill: bool, is_prefill: bool,
enable_force_load_balance: bool = False, enable_force_load_balance: bool = False,
top_k: Optional[int] = None, top_k: Optional[int] = None,
shared_experts: Optional[Any] = None): shared_experts: Optional[Any] = None,
replace_allreduce: bool = False):
assert self.quant_method is not None assert self.quant_method is not None
if top_k: if top_k:
@@ -1230,7 +1231,8 @@ class AscendFusedMoE(FusedMoE):
tp_size = get_tensor_model_parallel_world_size() tp_size = get_tensor_model_parallel_world_size()
if (tp_size > 1 and fused_moe_state != FusedMoEState.AllGather if (tp_size > 1 and fused_moe_state != FusedMoEState.AllGather
and fused_moe_state != FusedMoEState.AllGatherEP): and fused_moe_state != FusedMoEState.AllGatherEP
and not replace_allreduce):
if num_tokens < tp_size: if num_tokens < tp_size:
hidden_states = nn.functional.pad( hidden_states = nn.functional.pad(
hidden_states, (0, 0, 0, tp_size - num_tokens)) hidden_states, (0, 0, 0, tp_size - num_tokens))
@@ -1289,7 +1291,8 @@ class AscendFusedMoE(FusedMoE):
e_hidden_states, shared_hidden_states = e_hidden_states e_hidden_states, shared_hidden_states = e_hidden_states
if (tp_size > 1 and fused_moe_state != FusedMoEState.AllGather if (tp_size > 1 and fused_moe_state != FusedMoEState.AllGather
and fused_moe_state != FusedMoEState.AllGatherEP): and fused_moe_state != FusedMoEState.AllGatherEP
and not replace_allreduce):
dist.all_gather(list(chunk_hidden_states), e_hidden_states, dist.all_gather(list(chunk_hidden_states), e_hidden_states,
self.tp_group) self.tp_group)
final_hidden_states = torch.cat(chunk_hidden_states, dim=0) final_hidden_states = torch.cat(chunk_hidden_states, dim=0)