adjusting the communication method in graph mode (#1194)
### What this PR does / why we need it? Communication performance optimization: replace allreduce with reduce_scatter+all_gather in MLA layer's TP group,to remove stridedsliced and all_gather in MOE layer. when tp > 1, It is enabled during the decode phase of the graph mode when enable_multistream_moe、MLA, use_v1, and MC2 are used. According to the end-to-end RL inference test results, this PR can bring 3% gain in the decode stage. **Before Improvement** Profiling kernel_details  Evaluation   **After Improvement** Profiling kernel_details  Evaluation   ### Does this PR introduce _any_ user-facing change? Users need to configure enable_multistream_moe=True ### How was this patch tested? Add e2e test cases to cover code logic Signed-off-by: sharonyunyun <zhangying134@huawei.com>
This commit is contained in:
1
.github/workflows/vllm_ascend_test.yaml
vendored
1
.github/workflows/vllm_ascend_test.yaml
vendored
@@ -358,6 +358,7 @@ jobs:
|
|||||||
pytest -sv tests/e2e/multicard/test_ilama_lora_tp2.py
|
pytest -sv tests/e2e/multicard/test_ilama_lora_tp2.py
|
||||||
# Fixme: run VLLM_USE_MODELSCOPE=True pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py will raise error.
|
# Fixme: run VLLM_USE_MODELSCOPE=True pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py will raise error.
|
||||||
# To avoid oom, we need to run the test in a single process.
|
# To avoid oom, we need to run the test in a single process.
|
||||||
|
pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_multistream_moe
|
||||||
pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_QwQ
|
pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_QwQ
|
||||||
pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek
|
pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek
|
||||||
pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_topk
|
pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_topk
|
||||||
|
|||||||
@@ -47,6 +47,32 @@ def test_models_distributed_QwQ():
|
|||||||
vllm_model.generate_greedy(example_prompts, max_tokens)
|
vllm_model.generate_greedy(example_prompts, max_tokens)
|
||||||
|
|
||||||
|
|
||||||
|
def test_models_distributed_DeepSeek_multistream_moe():
|
||||||
|
example_prompts = [
|
||||||
|
"Hello, my name is",
|
||||||
|
]
|
||||||
|
dtype = "half"
|
||||||
|
max_tokens = 5
|
||||||
|
with VllmRunner(
|
||||||
|
"vllm-ascend/DeepSeek-V3-Pruning",
|
||||||
|
dtype=dtype,
|
||||||
|
tensor_parallel_size=2,
|
||||||
|
distributed_executor_backend="mp",
|
||||||
|
additional_config={
|
||||||
|
"torchair_graph_config": {
|
||||||
|
"enabled": True,
|
||||||
|
"enable_multistream_moe": True,
|
||||||
|
},
|
||||||
|
"ascend_scheduler_config": {
|
||||||
|
"enabled": True,
|
||||||
|
},
|
||||||
|
"refresh": True,
|
||||||
|
},
|
||||||
|
enforce_eager=False,
|
||||||
|
) as vllm_model:
|
||||||
|
vllm_model.generate_greedy(example_prompts, max_tokens)
|
||||||
|
|
||||||
|
|
||||||
def test_models_distributed_DeepSeek():
|
def test_models_distributed_DeepSeek():
|
||||||
example_prompts = [
|
example_prompts = [
|
||||||
"Hello, my name is",
|
"Hello, my name is",
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
|
|||||||
MLAAttentionImpl)
|
MLAAttentionImpl)
|
||||||
from vllm.attention.backends.utils import PAD_SLOT_ID
|
from vllm.attention.backends.utils import PAD_SLOT_ID
|
||||||
from vllm.config import get_current_vllm_config
|
from vllm.config import get_current_vllm_config
|
||||||
|
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||||
from vllm.model_executor.layers.linear import (LinearBase,
|
from vllm.model_executor.layers.linear import (LinearBase,
|
||||||
UnquantizedLinearMethod)
|
UnquantizedLinearMethod)
|
||||||
from vllm.utils import cdiv, round_down
|
from vllm.utils import cdiv, round_down
|
||||||
@@ -557,6 +558,7 @@ class AscendMLAImpl(MLAAttentionImpl):
|
|||||||
self.kv_a_proj_with_mqa = kwargs.get('kv_a_proj_with_mqa', None)
|
self.kv_a_proj_with_mqa = kwargs.get('kv_a_proj_with_mqa', None)
|
||||||
self.kv_a_layernorm = kwargs.get('kv_a_layernorm', None)
|
self.kv_a_layernorm = kwargs.get('kv_a_layernorm', None)
|
||||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||||
|
self.tp_size = get_tensor_model_parallel_world_size()
|
||||||
|
|
||||||
ascend_config = get_ascend_config()
|
ascend_config = get_ascend_config()
|
||||||
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
|
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
|
||||||
@@ -586,7 +588,7 @@ class AscendMLAImpl(MLAAttentionImpl):
|
|||||||
x = torch.bmm(x, self.W_UV)
|
x = torch.bmm(x, self.W_UV)
|
||||||
# Convert from (N, B, V) to (B, N * V)
|
# Convert from (N, B, V) to (B, N * V)
|
||||||
x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
|
x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
|
||||||
return self.o_proj(x)[0]
|
return self.o_proj(x, is_prefill=False)[0]
|
||||||
|
|
||||||
# Return `ql_nope`, `q_pe`
|
# Return `ql_nope`, `q_pe`
|
||||||
def _q_proj_and_k_up_proj(self, x):
|
def _q_proj_and_k_up_proj(self, x):
|
||||||
@@ -847,12 +849,12 @@ class AscendMLAImpl(MLAAttentionImpl):
|
|||||||
|
|
||||||
current_ms_metadata = get_multistream_comm_context()
|
current_ms_metadata = get_multistream_comm_context()
|
||||||
if current_ms_metadata is None:
|
if current_ms_metadata is None:
|
||||||
return self.o_proj(attn_output)[0]
|
return self.o_proj(attn_output, is_prefill=True)[0]
|
||||||
else:
|
else:
|
||||||
current_ms_metadata.before_comm_event.record()
|
current_ms_metadata.before_comm_event.record()
|
||||||
with torch.npu.stream(current_ms_metadata.comm_stream):
|
with torch.npu.stream(current_ms_metadata.comm_stream):
|
||||||
current_ms_metadata.before_comm_event.wait()
|
current_ms_metadata.before_comm_event.wait()
|
||||||
return self.o_proj(attn_output)[0]
|
return self.o_proj(attn_output, is_prefill=True)[0]
|
||||||
|
|
||||||
def exec_kv(
|
def exec_kv(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -42,8 +42,7 @@ from vllm.distributed.parallel_state import get_dp_group
|
|||||||
from vllm.forward_context import get_forward_context
|
from vllm.forward_context import get_forward_context
|
||||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||||
ReplicatedLinear,
|
ReplicatedLinear)
|
||||||
RowParallelLinear)
|
|
||||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||||
@@ -64,7 +63,8 @@ from vllm.sequence import IntermediateTensors
|
|||||||
|
|
||||||
import vllm_ascend.envs as envs_ascend
|
import vllm_ascend.envs as envs_ascend
|
||||||
from vllm_ascend.ascend_config import get_ascend_config
|
from vllm_ascend.ascend_config import get_ascend_config
|
||||||
from vllm_ascend.models.deepseek_v2 import CustomDeepseekV2MLP
|
from vllm_ascend.models.deepseek_v2 import (CustomDeepseekV2MLP,
|
||||||
|
CustomDeepseekV2RowParallelLinear)
|
||||||
from vllm_ascend.multistream.base import MSEventKey
|
from vllm_ascend.multistream.base import MSEventKey
|
||||||
from vllm_ascend.multistream.context import (
|
from vllm_ascend.multistream.context import (
|
||||||
advance_step_multistream_layer_context, get_multistream_comm_context,
|
advance_step_multistream_layer_context, get_multistream_comm_context,
|
||||||
@@ -325,11 +325,12 @@ class CustomDeepseekDBOMLAAttention(DeepseekV2MLAAttention):
|
|||||||
bias=False,
|
bias=False,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
prefix=f"{prefix}.kv_b_proj")
|
prefix=f"{prefix}.kv_b_proj")
|
||||||
self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim,
|
self.o_proj = CustomDeepseekV2RowParallelLinear(
|
||||||
self.hidden_size,
|
self.num_heads * self.v_head_dim,
|
||||||
bias=False,
|
self.hidden_size,
|
||||||
quant_config=quant_config,
|
bias=False,
|
||||||
prefix=f"{prefix}.o_proj")
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.o_proj")
|
||||||
|
|
||||||
if rope_scaling:
|
if rope_scaling:
|
||||||
rope_scaling["rope_type"] = 'deepseek_yarn'
|
rope_scaling["rope_type"] = 'deepseek_yarn'
|
||||||
|
|||||||
@@ -34,9 +34,12 @@ from torch import nn
|
|||||||
from transformers import PretrainedConfig
|
from transformers import PretrainedConfig
|
||||||
from vllm.attention import Attention, AttentionMetadata
|
from vllm.attention import Attention, AttentionMetadata
|
||||||
from vllm.config import CacheConfig, ModelConfig, VllmConfig
|
from vllm.config import CacheConfig, ModelConfig, VllmConfig
|
||||||
from vllm.distributed import (get_pp_group,
|
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
|
||||||
get_tensor_model_parallel_world_size,
|
get_tensor_model_parallel_world_size,
|
||||||
get_tp_group)
|
get_tp_group, split_tensor_along_last_dim,
|
||||||
|
tensor_model_parallel_all_gather,
|
||||||
|
tensor_model_parallel_all_reduce,
|
||||||
|
tensor_model_parallel_reduce_scatter)
|
||||||
from vllm.distributed.parallel_state import get_dp_group
|
from vllm.distributed.parallel_state import get_dp_group
|
||||||
from vllm.forward_context import get_forward_context
|
from vllm.forward_context import get_forward_context
|
||||||
from vllm.model_executor.layers.activation import SiluAndMul
|
from vllm.model_executor.layers.activation import SiluAndMul
|
||||||
@@ -133,6 +136,80 @@ class CustomDeepseekV2MergedReplicatedLinear(ReplicatedLinear):
|
|||||||
shard.copy_(loaded_weight)
|
shard.copy_(loaded_weight)
|
||||||
|
|
||||||
|
|
||||||
|
class CustomDeepseekV2RowParallelLinearReplaceAllreduce(RowParallelLinear):
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_,
|
||||||
|
is_prefill=True
|
||||||
|
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[nn.Parameter]]]:
|
||||||
|
if self.input_is_parallel:
|
||||||
|
input_parallel = input_
|
||||||
|
else:
|
||||||
|
tp_rank = get_tensor_model_parallel_rank()
|
||||||
|
splitted_input = split_tensor_along_last_dim(
|
||||||
|
input_, num_partitions=self.tp_size)
|
||||||
|
input_parallel = splitted_input[tp_rank].contiguous()
|
||||||
|
|
||||||
|
# Matrix multiply.
|
||||||
|
assert self.quant_method is not None
|
||||||
|
# Only fuse bias add into GEMM for rank 0 (this ensures that
|
||||||
|
# bias will not get added more than once in TP>1 case)
|
||||||
|
bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
|
||||||
|
output_parallel = self.quant_method.apply(self,
|
||||||
|
input_parallel,
|
||||||
|
bias=bias_)
|
||||||
|
if self.reduce_results and self.tp_size > 1:
|
||||||
|
if not is_prefill and output_parallel.shape[0] % self.tp_size == 0:
|
||||||
|
output = tensor_model_parallel_reduce_scatter(output_parallel,
|
||||||
|
dim=0)
|
||||||
|
else:
|
||||||
|
output = tensor_model_parallel_all_reduce(output_parallel)
|
||||||
|
else:
|
||||||
|
output = output_parallel
|
||||||
|
|
||||||
|
output_bias = self.bias if self.skip_bias_add else None
|
||||||
|
|
||||||
|
if not self.return_bias:
|
||||||
|
return output
|
||||||
|
return output, output_bias
|
||||||
|
|
||||||
|
|
||||||
|
class CustomDeepseekV2RowParallelLinear(RowParallelLinear):
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_,
|
||||||
|
is_prefill=True
|
||||||
|
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[nn.Parameter]]]:
|
||||||
|
if self.input_is_parallel:
|
||||||
|
input_parallel = input_
|
||||||
|
else:
|
||||||
|
tp_rank = get_tensor_model_parallel_rank()
|
||||||
|
splitted_input = split_tensor_along_last_dim(
|
||||||
|
input_, num_partitions=self.tp_size)
|
||||||
|
input_parallel = splitted_input[tp_rank].contiguous()
|
||||||
|
|
||||||
|
# Matrix multiply.
|
||||||
|
assert self.quant_method is not None
|
||||||
|
# Only fuse bias add into GEMM for rank 0 (this ensures that
|
||||||
|
# bias will not get added more than once in TP>1 case)
|
||||||
|
bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
|
||||||
|
output_parallel = self.quant_method.apply(self,
|
||||||
|
input_parallel,
|
||||||
|
bias=bias_)
|
||||||
|
if self.reduce_results and self.tp_size > 1:
|
||||||
|
output = tensor_model_parallel_all_reduce(output_parallel)
|
||||||
|
else:
|
||||||
|
output = output_parallel
|
||||||
|
|
||||||
|
output_bias = self.bias if self.skip_bias_add else None
|
||||||
|
|
||||||
|
if not self.return_bias:
|
||||||
|
return output
|
||||||
|
return output, output_bias
|
||||||
|
|
||||||
|
|
||||||
class CustomDeepseekV2MLP(nn.Module):
|
class CustomDeepseekV2MLP(nn.Module):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -289,10 +366,11 @@ class CustomDeepseekV2MoE(nn.Module):
|
|||||||
|
|
||||||
self.params_dtype = torch.get_default_dtype()
|
self.params_dtype = torch.get_default_dtype()
|
||||||
|
|
||||||
def forward(
|
def forward(self,
|
||||||
self,
|
hidden_states: torch.Tensor,
|
||||||
hidden_states: torch.Tensor,
|
attn_metadata: Optional[AttentionMetadata] = None,
|
||||||
attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
|
replace_allreduce: bool = False) -> torch.Tensor:
|
||||||
|
|
||||||
if attn_metadata is None:
|
if attn_metadata is None:
|
||||||
attn_metadata = get_forward_context().attn_metadata
|
attn_metadata = get_forward_context().attn_metadata
|
||||||
# when profile runs, force experts to load balanced tokens
|
# when profile runs, force experts to load balanced tokens
|
||||||
@@ -318,7 +396,7 @@ class CustomDeepseekV2MoE(nn.Module):
|
|||||||
top_k=CustomDeepseekV2MoE.top_k,
|
top_k=CustomDeepseekV2MoE.top_k,
|
||||||
enable_force_load_balance=enable_force_load_balance,
|
enable_force_load_balance=enable_force_load_balance,
|
||||||
shared_experts=self.shared_experts,
|
shared_experts=self.shared_experts,
|
||||||
)
|
replace_allreduce=replace_allreduce)
|
||||||
|
|
||||||
hidden_states = (
|
hidden_states = (
|
||||||
experts_hidden_states[0] * self.routed_scaling_factor +
|
experts_hidden_states[0] * self.routed_scaling_factor +
|
||||||
@@ -365,6 +443,14 @@ class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):
|
|||||||
self.rope_theta = rope_theta
|
self.rope_theta = rope_theta
|
||||||
self.max_position_embeddings = max_position_embeddings
|
self.max_position_embeddings = max_position_embeddings
|
||||||
|
|
||||||
|
self.prefix = prefix
|
||||||
|
self.debug_layer_idx = int(self.prefix.split(".")[-2])
|
||||||
|
|
||||||
|
ascend_config = get_ascend_config()
|
||||||
|
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
|
||||||
|
self.enable_multistream_mla = \
|
||||||
|
ascend_config.torchair_graph_config.enable_multistream_mla
|
||||||
|
|
||||||
if self.q_lora_rank is not None:
|
if self.q_lora_rank is not None:
|
||||||
self.q_a_proj = ReplicatedLinear(self.hidden_size,
|
self.q_a_proj = ReplicatedLinear(self.hidden_size,
|
||||||
self.q_lora_rank,
|
self.q_lora_rank,
|
||||||
@@ -401,11 +487,23 @@ class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):
|
|||||||
bias=False,
|
bias=False,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
prefix=f"{prefix}.kv_b_proj")
|
prefix=f"{prefix}.kv_b_proj")
|
||||||
self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim,
|
if (config.n_routed_experts is not None
|
||||||
self.hidden_size,
|
and self.debug_layer_idx >= config.first_k_dense_replace
|
||||||
bias=False,
|
and self.debug_layer_idx % config.moe_layer_freq == 0 and
|
||||||
quant_config=quant_config,
|
ascend_config.torchair_graph_config.enable_multistream_moe):
|
||||||
prefix=f"{prefix}.o_proj")
|
self.o_proj = CustomDeepseekV2RowParallelLinearReplaceAllreduce(
|
||||||
|
self.num_heads * self.v_head_dim,
|
||||||
|
self.hidden_size,
|
||||||
|
bias=False,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.o_proj")
|
||||||
|
else:
|
||||||
|
self.o_proj = CustomDeepseekV2RowParallelLinear(
|
||||||
|
self.num_heads * self.v_head_dim,
|
||||||
|
self.hidden_size,
|
||||||
|
bias=False,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.o_proj")
|
||||||
|
|
||||||
if rope_scaling:
|
if rope_scaling:
|
||||||
rope_scaling["rope_type"] = 'deepseek_yarn'
|
rope_scaling["rope_type"] = 'deepseek_yarn'
|
||||||
@@ -451,14 +549,6 @@ class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):
|
|||||||
o_proj=self.o_proj,
|
o_proj=self.o_proj,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.prefix = prefix
|
|
||||||
self.debug_layer_idx = int(self.prefix.split(".")[-2])
|
|
||||||
|
|
||||||
ascend_config = get_ascend_config()
|
|
||||||
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
|
|
||||||
self.enable_multistream_mla = \
|
|
||||||
ascend_config.torchair_graph_config.enable_multistream_mla
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
@@ -524,6 +614,10 @@ class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
|
|||||||
# with the layer's index.
|
# with the layer's index.
|
||||||
layer_idx = int(prefix.split(sep='.')[-1])
|
layer_idx = int(prefix.split(sep='.')[-1])
|
||||||
self.layer_idx = layer_idx
|
self.layer_idx = layer_idx
|
||||||
|
self.layers = config.num_hidden_layers
|
||||||
|
self.tp_size = get_tensor_model_parallel_world_size()
|
||||||
|
self.tp_rank = get_tp_group().rank_in_group
|
||||||
|
ascend_config = get_ascend_config()
|
||||||
# TODO: enable mla in vllm-ascend
|
# TODO: enable mla in vllm-ascend
|
||||||
if model_config.use_mla:
|
if model_config.use_mla:
|
||||||
attn_cls = CustomDeepseekV2MLAAttention
|
attn_cls = CustomDeepseekV2MLAAttention
|
||||||
@@ -555,6 +649,8 @@ class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
|
|||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
prefix=f"{prefix}.mlp",
|
prefix=f"{prefix}.mlp",
|
||||||
)
|
)
|
||||||
|
self.mla_moe_communication = ascend_config.torchair_graph_config.enable_multistream_moe \
|
||||||
|
and model_config.use_mla and envs.VLLM_USE_V1 and self.tp_size > 1
|
||||||
else:
|
else:
|
||||||
self.mlp = CustomDeepseekV2MLP(
|
self.mlp = CustomDeepseekV2MLP(
|
||||||
hidden_size=config.hidden_size,
|
hidden_size=config.hidden_size,
|
||||||
@@ -563,11 +659,13 @@ class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
|
|||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
prefix=f"{prefix}.mlp",
|
prefix=f"{prefix}.mlp",
|
||||||
)
|
)
|
||||||
|
self.mla_moe_communication = False
|
||||||
self.input_layernorm = RMSNorm(config.hidden_size,
|
self.input_layernorm = RMSNorm(config.hidden_size,
|
||||||
eps=config.rms_norm_eps)
|
eps=config.rms_norm_eps)
|
||||||
self.post_attention_layernorm = RMSNorm(config.hidden_size,
|
self.post_attention_layernorm = RMSNorm(config.hidden_size,
|
||||||
eps=config.rms_norm_eps)
|
eps=config.rms_norm_eps)
|
||||||
self.routed_scaling_factor = config.routed_scaling_factor
|
self.routed_scaling_factor = config.routed_scaling_factor
|
||||||
|
self.first_k_dense_replace = config.first_k_dense_replace
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@@ -576,8 +674,13 @@ class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
|
|||||||
residual: Optional[torch.Tensor],
|
residual: Optional[torch.Tensor],
|
||||||
kv_cache: Optional[torch.Tensor] = None,
|
kv_cache: Optional[torch.Tensor] = None,
|
||||||
attn_metadata: Optional[AttentionMetadata] = None,
|
attn_metadata: Optional[AttentionMetadata] = None,
|
||||||
|
replace_allreduce: bool = False,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
# Self Attention
|
# Self Attention
|
||||||
|
if attn_metadata is not None and attn_metadata.num_decodes > 0:
|
||||||
|
mla_moe_communication = self.mla_moe_communication and replace_allreduce
|
||||||
|
else:
|
||||||
|
mla_moe_communication = False
|
||||||
if residual is None:
|
if residual is None:
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
hidden_states = self.input_layernorm(hidden_states)
|
hidden_states = self.input_layernorm(hidden_states)
|
||||||
@@ -589,6 +692,9 @@ class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
|
|||||||
# to save npu memory because they're no longer used.
|
# to save npu memory because they're no longer used.
|
||||||
dispose_tensor(previous_hidden_states)
|
dispose_tensor(previous_hidden_states)
|
||||||
dispose_tensor(previous_residual)
|
dispose_tensor(previous_residual)
|
||||||
|
if mla_moe_communication and self.layer_idx > self.first_k_dense_replace:
|
||||||
|
hidden_states = tensor_model_parallel_all_gather(hidden_states,
|
||||||
|
dim=0)
|
||||||
|
|
||||||
hidden_states = self.self_attn(
|
hidden_states = self.self_attn(
|
||||||
positions=positions,
|
positions=positions,
|
||||||
@@ -597,6 +703,13 @@ class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
|
|||||||
attn_metadata=attn_metadata,
|
attn_metadata=attn_metadata,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if mla_moe_communication and residual.shape[0] != hidden_states.shape[
|
||||||
|
0]:
|
||||||
|
chunk_hidden_states = torch.tensor_split(residual,
|
||||||
|
self.tp_size,
|
||||||
|
dim=0)
|
||||||
|
residual = chunk_hidden_states[self.tp_rank]
|
||||||
|
|
||||||
if hidden_states.dtype == torch.float16:
|
if hidden_states.dtype == torch.float16:
|
||||||
# Fix FP16 overflow
|
# Fix FP16 overflow
|
||||||
# We scale both hidden_states and residual before
|
# We scale both hidden_states and residual before
|
||||||
@@ -612,7 +725,9 @@ class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
|
|||||||
hidden_states, residual)
|
hidden_states, residual)
|
||||||
|
|
||||||
if isinstance(self.mlp, CustomDeepseekV2MoE):
|
if isinstance(self.mlp, CustomDeepseekV2MoE):
|
||||||
hidden_states = self.mlp(hidden_states, attn_metadata)
|
hidden_states = self.mlp(hidden_states,
|
||||||
|
attn_metadata,
|
||||||
|
replace_allreduce=mla_moe_communication)
|
||||||
else:
|
else:
|
||||||
hidden_states = self.mlp(hidden_states)
|
hidden_states = self.mlp(hidden_states)
|
||||||
|
|
||||||
@@ -625,6 +740,10 @@ class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
|
|||||||
# The scaling of DeepseekV2MOE output would be done in the forward
|
# The scaling of DeepseekV2MOE output would be done in the forward
|
||||||
# of DeepseekV2MOE
|
# of DeepseekV2MOE
|
||||||
hidden_states *= 1. / self.routed_scaling_factor
|
hidden_states *= 1. / self.routed_scaling_factor
|
||||||
|
if mla_moe_communication and self.layer_idx == self.layers - 1:
|
||||||
|
hidden_states = tensor_model_parallel_all_gather(hidden_states,
|
||||||
|
dim=0)
|
||||||
|
residual = tensor_model_parallel_all_gather(residual, dim=0)
|
||||||
|
|
||||||
return hidden_states, residual
|
return hidden_states, residual
|
||||||
|
|
||||||
@@ -643,6 +762,7 @@ class CustomDeepseekV2Model(nn.Module):
|
|||||||
|
|
||||||
self.padding_idx = config.pad_token_id
|
self.padding_idx = config.pad_token_id
|
||||||
self.vocab_size = config.vocab_size
|
self.vocab_size = config.vocab_size
|
||||||
|
self.tp_size = get_tensor_model_parallel_world_size()
|
||||||
|
|
||||||
if get_pp_group().is_first_rank:
|
if get_pp_group().is_first_rank:
|
||||||
self.embed_tokens = VocabParallelEmbedding(
|
self.embed_tokens = VocabParallelEmbedding(
|
||||||
@@ -695,13 +815,18 @@ class CustomDeepseekV2Model(nn.Module):
|
|||||||
hidden_states = intermediate_tensors["hidden_states"]
|
hidden_states = intermediate_tensors["hidden_states"]
|
||||||
residual = intermediate_tensors["residual"]
|
residual = intermediate_tensors["residual"]
|
||||||
|
|
||||||
|
replace_allreduce = hidden_states.shape[0] % self.tp_size == 0
|
||||||
|
|
||||||
for i in range(self.start_layer, self.end_layer):
|
for i in range(self.start_layer, self.end_layer):
|
||||||
layer = self.layers[i]
|
layer = self.layers[i]
|
||||||
hidden_states, residual = layer(
|
hidden_states, residual = layer(
|
||||||
positions, hidden_states, residual,
|
positions,
|
||||||
|
hidden_states,
|
||||||
|
residual,
|
||||||
kv_caches[i -
|
kv_caches[i -
|
||||||
self.start_layer] if kv_caches is not None else None,
|
self.start_layer] if kv_caches is not None else None,
|
||||||
attn_metadata)
|
attn_metadata,
|
||||||
|
replace_allreduce=replace_allreduce)
|
||||||
|
|
||||||
if not get_pp_group().is_last_rank:
|
if not get_pp_group().is_last_rank:
|
||||||
return IntermediateTensors({
|
return IntermediateTensors({
|
||||||
|
|||||||
@@ -1211,7 +1211,8 @@ class AscendFusedMoE(FusedMoE):
|
|||||||
is_prefill: bool,
|
is_prefill: bool,
|
||||||
enable_force_load_balance: bool = False,
|
enable_force_load_balance: bool = False,
|
||||||
top_k: Optional[int] = None,
|
top_k: Optional[int] = None,
|
||||||
shared_experts: Optional[Any] = None):
|
shared_experts: Optional[Any] = None,
|
||||||
|
replace_allreduce: bool = False):
|
||||||
assert self.quant_method is not None
|
assert self.quant_method is not None
|
||||||
|
|
||||||
if top_k:
|
if top_k:
|
||||||
@@ -1230,7 +1231,8 @@ class AscendFusedMoE(FusedMoE):
|
|||||||
|
|
||||||
tp_size = get_tensor_model_parallel_world_size()
|
tp_size = get_tensor_model_parallel_world_size()
|
||||||
if (tp_size > 1 and fused_moe_state != FusedMoEState.AllGather
|
if (tp_size > 1 and fused_moe_state != FusedMoEState.AllGather
|
||||||
and fused_moe_state != FusedMoEState.AllGatherEP):
|
and fused_moe_state != FusedMoEState.AllGatherEP
|
||||||
|
and not replace_allreduce):
|
||||||
if num_tokens < tp_size:
|
if num_tokens < tp_size:
|
||||||
hidden_states = nn.functional.pad(
|
hidden_states = nn.functional.pad(
|
||||||
hidden_states, (0, 0, 0, tp_size - num_tokens))
|
hidden_states, (0, 0, 0, tp_size - num_tokens))
|
||||||
@@ -1289,7 +1291,8 @@ class AscendFusedMoE(FusedMoE):
|
|||||||
e_hidden_states, shared_hidden_states = e_hidden_states
|
e_hidden_states, shared_hidden_states = e_hidden_states
|
||||||
|
|
||||||
if (tp_size > 1 and fused_moe_state != FusedMoEState.AllGather
|
if (tp_size > 1 and fused_moe_state != FusedMoEState.AllGather
|
||||||
and fused_moe_state != FusedMoEState.AllGatherEP):
|
and fused_moe_state != FusedMoEState.AllGatherEP
|
||||||
|
and not replace_allreduce):
|
||||||
dist.all_gather(list(chunk_hidden_states), e_hidden_states,
|
dist.all_gather(list(chunk_hidden_states), e_hidden_states,
|
||||||
self.tp_group)
|
self.tp_group)
|
||||||
final_hidden_states = torch.cat(chunk_hidden_states, dim=0)
|
final_hidden_states = torch.cat(chunk_hidden_states, dim=0)
|
||||||
|
|||||||
Reference in New Issue
Block a user