diff --git a/.github/workflows/vllm_ascend_test.yaml b/.github/workflows/vllm_ascend_test.yaml index 60ce1ae..716fffc 100644 --- a/.github/workflows/vllm_ascend_test.yaml +++ b/.github/workflows/vllm_ascend_test.yaml @@ -358,6 +358,7 @@ jobs: pytest -sv tests/e2e/multicard/test_ilama_lora_tp2.py # Fixme: run VLLM_USE_MODELSCOPE=True pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py will raise error. # To avoid oom, we need to run the test in a single process. + pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_multistream_moe pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_QwQ pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_topk diff --git a/tests/e2e/multicard/test_offline_inference_distributed.py b/tests/e2e/multicard/test_offline_inference_distributed.py index d35f6cd..503157d 100644 --- a/tests/e2e/multicard/test_offline_inference_distributed.py +++ b/tests/e2e/multicard/test_offline_inference_distributed.py @@ -47,6 +47,32 @@ def test_models_distributed_QwQ(): vllm_model.generate_greedy(example_prompts, max_tokens) +def test_models_distributed_DeepSeek_multistream_moe(): + example_prompts = [ + "Hello, my name is", + ] + dtype = "half" + max_tokens = 5 + with VllmRunner( + "vllm-ascend/DeepSeek-V3-Pruning", + dtype=dtype, + tensor_parallel_size=2, + distributed_executor_backend="mp", + additional_config={ + "torchair_graph_config": { + "enabled": True, + "enable_multistream_moe": True, + }, + "ascend_scheduler_config": { + "enabled": True, + }, + "refresh": True, + }, + enforce_eager=False, + ) as vllm_model: + vllm_model.generate_greedy(example_prompts, max_tokens) + + def test_models_distributed_DeepSeek(): example_prompts = [ "Hello, my name is", diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index cb031ae..dbcf6ef 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -9,6 +9,7 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer, MLAAttentionImpl) from vllm.attention.backends.utils import PAD_SLOT_ID from vllm.config import get_current_vllm_config +from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.linear import (LinearBase, UnquantizedLinearMethod) from vllm.utils import cdiv, round_down @@ -557,6 +558,7 @@ class AscendMLAImpl(MLAAttentionImpl): self.kv_a_proj_with_mqa = kwargs.get('kv_a_proj_with_mqa', None) self.kv_a_layernorm = kwargs.get('kv_a_layernorm', None) self.num_queries_per_kv = self.num_heads // self.num_kv_heads + self.tp_size = get_tensor_model_parallel_world_size() ascend_config = get_ascend_config() self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled @@ -586,7 +588,7 @@ class AscendMLAImpl(MLAAttentionImpl): x = torch.bmm(x, self.W_UV) # Convert from (N, B, V) to (B, N * V) x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim) - return self.o_proj(x)[0] + return self.o_proj(x, is_prefill=False)[0] # Return `ql_nope`, `q_pe` def _q_proj_and_k_up_proj(self, x): @@ -847,12 +849,12 @@ class AscendMLAImpl(MLAAttentionImpl): current_ms_metadata = get_multistream_comm_context() if current_ms_metadata is None: - return self.o_proj(attn_output)[0] + return self.o_proj(attn_output, is_prefill=True)[0] else: current_ms_metadata.before_comm_event.record() with torch.npu.stream(current_ms_metadata.comm_stream): current_ms_metadata.before_comm_event.wait() - return self.o_proj(attn_output)[0] + return self.o_proj(attn_output, is_prefill=True)[0] def exec_kv( self, diff --git a/vllm_ascend/models/deepseek_dbo.py b/vllm_ascend/models/deepseek_dbo.py index 02405fb..bace69d 100644 --- a/vllm_ascend/models/deepseek_dbo.py +++ b/vllm_ascend/models/deepseek_dbo.py @@ -42,8 +42,7 @@ from vllm.distributed.parallel_state import get_dp_group from vllm.forward_context import get_forward_context from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (ColumnParallelLinear, - ReplicatedLinear, - RowParallelLinear) + ReplicatedLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope @@ -64,7 +63,8 @@ from vllm.sequence import IntermediateTensors import vllm_ascend.envs as envs_ascend from vllm_ascend.ascend_config import get_ascend_config -from vllm_ascend.models.deepseek_v2 import CustomDeepseekV2MLP +from vllm_ascend.models.deepseek_v2 import (CustomDeepseekV2MLP, + CustomDeepseekV2RowParallelLinear) from vllm_ascend.multistream.base import MSEventKey from vllm_ascend.multistream.context import ( advance_step_multistream_layer_context, get_multistream_comm_context, @@ -325,11 +325,12 @@ class CustomDeepseekDBOMLAAttention(DeepseekV2MLAAttention): bias=False, quant_config=quant_config, prefix=f"{prefix}.kv_b_proj") - self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim, - self.hidden_size, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.o_proj") + self.o_proj = CustomDeepseekV2RowParallelLinear( + self.num_heads * self.v_head_dim, + self.hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj") if rope_scaling: rope_scaling["rope_type"] = 'deepseek_yarn' diff --git a/vllm_ascend/models/deepseek_v2.py b/vllm_ascend/models/deepseek_v2.py index e96b2e9..908c60f 100644 --- a/vllm_ascend/models/deepseek_v2.py +++ b/vllm_ascend/models/deepseek_v2.py @@ -34,9 +34,12 @@ from torch import nn from transformers import PretrainedConfig from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig, ModelConfig, VllmConfig -from vllm.distributed import (get_pp_group, +from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, - get_tp_group) + get_tp_group, split_tensor_along_last_dim, + tensor_model_parallel_all_gather, + tensor_model_parallel_all_reduce, + tensor_model_parallel_reduce_scatter) from vllm.distributed.parallel_state import get_dp_group from vllm.forward_context import get_forward_context from vllm.model_executor.layers.activation import SiluAndMul @@ -133,6 +136,80 @@ class CustomDeepseekV2MergedReplicatedLinear(ReplicatedLinear): shard.copy_(loaded_weight) +class CustomDeepseekV2RowParallelLinearReplaceAllreduce(RowParallelLinear): + + def forward( + self, + input_, + is_prefill=True + ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[nn.Parameter]]]: + if self.input_is_parallel: + input_parallel = input_ + else: + tp_rank = get_tensor_model_parallel_rank() + splitted_input = split_tensor_along_last_dim( + input_, num_partitions=self.tp_size) + input_parallel = splitted_input[tp_rank].contiguous() + + # Matrix multiply. + assert self.quant_method is not None + # Only fuse bias add into GEMM for rank 0 (this ensures that + # bias will not get added more than once in TP>1 case) + bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias + output_parallel = self.quant_method.apply(self, + input_parallel, + bias=bias_) + if self.reduce_results and self.tp_size > 1: + if not is_prefill and output_parallel.shape[0] % self.tp_size == 0: + output = tensor_model_parallel_reduce_scatter(output_parallel, + dim=0) + else: + output = tensor_model_parallel_all_reduce(output_parallel) + else: + output = output_parallel + + output_bias = self.bias if self.skip_bias_add else None + + if not self.return_bias: + return output + return output, output_bias + + +class CustomDeepseekV2RowParallelLinear(RowParallelLinear): + + def forward( + self, + input_, + is_prefill=True + ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[nn.Parameter]]]: + if self.input_is_parallel: + input_parallel = input_ + else: + tp_rank = get_tensor_model_parallel_rank() + splitted_input = split_tensor_along_last_dim( + input_, num_partitions=self.tp_size) + input_parallel = splitted_input[tp_rank].contiguous() + + # Matrix multiply. + assert self.quant_method is not None + # Only fuse bias add into GEMM for rank 0 (this ensures that + # bias will not get added more than once in TP>1 case) + bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias + output_parallel = self.quant_method.apply(self, + input_parallel, + bias=bias_) + if self.reduce_results and self.tp_size > 1: + output = tensor_model_parallel_all_reduce(output_parallel) + else: + output = output_parallel + + output_bias = self.bias if self.skip_bias_add else None + + if not self.return_bias: + return output + return output, output_bias + + class CustomDeepseekV2MLP(nn.Module): def __init__( @@ -289,10 +366,11 @@ class CustomDeepseekV2MoE(nn.Module): self.params_dtype = torch.get_default_dtype() - def forward( - self, - hidden_states: torch.Tensor, - attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor: + def forward(self, + hidden_states: torch.Tensor, + attn_metadata: Optional[AttentionMetadata] = None, + replace_allreduce: bool = False) -> torch.Tensor: + if attn_metadata is None: attn_metadata = get_forward_context().attn_metadata # when profile runs, force experts to load balanced tokens @@ -318,7 +396,7 @@ class CustomDeepseekV2MoE(nn.Module): top_k=CustomDeepseekV2MoE.top_k, enable_force_load_balance=enable_force_load_balance, shared_experts=self.shared_experts, - ) + replace_allreduce=replace_allreduce) hidden_states = ( experts_hidden_states[0] * self.routed_scaling_factor + @@ -365,6 +443,14 @@ class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention): self.rope_theta = rope_theta self.max_position_embeddings = max_position_embeddings + self.prefix = prefix + self.debug_layer_idx = int(self.prefix.split(".")[-2]) + + ascend_config = get_ascend_config() + self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled + self.enable_multistream_mla = \ + ascend_config.torchair_graph_config.enable_multistream_mla + if self.q_lora_rank is not None: self.q_a_proj = ReplicatedLinear(self.hidden_size, self.q_lora_rank, @@ -401,11 +487,23 @@ class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention): bias=False, quant_config=quant_config, prefix=f"{prefix}.kv_b_proj") - self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim, - self.hidden_size, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.o_proj") + if (config.n_routed_experts is not None + and self.debug_layer_idx >= config.first_k_dense_replace + and self.debug_layer_idx % config.moe_layer_freq == 0 and + ascend_config.torchair_graph_config.enable_multistream_moe): + self.o_proj = CustomDeepseekV2RowParallelLinearReplaceAllreduce( + self.num_heads * self.v_head_dim, + self.hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj") + else: + self.o_proj = CustomDeepseekV2RowParallelLinear( + self.num_heads * self.v_head_dim, + self.hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj") if rope_scaling: rope_scaling["rope_type"] = 'deepseek_yarn' @@ -451,14 +549,6 @@ class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention): o_proj=self.o_proj, ) - self.prefix = prefix - self.debug_layer_idx = int(self.prefix.split(".")[-2]) - - ascend_config = get_ascend_config() - self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled - self.enable_multistream_mla = \ - ascend_config.torchair_graph_config.enable_multistream_mla - def forward( self, positions: torch.Tensor, @@ -524,6 +614,10 @@ class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer): # with the layer's index. layer_idx = int(prefix.split(sep='.')[-1]) self.layer_idx = layer_idx + self.layers = config.num_hidden_layers + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tp_group().rank_in_group + ascend_config = get_ascend_config() # TODO: enable mla in vllm-ascend if model_config.use_mla: attn_cls = CustomDeepseekV2MLAAttention @@ -555,6 +649,8 @@ class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer): quant_config=quant_config, prefix=f"{prefix}.mlp", ) + self.mla_moe_communication = ascend_config.torchair_graph_config.enable_multistream_moe \ + and model_config.use_mla and envs.VLLM_USE_V1 and self.tp_size > 1 else: self.mlp = CustomDeepseekV2MLP( hidden_size=config.hidden_size, @@ -563,11 +659,13 @@ class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer): quant_config=quant_config, prefix=f"{prefix}.mlp", ) + self.mla_moe_communication = False self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.routed_scaling_factor = config.routed_scaling_factor + self.first_k_dense_replace = config.first_k_dense_replace def forward( self, @@ -576,8 +674,13 @@ class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer): residual: Optional[torch.Tensor], kv_cache: Optional[torch.Tensor] = None, attn_metadata: Optional[AttentionMetadata] = None, + replace_allreduce: bool = False, ) -> torch.Tensor: # Self Attention + if attn_metadata is not None and attn_metadata.num_decodes > 0: + mla_moe_communication = self.mla_moe_communication and replace_allreduce + else: + mla_moe_communication = False if residual is None: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) @@ -589,6 +692,9 @@ class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer): # to save npu memory because they're no longer used. dispose_tensor(previous_hidden_states) dispose_tensor(previous_residual) + if mla_moe_communication and self.layer_idx > self.first_k_dense_replace: + hidden_states = tensor_model_parallel_all_gather(hidden_states, + dim=0) hidden_states = self.self_attn( positions=positions, @@ -597,6 +703,13 @@ class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer): attn_metadata=attn_metadata, ) + if mla_moe_communication and residual.shape[0] != hidden_states.shape[ + 0]: + chunk_hidden_states = torch.tensor_split(residual, + self.tp_size, + dim=0) + residual = chunk_hidden_states[self.tp_rank] + if hidden_states.dtype == torch.float16: # Fix FP16 overflow # We scale both hidden_states and residual before @@ -612,7 +725,9 @@ class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer): hidden_states, residual) if isinstance(self.mlp, CustomDeepseekV2MoE): - hidden_states = self.mlp(hidden_states, attn_metadata) + hidden_states = self.mlp(hidden_states, + attn_metadata, + replace_allreduce=mla_moe_communication) else: hidden_states = self.mlp(hidden_states) @@ -625,6 +740,10 @@ class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer): # The scaling of DeepseekV2MOE output would be done in the forward # of DeepseekV2MOE hidden_states *= 1. / self.routed_scaling_factor + if mla_moe_communication and self.layer_idx == self.layers - 1: + hidden_states = tensor_model_parallel_all_gather(hidden_states, + dim=0) + residual = tensor_model_parallel_all_gather(residual, dim=0) return hidden_states, residual @@ -643,6 +762,7 @@ class CustomDeepseekV2Model(nn.Module): self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size + self.tp_size = get_tensor_model_parallel_world_size() if get_pp_group().is_first_rank: self.embed_tokens = VocabParallelEmbedding( @@ -695,13 +815,18 @@ class CustomDeepseekV2Model(nn.Module): hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] + replace_allreduce = hidden_states.shape[0] % self.tp_size == 0 + for i in range(self.start_layer, self.end_layer): layer = self.layers[i] hidden_states, residual = layer( - positions, hidden_states, residual, + positions, + hidden_states, + residual, kv_caches[i - self.start_layer] if kv_caches is not None else None, - attn_metadata) + attn_metadata, + replace_allreduce=replace_allreduce) if not get_pp_group().is_last_rank: return IntermediateTensors({ diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py index d65f12c..f8a4f5e 100644 --- a/vllm_ascend/ops/fused_moe.py +++ b/vllm_ascend/ops/fused_moe.py @@ -1211,7 +1211,8 @@ class AscendFusedMoE(FusedMoE): is_prefill: bool, enable_force_load_balance: bool = False, top_k: Optional[int] = None, - shared_experts: Optional[Any] = None): + shared_experts: Optional[Any] = None, + replace_allreduce: bool = False): assert self.quant_method is not None if top_k: @@ -1230,7 +1231,8 @@ class AscendFusedMoE(FusedMoE): tp_size = get_tensor_model_parallel_world_size() if (tp_size > 1 and fused_moe_state != FusedMoEState.AllGather - and fused_moe_state != FusedMoEState.AllGatherEP): + and fused_moe_state != FusedMoEState.AllGatherEP + and not replace_allreduce): if num_tokens < tp_size: hidden_states = nn.functional.pad( hidden_states, (0, 0, 0, tp_size - num_tokens)) @@ -1289,7 +1291,8 @@ class AscendFusedMoE(FusedMoE): e_hidden_states, shared_hidden_states = e_hidden_states if (tp_size > 1 and fused_moe_state != FusedMoEState.AllGather - and fused_moe_state != FusedMoEState.AllGatherEP): + and fused_moe_state != FusedMoEState.AllGatherEP + and not replace_allreduce): dist.all_gather(list(chunk_hidden_states), e_hidden_states, self.tp_group) final_hidden_states = torch.cat(chunk_hidden_states, dim=0)